.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/tuto_04_pseudoinverse_cnn_linear.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_gallery_tuto_04_pseudoinverse_cnn_linear.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_tuto_04_pseudoinverse_cnn_linear.py:


04.a. Pseudoinverse + CNN (reconstruction)
==========================================

.. _tuto_04_pseudoinverse_cnn_linear:

This tutorial shows how to simulate measurements and perform image reconstruction
using the :class:`spyrit.core.recon.PinvNet` class of the :mod:`spyrit.core.recon`
submodule.

.. image:: ../fig/tuto4_pinvnet.png
   :width: 600
   :align: center
   :alt: Reconstruction architecture sketch

|

.. GENERATED FROM PYTHON SOURCE LINES 17-19

Load a batch of images
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 21-24

We load a batch of images from the :attr:`/images/` folder. Calling the
:func:`spyrit.misc.statistics.transform_gray_norm` function with the
:attr:`normalize=False` argument returns images with values in (0,1).

.. GENERATED FROM PYTHON SOURCE LINES 24-42

.. code-block:: Python

    import os

    import torchvision
    import torch.nn

    from spyrit.misc.statistics import transform_gray_norm

    spyritPath = os.getcwd()
    imgs_path = os.path.join(spyritPath, "images/")

    # Grayscale images of size 64 x 64, no normalization to keep values in (0,1)
    transform = transform_gray_norm(img_size=64, normalize=False)

    # Create dataset and loader (expects class folder 'images/test/')
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

    x, _ = next(iter(dataloader))
    print(f"Ground-truth images: {x.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Ground-truth images: torch.Size([7, 1, 64, 64])

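For intuition about the (0,1) range, here is a minimal NumPy sketch. It assumes the transform converts 8-bit gray levels to floats by dividing by 255, as :class:`torchvision.transforms.ToTensor` does; we have not checked the internals of :func:`transform_gray_norm`. The random batch is a hypothetical stand-in for the files in ``images/``, laid out as (batch, channel, height, width) like ``x``.

.. code-block:: Python

    import numpy as np

    # Hypothetical stand-in for a batch of 8-bit grayscale images,
    # laid out as (batch, channel, height, width) like the tensor x above
    rng = np.random.default_rng(0)
    imgs_uint8 = rng.integers(0, 256, size=(7, 1, 64, 64), dtype=np.uint8)

    # Dividing by 255 maps gray levels {0, ..., 255} onto the unit interval
    x_np = imgs_uint8.astype(np.float32) / 255.0

    print(x_np.shape)              # (7, 1, 64, 64)
    print(x_np.min(), x_np.max())  # both lie within the unit interval
    second_image = x_np[1, 0]      # a single 64 x 64 image
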
.. GENERATED FROM PYTHON SOURCE LINES 43-44

We plot the second image in the batch.

.. GENERATED FROM PYTHON SOURCE LINES 44-48

.. code-block:: Python

    from spyrit.misc.disp import imagesc

    imagesc(x[1, 0, :, :], "x[1, 0, :, :]")


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_001.png
   :alt: x[1, 0, :, :]
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 49-51

Linear measurements (no noise)
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 53-55

We choose the acquisition matrix as the positive part of a 2D Walsh-Hadamard
matrix. This is a binary matrix with entries in {0, 1} and shape
(64*64, 64*64).

.. GENERATED FROM PYTHON SOURCE LINES 55-63

.. code-block:: Python

    from spyrit.core.torch import walsh_matrix_2d

    H = walsh_matrix_2d(64)
    H = torch.where(H > 0, 1.0, 0.0)

    print(f"Acquisition matrix: {H.shape}", end=" ")
    print(rf"with values in {{{H.min()}, {H.max()}}}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Acquisition matrix: torch.Size([4096, 4096]) with values in {0.0, 1.0}


.. GENERATED FROM PYTHON SOURCE LINES 64-66

We subsample the measurement operator by a factor of four, keeping only the
low-frequency components.

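The construction and the low-frequency subsampling can be pictured on a toy problem. Below is a minimal NumPy sketch scaled down to 8 x 8 images. It assumes the 2D matrix is the Kronecker product of two 1D Walsh-Hadamard matrices in sequency order; this is our assumption and has not been checked against :func:`spyrit.core.torch.walsh_matrix_2d`. The helper ``walsh_1d`` is hypothetical. Keeping the top-left quarter of the sampling map keeps one row in four, as the 1024 = 32 x 32 rows do in the tutorial.

.. code-block:: Python

    import numpy as np


    def walsh_1d(n):
        """Hypothetical helper: Walsh-Hadamard matrix of order n (a power of 2),
        with rows sorted by sequency (number of sign changes)."""
        H = np.array([[1]])
        while H.shape[0] < n:
            H = np.block([[H, H], [H, -H]])  # Sylvester construction
        sign_changes = np.sum(H[:, 1:] != H[:, :-1], axis=1)
        return H[np.argsort(sign_changes)]


    n = 8  # toy size; the tutorial uses n = 64
    W = walsh_1d(n)
    H2 = np.kron(W, W)  # (n*n, n*n), acts on row-major flattened n x n images
    H2_pos = np.where(H2 > 0, 1.0, 0.0)  # positive part, entries in {0, 1}

    # Keep the low-frequency block (analogue of Sampling_square[:32, :32] = 1)
    mask = np.zeros((n, n))
    mask[: n // 2, : n // 2] = 1
    H_sub = H2_pos[mask.ravel() == 1]  # (n*n // 4, n*n)

    print(H2_pos.shape, H_sub.shape)  # (64, 64) (16, 64)
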
.. GENERATED FROM PYTHON SOURCE LINES 66-72

.. code-block:: Python

    Sampling_square = torch.zeros(64, 64)
    Sampling_square[:32, :32] = 1

    imagesc(Sampling_square, "Sampling map")


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_002.png
   :alt: Sampling map
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 73-75

We use :func:`spyrit.core.torch.sort_by_significance` to permute the rows of
:attr:`H` according to the sampling map, so that the rows it selects come
first. Then, we keep the first 1024 rows.

.. GENERATED FROM PYTHON SOURCE LINES 75-83

.. code-block:: Python

    from spyrit.core.torch import sort_by_significance

    H = sort_by_significance(H, Sampling_square, "rows", False)
    H = H[: 32 * 32, :]

    print(f"Shape of the measurement matrix: {H.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the measurement matrix: torch.Size([1024, 4096])


.. GENERATED FROM PYTHON SOURCE LINES 84-87

We instantiate a :class:`spyrit.core.meas.Linear` operator. To indicate that the
operator works in 2D, on images with shape (64, 64), we pass this shape as the
:attr:`meas_shape` argument.

.. GENERATED FROM PYTHON SOURCE LINES 87-92

.. code-block:: Python

    from spyrit.core.meas import Linear

    meas_op = Linear(H, (64, 64))


.. GENERATED FROM PYTHON SOURCE LINES 93-94

We simulate the measurement vectors, which have shape (7, 1, 1024).

.. GENERATED FROM PYTHON SOURCE LINES 94-98

.. code-block:: Python

    y = meas_op(x)
    print(f"Measurement vectors: {y.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Measurement vectors: torch.Size([7, 1, 1024])


.. GENERATED FROM PYTHON SOURCE LINES 99-101

To display the subsampled measurement vector as an image in the transformed
domain, we use the :func:`spyrit.core.torch.meas2img` function.

.. GENERATED FROM PYTHON SOURCE LINES 101-111

.. code-block:: Python

    # Reshape the measurements into 64 x 64 images for display
    from spyrit.core.torch import meas2img

    m_plot = meas2img(y, Sampling_square)
    print(f"Shape of the preprocessed measurement image: {m_plot.shape}")

    imagesc(m_plot[0, 0, :, :], "Measurements (reshaped)")


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_003.png
   :alt: Measurements (reshaped)
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Shape of the preprocessed measurement image: torch.Size([7, 1, 64, 64])


.. GENERATED FROM PYTHON SOURCE LINES 112-114

Pseudoinverse solution with PinvNet
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 116-119

The :class:`spyrit.core.recon.PinvNet` class reconstructs an image by computing
the pseudoinverse solution. By default, the :func:`torch.linalg.lstsq` solver is
used.

.. GENERATED FROM PYTHON SOURCE LINES 119-124

.. code-block:: Python

    from spyrit.core.recon import PinvNet

    pinv_net = PinvNet(meas_op)

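To see what the pseudoinverse solution is, here is a minimal NumPy sketch (not spyrit's implementation), with a random Gaussian matrix as a hypothetical stand-in for the subsampled acquisition matrix. With fewer measurements than pixels, :func:`numpy.linalg.lstsq` returns the minimum-norm solution, which coincides with applying the Moore-Penrose pseudoinverse. It reproduces the measurements exactly but does not recover the image itself.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)
    M, N = 16, 64  # toy sizes; the tutorial has M = 1024, N = 4096
    H = rng.standard_normal((M, N))  # stand-in measurement matrix (full row rank)
    x_true = rng.random(N)  # stand-in flattened image
    y = H @ x_true  # noiseless measurements

    # Minimum-norm least-squares solution, with and without an explicit pseudoinverse
    x_lstsq, *_ = np.linalg.lstsq(H, y, rcond=None)
    x_pinv = np.linalg.pinv(H) @ y

    print(np.allclose(x_lstsq, x_pinv))  # the two solutions agree
    print(np.allclose(H @ x_lstsq, y))  # the measurements are reproduced
    print(np.linalg.norm(x_lstsq - x_true))  # but the image is not recovered (M < N)
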
.. GENERATED FROM PYTHON SOURCE LINES 125-127

We use the :meth:`~spyrit.core.recon.PinvNet.reconstruct` method to reconstruct
the images from the measurement vectors :attr:`y`.

.. GENERATED FROM PYTHON SOURCE LINES 127-132

.. code-block:: Python

    x_rec = pinv_net.reconstruct(y)

    imagesc(x_rec[1, 0, :, :], "Pseudo Inverse")


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_004.png
   :alt: Pseudo Inverse
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_004.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 133-137

Alternatively, the pseudoinverse of the acquisition matrix can be computed once
and stored. This is more efficient when a large number of reconstructions are
performed (e.g., during training). To do so, we set :attr:`store_H_pinv` to
``True``.

.. GENERATED FROM PYTHON SOURCE LINES 137-143

.. code-block:: Python

    pinv_net_2 = PinvNet(meas_op, store_H_pinv=True)
    x_rec_2 = pinv_net_2.reconstruct(y)

    imagesc(x_rec_2[1, 0, :, :], "Pseudo Inverse")


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_005.png
   :alt: Pseudo Inverse
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_005.png
   :class: sphx-glr-single-img

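The trade-off behind :attr:`store_H_pinv` can be sketched in NumPy (an illustration, not spyrit's code). The pseudoinverse, with shape (N, M), is computed once, after which each reconstruction reduces to a matrix product instead of a least-squares solve. Sizes and data below are hypothetical.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(1)
    M, N = 16, 64  # toy sizes; the tutorial has M = 1024, N = 4096
    H = rng.standard_normal((M, N))  # stand-in measurement matrix
    H_pinv = np.linalg.pinv(H)  # computed once and stored, shape (N, M)

    # Several batches of measurement vectors with shape (batch, channel, M)
    batches = [rng.standard_normal((7, 1, M)) for _ in range(3)]

    # Each reconstruction is now a single matrix product along the last axis
    recons = [y @ H_pinv.T for y in batches]  # each of shape (7, 1, N)

    # Sanity check against a per-vector least-squares solve
    x_ref, *_ = np.linalg.lstsq(H, batches[0][2, 0], rcond=None)
    print(np.allclose(recons[0][2, 0], x_ref))
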
.. GENERATED FROM PYTHON SOURCE LINES 144-146

Unlike ``pinv_net``, ``pinv_net_2`` stores the pseudoinverse matrix, which has
shape (4096, 1024).

.. GENERATED FROM PYTHON SOURCE LINES 146-150

.. code-block:: Python

    print(f"pinv_net: {hasattr(pinv_net.pinv, 'pinv')}")
    print(f"pinv_net_2: {hasattr(pinv_net_2.pinv, 'pinv')}")
    print(f"Shape: {pinv_net_2.pinv.pinv.shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    pinv_net: False
    pinv_net_2: True
    Shape: torch.Size([4096, 1024])


.. GENERATED FROM PYTHON SOURCE LINES 151-153

CNN post-processing with PinvNet
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 155-161

Reconstruction artefacts can be reduced by post-processing the pseudoinverse
solution with a denoising neural network. In the following, we select a small
CNN using the :class:`spyrit.core.nnet.ConvNet` class, but it can be replaced by
any other neural network (e.g., a UNet from :class:`spyrit.core.nnet.Unet`).

.. GENERATED FROM PYTHON SOURCE LINES 163-164

We download a ConvNet that has been trained on the STL-10 dataset.

.. GENERATED FROM PYTHON SOURCE LINES 164-173

.. code-block:: Python

    from spyrit.misc.load_data import download_girder

    url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
    dataID = "68639a2af39e1d2884b09abf"  # unique ID of the file
    model_folder = "./model/"
    model_cnn_path = download_girder(url, dataID, model_folder)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Local folder not found, creating it... done.
    Downloading tuto_4b.pth...
    Downloading tuto_4b.pth... done.


.. GENERATED FROM PYTHON SOURCE LINES 174-176

The CNN should be placed in an ordered dictionary and passed to a
:class:`torch.nn.Sequential`. The dictionary key (here ``"denoi"``) becomes the
name of the corresponding submodule.

.. GENERATED FROM PYTHON SOURCE LINES 176-182

.. code-block:: Python

    from collections import OrderedDict

    from spyrit.core.nnet import ConvNet

    denoiser = torch.nn.Sequential(OrderedDict({"denoi": ConvNet()}))


.. GENERATED FROM PYTHON SOURCE LINES 183-184

We load the denoiser and send it to the GPU, if available.

.. GENERATED FROM PYTHON SOURCE LINES 184-191

.. code-block:: Python

    from spyrit.core.train import load_net

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    load_net(model_cnn_path, denoiser, device, False)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Model Loaded: /home/docs/checkouts/readthedocs.org/user_builds/spyrit/checkouts/3.1.1/tutorial/model/tuto_4b.pth


.. GENERATED FROM PYTHON SOURCE LINES 192-193

We create a PinvNet with a post-processing denoising step.

.. GENERATED FROM PYTHON SOURCE LINES 193-196

.. code-block:: Python

    pinv_net = PinvNet(meas_op, denoi=denoiser, device=device)


.. GENERATED FROM PYTHON SOURCE LINES 197-198

We reconstruct the images using PinvNet in evaluation mode, without tracking
gradients.

.. GENERATED FROM PYTHON SOURCE LINES 198-205

.. code-block:: Python

    pinv_net.eval()
    y = y.to(device)

    with torch.no_grad():
        x_rec_cnn = pinv_net.reconstruct(y)


.. GENERATED FROM PYTHON SOURCE LINES 206-207

Finally, we plot the results.

.. GENERATED FROM PYTHON SOURCE LINES 207-230

.. code-block:: Python

    import matplotlib.pyplot as plt
    from spyrit.misc.disp import add_colorbar, noaxis

    f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    im1 = ax1.imshow(x[1, 0, :, :], cmap="gray")
    ax1.set_title("Ground-truth", fontsize=20)
    noaxis(ax1)
    add_colorbar(im1, "bottom", size="20%")

    im2 = ax2.imshow(x_rec[1, 0, :, :].cpu(), cmap="gray")
    ax2.set_title("Pinv", fontsize=20)
    noaxis(ax2)
    add_colorbar(im2, "bottom", size="20%")

    im3 = ax3.imshow(x_rec_cnn.cpu()[1, 0, :, :], cmap="gray")
    ax3.set_title("Pinv + CNN", fontsize=20)
    noaxis(ax3)
    add_colorbar(im3, "bottom", size="20%")

    plt.show()


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_006.png
   :alt: Ground-truth, Pinv, Pinv + CNN
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_006.png
   :class: sphx-glr-single-img

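The "Pinv + CNN" panel is the output of a two-stage pipeline in which the denoiser is applied to the pseudoinverse image. Below is a minimal NumPy sketch of that composition. The 3 x 3 moving-average filter is a hypothetical stand-in for the trained :class:`~spyrit.core.nnet.ConvNet`, chosen only to keep the example self-contained, and both function names are ours.

.. code-block:: Python

    import numpy as np


    def box_filter_3x3(img):
        """Hypothetical stand-in denoiser: 3 x 3 moving average with edge padding."""
        h, w = img.shape
        p = np.pad(img, 1, mode="edge")
        return sum(p[i : i + h, j : j + w] for i in range(3) for j in range(3)) / 9.0


    def pinv_then_denoise(y, H_pinv, shape, denoiser):
        """Stage 1: pseudoinverse image. Stage 2: image-domain post-processing."""
        x_pinv = (H_pinv @ y).reshape(shape)
        return denoiser(x_pinv)


    rng = np.random.default_rng(0)
    n, M = 8, 16  # toy sizes
    H = rng.standard_normal((M, n * n))  # stand-in measurement matrix
    x_true = rng.random((n, n))
    y = H @ x_true.ravel()

    x_out = pinv_then_denoise(y, np.linalg.pinv(H), (n, n), box_filter_3x3)
    print(x_out.shape)  # (8, 8)
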
.. GENERATED FROM PYTHON SOURCE LINES 231-233

We show the best result again; this figure is used as the tutorial thumbnail.

.. GENERATED FROM PYTHON SOURCE LINES 233-236

.. code-block:: Python

    imagesc(x_rec_cnn.cpu()[1, 0, :, :], "Pinv + CNN", title_fontsize=20)


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_007.png
   :alt: Pinv + CNN
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_007.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 237-241

.. note::

   In the :ref:`next tutorial `, we will show how to train a PinvNet with a CNN
   denoiser.

.. GENERATED FROM PYTHON SOURCE LINES 243-245

Compatibility between SPyRiT 2 and SPyRiT 3
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 247-250

SPyRiT 2.4 trains neural networks for images with values in the range (-1,1),
while SPyRiT 3 assumes images with values in the range (0,1). This difference can
be compensated for using :class:`spyrit.core.prep.Rerange`.

.. GENERATED FROM PYTHON SOURCE LINES 250-260

.. code-block:: Python

    from spyrit.core.prep import Rerange

    rerange = Rerange((0, 1), (-1, 1))
    denoiser = OrderedDict(
        {"rerange": rerange, "denoi": ConvNet(), "rerange_inv": rerange.inverse()}
    )
    denoiser = torch.nn.Sequential(denoiser)


.. GENERATED FROM PYTHON SOURCE LINES 261-262

We load a SPyRiT 2.4 denoiser and show the reconstruction.

.. GENERATED FROM PYTHON SOURCE LINES 262-273

.. code-block:: Python

    dataID = "67221889f03a54733161e963"  # unique ID of the file
    model_cnn_path = download_girder(url, dataID, model_folder)
    load_net(model_cnn_path, denoiser, device, False)

    pinv_net = PinvNet(meas_op, denoi=denoiser, device=device)
    pinv_net.eval()

    with torch.no_grad():
        x_rec_cnn = pinv_net.reconstruct(y)

    imagesc(x_rec_cnn.cpu()[1, 0, :, :], "Pinv + CNN (v2.4)", title_fontsize=20)


.. image-sg:: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_008.png
   :alt: Pinv + CNN (v2.4)
   :srcset: /gallery/images/sphx_glr_tuto_04_pseudoinverse_cnn_linear_008.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth...
    Downloading tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth... done.
    Model Loaded: /home/docs/checkouts/readthedocs.org/user_builds/spyrit/checkouts/3.1.1/tutorial/model/tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth

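The rerange wrapping used above can be pictured as follows. This assumes :class:`~spyrit.core.prep.Rerange` is the affine map sending one interval onto another, with ``inverse()`` giving the reverse map; that is our reading of the class name, not checked against its source. The NumPy helpers ``rerange`` and ``wrap_legacy_denoiser`` are hypothetical.

.. code-block:: Python

    import numpy as np


    def rerange(x, src=(0.0, 1.0), dst=(-1.0, 1.0)):
        """Affine map sending the interval src onto the interval dst."""
        (a, b), (c, d) = src, dst
        return (x - a) * (d - c) / (b - a) + c


    def wrap_legacy_denoiser(denoiser):
        """Run a denoiser trained on (-1, 1) images on (0, 1) images."""

        def wrapped(x01):
            out = denoiser(rerange(x01, (0.0, 1.0), (-1.0, 1.0)))  # like "rerange"
            return rerange(out, (-1.0, 1.0), (0.0, 1.0))  # like "rerange_inv"

        return wrapped


    x = np.array([0.0, 0.25, 1.0])
    print(rerange(x))  # maps 0 -> -1, 0.25 -> -0.5 and 1 -> 1
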
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 12.960 seconds)


.. _sphx_glr_download_gallery_tuto_04_pseudoinverse_cnn_linear.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tuto_04_pseudoinverse_cnn_linear.ipynb <tuto_04_pseudoinverse_cnn_linear.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tuto_04_pseudoinverse_cnn_linear.py <tuto_04_pseudoinverse_cnn_linear.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tuto_04_pseudoinverse_cnn_linear.zip <tuto_04_pseudoinverse_cnn_linear.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_