a00. Connect to deepinverse (HadamSplit2d)

This tutorial shows how to use DeepInverse (https://github.com/deepinv/deepinv) algorithms with a HadamSplit2d linear model. It uses the HadamSplit2d class of the spyrit.core.meas submodule.

Reconstruction architecture sketch

Load images

We load a batch of grayscale images from the images/ folder, with pixel values in (0, 1).

import os
import torchvision
import torch

import matplotlib.pyplot as plt

from spyrit.misc.disp import imagesc
from spyrit.misc.statistics import transform_gray_norm

import deepinv as dinv

spyritPath = os.getcwd()
imgs_path = os.path.join(spyritPath, "images/")

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

# Grayscale images of size (32, 32), no normalization to keep values in (0,1)
transform = transform_gray_norm(img_size=32, normalize=False)

# Create dataset and loader (expects class folder 'images/test/')
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

x, _ = next(iter(dataloader))
print(f"Ground-truth images: {x.shape}")
Ground-truth images: torch.Size([7, 1, 32, 32])

We select the second image in the batch and plot it.

i_plot = 1
imagesc(x[i_plot, 0, :, :], r"$32\times 32$ image $X$")
$32\times 32$ image $X$

Basic example

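We instantiate a HadamSplit2d object and simulate the 2D Hadamard transform of the input images, keeping 512 of the 32 × 32 = 1024 Hadamard patterns. Setting reshape_output=True is necessary for use with deepinv. We also add Poisson noise. Because the measurements are split into positive and negative parts, y contains 2 × 512 = 1024 values per image.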

from spyrit.core.meas import HadamSplit2d
import spyrit.core.noise as noise
from spyrit.core.prep import UnsplitRescale

meas_spyrit = HadamSplit2d(32, 512, device=device, reshape_output=True)
alpha = 50  # image intensity
meas_spyrit.noise_model = noise.Poisson(alpha)
y = meas_spyrit(x)

# Preprocess: recombine the split measurements and rescale by the intensity alpha
prep = UnsplitRescale(alpha)
m_spyrit = prep(y)

print(y.shape)
torch.Size([7, 1, 1024])
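The split measurement and preprocessing steps can be illustrated with a minimal NumPy sketch, independent of spyrit. It assumes that the split operator uses the nonnegative patterns H_pos = max(H, 0) and H_neg = max(-H, 0) in interleaved order, that the Poisson model draws counts with mean alpha times the noiseless measurement, and that UnsplitRescale returns (y_pos - y_neg) / alpha. These assumptions mirror the idea, not necessarily spyrit's exact implementation; the toy sizes and the helper sylvester_hadamard are illustrative.

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


rng = np.random.default_rng(0)

h = 8                     # toy image side (the tutorial uses 32)
N = h * h                 # number of pixels
M = 32                    # number of kept patterns (assumption: simply the first M rows)

H1 = sylvester_hadamard(h)
H = np.kron(H1, H1)[:M]   # 2D Hadamard patterns acting on the row-major flattened image

# Split each signed pattern into two nonnegative ones, so that H = H_pos - H_neg
H_pos = np.maximum(H, 0)
H_neg = np.maximum(-H, 0)

x = rng.uniform(0, 1, N)  # toy image with values in (0, 1), flattened
alpha = 50.0              # image intensity, as in the tutorial

# Assumed noise model: Poisson counts with mean alpha * (noiseless split measurement),
# stored interleaved as [pos_0, neg_0, pos_1, neg_1, ...] (the ordering is an assumption)
y = np.empty(2 * M)
y[0::2] = rng.poisson(alpha * (H_pos @ x))
y[1::2] = rng.poisson(alpha * (H_neg @ x))

# Assumed preprocessing: subtract the negative channel and undo the intensity scaling
m = (y[0::2] - y[1::2]) / alpha

print(y.shape, m.shape)   # (64,) (32,)
```

The split doubles the number of acquired values, which is why y has twice as many entries as the number of kept patterns, while the preprocessed m has one entry per pattern.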

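Before passing the operator to deepinv, we normalize it by its norm, i.e., its largest singular value (spectral norm). With a unit-norm operator, the gradient of the L2 data-fidelity term is 1-Lipschitz, so a step size of 1 is admissible in the gradient-based solvers used below.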

norm = torch.linalg.norm(meas_spyrit.H, ord=2)
print(norm)
tensor(32.0000)
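The value 32 can be checked independently. Assuming meas_spyrit.H is a subset of rows of a ±1 Hadamard matrix of order N = 32 × 32 = 1024 (consistent with the printed norm), its rows are orthogonal with squared norm N, so H H^T = N I and every singular value equals √N = 32. The NumPy sketch below verifies this on a Sylvester-type Hadamard matrix; the helper sylvester_hadamard and the choice of the first 512 rows are illustrative, not spyrit's pattern ordering.

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


N = 32 * 32                                   # number of pixels of a 32 x 32 image
H1 = sylvester_hadamard(32)
H = np.kron(H1, H1)                           # 1024 x 1024 Hadamard matrix for 2D patterns
H_sub = H[:512]                               # 512 rows (an arbitrary subset for this sketch)

# Rows are mutually orthogonal with squared norm N, hence H_sub @ H_sub.T = N * I
gram_is_scaled_identity = np.allclose(H_sub @ H_sub.T, N * np.eye(512))
spectral_norm = np.linalg.norm(H_sub, 2)      # largest singular value

print(gram_is_scaled_identity)                # True
print(round(spectral_norm, 6))                # 32.0
```

Any subset of rows gives the same result, so the norm does not depend on which 512 patterns are kept.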

Forward operator

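The spyrit operator can be passed directly to deepinv by wrapping its forward and adjoint methods, both divided by the norm, in a dinv.physics.LinearPhysics object. Noise can be added with either a deepinv noise model (see the commented line) or a spyrit noise model.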

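meas_deepinv = dinv.physics.LinearPhysics(
    # Forward operator, normalized to unit spectral norm
    lambda x: meas_spyrit.measure_H(x) / norm,
    # Adjoint with the same normalization, reshaped back to images
    A_adjoint=lambda y: meas_spyrit.unvectorize(meas_spyrit.adjoint_H(y) / norm),
)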
# meas_deepinv.noise_model = dinv.physics.GaussianNoise(sigma=0.01)
m_deepinv = meas_deepinv(x)
print("diff:", torch.linalg.norm(m_spyrit / norm - m_deepinv))
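diff: tensor(5.6969)

The difference is not zero because m_spyrit contains Poisson noise, whereas m_deepinv is computed without noise (the deepinv noise model is commented out).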
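When wrapping an external operator, a dot-product test is a cheap way to check that the forward and adjoint functions are consistent: for random x and y, <A x, y> must equal <x, A^T y>. The sketch below runs it in NumPy on a small explicit Hadamard matrix standing in for meas_spyrit (the matrix, sizes, and function names are illustrative); the same check can be run on meas_deepinv with random torch tensors of the right shapes.

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


rng = np.random.default_rng(1)
N, M = 64, 32
H = sylvester_hadamard(N)[:M]          # small explicit stand-in for meas_spyrit.H
norm = np.linalg.norm(H, 2)            # sqrt(N) = 8 here


def forward(x):
    """Normalized forward operator, wrapped like the LinearPhysics forward above."""
    return H @ x / norm


def adjoint(y):
    """Adjoint with the same normalization."""
    return H.T @ y / norm


# Dot-product test: <A x, y> should equal <x, A^T y> for random x and y
x = rng.standard_normal(N)
y = rng.standard_normal(M)
mismatch = abs(np.dot(forward(x), y) - np.dot(x, adjoint(y)))
print(mismatch < 1e-10)                # True
```

This is also why both functions divide by the same norm: scaling only one of them would make the test fail.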

Reconstruction with deepinverse

First, we reconstruct the images with the adjoint and the dagger (pseudo-inverse) operators.

x_adj = meas_deepinv.A_adjoint(m_spyrit / norm)
imagesc(x_adj[1, 0, :, :].cpu(), "Adjoint")

x_pinv = meas_deepinv.A_dagger(m_spyrit / norm)
imagesc(x_pinv[1, 0, :, :].cpu(), "Pinv")
  • Adjoint
  • Pinv
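Assuming, as the norm of 32 suggests, that meas_spyrit.H consists of Hadamard rows, the normalized operator A = H / 32 has orthonormal rows. Then A A^T = I and the pseudo-inverse A⁺ = A^T (A A^T)⁻¹ reduces to A^T, so the adjoint and pseudo-inverse reconstructions should agree up to numerical accuracy. A small NumPy check of this identity (toy sizes, illustrative helper):

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


N, M = 64, 32
A = sylvester_hadamard(N)[:M] / np.sqrt(N)    # normalized operator: orthonormal rows

# With orthonormal rows, A A^T = I, so the pseudo-inverse A^T (A A^T)^-1 equals A^T
print(np.allclose(A @ A.T, np.eye(M)))        # True
print(np.allclose(np.linalg.pinv(A), A.T))    # True
```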

You can also use optimization-based methods from deepinv. Here, we use total variation (TV) regularization with a proximal gradient descent (PGD) algorithm. Note the custom_init parameter, which initializes the algorithm with the dagger (pseudo-inverse) reconstruction.

model_tv = dinv.optim.optim_builder(
    iteration="PGD",
    prior=dinv.optim.TVPrior(),
    data_fidelity=dinv.optim.L2(),
    params_algo={"stepsize": 1, "lambda": 5e-2},
    max_iter=10,
    custom_init=lambda y, Physics: {"est": (Physics.A_dagger(y),)},
)

x_tv, metrics_TV = model_tv(m_spyrit / norm, meas_deepinv, compute_metrics=True, x_gt=x)
dinv.utils.plot_curves(metrics_TV)
imagesc(x_tv[1, 0, :, :].cpu(), "TV recon")
  • PSNR, F, residual
  • TV recon
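The PGD iteration can be sketched in NumPy as x ← prox_{γλg}(x − γ A^T (A x − y)), using the gradient of the L2 data-fidelity term. Because the TV proximal operator has no closed form, this sketch swaps in an L1 penalty, whose proximal operator is soft-thresholding; it illustrates the structure of the iteration, not deepinv's TVPrior. With the normalized operator (‖A‖ = 1), the step size of 1 used above equals 1/L, for which the objective is non-increasing. Toy sizes and all function names are illustrative.

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1 (L1 stand-in for the TV proximal operator)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def objective(A, y, x, lam):
    """0.5 * ||A x - y||^2 + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - y) ** 2) + lam * np.sum(np.abs(x))


def pgd(A, y, x0, prox, stepsize=1.0, lam=5e-2, max_iter=10):
    """Proximal gradient descent; returns the estimate and the objective history."""
    x = x0.copy()
    history = [objective(A, y, x, lam)]
    for _ in range(max_iter):
        grad = A.T @ (A @ x - y)                    # gradient of the L2 data-fidelity term
        x = prox(x - stepsize * grad, stepsize * lam)
        history.append(objective(A, y, x, lam))
    return x, history


rng = np.random.default_rng(0)
N, M = 64, 32
A = sylvester_hadamard(N)[:M] / np.sqrt(N)        # normalized operator, ||A||_2 = 1
x_true = rng.uniform(0, 1, N)
y = A @ x_true + 0.01 * rng.standard_normal(M)

x0 = np.linalg.pinv(A) @ y                        # pseudo-inverse initialization, like custom_init
x_hat, history = pgd(A, y, x0, soft_threshold)
print(len(history))                               # 11
```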

The Deep Plug-and-Play Image Restoration (DPIR) algorithm can also be used with a pretrained denoiser. Here, we use the DRUNet denoiser.

denoiser = dinv.models.DRUNet(in_channels=1, out_channels=1, device=device)
model_dpir = dinv.optim.DPIR(sigma=1e-1, device=device, denoiser=denoiser)
model_dpir.custom_init = lambda y, Physics: {"est": (Physics.A_dagger(y),)}
with torch.no_grad():
    x_dpir = model_dpir(m_spyrit / norm, meas_deepinv)
imagesc(x_dpir[1, 0, :, :].cpu(), "DPIR recon")
DPIR recon
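DPIR alternates a data-consistency step with a denoising step while decreasing the assumed noise level, in a half-quadratic splitting scheme. The NumPy sketch below shows that structure with a hypothetical toy denoiser standing in for DRUNet (a blend with the 3 × 3 local mean) and a simplified coupling schedule; the schedule, the weights, and all names are illustrative assumptions, not deepinv's DPIR parameters.

```python
import numpy as np


def sylvester_hadamard(n):
    """Return the n x n Sylvester Hadamard matrix with +/-1 entries (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def toy_denoiser(img, sigma):
    """Hypothetical stand-in for DRUNet: blend the image with its 3x3 local mean."""
    h, w = img.shape
    padded = np.pad(img, 1, mode="edge")
    local_mean = sum(padded[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0
    weight = min(1.0, sigma / 0.2)               # smooth more at higher noise levels
    return (1.0 - weight) * img + weight * local_mean


def pnp_hqs(A, y, x0, denoiser, sigmas, shape, noise_level=0.01):
    """Plug-and-play half-quadratic splitting with a decreasing noise-level schedule."""
    n = A.shape[1]
    AtA, Aty = A.T @ A, A.T @ y
    z = x0.copy()
    for sigma in sigmas:
        mu = (noise_level / sigma) ** 2          # coupling weight grows as sigma decreases
        # Data step: argmin_x ||A x - y||^2 + mu * ||x - z||^2
        x = np.linalg.solve(AtA + mu * np.eye(n), Aty + mu * z)
        # Prior step: denoise the current estimate at level sigma
        z = denoiser(x.reshape(shape), sigma).ravel()
    return z


rng = np.random.default_rng(0)
h = 8
N, M = h * h, 32
A = sylvester_hadamard(N)[:M] / np.sqrt(N)       # normalized operator with orthonormal rows
x_true = rng.uniform(0, 1, N)
y = A @ x_true + 0.01 * rng.standard_normal(M)

x0 = A.T @ y                                     # equals the pseudo-inverse solution here
sigmas = np.geomspace(0.2, 0.01, 8)              # decreasing noise levels
x_hat = pnp_hqs(A, y, x0, toy_denoiser, sigmas, (h, h))
print(x_hat.shape)                               # (64,)
```

Early iterations rely mostly on the denoiser (small coupling weight), while later ones enforce data consistency more strongly as the noise level decreases.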

The pretrained Reconstruct Anything Model (RAM) can also be used.

model_ram = dinv.models.RAM(pretrained=True, device=device)
model_ram.sigma_threshold = 1e-1
with torch.no_grad():
    x_ram = model_ram(m_spyrit / norm, meas_deepinv)
imagesc(x_ram[1, 0, :, :].cpu(), "RAM recon")
RAM recon
