spyrit.core.train

Training functions for deep learning models.

Functions

attr_removal(old_key, old_name)

attr_transformation(old_key, old_name, new_name)

boxplot(model1, model2, model3, criterion, ...)

boxplotconsist(model1, model2, model3, ...)

checkpoint(root, epoch, model)

Saves the dictionaries of a given PyTorch model for the given epoch

compare_model(model1, model2, model3, ...[, ...])

Compares three models

count_memory(model)

count_param(model)

count_trainable_param(model)

images_norm(images)

imshow(img[, title])

load_net(title, model[, device, strict])

Loads the network defined by title

multiplot(train_info1, train_info2, train_info3)

read_param(path)

remove_model_attributes(source, old_name[, ...])

Removes attributes from a saved model (nn.Module)

rename_model_attributes(source, old_name, ...)

Renames the attributes of a saved model (nn.Module)

save_net(title, model)

Saves the dictionaries of a given PyTorch model at the location defined by title

tb_profiler(path_prof, model, criterion, ...)

TensorBoard profiler: profiles code execution

tb_writer_add_image(writer, name_metric, ...)

TensorBoard writer: adds an image

tb_writer_add_scalar(writer, name_metric, ...)

TensorBoard writer: adds a scalar (loss)

tb_writer_init(tb_path[, samples])

Initializes a TensorBoard log for PyTorch

train_model(model, criterion, optimizer, ...)

Trains the PyTorch model

train_model_supervised(model, criterion, ...)

Trains the PyTorch model in a supervised way

visualize_conv_layers(conv_layer[, ...])

Displays the first 8 filters of the convolution layer conv_layer

visualize_model(model, dataloaders, device)

Takes 8 images from the dataloader and shows each input image side by side with its reconstruction
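The attribute-editing helpers (remove_model_attributes, rename_model_attributes) act on a saved model. The sketch below illustrates the idea under an assumption not stated in this summary: that the model is stored as a state dictionary whose keys are dotted attribute paths, as PyTorch's state_dict() produces. rename_attribute and remove_attribute are hypothetical helpers, not part of spyrit, and plain lists stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict


def rename_attribute(state_dict, old_name, new_name):
    """Return a copy of state_dict with keys under `old_name` moved to `new_name`.

    Hypothetical helper: keys are assumed to be dotted attribute paths
    such as "fc.weight", as in a PyTorch state_dict.
    """
    renamed = OrderedDict()
    for key, value in state_dict.items():
        # Match the attribute itself or its sub-keys, not mere string prefixes
        if key == old_name or key.startswith(old_name + "."):
            key = new_name + key[len(old_name):]
        renamed[key] = value
    return renamed


def remove_attribute(state_dict, old_name):
    """Return a copy of state_dict without the entries under `old_name`."""
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not (key == old_name or key.startswith(old_name + "."))
    )


# Plain lists stand in for tensors so the sketch runs without PyTorch.
saved = OrderedDict([
    ("fc.weight", [[1.0, 2.0]]),
    ("fc.bias", [0.5]),
    ("fcn.weight", [[3.0]]),  # shares the "fc" prefix but is another attribute
    ("conv.weight", [[[0.1]]]),
])

print(list(rename_attribute(saved, "fc", "linear").keys()))
print(list(remove_attribute(saved, "conv").keys()))
```

Matching on old_name + "." rather than on a bare string prefix keeps fcn.weight untouched when fc is renamed. With a real checkpoint, the dictionary would typically come from torch.load and be written back with torch.save.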

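count_param and count_trainable_param report the size of a model. A minimal sketch, assuming they follow the common PyTorch idiom of summing p.numel() over model.parameters(), restricted to parameters with requires_grad set for the trainable count. The Param class and the total_params / trainable_params names are hypothetical stand-ins so that the example runs without PyTorch.

```python
from dataclasses import dataclass
from math import prod


@dataclass
class Param:
    """Hypothetical stand-in for torch.nn.Parameter: only shape and trainability matter."""
    shape: tuple
    requires_grad: bool = True

    def numel(self):
        # Number of scalar entries, analogous to torch.Tensor.numel()
        return prod(self.shape)


def total_params(params):
    # Every parameter counts, frozen or not
    return sum(p.numel() for p in params)


def trainable_params(params):
    # Only parameters that receive gradients
    return sum(p.numel() for p in params if p.requires_grad)


# A 1024 -> 64 linear layer (weight + bias) followed by a frozen 64 x 64 matrix.
params = [Param((64, 1024)), Param((64,)), Param((64, 64), requires_grad=False)]
print(total_params(params))
print(trainable_params(params))
```

Freezing a parameter (requires_grad=False) removes it from the trainable count but not from the total.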
Classes

Train_par(batch_size, learning_rate, img_size)

Weight_Decay_Loss(loss)
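Weight_Decay_Loss(loss) suggests a loss augmented with a weight-decay term. Assuming it adds the usual L2 penalty, the total objective is data_loss + weight_decay * sum(w**2) over the network weights. The sketch below shows only that arithmetic; the function name, the weight_decay coefficient and the flat lists of weights are assumptions, since the class signature here takes only loss.

```python
def weight_decay_loss(data_loss, weights, weight_decay):
    """Return data_loss plus an L2 (weight-decay) penalty.

    Hypothetical helper: `weights` is a list of layers, each a flat list
    of floats; `weight_decay` is the penalty coefficient (lambda).
    """
    # Sum of squared weights over every layer
    penalty = sum(w * w for layer in weights for w in layer)
    return data_loss + weight_decay * penalty


data_loss = 0.5                 # e.g. a reconstruction MSE
weights = [[1.0, -2.0], [0.5]]  # two tiny "layers"
total = weight_decay_loss(data_loss, weights, weight_decay=0.01)
print(total)  # 0.5 + 0.01 * (1 + 4 + 0.25)
```

In PyTorch the same penalty is often applied through the optimizer's weight_decay argument (for example torch.optim.SGD or torch.optim.Adam) rather than inside the loss.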