.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/tuto_a00_connect_deepinv.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_tuto_a00_connect_deepinv.py:


a00. Connect to DeepInverse (HadamSplit2d)
====================================================

.. _tuto_connect_deepinv:

This tutorial shows how to use DeepInverse (https://github.com/deepinv/deepinv)
algorithms with a HadamSplit2d linear model. It uses the
:class:`spyrit.core.meas.HadamSplit2d` class of the :mod:`spyrit.core.meas`
submodule.

.. image:: https://github.com/deepinv/deepinv/raw/main/docs/source/figures/deepinv_logolarge.png
   :width: 600
   :align: center
   :alt: DeepInverse logo

|

.. GENERATED FROM PYTHON SOURCE LINES 19-21

Load images
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 23-24

We load a batch of images from the ``/images/`` folder with values in (0,1).

.. GENERATED FROM PYTHON SOURCE LINES 24-50

.. code-block:: Python


    import os

    import torchvision
    import torch

    import matplotlib.pyplot as plt

    from spyrit.misc.disp import imagesc
    from spyrit.misc.statistics import transform_gray_norm

    import deepinv as dinv

    spyritPath = os.getcwd()
    imgs_path = os.path.join(spyritPath, "images/")

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    # Grayscale images of size (32, 32), no normalization to keep values in (0,1)
    transform = transform_gray_norm(img_size=32, normalize=False)

    # Create dataset and loader (expects class folder 'images/test/')
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

    x, _ = next(iter(dataloader))
    print(f"Ground-truth images: {x.shape}")




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /Users/tbaudier/spyrit/deepinv/deepinv/__about__.py:8: DeprecationWarning: Implicit None on return values is deprecated and will raise KeyErrors.
      __license__ = metadata["License"]
    Ground-truth images: torch.Size([7, 1, 32, 32])




.. GENERATED FROM PYTHON SOURCE LINES 51-52

We select the second image in the batch and plot it.

.. GENERATED FROM PYTHON SOURCE LINES 52-56

.. code-block:: Python


    i_plot = 1
    imagesc(x[i_plot, 0, :, :], r"$32\times 32$ image $X$")




.. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png
   :alt: $32\times 32$ image $X$
   :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 57-59

Basic example
-----------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 61-62

We instantiate a HadamSplit2d object and simulate the 2D Hadamard transform of
the input images. The ``reshape_output=True`` option is necessary for deepinv.
We also add Poisson noise.

.. GENERATED FROM PYTHON SOURCE LINES 62-78

.. code-block:: Python


    from spyrit.core.meas import HadamSplit2d
    import spyrit.core.noise as noise
    from spyrit.core.prep import UnsplitRescale

    meas_spyrit = HadamSplit2d(32, 512, device=device, reshape_output=True)
    alpha = 50  # image intensity
    meas_spyrit.noise_model = noise.Poisson(alpha)
    y = meas_spyrit(x)

    # preprocess
    prep = UnsplitRescale(alpha)
    m_spyrit = prep(y)

    print(y.shape)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([7, 1, 1024])




.. GENERATED FROM PYTHON SOURCE LINES 79-80

The operator norm must be computed so that a normalized operator can be passed
to deepinv. We use the largest singular value of the linear operator.

.. GENERATED FROM PYTHON SOURCE LINES 80-84

.. code-block:: Python


    norm = torch.linalg.norm(meas_spyrit.H, ord=2)
    print(norm)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor(32.0000)




.. GENERATED FROM PYTHON SOURCE LINES 85-87

Forward operator
----------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 89-90

You can directly pass the forward operator to deepinv. Noise can be added
using either a deepinv noise model or a spyrit noise model.

.. GENERATED FROM PYTHON SOURCE LINES 90-99

.. code-block:: Python


    meas_deepinv = dinv.physics.LinearPhysics(
        lambda y: meas_spyrit.measure_H(y) / norm,
        A_adjoint=lambda y: meas_spyrit.unvectorize(meas_spyrit.adjoint_H(y) / norm),
    )
    # meas_deepinv.noise_model = dinv.physics.GaussianNoise(sigma=0.01)
    m_deepinv = meas_deepinv(x)
    print("diff:", torch.linalg.norm(m_spyrit / norm - m_deepinv))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    diff: tensor(5.6969)




.. GENERATED FROM PYTHON SOURCE LINES 100-102

Reconstruction with DeepInverse
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 104-105

First, we use the adjoint and dagger (pseudo-inverse) operators to reconstruct
the image.

.. GENERATED FROM PYTHON SOURCE LINES 105-112

.. code-block:: Python


    x_adj = meas_deepinv.A_adjoint(m_spyrit / norm)
    imagesc(x_adj[1, 0, :, :].cpu(), "Adjoint")

    x_pinv = meas_deepinv.A_dagger(m_spyrit / norm)
    imagesc(x_pinv[1, 0, :, :].cpu(), "Pinv")




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png
         :alt: Adjoint
         :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png
         :alt: Pinv
         :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_003.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 113-114

You can also use optimization-based methods from deepinv. Here, we use Total
Variation (TV) regularization with a proximal gradient descent (PGD) algorithm.
Note the use of the ``custom_init`` parameter to initialize the algorithm with
the dagger (pseudo-inverse) reconstruction.

.. GENERATED FROM PYTHON SOURCE LINES 114-127

.. code-block:: Python


    model_tv = dinv.optim.optim_builder(
        iteration="PGD",
        prior=dinv.optim.TVPrior(),
        data_fidelity=dinv.optim.L2(),
        params_algo={"stepsize": 1, "lambda": 5e-2},
        max_iter=10,
        custom_init=lambda y, Physics: {"est": (Physics.A_dagger(y),)},
    )

    x_tv, metrics_TV = model_tv(m_spyrit / norm, meas_deepinv, compute_metrics=True, x_gt=x)
    dinv.utils.plot_curves(metrics_TV)
    imagesc(x_tv[1, 0, :, :].cpu(), "TV recon")




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png
         :alt: PSNR, F, residual
         :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png
         :alt: TV recon
         :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_005.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 128-129

The Deep Plug-and-Play Image Restoration (DPIR) algorithm can also be used with
a pretrained denoiser. Here, we use the DRUNet denoiser.

.. GENERATED FROM PYTHON SOURCE LINES 129-136

.. code-block:: Python


    denoiser = dinv.models.DRUNet(in_channels=1, out_channels=1, device=device)
    model_dpir = dinv.optim.DPIR(sigma=1e-1, device=device, denoiser=denoiser)
    model_dpir.custom_init = lambda y, Physics: {"est": (Physics.A_dagger(y),)}
    with torch.no_grad():
        x_dpir = model_dpir(m_spyrit / norm, meas_deepinv)
    imagesc(x_dpir[1, 0, :, :].cpu(), "DPIR recon")




.. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png
   :alt: DPIR recon
   :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_006.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 137-138

The Reconstruct Anything Model (RAM) can also be used.

.. GENERATED FROM PYTHON SOURCE LINES 138-143

.. code-block:: Python


    model_ram = dinv.models.RAM(pretrained=True, device=device)
    model_ram.sigma_threshold = 1e-1
    with torch.no_grad():
        x_ram = model_ram(m_spyrit / norm, meas_deepinv)
    imagesc(x_ram[1, 0, :, :].cpu(), "RAM recon")




.. image-sg:: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png
   :alt: RAM recon
   :srcset: /gallery/images/sphx_glr_tuto_a00_connect_deepinv_007.png
   :class: sphx-glr-single-img





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 11.085 seconds)
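
The normalization step used in this tutorial can be checked independently of
spyrit and deepinv. The sketch below is not part of the original script. It
assumes that the measurement matrix is made of rows of a Hadamard matrix of
order :math:`N` (built here with a hypothetical helper ``sylvester_hadamard``),
uses a small :math:`8 \times 8` image rather than :math:`32 \times 32`, and
relies on NumPy only. It computes the largest singular value, which equals
:math:`\sqrt{N}` for such a matrix, and runs a dot-product test to confirm that
the normalized operator and its transpose form a consistent forward/adjoint
pair.

.. code-block:: Python

    import numpy as np


    def sylvester_hadamard(order: int) -> np.ndarray:
        """Hypothetical helper: Sylvester Hadamard matrix (order is a power of 2)."""
        if order < 1 or order & (order - 1):
            raise ValueError("order must be a power of two")
        H = np.ones((1, 1))
        while H.shape[0] < order:
            H = np.block([[H, H], [H, -H]])
        return H


    side = 8          # illustrative image side (assumption); the tutorial uses 32
    N = side * side   # number of pixels
    M = N // 2        # keep half of the patterns, like 512 out of 1024

    H = sylvester_hadamard(N)[:M]      # (M, N) subsampled Hadamard matrix
    norm = np.linalg.norm(H, ord=2)    # largest singular value of H

    A = H / norm                       # normalized forward operator
    rng = np.random.default_rng(0)
    x = rng.random(N)                  # random "image" with values in [0, 1)
    y = rng.standard_normal(M)         # random vector in measurement space

    # Dot-product (adjoint) test: <A x, y> should equal <x, A^T y>
    lhs = np.dot(A @ x, y)
    rhs = np.dot(x, A.T @ y)

    print(f"spectral norm: {norm:.4f}, sqrt(N): {np.sqrt(N):.4f}")
    print(f"adjoint test gap: {abs(lhs - rhs):.2e}")

With ``side = 32`` the same computation gives :math:`\sqrt{1024} = 32`, which
agrees with the value ``tensor(32.0000)`` printed earlier in this tutorial.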