.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "gallery/tuto_06_dcnet_split_measurements.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_gallery_tuto_06_dcnet_split_measurements.py: ========================================= 06. Denoised Completion Network (DCNet) ========================================= .. _tuto_dcnet_split_measurements: This tutorial shows how to perform image reconstruction using the denoised completion network (DCNet) with a trainable image denoiser. In the next tutorial, we will plug a denoiser into a DCNet, which requires no training. .. figure:: ../fig/tuto3.png :width: 600 :align: center :alt: Reconstruction and neural network denoising architecture sketch using split measurements .. GENERATED FROM PYTHON SOURCE LINES 17-20 .. note:: As in the previous tutorials, we consider a split Hadamard operator and measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `). .. GENERATED FROM PYTHON SOURCE LINES 22-24 Load a batch of images ========================================= .. GENERATED FROM PYTHON SOURCE LINES 26-27 Update search path .. GENERATED FROM PYTHON SOURCE LINES 27-42 .. code-block:: Python # sphinx_gallery_thumbnail_path = 'fig/tuto6.png' import os import torch import torchvision import matplotlib.pyplot as plt import spyrit.core.torch as spytorch from spyrit.misc.disp import imagesc from spyrit.misc.statistics import transform_gray_norm spyritPath = os.getcwd() imgs_path = os.path.join(spyritPath, "images/") .. GENERATED FROM PYTHON SOURCE LINES 43-44 Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. .. GENERATED FROM PYTHON SOURCE LINES 44-48 .. 
code-block:: Python h = 64 # image is resized to h x h transform = transform_gray_norm(img_size=h) .. GENERATED FROM PYTHON SOURCE LINES 49-50 Create a data loader from some dataset (images must be in the folder `images/test/`) .. GENERATED FROM PYTHON SOURCE LINES 50-57 .. code-block:: Python dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) x, _ = next(iter(dataloader)) print(f"Shape of input images: {x.shape}") .. rst-class:: sphx-glr-script-out .. code-block:: none Shape of input images: torch.Size([7, 1, 64, 64]) .. GENERATED FROM PYTHON SOURCE LINES 58-59 Select the `i`-th image in the batch .. GENERATED FROM PYTHON SOURCE LINES 59-65 .. code-block:: Python i = 1 # Image index (modify to change the image) x = x[i : i + 1, :, :, :] x = x.detach().clone() print(f"Shape of selected image: {x.shape}") b, c, h, w = x.shape .. rst-class:: sphx-glr-script-out .. code-block:: none Shape of selected image: torch.Size([1, 1, 64, 64]) .. GENERATED FROM PYTHON SOURCE LINES 66-67 Plot the selected image .. GENERATED FROM PYTHON SOURCE LINES 67-70 .. code-block:: Python imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]") .. image-sg:: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_001.png :alt: $x$ in [-1, 1] :srcset: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_001.png :class: sphx-glr-single-img .. 
GENERATED FROM PYTHON SOURCE LINES 71-73 Forward operators for split measurements ========================================= .. GENERATED FROM PYTHON SOURCE LINES 75-76 We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `). .. GENERATED FROM PYTHON SOURCE LINES 78-79 First, we download the covariance matrix from our warehouse. .. GENERATED FROM PYTHON SOURCE LINES 79-98 .. code-block:: Python import girder_client from spyrit.misc.load_data import download_girder # Get covariance matrix url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" dataId = "672207cbf03a54733161e95d" data_folder = "./stat/" cov_name = "Cov_64x64.pt" # download file_abs_path = download_girder(url, dataId, data_folder, cov_name) try: Cov = torch.load(file_abs_path, weights_only=True) print(f"Cov matrix {cov_name} loaded") except Exception: Cov = torch.eye(h * h) print(f"Cov matrix {cov_name} not found! Set to the identity") .. rst-class:: sphx-glr-script-out .. code-block:: none File already exists at ./stat/Cov_64x64.pt Cov matrix Cov_64x64.pt loaded .. GENERATED FROM PYTHON SOURCE LINES 99-104 We define the measurement, noise and preprocessing operators and then simulate a measurement vector corrupted by Poisson noise. As in the previous tutorials, we simulate an accelerated acquisition by retaining only the first rows of a Hadamard matrix whose rows are permuted according to the diagonal of the covariance matrix. .. GENERATED FROM PYTHON SOURCE LINES 104-128 .. 
code-block:: Python from spyrit.core.meas import HadamSplit from spyrit.core.noise import Poisson from spyrit.core.prep import SplitPoisson # Measurement parameters M = h**2 // 4 # Number of measurements (here, 1/4 of the pixels) alpha = 100.0 # number of photons # Measurement and noise operators Ord = spytorch.Cov2Var(Cov) meas_op = HadamSplit(M, h, Ord) noise_op = Poisson(meas_op, alpha) prep_op = SplitPoisson(alpha, meas_op) print(f"Shape of image: {x.shape}") # Measurements y = noise_op(x) # a noisy measurement vector m = prep_op(y) # preprocessed measurement vector m_plot = spytorch.meas2img(m, Ord) imagesc(m_plot[0, 0, :, :], r"Measurements $m$") .. image-sg:: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_002.png :alt: Measurements $m$ :srcset: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_002.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none Shape of image: torch.Size([1, 1, 64, 64]) .. GENERATED FROM PYTHON SOURCE LINES 129-131 Pseudo inverse solution ========================================= .. GENERATED FROM PYTHON SOURCE LINES 133-134 We compute the pseudo inverse solution using the :class:`spyrit.core.recon.PinvNet` class, as in the previous tutorial. .. GENERATED FROM PYTHON SOURCE LINES 134-150 .. 
code-block:: Python # Instantiate a PinvNet (with no denoising by default) from spyrit.core.recon import PinvNet pinvnet = PinvNet(noise_op, prep_op) # Use GPU, if available device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("Using device: ", device) pinvnet = pinvnet.to(device) y = y.to(device) # Reconstruction with torch.no_grad(): z_invnet = pinvnet.reconstruct(y) .. rst-class:: sphx-glr-script-out .. code-block:: none Using device: cpu .. GENERATED FROM PYTHON SOURCE LINES 151-153 Denoised completion network (DCNet) ========================================= .. GENERATED FROM PYTHON SOURCE LINES 155-159 .. image:: ../fig/dcnet.png :width: 400 :align: center :alt: Sketch of the DCNet architecture .. GENERATED FROM PYTHON SOURCE LINES 161-172 The DCNet is based on four sequential steps: i) Denoising in the measurement domain. ii) Estimation of the missing measurements from the denoised ones. iii) Image-domain mapping. iv) (Learned) Denoising in the image domain. Typically, only the last step involves learnable parameters. .. GENERATED FROM PYTHON SOURCE LINES 175-177 Denoised completion ========================================= .. GENERATED FROM PYTHON SOURCE LINES 179-185 The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing .. math:: \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}}, where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details). .. 
GENERATED FROM PYTHON SOURCE LINES 187-188 In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior. .. GENERATED FROM PYTHON SOURCE LINES 188-200 .. code-block:: Python from spyrit.core.recon import DCNet dcnet = DCNet(noise_op, prep_op, Cov) # Use GPU, if available dcnet = dcnet.to(device) y = y.to(device) with torch.no_grad(): z_dcnet = dcnet.reconstruct(y) .. GENERATED FROM PYTHON SOURCE LINES 201-203 .. note:: In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction. .. GENERATED FROM PYTHON SOURCE LINES 206-208 (Learned) Denoising in the image domain ========================================= .. GENERATED FROM PYTHON SOURCE LINES 210-211 To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`. .. GENERATED FROM PYTHON SOURCE LINES 211-218 .. code-block:: Python from spyrit.core.nnet import Unet denoi = Unet() dcnet_unet = DCNet(noise_op, prep_op, Cov, denoi) dcnet_unet = dcnet_unet.to(device) # Use GPU, if available .. GENERATED FROM PYTHON SOURCE LINES 219-220 We load pretrained weights for the UNet .. GENERATED FROM PYTHON SOURCE LINES 220-253 .. 
code-block:: Python from spyrit.core.train import load_net local_folder = "./model/" # Create model folder if os.path.exists(local_folder): print(f"{local_folder} found") else: os.mkdir(local_folder) print(f"Created {local_folder}") # Load pretrained model url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1" dataID = "67221559f03a54733161e960" # unique ID of the file data_name = "tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth" model_unet_path = os.path.join(local_folder, data_name) if os.path.exists(model_unet_path): print(f"Model found : {data_name}") else: print(f"Model not found : {data_name}") print(f"Downloading model... ", end="") try: gc = girder_client.GirderClient(apiUrl=url) gc.downloadFile(dataID, model_unet_path) print("Done") except Exception as e: print("Failed with error: ", e) # Load pretrained model load_net(model_unet_path, dcnet_unet, device, False) .. rst-class:: sphx-glr-script-out .. code-block:: none ./model/ found Model not found : tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth Downloading model... Done Model Loaded: ./model/tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth .. GENERATED FROM PYTHON SOURCE LINES 254-255 We reconstruct the image .. GENERATED FROM PYTHON SOURCE LINES 255-258 .. code-block:: Python with torch.no_grad(): z_dcnet_unet = dcnet_unet.reconstruct(y) .. GENERATED FROM PYTHON SOURCE LINES 259-261 Results ========================================= .. GENERATED FROM PYTHON SOURCE LINES 261-292 .. 
code-block:: Python from spyrit.misc.disp import add_colorbar, noaxis f, axs = plt.subplots(2, 2, figsize=(10, 10)) # Plot the ground-truth image im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray") axs[0, 0].set_title("Ground-truth image", fontsize=16) noaxis(axs[0, 0]) add_colorbar(im1, "bottom") # Plot the pseudo inverse solution im2 = axs[0, 1].imshow(z_invnet.cpu()[0, 0, :, :], cmap="gray") axs[0, 1].set_title("Pseudo inverse", fontsize=16) noaxis(axs[0, 1]) add_colorbar(im2, "bottom") # Plot the solution obtained from denoised completion im3 = axs[1, 0].imshow(z_dcnet.cpu()[0, 0, :, :], cmap="gray") axs[1, 0].set_title("Denoised completion", fontsize=16) noaxis(axs[1, 0]) add_colorbar(im3, "bottom") # Plot the solution obtained from denoised completion with UNet denoising im4 = axs[1, 1].imshow(z_dcnet_unet.cpu()[0, 0, :, :], cmap="gray") axs[1, 1].set_title("Denoised completion with UNet denoising", fontsize=16) noaxis(axs[1, 1]) add_colorbar(im4, "bottom") plt.show() .. image-sg:: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_003.png :alt: Ground-truth image, Pseudo inverse, Denoised completion, Denoised completion with UNet denoising :srcset: /gallery/images/sphx_glr_tuto_06_dcnet_split_measurements_003.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 293-295 .. note:: While the pseudo inverse reconstruction is pixelated, the solution obtained by denoised completion is smoother. DCNet with UNet denoising in the image domain provides the best reconstruction. .. GENERATED FROM PYTHON SOURCE LINES 297-299 .. 
note:: See the `spyrit-examples tutorials `_ for a comparison of different solutions (PinvNet, DCNet and DRUNet) that can be run in Colab. .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 4.727 seconds) .. _sphx_glr_download_gallery_tuto_06_dcnet_split_measurements.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: tuto_06_dcnet_split_measurements.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: tuto_06_dcnet_split_measurements.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: tuto_06_dcnet_split_measurements.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_