.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/tuto_03_pseudoinverse_cnn_linear.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py:


03. Pseudoinverse solution + CNN denoising
==========================================

.. _tuto_pseudoinverse_cnn_linear:

This tutorial shows how to simulate measurements and perform image reconstruction
using PinvNet (pseudoinverse linear network) with CNN denoising as the last layer.
It continues the :ref:`Pseudoinverse solution tutorial `, but uses a CNN denoiser
instead of the identity operator to remove artefacts.

The measurement operator is a Hadamard matrix with positive coefficients, which can
be replaced by any other matrix.

.. image:: ../fig/tuto3.png
   :width: 600
   :align: center
   :alt: Reconstruction and neural network denoising architecture sketch

These tutorials load image samples from ``/images/``.

.. GENERATED FROM PYTHON SOURCE LINES 23-25

Load a batch of images
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 27-29

Images :math:`x` used for training are expected to have values in [-1, 1].
The images are normalized using the :func:`transform_gray_norm` function.
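
As a rough illustration of this convention, the sketch below maps 8-bit gray
levels :math:`p` to :math:`x = 2p/255 - 1`. It assumes that
:func:`transform_gray_norm` ends with the usual torchvision pair ``ToTensor()``
and ``Normalize(mean=0.5, std=0.5)``; this is an assumption, not something
checked against the spyrit source. The helpers ``to_unit_range`` and
``from_unit_range`` are hypothetical and written in plain Python.

.. code-block:: Python

    def to_unit_range(pixels):
        """Map 8-bit gray levels in [0, 255] to floats in [-1, 1].

        Equivalent to ToTensor() (divide by 255) followed by
        Normalize(mean=0.5, std=0.5), i.e. (p / 255 - 0.5) / 0.5.
        Assumption: this is the normalization applied by transform_gray_norm.
        """
        return [(p / 255 - 0.5) / 0.5 for p in pixels]


    def from_unit_range(values):
        """Map values in [-1, 1] back to 8-bit gray levels (e.g., for display)."""
        return [round((v + 1) / 2 * 255) for v in values]


    pixels = [0, 51, 255]
    normalized = to_unit_range(pixels)  # black -> -1.0, white -> 1.0
    print(normalized)
    print(from_unit_range(normalized))  # round trip back to the gray levels

The inverse map is only needed for display; the reconstruction pipeline works
directly on values in [-1, 1].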

.. GENERATED FROM PYTHON SOURCE LINES 29-67

.. code-block:: Python


    import os

    import torch
    import torchvision
    import matplotlib.pyplot as plt

    import spyrit.core.torch as spytorch
    from spyrit.misc.disp import imagesc
    from spyrit.misc.statistics import transform_gray_norm

    # sphinx_gallery_thumbnail_path = 'fig/tuto3.png'

    h = 64  # image size hxh
    i = 1  # Image index (modify to change the image)
    spyritPath = os.getcwd()
    imgs_path = os.path.join(spyritPath, "images/")

    # Create a transform for natural images to normalized grayscale image tensors
    transform = transform_gray_norm(img_size=h)

    # Create dataset and loader (expects class folder 'images/test/')
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

    x, _ = next(iter(dataloader))
    print(f"Shape of input images: {x.shape}")

    # Select image
    x = x[i : i + 1, :, :, :]
    x = x.detach().clone()
    print(f"Shape of selected image: {x.shape}")
    b, c, h, w = x.shape

    # plot
    imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")


.. image-sg:: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_001.png
   :alt: $x$ in [-1, 1]
   :srcset: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of input images: torch.Size([7, 1, 64, 64])
    Shape of selected image: torch.Size([1, 1, 64, 64])


.. GENERATED FROM PYTHON SOURCE LINES 68-70

Define a measurement operator
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 72-76

We consider the case where the measurement matrix is the positive component of
a Hadamard matrix and the sampling operator preserves only the first :attr:`M`
low-frequency coefficients
(see :ref:`Positive Hadamard matrix ` for a full explanation).

.. GENERATED FROM PYTHON SOURCE LINES 76-91

.. code-block:: Python


    import math

    F = spytorch.walsh2_matrix(h)
    F = torch.max(F, torch.zeros_like(F))  # keep the positive component

    und = 4  # undersampling factor
    M = h**2 // und  # number of measurements (undersampling factor = 4)

    # Keep only the top-left M_xy x M_xy block (low frequencies)
    Sampling_map = torch.zeros(h, h)
    M_xy = math.ceil(M**0.5)
    Sampling_map[:M_xy, :M_xy] = 1

    imagesc(Sampling_map, "low-frequency sampling map")


.. image-sg:: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_002.png
   :alt: low-frequency sampling map
   :srcset: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 92-94

After permuting the rows of the full Hadamard matrix according to the sampling
map, we keep only its first :attr:`M` rows.

.. GENERATED FROM PYTHON SOURCE LINES 94-100

.. code-block:: Python


    # Reorder the rows according to the sampling map, then keep the first M
    F = spytorch.sort_by_significance(F, Sampling_map, "rows", False)
    H = F[:M, :]

    print(f"Shape of the measurement matrix: {H.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the measurement matrix: torch.Size([1024, 4096])


.. GENERATED FROM PYTHON SOURCE LINES 101-102

Then, we instantiate a :class:`spyrit.core.meas.Linear` measurement operator.

.. GENERATED FROM PYTHON SOURCE LINES 102-107

.. code-block:: Python


    from spyrit.core.meas import Linear

    meas_op = Linear(H, pinv=True)  # pinv=True: also store the pseudoinverse of H


.. GENERATED FROM PYTHON SOURCE LINES 108-110

Noiseless case
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 112-114

In the noiseless case, we consider the :class:`spyrit.core.noise.NoNoise` noise
operator.

.. GENERATED FROM PYTHON SOURCE LINES 114-124

.. code-block:: Python


    from spyrit.core.noise import NoNoise

    N0 = 1.0  # Noise level (noiseless)
    noise = NoNoise(meas_op)

    # Simulate measurements
    y = noise(x)
    print(f"Shape of raw measurements: {y.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of raw measurements: torch.Size([1, 1, 1024])


.. GENERATED FROM PYTHON SOURCE LINES 125-127

We now compute and plot the preprocessed measurements corresponding to an image
in [-1, 1].

.. GENERATED FROM PYTHON SOURCE LINES 127-135

.. code-block:: Python


    from spyrit.core.prep import DirectPoisson

    prep = DirectPoisson(N0, meas_op)  # "Undo" the NoNoise operator

    m = prep(y)
    print(f"Shape of the preprocessed measurements: {m.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the preprocessed measurements: torch.Size([1, 1, 1024])


.. GENERATED FROM PYTHON SOURCE LINES 136-138

To display the subsampled measurement vector as an image in the transformed
domain, we use the :func:`spyrit.core.torch.meas2img` function.

.. GENERATED FROM PYTHON SOURCE LINES 138-145

.. code-block:: Python


    # plot
    m_plot = spytorch.meas2img(m, Sampling_map)
    print(f"Shape of the preprocessed measurement image: {m_plot.shape}")

    imagesc(m_plot[0, 0, :, :], "Preprocessed measurements (no noise)")


.. image-sg:: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_003.png
   :alt: Preprocessed measurements (no noise)
   :srcset: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the preprocessed measurement image: torch.Size([1, 1, 64, 64])


.. GENERATED FROM PYTHON SOURCE LINES 146-148

PinvNet Network
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 150-154

We consider the :class:`spyrit.core.recon.PinvNet` class, which reconstructs an
image by computing the pseudoinverse solution and feeding it to a neural network
denoiser. To compute the pseudoinverse solution only, the denoiser can be set to
the identity operator:

.. GENERATED FROM PYTHON SOURCE LINES 154-159

.. code-block:: Python


    from spyrit.core.recon import PinvNet

    pinv_net = PinvNet(noise, prep, denoi=torch.nn.Identity())


.. GENERATED FROM PYTHON SOURCE LINES 160-161

or, equivalently (the identity is the default denoiser):

.. GENERATED FROM PYTHON SOURCE LINES 161-163

.. code-block:: Python

    pinv_net = PinvNet(noise, prep)


.. GENERATED FROM PYTHON SOURCE LINES 164-166

Then, we reconstruct the image from the measurement vector :attr:`y` using the
:func:`~spyrit.core.recon.PinvNet.reconstruct` method.

.. GENERATED FROM PYTHON SOURCE LINES 166-169

.. code-block:: Python


    x_rec = pinv_net.reconstruct(y)


.. GENERATED FROM PYTHON SOURCE LINES 170-172

Removing artefacts with a CNN
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 174-178

Artefacts can be removed by selecting a neural network denoiser (the last layer
of PinvNet).
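
Structurally, the reconstruction applies the pseudoinverse to the preprocessed
measurements and feeds the result to the denoiser,
:math:`\hat{x} = \mathcal{D}(H^\dagger m)`. The plain-Python sketch below
illustrates this composition on a toy problem under simplifying assumptions: it
uses rows of a :math:`\pm 1` Sylvester-Hadamard matrix, for which
:math:`H H^\top = n I` and hence :math:`H^\dagger = H^\top / n`, and a clipping
function stands in for the CNN. The shortcut does not apply to the positive
(0/1) matrix used in this tutorial. All helper names (``pinvnet_like`` and the
others) are hypothetical and are not spyrit's implementation.

.. code-block:: Python

    def sylvester_hadamard(n):
        """Return the n x n Sylvester-Hadamard matrix (n must be a power of 2)."""
        H = [[1]]
        while len(H) < n:
            H = [row + row for row in H] + [row + [-v for v in row] for row in H]
        return H


    def matvec(A, v):
        """Matrix-vector product for a matrix stored as a list of rows."""
        return [sum(a * b for a, b in zip(row, v)) for row in A]


    def pinv_hadamard_rows(H_M, n):
        """Pseudoinverse of M rows of an n x n +/-1 Hadamard matrix.

        The rows are orthogonal with squared norm n, so H_M @ H_M.T = n * I
        and the pseudoinverse reduces to H_M.T / n (an n x M matrix).
        """
        return [[H_M[i][j] / n for i in range(len(H_M))] for j in range(n)]


    def pinvnet_like(m, H_pinv, denoi=lambda z: z):
        """Hypothetical PinvNet-style pipeline: denoi(H^+ m), identity by default."""
        return denoi(matvec(H_pinv, m))


    def clip(z):
        """Stand-in 'denoiser': clip every value to [-1, 1]."""
        return [min(1.0, max(-1.0, v)) for v in z]


    n, M = 4, 2
    H_M = sylvester_hadamard(n)[:M]  # keep the first M rows (undersampling)
    x = [0.5, -0.5, 0.25, 0.75]  # toy 4-pixel "image" in [-1, 1]
    m = matvec(H_M, x)  # noiseless measurements

    x_pinv = pinvnet_like(m, pinv_hadamard_rows(H_M, n))  # pseudoinverse only
    x_clip = pinvnet_like(m, pinv_hadamard_rows(H_M, n), denoi=clip)
    print(x_pinv, x_clip)

Here ``x_pinv`` is the minimum-norm image consistent with the :math:`M`
measurements: it reproduces ``m`` exactly but differs from ``x``, because the
undersampled rows carry no information about the rest of the image. The
denoiser is meant to reduce this kind of error.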

We select a simple CNN using the :class:`spyrit.core.nnet.ConvNet` class, but
this can be replaced by any neural network (e.g., a U-Net from
:class:`spyrit.core.nnet.Unet`).

.. GENERATED FROM PYTHON SOURCE LINES 180-184

.. image:: ../fig/pinvnet_cnn.png
   :width: 400
   :align: center
   :alt: Sketch of the PinvNet with CNN architecture

.. GENERATED FROM PYTHON SOURCE LINES 184-197

.. code-block:: Python


    from spyrit.core.nnet import ConvNet
    from spyrit.core.train import load_net

    # Define PinvNet with a ConvNet denoising layer
    denoi = ConvNet()
    pinv_net_cnn = PinvNet(noise, prep, denoi)

    # Send to GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    pinv_net_cnn = pinv_net_cnn.to(device)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using device: cpu


.. GENERATED FROM PYTHON SOURCE LINES 198-200

As an example, we use a simple ConvNet that has been pretrained on the STL-10
dataset. We download the pretrained weights and load them into the network.

.. GENERATED FROM PYTHON SOURCE LINES 200-215

.. code-block:: Python


    from spyrit.misc.load_data import download_girder

    # Load pretrained model
    url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
    dataID = "67221889f03a54733161e963"  # unique ID of the file
    local_folder = "./model/"
    data_name = "tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth"
    # download the model and save it in the local folder
    model_cnn_path = download_girder(url, dataID, local_folder, data_name)

    # Load model weights
    load_net(model_cnn_path, pinv_net_cnn, device, False)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Local folder not found, creating it... done.
    Downloading tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth...
    Downloading tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth... done.
    Model Loaded: /home/docs/checkouts/readthedocs.org/user_builds/spyrit/checkouts/2.4.0/tutorial/model/tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth


.. GENERATED FROM PYTHON SOURCE LINES 216-218

We now reconstruct the image using PinvNet with pretrained CNN denoising and plot
the result side by side with the PinvNet reconstruction without denoising.

.. GENERATED FROM PYTHON SOURCE LINES 218-245

.. code-block:: Python


    from spyrit.misc.disp import add_colorbar, noaxis

    with torch.no_grad():
        x_rec_cnn = pinv_net_cnn.reconstruct(y.to(device))
        # Alternatively, the forward pass simulates the measurements from x and
        # reconstructs them in one call (same result here, as there is no noise):
        # x_rec_cnn = pinv_net_cnn(x.to(device))

    # plot
    f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    im1 = ax1.imshow(x[0, 0, :, :], cmap="gray")
    ax1.set_title("Ground-truth image", fontsize=20)
    noaxis(ax1)
    add_colorbar(im1, "bottom", size="20%")

    im2 = ax2.imshow(x_rec[0, 0, :, :], cmap="gray")
    ax2.set_title("Pinv reconstruction", fontsize=20)
    noaxis(ax2)
    add_colorbar(im2, "bottom", size="20%")

    im3 = ax3.imshow(x_rec_cnn.cpu()[0, 0, :, :], cmap="gray")
    ax3.set_title("Pinv + CNN (trained 30 epochs)", fontsize=20)
    noaxis(ax3)
    add_colorbar(im3, "bottom", size="20%")

    plt.show()


.. image-sg:: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_004.png
   :alt: Ground-truth image, Pinv reconstruction, Pinv + CNN (trained 30 epochs)
   :srcset: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_004.png
   :class: sphx-glr-single-img
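
The comparison above is visual. A quantitative complement, not part of the
original tutorial, is the peak signal-to-noise ratio (PSNR) with respect to the
ground truth. The minimal sketch below assumes images normalized to [-1, 1], so
the peak-to-peak pixel range is 2; the ``psnr`` helper is hypothetical and works
on flattened lists of pixel values.

.. code-block:: Python

    import math


    def psnr(reference, estimate, data_range=2.0):
        """Peak signal-to-noise ratio in dB between two flattened images.

        data_range is the peak-to-peak pixel range: 2.0 for images in [-1, 1].
        """
        if len(reference) != len(estimate):
            raise ValueError("images must have the same number of pixels")
        mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
        if mse == 0:
            return math.inf  # identical images
        return 10 * math.log10(data_range**2 / mse)


    # Toy example: a constant error of 0.1 gives MSE = 0.01, i.e. about 26 dB
    print(psnr([0.0, 0.5, -0.5, 1.0], [0.1, 0.6, -0.4, 1.1]))

With the tensors of this tutorial, a call could look like
``psnr(x[0, 0].flatten().tolist(), x_rec_cnn.cpu()[0, 0].flatten().tolist())``.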

.. GENERATED FROM PYTHON SOURCE LINES 246-247

We show the best result again (for the tutorial thumbnail).

.. GENERATED FROM PYTHON SOURCE LINES 247-253

.. code-block:: Python


    # Plot
    imagesc(
        x_rec_cnn.cpu()[0, 0, :, :], "Pinv + CNN (trained 30 epochs)", title_fontsize=20
    )


.. image-sg:: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_005.png
   :alt: Pinv + CNN (trained 30 epochs)
   :srcset: /gallery/images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_005.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 254-255

In the next tutorial, we will show how to train PinvNet with a CNN denoiser.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.515 seconds)


.. _sphx_glr_download_gallery_tuto_03_pseudoinverse_cnn_linear.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tuto_03_pseudoinverse_cnn_linear.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tuto_03_pseudoinverse_cnn_linear.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tuto_03_pseudoinverse_cnn_linear.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_