.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/tuto_04_train_pseudoinverse_cnn_linear.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py:


04. Train pseudoinverse solution + CNN denoising
================================================

.. _tuto_train_pseudoinverse_cnn_linear:

This tutorial shows how to train PinvNet with a CNN denoiser to reconstruct images
from linear measurements (the results are shown in the :ref:`previous tutorial `).
As an example, we use a small CNN, which can be replaced by any other network,
for example a Unet. Training is performed on the STL-10 dataset.

You can use TensorBoard with PyTorch to track experiments and to visualize the
training process: losses, network weights, and intermediate results
(reconstructed images at different epochs).

The linear measurement operator is chosen as the positive part of a Hadamard matrix,
but this matrix can be replaced by any desired matrix.

These tutorials load image samples from ``/images/``.

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Load a batch of images
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-28

First, we load an image :math:`x` and normalize it to :math:`[-1, 1]`, as in the
previous examples.

.. GENERATED FROM PYTHON SOURCE LINES 28-64

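
The normalization is a simple affine rescaling of the pixel values. The NumPy sketch
below illustrates it on a toy 8-bit image. It assumes, without checking the spyrit
source, that ``transform_gray_norm`` first scales the pixels to :math:`[0, 1]` and then
normalizes them with mean 0.5 and standard deviation 0.5, which amounts to
:math:`x \mapsto 2x - 1`. The helper ``to_minus_one_one`` is hypothetical and is not
part of spyrit.

.. code-block:: Python

    import numpy as np


    def to_minus_one_one(img_uint8: np.ndarray) -> np.ndarray:
        """Hypothetical helper: map 8-bit grayscale pixels to the range [-1, 1]."""
        x01 = img_uint8.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
        return (x01 - 0.5) / 0.5  # [0, 1] -> [-1, 1], i.e. 2 * x01 - 1


    # Toy 2x2 grayscale "image" (illustrative values)
    img_toy = np.array([[0, 64], [191, 255]], dtype=np.uint8)
    x_toy = to_minus_one_one(img_toy)
    print(x_toy)
    print(x_toy.min(), x_toy.max())

The inverse map :math:`x \mapsto (x + 1)/2` brings the values back to :math:`[0, 1]`.
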

.. code-block:: Python


    import os

    import torch
    import torchvision
    import matplotlib.pyplot as plt

    import spyrit.core.torch as spytorch
    from spyrit.misc.disp import imagesc
    from spyrit.misc.statistics import transform_gray_norm

    h = 64  # image size hxh
    i = 1  # Image index (modify to change the image)
    spyritPath = os.getcwd()
    imgs_path = os.path.join(spyritPath, "images/")

    # Create a transform for natural images to normalized grayscale image tensors
    transform = transform_gray_norm(img_size=h)

    # Create dataset and loader (expects class folder 'images/test/')
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

    x, _ = next(iter(dataloader))
    print(f"Shape of input images: {x.shape}")

    # Select image
    x = x[i : i + 1, :, :, :]
    x = x.detach().clone()
    print(f"Shape of selected image: {x.shape}")
    b, c, h, w = x.shape

    # plot
    imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")




.. image-sg:: /gallery/images/sphx_glr_tuto_04_train_pseudoinverse_cnn_linear_001.png
   :alt: $x$ in [-1, 1]
   :srcset: /gallery/images/sphx_glr_tuto_04_train_pseudoinverse_cnn_linear_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of input images: torch.Size([7, 1, 64, 64])
    Shape of selected image: torch.Size([1, 1, 64, 64])


.. GENERATED FROM PYTHON SOURCE LINES 65-75

Define a dataloader
-----------------------------------------------------------------------------
We define a dataloader for the STL-10 dataset using
:func:`spyrit.misc.statistics.data_loaders_stl10`. It downloads the dataset to the
provided path if it is not already there. It is based on PyTorch's built-in dataset
:class:`torchvision.datasets.STL10` and on :class:`torch.utils.data.DataLoader`,
which wraps the dataset in an iterable that returns a batch of images and labels
at each iteration.

Set :attr:`mode_run` to True in the script below to download the dataset and train
the network; otherwise, the training history of a pretrained network is downloaded
for display.

.. GENERATED FROM PYTHON SOURCE LINES 75-96

.. code-block:: Python


    from spyrit.misc.statistics import data_loaders_stl10
    from pathlib import Path

    # Parameters
    h = 64  # image size hxh
    data_root = Path("./data")  # path to data folder (where the dataset is stored)
    batch_size = 512

    # Dataloader for STL-10 dataset
    mode_run = False
    if mode_run:
        dataloaders = data_loaders_stl10(
            data_root,
            img_size=h,
            batch_size=batch_size,
            seed=7,
            shuffle=True,
            download=True,
        )








.. GENERATED FROM PYTHON SOURCE LINES 97-99

Define a measurement operator
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 101-106

We consider the case where the measurement matrix is the positive component of
a Hadamard matrix, which is often used in single-pixel imaging
(see :ref:`Hadamard matrix `).
Then, we simulate an accelerated acquisition by keeping only the first
:attr:`M` low-frequency coefficients (see :ref:`low frequency sampling `).

.. GENERATED FROM PYTHON SOURCE LINES 106-126

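
Taking the positive part does not lose information. Since the entries of a Hadamard
matrix :math:`H` are :math:`\pm 1`, its positive part is
:math:`H_+ = (H + \mathbf{1}\mathbf{1}^\top)/2`, and a row of :math:`H` that contains
only ones yields the measurement :math:`\mathbf{1}^\top x`. Hence
:math:`Hx = 2 H_+ x - (\mathbf{1}^\top x)\,\mathbf{1}`, and this holds row by row for a
subset of rows, provided the all-ones pattern is measured. The NumPy sketch below checks
the identity on a toy 1D example. For simplicity, it builds the Hadamard matrix with the
Sylvester construction (natural ordering) instead of the 2D Walsh-ordered matrix
returned by ``spytorch.walsh2_matrix``; ``sylvester_hadamard`` is a hypothetical helper.

.. code-block:: Python

    import numpy as np


    def sylvester_hadamard(n: int) -> np.ndarray:
        """Hypothetical helper: +/-1 Hadamard matrix of order n (n a power of two)."""
        had = np.array([[1.0]])
        while had.shape[0] < n:
            had = np.kron(had, np.array([[1.0, 1.0], [1.0, -1.0]]))
        return had


    n_toy = 16  # toy signal length, e.g. a flattened 4x4 image
    had = sylvester_hadamard(n_toy)
    had_pos = np.maximum(had, 0.0)  # positive part: entries in {0, 1}

    rng = np.random.default_rng(0)
    x_toy = rng.uniform(-1.0, 1.0, n_toy)  # toy image with values in [-1, 1]

    y_pos = had_pos @ x_toy  # measurements with the positive patterns
    # In this construction the first row of had_pos is all ones, so y_pos[0] = sum(x_toy);
    # the +/-1 measurements follow as had @ x_toy = 2 * y_pos - sum(x_toy).
    y_full = 2.0 * y_pos - y_pos[0]

    print(np.allclose(y_full, had @ x_toy))  # True
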

.. code-block:: Python


    import math

    und = 4  # undersampling factor
    M = h**2 // und  # number of measurements (undersampling factor = 4)

    # Positive part of the 2D Walsh-Hadamard matrix (entries in {0, 1})
    F = spytorch.walsh2_matrix(h)
    F = torch.max(F, torch.zeros_like(F))

    # Low-frequency sampling map: keep the top-left M_xy x M_xy block
    Sampling_map = torch.zeros(h, h)
    M_xy = math.ceil(M**0.5)
    Sampling_map[:M_xy, :M_xy] = 1

    # imagesc(Sampling_map, 'low-frequency sampling map')

    # Reorder the rows of F so that the sampled patterns come first, then keep M rows
    F = spytorch.sort_by_significance(F, Sampling_map, "rows", False)
    H = F[:M, :]

    print(f"Shape of the measurement matrix: {H.shape}")





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the measurement matrix: torch.Size([1024, 4096])




.. GENERATED FROM PYTHON SOURCE LINES 127-130

Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator,
a :class:`spyrit.core.noise.NoNoise` noise operator for the noiseless case,
and a :class:`spyrit.core.prep.DirectPoisson` operator that preprocesses the
measurements.

.. GENERATED FROM PYTHON SOURCE LINES 130-140

.. code-block:: Python


    from spyrit.core.meas import Linear
    from spyrit.core.noise import NoNoise
    from spyrit.core.prep import DirectPoisson

    meas_op = Linear(H, pinv=True)
    noise = NoNoise(meas_op)
    N0 = 1.0  # Mean maximum total number of photons
    prep = DirectPoisson(N0, meas_op)  # "Undo" the NoNoise operator








.. GENERATED FROM PYTHON SOURCE LINES 141-143

PinvNet Network
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 145-151

We consider the :class:`spyrit.core.recon.PinvNet` class, which reconstructs an
image by computing the pseudoinverse solution and then applying a nonlinear
network denoiser. First, we must define the denoiser. As an example,
we choose a small CNN using the :class:`spyrit.core.nnet.ConvNet` class.
Then, we define the PinvNet network by passing the noise and preprocessing
operators and the denoiser.

.. GENERATED FROM PYTHON SOURCE LINES 151-169

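
The pseudoinverse step can be illustrated on a toy problem. In the NumPy sketch below,
a random Gaussian matrix stands in for the measurement matrix. This is an assumption
made only to keep the example small and self-contained. With fewer measurements than
unknowns, the pseudoinverse solution :math:`H^\dagger m` reproduces the measurements
exactly, but it is the orthogonal projection of :math:`x` onto the row space of
:math:`H`, so the null-space component of :math:`x` is lost. Roughly speaking, this
missing part is what the learned denoiser has to compensate for.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    n_toy, m_toy = 16, 4  # toy sizes: n_toy unknowns, m_toy < n_toy measurements
    H_toy = rng.standard_normal((m_toy, n_toy))  # stand-in measurement matrix
    x_toy = rng.standard_normal(n_toy)  # toy ground-truth (flattened) image
    m_meas = H_toy @ x_toy  # noiseless measurements

    H_toy_pinv = np.linalg.pinv(H_toy)  # Moore-Penrose pseudoinverse, shape (n_toy, m_toy)
    x_pinv = H_toy_pinv @ m_meas  # pseudoinverse (minimum-norm) solution

    # The pseudoinverse solution reproduces the measurements exactly...
    print(np.allclose(H_toy @ x_pinv, m_meas))
    # ...but the component of x_toy in the null space of H_toy is lost.
    x_null = x_toy - x_pinv
    print(np.allclose(H_toy @ x_null, 0.0), np.linalg.norm(x_null))
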

.. code-block:: Python


    import torch.nn as nn

    from spyrit.core.nnet import ConvNet
    from spyrit.core.recon import PinvNet

    denoiser = ConvNet()
    model = PinvNet(noise, prep, denoi=denoiser)

    # Send to GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Use multiple GPUs if available
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    print("Using device:", device)
    model = model.to(device)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using device: cpu




.. GENERATED FROM PYTHON SOURCE LINES 170-175

.. note::

    In the example provided, we choose a small CNN using the
    :class:`spyrit.core.nnet.ConvNet` class. This can be replaced by any denoiser,
    for example the :class:`spyrit.core.nnet.Unet` class or a custom denoiser.

.. GENERATED FROM PYTHON SOURCE LINES 178-180

Define a loss function, optimizer and scheduler
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 182-186

To train the network, we need to define a loss function, an optimizer and a
scheduler. We use the mean squared error (MSE) loss combined with a weight decay
term (``Weight_Decay_Loss``), and the Adam optimizer.
The scheduler decreases the learning rate by a factor of :attr:`gamma`
every :attr:`step_size` epochs.

.. GENERATED FROM PYTHON SOURCE LINES 186-202

.. code-block:: Python


    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from spyrit.core.train import save_net, Weight_Decay_Loss

    # Parameters
    lr = 1e-3
    step_size = 10
    gamma = 0.5

    loss = nn.MSELoss()
    criterion = Weight_Decay_Loss(loss)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)








.. GENERATED FROM PYTHON SOURCE LINES 203-205

Train the network
-----------------------------------------------------------------------------

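
Before launching the training, it can help to check the learning-rate schedule
configured above. Assuming the scheduler is stepped once per epoch (we do not verify
this inside :func:`~spyrit.core.train.train_model`), the learning rate at epoch
:math:`e` is :math:`\mathrm{lr}_e = \mathrm{lr}_0\,\gamma^{\lfloor e/s \rfloor}`, where
:math:`s` is :attr:`step_size`. The plain-Python sketch below tabulates it for the values
used in this tutorial over 30 epochs; ``step_lr`` is a hypothetical helper, not a spyrit
or PyTorch function.

.. code-block:: Python

    def step_lr(lr0: float, step_size: int, gamma: float, epoch: int) -> float:
        """Hypothetical helper: learning rate at `epoch` for a step-decay schedule."""
        return lr0 * gamma ** (epoch // step_size)


    # Same values as in this tutorial, with distinct names to avoid overwriting them
    lr0_demo, step_demo, gamma_demo, epochs_demo = 1e-3, 10, 0.5, 30
    schedule = [step_lr(lr0_demo, step_demo, gamma_demo, e) for e in range(epochs_demo)]

    for e in (0, 9, 10, 19, 20, 29):
        print(f"epoch {e:2d}: lr = {schedule[e]:.2e}")

With these values, the learning rate is :math:`10^{-3}` for epochs 0 to 9,
:math:`5 \times 10^{-4}` for epochs 10 to 19 and :math:`2.5 \times 10^{-4}` for
epochs 20 to 29.
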

.. GENERATED FROM PYTHON SOURCE LINES 207-213

To train the network, we use the :func:`~spyrit.core.train.train_model` function,
which handles the training process. It iterates through the dataloader, feeds the
inputs to the network and optimizes the network weights (computing the loss and its
gradients and updating the weights at each iteration). In addition, it computes the
loss and the desired metrics on the training and validation sets at each iteration.
The training process can be monitored using TensorBoard.

.. GENERATED FROM PYTHON SOURCE LINES 216-225

.. note::

    To launch TensorBoard, type the following in a new console::

        tensorboard --logdir runs

    and open the provided link in a browser. The training process can be monitored
    in real time in the "Scalars" tab. The "Images" tab shows the reconstructed
    images every :attr:`tb_freq` iterations.

.. GENERATED FROM PYTHON SOURCE LINES 228-231

To train the network, set :attr:`mode_run` to True. It is set to False by default
because training takes around 40 min for 30 epochs; in that case, the training
history of a pretrained network is downloaded for display instead.

.. GENERATED FROM PYTHON SOURCE LINES 231-269

.. code-block:: Python


    # We train for a few epochs only (num_epochs = 5) to check that everything works.
    from spyrit.core.train import train_model
    from datetime import datetime

    # Parameters
    model_root = Path("./model")  # path to model saving files
    num_epochs = 5  # number of training epochs (30 for the pretrained results)
    checkpoint_interval = 2  # interval between saving model checkpoints
    tb_freq = (
        50  # interval between logging to Tensorboard (iterations through the dataloader)
    )

    # Path for Tensorboard experiment tracking logs
    name_run = "stl10_hadampos"
    now = datetime.now().strftime("%Y-%m-%d_%H-%M")
    tb_path = f"runs/runs_{name_run}_n{int(N0)}_m{M}/{now}"

    # Train the network
    if mode_run:
        model, train_info = train_model(
            model,
            criterion,
            optimizer,
            scheduler,
            dataloaders,
            device,
            model_root,
            num_epochs=num_epochs,
            disp=True,
            do_checkpoint=checkpoint_interval,
            tb_path=tb_path,
            tb_freq=tb_freq,
        )
    else:
        train_info = {}








.. GENERATED FROM PYTHON SOURCE LINES 270-272

Save the network and training history
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 274-276

We save the model so that it can be reused later. We save the network's
architecture, the training parameters and the training history.

.. GENERATED FROM PYTHON SOURCE LINES 276-329

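
The training history is stored with :mod:`pickle`. As a minimal, self-contained sketch
of the save/load round trip, the code below pickles a ``types.SimpleNamespace`` used as
a hypothetical stand-in for spyrit's ``Train_par`` object, whose internals are not
reproduced here. It only mirrors the ``train_loss`` and ``val_loss`` attributes that the
tutorial reads back after loading; the file name and loss values are toy placeholders.

.. code-block:: Python

    import pickle
    import tempfile
    from pathlib import Path
    from types import SimpleNamespace

    # Hypothetical stand-in for a Train_par object: a few parameters plus the
    # per-epoch losses that the tutorial reads back (toy values).
    history = SimpleNamespace(
        batch_size=512,
        lr=1e-3,
        img_size=64,
        train_loss=[0.12, 0.08, 0.06],
        val_loss=[0.11, 0.09, 0.08],
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        pkl_path = Path(tmp_dir) / "TRAIN_demo.pkl"  # hypothetical file name
        with open(pkl_path, "wb") as param_file:
            pickle.dump(history, param_file)
        with open(pkl_path, "rb") as param_file:
            loaded = pickle.load(param_file)

    # Same access pattern as in the tutorial: per-epoch losses, ready to plot
    history_dict = {"train": loaded.train_loss, "val": loaded.val_loss}
    print(history_dict)
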

.. code-block:: Python


    from spyrit.core.train import save_net

    # Training parameters
    train_type = "N0_{:g}".format(N0)
    arch = "pinv-net"
    denoi = "cnn"
    data = "stl10"
    reg = 1e-7  # Default value
    suffix = "N_{}_M_{}_epo_{}_lr_{}_sss_{}_sdr_{}_bs_{}".format(
        h, M, num_epochs, lr, step_size, gamma, batch_size
    )
    title = model_root / f"{arch}_{denoi}_{data}_{train_type}_{suffix}"
    print(title)

    Path(model_root).mkdir(parents=True, exist_ok=True)

    if checkpoint_interval:
        Path(title).mkdir(parents=True, exist_ok=True)

    save_net(str(title) + ".pth", model)

    # Save training history
    import pickle

    if mode_run:
        from spyrit.core.train import Train_par

        params = Train_par(batch_size, lr, h, reg=reg)
        params.set_loss(train_info)

        train_path = model_root / f"TRAIN_{arch}_{denoi}_{data}_{train_type}_{suffix}.pkl"

        with open(train_path, "wb") as param_file:
            pickle.dump(params, param_file)
        torch.cuda.empty_cache()

    else:
        from spyrit.misc.load_data import download_girder

        url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
        dataID = "667ebfe4baa5a90007058964"  # unique ID of the file
        data_name = "tuto4_TRAIN_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pkl"
        train_path = os.path.join(model_root, data_name)
        # download girder file
        download_girder(url, dataID, model_root, data_name)

        with open(train_path, "rb") as param_file:
            params = pickle.load(param_file)
        train_info["train"] = params.train_loss
        train_info["val"] = params.val_loss





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    model/pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_5_lr_0.001_sss_10_sdr_0.5_bs_512
    model/pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_5_lr_0.001_sss_10_sdr_0.5_bs_512.pth
    Model Saved
    Downloading tuto4_TRAIN_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pkl...
    Downloading tuto4_TRAIN_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pkl... done.


.. GENERATED FROM PYTHON SOURCE LINES 330-331

We plot the training and validation losses.

.. GENERATED FROM PYTHON SOURCE LINES 331-343

.. code-block:: Python


    # Plot
    # sphinx_gallery_thumbnail_number = 2

    fig = plt.figure()
    plt.plot(train_info["train"], label="train")
    plt.plot(train_info["val"], label="val")
    plt.xlabel("Epochs", fontsize=20)
    plt.ylabel("Loss", fontsize=20)
    plt.legend(fontsize=20)
    plt.show()




.. image-sg:: /gallery/images/sphx_glr_tuto_04_train_pseudoinverse_cnn_linear_002.png
   :alt: tuto 04 train pseudoinverse cnn linear
   :srcset: /gallery/images/sphx_glr_tuto_04_train_pseudoinverse_cnn_linear_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 344-350

.. note::

    See the Google Colab notebook
    `spyrit-examples/tutorial/tuto_train_lin_meas_colab.ipynb `_
    for training a reconstruction network on a GPU. It shows how to train with
    different architectures, denoisers and other hyperparameters of the
    :func:`~spyrit.core.train.train_model` function.


.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 3.123 seconds)