.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "gallery/tuto_04_b_train_pseudoinverse_cnn_linear.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_gallery_tuto_04_b_train_pseudoinverse_cnn_linear.py: 04.b. Pseudoinverse + CNN (training) ================================================ .. _tuto_4b_train_pseudoinverse_cnn_linear: This tutorial trains a post processing CNN used by a :class:`spyrit.core.recon.PinvNet` (see the :ref:`previous tutorial `). .. image:: ../fig/tuto4_pinvnet.png :width: 600 :align: center :alt: Reconstruction architecture sketch | For post-processing, we consider a small CNN; however, it be replaced by any other network (e.g., a Unet). Training is performed on the STL-10 dataset, but any other database can be considered. You can use Tensorboard for Pytorch for experiment tracking and for visualizing the training process: losses, network weights, and intermediate results (reconstructed images at different epochs). .. GENERATED FROM PYTHON SOURCE LINES 25-27 Measurement operator ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 29-32 We choose the acquisition matrix as the positive component of a Hadamard matrix in "2D". We subsample it by a factor four, keeping only the low-frequency components (see :ref:`Tutorial 4 ` for details). .. GENERATED FROM PYTHON SOURCE LINES 34-35 Positive component of a Hadamard matrix in "2D". .. GENERATED FROM PYTHON SOURCE LINES 35-41 .. code-block:: Python import torch from spyrit.core.torch import walsh_matrix_2d H = walsh_matrix_2d(64) H = torch.where(H > 0, 1.0, 0.0) .. GENERATED FROM PYTHON SOURCE LINES 42-43 Subsampling map .. GENERATED FROM PYTHON SOURCE LINES 43-47 .. code-block:: Python Sampling_square = torch.zeros(64, 64) Sampling_square[:32, :32] = 1 .. GENERATED FROM PYTHON SOURCE LINES 48-49 Permutation of the rows and subsampling .. GENERATED FROM PYTHON SOURCE LINES 49-55 .. code-block:: Python from spyrit.core.torch import sort_by_significance H = sort_by_significance(H, Sampling_square, "rows", False) H = H[: 32 * 32, :] .. GENERATED FROM PYTHON SOURCE LINES 56-57 Associated :class:`spyrit.core.meas.Linear` operator .. GENERATED FROM PYTHON SOURCE LINES 57-65 .. code-block:: Python from spyrit.core.meas import Linear # Send to GPU if available device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") meas_op = Linear(H, (64, 64), device=device) .. GENERATED FROM PYTHON SOURCE LINES 66-70 .. note:: The linear measurement operator is chosen as the positive part of a subsampled Hadamard matrix, but any other matrix can be used. .. GENERATED FROM PYTHON SOURCE LINES 73-75 Pseudo inverse solution followed by a CNN ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 77-82 We consider the :class:`spyrit.core.recon.PinvNet` class that reconstructs an image by computing the pseudoinverse solution and applies a nonlinear network denoiser. First, we must define the denoiser. As an example, we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class. Then, we define the PinvNet network by passing the noise and preprocessing operators and the denoiser. .. GENERATED FROM PYTHON SOURCE LINES 82-89 .. code-block:: Python from typing import OrderedDict from spyrit.core.nnet import ConvNet denoiser = torch.nn.Sequential(OrderedDict({"denoi": ConvNet()})) .. GENERATED FROM PYTHON SOURCE LINES 90-94 .. note:: Here, we consider a small CNN; however, it be replaced by any other network (e.g., a Unet). .. GENERATED FROM PYTHON SOURCE LINES 97-99 We instantiate a :class:`spyrit.core.recon.PinvNet` with the CNN as an image-domain post processing .. GENERATED FROM PYTHON SOURCE LINES 99-105 .. code-block:: Python from spyrit.core.recon import PinvNet pinv_net = PinvNet(meas_op, denoi=denoiser, device=device, store_H_pinv=True) .. GENERATED FROM PYTHON SOURCE LINES 106-111 .. important:: We use :attr:`store_H_pinv=True` to compute and store the pseudo inverse matrix. This will be *much* faster that using a solver (default option) when a large number of pseudoinverse solutions will have to be computed during training. .. GENERATED FROM PYTHON SOURCE LINES 113-120 Dataloader for training ----------------------------------------------------------------------------- We now consider the STL10 dataset and use the the :attr:`normalize=False` argument to keep images with values in (0,1). Set :attr:`mode_run=True` in the the script below to download the STL10 dataset and train the CNN. Otherwise, the CNN paramameters will be downloaded. .. GENERATED FROM PYTHON SOURCE LINES 120-143 .. code-block:: Python # import torch.nn from spyrit.misc.statistics import data_loaders_stl10 from pathlib import Path # Parameters h = 64 # image size hxh data_root = Path("./data/") # path to data folder (where the dataset is stored) batch_size = 700 # Dataloader for STL-10 dataset mode_run = False if mode_run: dataloaders = data_loaders_stl10( data_root, img_size=h, batch_size=batch_size, seed=7, shuffle=True, download=True, normalize=False, ) .. GENERATED FROM PYTHON SOURCE LINES 144-148 .. note:: Here, training is performed on the STL-10 dataset, but any other database can be considered. .. GENERATED FROM PYTHON SOURCE LINES 150-152 Optimizer ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 154-157 We define a loss function (mean squared error), an optimizer (Adam) and a scheduler. The scheduler decreases the learning rate by a factor of :attr:`gamma` every :attr:`step_size` epochs. .. GENERATED FROM PYTHON SOURCE LINES 157-170 .. code-block:: Python from spyrit.core.train import Weight_Decay_Loss # Parameters lr = 1e-3 step_size = 10 gamma = 0.5 loss = torch.nn.MSELoss() criterion = Weight_Decay_Loss(loss) optimizer = torch.optim.Adam(pinv_net.parameters(), lr=lr) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) .. GENERATED FROM PYTHON SOURCE LINES 171-173 Training ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 175-180 We use the :func:`spyrit.core.train.train_model` function, which iterates through the dataloader, feeds the STL10 images to the full network and optimizes the parameters of the CNN. In addition, it computes the loss and desired metrics on the training and validation sets at each iteration. The training process can be monitored using Tensorboard. .. GENERATED FROM PYTHON SOURCE LINES 183-185 Set :attr:`mode_run=True` to train the CNN (e.g., around 60 min for 20 epochs on my laptop equipped with a NVIDIA Quadro P1000). Otherwise, download the CNN parameters. .. GENERATED FROM PYTHON SOURCE LINES 185-221 .. code-block:: Python from spyrit.core.train import train_model from datetime import datetime # Parameters model_root = Path("./model") # path to model saving files num_epochs = 20 # number of training epochs (num_epochs = 30) checkpoint_interval = 0 # interval between saving model checkpoints tb_freq = ( 50 # interval between logging to Tensorboard (iterations through the dataloader) ) # Path for Tensorboard experiment tracking logs name_run = "stl10_hadam_positive" now = datetime.now().strftime("%Y-%m-%d_%H-%M") tb_path = f"runs/runs_{name_run}_nonoise_m{meas_op.M}/{now}" # Train the network if mode_run: pinv_net, train_info = train_model( pinv_net, criterion, optimizer, scheduler, dataloaders, device, model_root, num_epochs=num_epochs, disp=True, do_checkpoint=checkpoint_interval, tb_path=tb_path, tb_freq=tb_freq, ) else: train_info = {} .. GENERATED FROM PYTHON SOURCE LINES 222-231 .. note:: To launch Tensorboard type in a new console: tensorboard --logdir runs and open the provided link in a browser. The training process can be monitored in real time in the "Scalars" tab. The "Images" tab allows to visualize the reconstructed images at different iterations :attr:`tb_freq`. .. GENERATED FROM PYTHON SOURCE LINES 233-235 Training history ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 237-239 We save the model so that it can later be utilized. We save the network's architecture, the training parameters and the training history. .. GENERATED FROM PYTHON SOURCE LINES 239-285 .. code-block:: Python from spyrit.core.train import save_net title = "tuto_4b" Path(model_root).mkdir(parents=True, exist_ok=True) model_path = model_root / (title + ".pth") train_path = model_root / (title + ".pkl") if checkpoint_interval: Path(model_path).mkdir(parents=True, exist_ok=True) save_net(model_path, pinv_net.denoi) # save_net(model_root/(title+"_cnn.pth"), pinv_net.denoi.denoi) # Save training history import pickle if mode_run: from spyrit.core.train import Train_par reg = 1e-7 # Default value params = Train_par(batch_size, lr, h, reg=reg) params.set_loss(train_info) train_path = model_root / (title + ".pkl") with open(train_path, "wb") as param_file: pickle.dump(params, param_file) torch.cuda.empty_cache() else: from spyrit.misc.load_data import download_girder url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" dataID = "68639a2af39e1d2884b09abc" # unique ID of the file download_girder(url, dataID, model_root) with open(train_path, "rb") as param_file: params = pickle.load(param_file) train_info["train"] = params.train_loss train_info["val"] = params.val_loss .. rst-class:: sphx-glr-script-out .. code-block:: none model/tuto_4b.pth Model Saved Downloading tuto_4b.pkl... Downloading tuto_4b.pkl... done. .. GENERATED FROM PYTHON SOURCE LINES 286-288 Validation and training losses ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 290-291 We plot the training loss and validation loss .. GENERATED FROM PYTHON SOURCE LINES 291-305 .. code-block:: Python import matplotlib.pyplot as plt import numpy as np epoch = np.arange(1, num_epochs + 1) fig = plt.figure() plt.semilogy(epoch, train_info["train"], label="train") plt.semilogy(epoch, train_info["val"], label="val") plt.xticks([5, 10, 15, 20]) plt.xlabel("Epochs", fontsize=20) plt.ylabel("Loss", fontsize=20) plt.legend(fontsize=20) plt.show() .. image-sg:: /gallery/images/sphx_glr_tuto_04_b_train_pseudoinverse_cnn_linear_001.png :alt: tuto 04 b train pseudoinverse cnn linear :srcset: /gallery/images/sphx_glr_tuto_04_b_train_pseudoinverse_cnn_linear_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 2.871 seconds) .. _sphx_glr_download_gallery_tuto_04_b_train_pseudoinverse_cnn_linear.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: tuto_04_b_train_pseudoinverse_cnn_linear.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: tuto_04_b_train_pseudoinverse_cnn_linear.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: tuto_04_b_train_pseudoinverse_cnn_linear.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_