.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "gallery/tuto_05_dcnet.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_gallery_tuto_05_dcnet.py: ========================================= 05. Denoised Completion Network (DC-Net) ========================================= .. _tuto_dcnet_split_measurements: This tutorial shows how to perform image reconstruction using a denoised completion network (DC-Net) [1]_ with a trainable image denoiser. .. figure:: ../fig/tuto5_dcnet.png :width: 600 :align: center :alt: Reconstruction and neural network denoising architecture sketch using split measurements | .. [1] A Lorente Mur, P Leclerc, F Peyrin, and N Ducros, "Single-pixel image reconstruction from experimental data using neural networks," *Opt. Express*, Vol. 29, Issue 11, 17097-17110 (2021). `DOI `_. .. GENERATED FROM PYTHON SOURCE LINES 21-23 Load a batch of images ----------------------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 25-28 We load a batch of images from the :attr:`/images/` folder. Using the :func:`spyrit.misc.statistics.transform_gray_norm` function with the :attr:`normalize=False` argument returns images with values in (0,1). .. GENERATED FROM PYTHON SOURCE LINES 28-46 .. 
code-block:: Python import os import torchvision import torch.nn from spyrit.misc.statistics import transform_gray_norm spyritPath = os.getcwd() imgs_path = os.path.join(spyritPath, "images/") # Grayscale images of size 64 x 64, no normalization to keep values in (0,1) transform = transform_gray_norm(img_size=64, normalize=False) # Create dataset and loader (expects class folder 'images/test/') dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) x, _ = next(iter(dataloader)) print(f"Ground-truth images: {x.shape}") .. rst-class:: sphx-glr-script-out .. code-block:: none Ground-truth images: torch.Size([7, 1, 64, 64]) .. GENERATED FROM PYTHON SOURCE LINES 47-48 We display the second image in the batch .. GENERATED FROM PYTHON SOURCE LINES 48-52 .. code-block:: Python from spyrit.misc.disp import imagesc imagesc(x[1, 0, :, :], "x[1, 0, :, :]") .. image-sg:: /gallery/images/sphx_glr_tuto_05_dcnet_001.png :alt: x[1, 0, :, :] :srcset: /gallery/images/sphx_glr_tuto_05_dcnet_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) .. GENERATED FROM PYTHON SOURCE LINES 53-55 Forward operators for split measurements ========================================= .. GENERATED FROM PYTHON SOURCE LINES 57-65 We consider Poisson noise, i.e., a noisy measurement vector given by .. 
math:: y \sim \mathcal{P}(\alpha A x), where :math:`\alpha` is a scalar value that represents the maximum image intensity (in photons), :math:`A \in \mathbb{R}_+^{2M\times N}` is the acquisition matrix that contains the DMD patterns, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`2M` is the number of DMD patterns, and :math:`N` is the dimension of the signal. The larger :math:`\alpha`, the higher the signal-to-noise ratio of the measurements. .. GENERATED FROM PYTHON SOURCE LINES 67-68 The acquisition matrix :math:`A` is chosen as a split Hadamard matrix. It is subsampled by a factor of four by retaining the rows that give, statistically, the coefficients with the largest variance. This is achieved by the :class:`~spyrit.core.meas.HadamSplit2d` class (see :ref:`Tutorial 1.c ` for details). .. GENERATED FROM PYTHON SOURCE LINES 71-72 First, we download a covariance matrix (for subsampling). .. GENERATED FROM PYTHON SOURCE LINES 72-82 .. code-block:: Python from spyrit.misc.load_data import download_girder # Get covariance matrix url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" dataId = "672207cbf03a54733161e95d" data_folder = "./stat/" file_abs_path = download_girder(url, dataId, data_folder) Cov = torch.load(file_abs_path, weights_only=True) .. rst-class:: sphx-glr-script-out .. code-block:: none File already exists at ./stat/Cov_64x64.pt .. GENERATED FROM PYTHON SOURCE LINES 83-84 Then, we choose a subsampling factor of four and specify the subsampling strategy using the :attr:`order` attribute. Finally, we set the noise model using the :attr:`noise_model` attribute. We use the :class:`spyrit.core.noise.Poisson` class and set :math:`\alpha` to 100 photons. .. GENERATED FROM PYTHON SOURCE LINES 84-95 ..
code-block:: Python from spyrit.core.torch import Cov2Var from spyrit.core.meas import HadamSplit2d from spyrit.core.noise import Poisson M = 64 * 64 // 4 alpha = 100.0 # image intensity Variance = Cov2Var(Cov) noise_model = Poisson(alpha) meas_op = HadamSplit2d(64, M=M, order=Variance, noise_model=noise_model) .. GENERATED FROM PYTHON SOURCE LINES 96-97 We simulate the measurements .. GENERATED FROM PYTHON SOURCE LINES 97-99 .. code-block:: Python y = meas_op(x) .. GENERATED FROM PYTHON SOURCE LINES 100-102 Pseudo inverse solution with preprocessing ========================================== .. GENERATED FROM PYTHON SOURCE LINES 104-109 We compute the pseudo inverse solution using :class:`spyrit.core.recon.PinvNet`, which can include a preprocessing step .. math:: m = \texttt{Prep}(y). .. GENERATED FROM PYTHON SOURCE LINES 111-122 The :class:`spyrit.core.prep.UnsplitRescale` class is intended to "undo": * The splitting of an acquisition matrix (see :class:`spyrit.core.meas.LinearSplit`) * The scaling that controls the SNR of Poisson-corrupted measurements (see :class:`spyrit.core.noise.Poisson`). To do so, it computes .. math:: m = \frac{(y_+-y_-)}{\alpha}, where :math:`y_+` and :math:`y_-` are the measurements acquired with the positive and negative patterns :math:`H_+` and :math:`H_-`, respectively (without noise, :math:`y_+ = \alpha H_+ x` and :math:`y_- = \alpha H_- x`). .. GENERATED FROM PYTHON SOURCE LINES 122-127 .. code-block:: Python from spyrit.core.recon import PinvNet from spyrit.core.prep import UnsplitRescale prep_op = UnsplitRescale(alpha) pinvnet = PinvNet(meas_op, prep_op) .. GENERATED FROM PYTHON SOURCE LINES 128-129 Use GPU, if available .. GENERATED FROM PYTHON SOURCE LINES 129-135 .. code-block:: Python device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("Using device: ", device) pinvnet = pinvnet.to(device) y = y.to(device) .. rst-class:: sphx-glr-script-out .. code-block:: none Using device: cpu .. GENERATED FROM PYTHON SOURCE LINES 136-137 Reconstruction .. GENERATED FROM PYTHON SOURCE LINES 137-140 ..
code-block:: Python with torch.no_grad(): x_pinv = pinvnet.reconstruct(y) .. GENERATED FROM PYTHON SOURCE LINES 141-142 We display the second image in the batch .. GENERATED FROM PYTHON SOURCE LINES 142-144 .. code-block:: Python imagesc(x_pinv[1, 0, :, :].cpu(), "pinv") .. image-sg:: /gallery/images/sphx_glr_tuto_05_dcnet_002.png :alt: pinv :srcset: /gallery/images/sphx_glr_tuto_05_dcnet_002.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) .. GENERATED FROM PYTHON SOURCE LINES 145-147 .. note:: Thanks to preprocessing, the reconstructed image has values in the range (0, 1), like the ground-truth image. .. GENERATED FROM PYTHON SOURCE LINES 150-152 Denoised Completion Network (DC-Net) ========================================= .. GENERATED FROM PYTHON SOURCE LINES 154-165 A DC-Net is based on four sequential steps: 1. Denoising in the measurement domain. 2. Estimation of the missing measurements from the denoised ones. 3. Image-domain mapping. 4. (Learned) Denoising in the image domain. Typically, only the last step involves learnable parameters. .. GENERATED FROM PYTHON SOURCE LINES 167-169 Denoised Completion ========================================= .. GENERATED FROM PYTHON SOURCE LINES 171-177 The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`m = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`m` by minimizing ..
math:: \| m - Hx \|^2_{\Gamma^{-1}} + \|x\|^2_{\Sigma^{-1}}, where :math:`\Sigma` is a covariance prior and :math:`\Gamma` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details). .. GENERATED FROM PYTHON SOURCE LINES 179-180 In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior. .. GENERATED FROM PYTHON SOURCE LINES 180-184 .. code-block:: Python from spyrit.core.recon import DCNet dcnet = DCNet(meas_op, prep_op, Cov / 4, device=device) .. GENERATED FROM PYTHON SOURCE LINES 185-187 .. note:: We divide the covariance by four because it was computed using images with values in the range (-1, 1), whereas our images are in the range (0, 1). Halving the intensity range divides the covariance by :math:`2^2 = 4`, so the downloaded covariance is four times larger than the one that matches our images. .. GENERATED FROM PYTHON SOURCE LINES 189-190 Reconstruction .. GENERATED FROM PYTHON SOURCE LINES 190-194 .. code-block:: Python dcnet = dcnet.to(device) with torch.no_grad(): x_dc = dcnet.reconstruct(y) .. GENERATED FROM PYTHON SOURCE LINES 195-196 We display the second image in the batch .. GENERATED FROM PYTHON SOURCE LINES 196-198 .. code-block:: Python imagesc(x_dc[1, 0, :, :].cpu(), "denoised completion") .. image-sg:: /gallery/images/sphx_glr_tuto_05_dcnet_003.png :alt: denoised completion :srcset: /gallery/images/sphx_glr_tuto_05_dcnet_003.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.
To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) .. GENERATED FROM PYTHON SOURCE LINES 199-201 .. note:: In this tutorial, the covariance matrix used to define subsampling is also used as prior for reconstruction. .. GENERATED FROM PYTHON SOURCE LINES 203-205 (Learned) Denoising in the image domain ========================================= .. GENERATED FROM PYTHON SOURCE LINES 208-209 We download the parameters of a (spyrit 2.4) UNet denoiser .. GENERATED FROM PYTHON SOURCE LINES 209-216 .. code-block:: Python from spyrit.misc.load_data import download_girder url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" model_folder = "./model/" dataID = "67221559f03a54733161e960" # unique ID of the file model_cnn_path = download_girder(url, dataID, model_folder) .. rst-class:: sphx-glr-script-out .. code-block:: none Downloading tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth... Downloading tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth... done. .. GENERATED FROM PYTHON SOURCE LINES 217-222 The UNet should be placed in an ordered dictionary and passed to a :class:`nn.Sequential`. SPyRiT 2.4 trains neural networks for images with values in the range (-1, 1), while SPyRiT 3 assumes images with values in the range (0, 1). This can be compensated for using :class:`spyrit.core.prep.Rerange`. .. GENERATED FROM PYTHON SOURCE LINES 222-234 .. code-block:: Python from spyrit.core.prep import Rerange from typing import OrderedDict from spyrit.core.nnet import Unet from spyrit.core.train import load_net rerange = Rerange((0, 1), (-1, 1)) denoiser = OrderedDict( {"rerange": rerange, "denoi": Unet(), "rerange_inv": rerange.inverse()} ) denoiser = torch.nn.Sequential(denoiser) load_net(model_cnn_path, denoiser, device, False) .. 
rst-class:: sphx-glr-script-out .. code-block:: none Model Loaded: /home/docs/checkouts/readthedocs.org/user_builds/spyrit/checkouts/3.1.1/tutorial/model/tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth .. GENERATED FROM PYTHON SOURCE LINES 235-236 To implement denoising in the image domain, we pass the :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`. .. GENERATED FROM PYTHON SOURCE LINES 236-239 .. code-block:: Python dcnet = DCNet(meas_op, prep_op, Cov, denoiser, device=device) dcnet = dcnet.to(device) # Use GPU, if available .. GENERATED FROM PYTHON SOURCE LINES 240-241 We reconstruct the image .. GENERATED FROM PYTHON SOURCE LINES 241-244 .. code-block:: Python with torch.no_grad(): x_dcnet = dcnet.reconstruct(y) .. GENERATED FROM PYTHON SOURCE LINES 245-247 We display the second image in the batch .. GENERATED FROM PYTHON SOURCE LINES 247-249 .. code-block:: Python im = imagesc(x_dcnet[1, 0, :, :].cpu(), "denoised completion") .. image-sg:: /gallery/images/sphx_glr_tuto_05_dcnet_004.png :alt: denoised completion :srcset: /gallery/images/sphx_glr_tuto_05_dcnet_004.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) .. GENERATED FROM PYTHON SOURCE LINES 250-252 Results ========================================= .. GENERATED FROM PYTHON SOURCE LINES 252-283 ..
code-block:: Python import matplotlib.pyplot as plt from spyrit.misc.disp import add_colorbar, noaxis i_im = 1 f, axs = plt.subplots(2, 2, figsize=(10, 10)) # Plot the ground-truth image im1 = axs[0, 0].imshow(x[i_im, 0, :, :], cmap="gray") axs[0, 0].set_title("Ground-truth image", fontsize=16) noaxis(axs[0, 0]) add_colorbar(im1, "bottom") # Plot the pseudo inverse solution im2 = axs[0, 1].imshow(x_pinv.cpu()[i_im, 0, :, :], cmap="gray") axs[0, 1].set_title("Pseudoinverse", fontsize=16) noaxis(axs[0, 1]) add_colorbar(im2, "bottom") # Plot the solution obtained from denoised completion im3 = axs[1, 0].imshow(x_dc.cpu()[i_im, 0, :, :], cmap="gray") axs[1, 0].set_title("Denoised completion", fontsize=16) noaxis(axs[1, 0]) add_colorbar(im3, "bottom") # Plot the solution obtained from denoised completion with UNet denoising im4 = axs[1, 1].imshow(x_dcnet.cpu()[i_im, 0, :, :], cmap="gray") axs[1, 1].set_title("Denoised completion + UNet", fontsize=16) noaxis(axs[1, 1]) add_colorbar(im4, "bottom") .. image-sg:: /gallery/images/sphx_glr_tuto_05_dcnet_005.png :alt: Ground-truth image, Pseudoinverse, Denoised completion, Denoised completion + UNet :srcset: /gallery/images/sphx_glr_tuto_05_dcnet_005.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. 
__array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) /home/docs/checkouts/readthedocs.org/user_builds/spyrit/envs/3.1.1/lib/python3.11/site-packages/matplotlib/cbook.py:684: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword x = np.array(x, subok=True, copy=copy) .. GENERATED FROM PYTHON SOURCE LINES 284-285 While the pseudo inverse reconstruction is pixelated, the solution obtained by denoised completion is smoother. DC-Net with the UNet denoiser provides the best reconstruction. .. GENERATED FROM PYTHON SOURCE LINES 287-289 .. note:: We refer to `spyrit-examples `_ for a comparison of several methods (e.g., pinvNet, DCNet, DRUNet). .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 4.486 seconds) .. _sphx_glr_download_gallery_tuto_05_dcnet.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: tuto_05_dcnet.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: tuto_05_dcnet.py ` ..
container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: tuto_05_dcnet.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_