08. Learned proximal gradient descent (LPGD) for split measurements
This tutorial shows how to perform image reconstruction with unrolled Learned Proximal Gradient Descent (LPGD) for split measurements.
Unfortunately, this example has a large memory footprint, so it cannot be run interactively. If you want to run it yourself, remove the “if False:” statement at the beginning of each code block and unindent the code below it. The figures displayed are the ones the code generates when it is run.
LPGD is an unrolled method. It can be viewed as a recurrent network in which each block corresponds to one unrolled iteration of proximal gradient descent. At each iteration, the network performs a gradient step followed by a denoising step.
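The update rule for the LPGD network is given by:
\[x^{(k+1)} = \mathcal{G}_{\theta}\left(x^{(k)} - \gamma H^\top \left(H x^{(k)} - m\right)\right),\]
where \(x^{(k)}\) is the image estimate at iteration \(k\), \(H\) is the forward operator, \(m\) is the preprocessed measurement vector, \(\gamma\) is the step size, and \(\mathcal{G}_{\theta}\) is a denoising network with learnable parameters \(\theta\).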
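The update rule can be made concrete with a minimal NumPy sketch of a few unrolled iterations. It assumes a small random matrix in place of the Hadamard operator \(H\), noiseless measurements \(m = Hx\), and an initial estimate \(x^{(0)} = 0\). It also replaces the learned denoiser \(\mathcal{G}_{\theta}\) with a projection onto [-1, 1] (the proximal operator of a box constraint), so it illustrates classical projected gradient descent rather than the trained LPGD network. The function names are hypothetical.

```python
import numpy as np


def pgd_step(x, H, m, gamma, denoiser):
    """One iteration: gradient step on 0.5 * ||Hx - m||^2, then a denoising step."""
    grad = H.T @ (H @ x - m)           # gradient of the data-fidelity term
    return denoiser(x - gamma * grad)  # stand-in for G_theta applied to the gradient step


def box_projection(z):
    """Stand-in for G_theta: proximal operator of the indicator of [-1, 1]."""
    return np.clip(z, -1.0, 1.0)


def unrolled_pgd(H, m, gamma, n_iter, denoiser=box_projection):
    x = np.zeros(H.shape[1])  # assumed initialization x^(0) = 0
    history = [x]
    for _ in range(n_iter):
        x = pgd_step(x, H, m, gamma, denoiser)
        history.append(x)
    return x, history


rng = np.random.default_rng(0)
H = rng.standard_normal((12, 16))         # toy underdetermined forward operator
x_true = rng.uniform(-1.0, 1.0, size=16)  # ground truth in [-1, 1]
m = H @ x_true                            # noiseless measurements

gamma = 1.0 / np.linalg.norm(H, 2) ** 2   # step size 1/L with L = ||H||_2^2
x_hat, history = unrolled_pgd(H, m, gamma, n_iter=50)
residuals = [np.linalg.norm(H @ x - m) for x in history]
print(f"residual: {residuals[0]:.3f} -> {residuals[-1]:.3f}")
```

With the step size \(1/\|H\|_2^2\), the data residual cannot increase from one iteration to the next. Swapping box_projection for a trained denoiser is, in spirit, what turns this loop into an unrolled learned method.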
Load a batch of images
The neural networks expect input images \(x\) with values in [-1, 1]. The images are therefore normalized (and resized) using the transform_gray_norm() function.
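The mapping to [-1, 1] is an affine rescaling of the pixel intensities. The minimal NumPy sketch below assumes 8-bit grayscale input. transform_gray_norm is applied to the images loaded by ImageFolder and also resizes them; this sketch only shows the intensity rescaling, and its function names are hypothetical.

```python
import numpy as np


def to_minus_one_one(img_uint8):
    """Rescale an 8-bit grayscale image from [0, 255] to [-1, 1]."""
    x01 = img_uint8.astype(np.float64) / 255.0  # [0, 255] -> [0, 1]
    return 2.0 * x01 - 1.0                       # [0, 1] -> [-1, 1], i.e. (x - 0.5) / 0.5


def from_minus_one_one(x):
    """Inverse mapping, useful for displaying or saving results."""
    return np.round((x + 1.0) / 2.0 * 255.0).astype(np.uint8)


img = np.array([[0, 51, 255], [102, 204, 153]], dtype=np.uint8)
x = to_minus_one_one(img)
print(x.min(), x.max())  # -1.0 1.0
```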
# sphinx_gallery_thumbnail_path = 'fig/lpgd.png'
if False:
    import os
    import torch
    import torchvision
    import matplotlib.pyplot as plt
    import spyrit.core.torch as spytorch
    from spyrit.misc.disp import imagesc
    from spyrit.misc.statistics import transform_gray_norm

    spyritPath = os.getcwd()
    imgs_path = os.path.join(spyritPath, "images/")

    # Images x for training neural networks expect values in [-1, 1].
    # The images are normalized and resized using transform_gray_norm().
    h = 128  # image is resized to h x h
    transform = transform_gray_norm(img_size=h)

    # Create a data loader from some dataset (images must be in the folder `images/test/`)
    dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

    x, _ = next(iter(dataloader))
    print(f"Shape of input images: {x.shape}")

    # Select the i-th image in the batch
    i = 1  # Image index (modify to change the image)
    x = x[i : i + 1, :, :, :]
    x = x.detach().clone()
    print(f"Shape of selected image: {x.shape}")
    b, c, h, w = x.shape

    # Plot the selected image
    imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
Forward operators for split measurements
We consider noisy split measurements for a Hadamard operator and a simple “rectangular” subsampling strategy (for more details, refer to Acquisition - split measurements).
We define the measurement, noise and preprocessing operators and then simulate a measurement vector \(y\) corrupted by Poisson noise. As in the previous tutorial, we simulate an accelerated acquisition by subsampling the measurement matrix, retaining only \(M\) rows of a Hadamard matrix, selected by the sampling map Ord_rec.
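Split measurements deal with the negative entries of \(H\), which a nonnegative light modulator cannot display directly, by writing \(H = H^{+} - H^{-}\) with nonnegative \(H^{+}\) and \(H^{-}\) and acquiring both. The NumPy sketch below is a conceptual illustration, not SPyRiT's implementation. It assumes a Sylvester-type Hadamard matrix for a tiny 1D signal, that images in [-1, 1] are mapped to [0, 1] before acquisition, that \(\alpha\) scales the expected photon counts, and that the positive and negative measurements are stacked rather than interleaved. The function names are hypothetical.

```python
import numpy as np


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def split(H):
    """Nonnegative split H = H_pos - H_neg."""
    return np.maximum(H, 0.0), np.maximum(-H, 0.0)


def acquire(x, H, alpha, rng=None):
    """Simulate split measurements of x in [-1, 1].

    Assumption: x is first mapped to [0, 1] and alpha scales the photon counts.
    If rng is None, the expected (noiseless) counts are returned.
    """
    H_pos, H_neg = split(H)
    x01 = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    mean = alpha * np.concatenate([H_pos @ x01, H_neg @ x01])
    return mean if rng is None else rng.poisson(mean).astype(np.float64)


def preprocess(y, H, alpha):
    """Estimate H @ x from stacked split counts y = [y_pos, y_neg]."""
    M = H.shape[0]
    y_pos, y_neg = y[:M], y[M:]
    Hx01 = (y_pos - y_neg) / alpha                # estimates H @ x01
    return 2.0 * Hx01 - H @ np.ones(H.shape[1])   # H x = 2 H x01 - H 1


rng = np.random.default_rng(0)
H = hadamard(16)[:8]  # keep M = 8 of the 16 rows
x = rng.uniform(-1.0, 1.0, size=16)

m_clean = preprocess(acquire(x, H, alpha=10.0), H, alpha=10.0)
m_noisy = preprocess(acquire(x, H, alpha=10.0, rng=rng), H, alpha=10.0)
print(np.max(np.abs(m_clean - H @ x)))  # ~0 (floating-point error only)
print(np.max(np.abs(m_noisy - H @ x)))  # nonzero because of Poisson noise
```

A lower \(\alpha\) means fewer photons per measurement and hence a relatively noisier estimate of \(Hx\).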
if False:
    import math
    from spyrit.core.meas import HadamSplit
    from spyrit.core.noise import Poisson
    from spyrit.core.prep import SplitPoisson

    # Measurement parameters
    M = h**2 // 4  # Number of measurements (here, 1/4 of the pixels)
    alpha = 10.0  # number of photons

    # Sampling: rectangular matrix
    Ord_rec = torch.zeros(h, h)
    n_sub = math.ceil(M**0.5)
    Ord_rec[:n_sub, :n_sub] = 1

    # Measurement and noise operators
    meas_op = HadamSplit(M, h, Ord_rec)
    noise_op = Poisson(meas_op, alpha)
    prep_op = SplitPoisson(alpha, meas_op)

    print(f"Shape of image: {x.shape}")

    # Measurements
    y = noise_op(x)  # a noisy measurement vector
    m = prep_op(y)  # preprocessed measurement vector

    m_plot = spytorch.meas2img(m, Ord_rec)
    imagesc(m_plot[0, 0, :, :], r"Measurements $m$")
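The rectangular sampling map Ord_rec keeps an n_sub x n_sub block of 2D Hadamard coefficients, i.e. \(M = n_{sub}^2\) measurements out of \(h^2\). As a rough NumPy illustration, the sketch below assumes that the 2D Hadamard transform of an h x h image is computed separably as \(H_1 X H_1^\top\) with a Sylvester matrix \(H_1\) (SPyRiT may use a different row ordering). It keeps the top-left block and places the kept coefficients back into an h x h array, similar in spirit to how meas2img arranges \(m\) for display. The helper names are hypothetical.

```python
import math

import numpy as np


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


h = 8
M = h**2 // 4  # 1/4 of the pixels, as in the tutorial
n_sub = math.ceil(M**0.5)

# Rectangular sampling map: 1 on the kept n_sub x n_sub block, 0 elsewhere
Ord_rec = np.zeros((h, h))
Ord_rec[:n_sub, :n_sub] = 1

H1 = hadamard(h)
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(h, h))

coeffs = H1 @ X @ H1.T     # full 2D Hadamard transform (h^2 coefficients)
m = coeffs[Ord_rec == 1]   # the M kept coefficients, as a vector

# Arrange the measurements as an h x h image (zeros where not measured)
m_img = np.zeros((h, h))
m_img[Ord_rec == 1] = m

# Zero-filled reconstruction, using H1.T @ H1 = h * I
X_zf = H1.T @ m_img @ H1 / h**2
print(m.size, M)  # 16 16
```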
We define the LearnedPGD network by passing the noise operator (which wraps the measurement operator), the preprocessing operator, the denoiser and other optional parameters to the class spyrit.core.recon.LearnedPGD. The optional parameters include the number of unrolled iterations (iter_stop) and the step size decay factor (step_decay).
As in previous tutorials, we choose Unet as the denoiser. For the optional parameters, we use three iterations and a step size decay factor of 0.9, which worked well on this data (these values should match the parameters used during training).
if False:
    from spyrit.core.nnet import Unet
    from spyrit.core.recon import LearnedPGD

    # use GPU, if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Define UNet denoiser
    denoi = Unet()

    # Define the LearnedPGD model
    lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9)
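The role of iter_stop and step_decay can be sketched without SPyRiT. The NumPy class below is a hypothetical stand-in for LearnedPGD, not its actual implementation. It assumes that the step size at iteration \(k\) is \(\gamma_k = \gamma_0 \cdot \mathrm{step\_decay}^k\), that each unrolled iteration is given its own denoiser callable (passing the same callable several times amounts to sharing weights), and that the initial estimate is zero.

```python
import numpy as np


class UnrolledPGD:
    """Hypothetical, minimal stand-in for an unrolled PGD network.

    Assumptions (not SPyRiT's actual implementation):
    - the step size at iteration k is gamma0 * step_decay**k;
    - denoisers holds one callable per unrolled iteration;
    - the initial estimate is x0 = 0.
    """

    def __init__(self, H, denoisers, gamma0, step_decay=0.9):
        self.H = H
        self.denoisers = list(denoisers)
        self.gamma0 = gamma0
        self.step_decay = step_decay

    @property
    def iter_stop(self):
        # Number of unrolled iterations = number of denoising blocks
        return len(self.denoisers)

    def step_sizes(self):
        return [self.gamma0 * self.step_decay**k for k in range(self.iter_stop)]

    def reconstruct(self, m):
        H = self.H
        x = np.zeros(H.shape[1])
        for gamma, denoise in zip(self.step_sizes(), self.denoisers):
            x = denoise(x - gamma * H.T @ (H @ x - m))  # gradient step, then denoising
        return x


def identity(z):
    """Placeholder for a trained denoiser such as a U-Net."""
    return z


rng = np.random.default_rng(0)
H = rng.standard_normal((6, 10))
m = H @ rng.uniform(-1.0, 1.0, size=10)
gamma0 = 1.0 / np.linalg.norm(H, 2) ** 2

net = UnrolledPGD(H, [identity] * 3, gamma0, step_decay=0.9)  # iter_stop = 3
print(net.iter_stop, [round(g / gamma0, 3) for g in net.step_sizes()])  # 3 [1.0, 0.9, 0.81]
x_hat = net.reconstruct(m)
```

Because the number of iterations is fixed, the whole loop is a finite computation graph, which is what makes end-to-end training of the denoisers possible in an actual unrolled network.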
Now, we download the pretrained weights and load them into the LPGD network. Unfortunately, the pretrained weights are too large (2 GB) to be downloaded here. The final figure is nonetheless displayed to show the results.
if False:
    from spyrit.core.train import load_net
    from spyrit.misc.load_data import download_girder

    # Download parameters
    url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
    dataID = "67221f60f03a54733161e96c"  # unique ID of the file
    local_folder = "./model/"
    data_name = "tuto8_model_lpgd_light.pth"

    # Download from Girder
    model_abs_path = download_girder(url, dataID, local_folder, data_name)

    # Load pretrained weights into the model
    load_net(model_abs_path, lpgd_net, device, strict=False)

    lpgd_net.eval()
    lpgd_net.to(device)
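The weights are loaded with strict=False. Assuming load_net forwards this flag to PyTorch's Module.load_state_dict (an assumption about load_net), non-strict loading tolerates parameter names that are missing from the checkpoint or present only in the checkpoint and reports them instead of raising an error, while shape mismatches still raise. The pure-Python sketch below mimics that key-matching logic on plain dictionaries of parameter shapes; match_state_dict and the parameter names are hypothetical, not PyTorch or SPyRiT code.

```python
def match_state_dict(model_params, checkpoint, strict=True):
    """Mimic the key matching of a state-dict load (hypothetical helper).

    model_params and checkpoint map parameter names to shapes (tuples).
    Returns (loaded, missing_keys, unexpected_keys).
    """
    missing = sorted(k for k in model_params if k not in checkpoint)
    unexpected = sorted(k for k in checkpoint if k not in model_params)
    mismatched = sorted(
        k for k in model_params if k in checkpoint and checkpoint[k] != model_params[k]
    )
    if mismatched:
        # Shape mismatches are errors whatever the value of strict
        raise RuntimeError(f"size mismatch for {mismatched}")
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    loaded = sorted(k for k in model_params if k in checkpoint)
    return loaded, missing, unexpected


# Hypothetical parameter names and shapes
model = {"denoi.conv1.weight": (32, 1, 3, 3), "denoi.conv1.bias": (32,), "step": ()}
ckpt = {"denoi.conv1.weight": (32, 1, 3, 3), "denoi.conv1.bias": (32,), "extra.buffer": (4,)}

loaded, missing, unexpected = match_state_dict(model, ckpt, strict=False)
print(missing, unexpected)  # ['step'] ['extra.buffer']
```

Non-strict loading is convenient when a checkpoint and a model differ only in auxiliary entries, but it can also silently leave some parameters at their initial values, so the reported missing keys are worth inspecting.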
We reconstruct the image by calling the reconstruct method, as in previous tutorials, and display the results.
if False:
    from spyrit.misc.disp import add_colorbar, noaxis

    with torch.no_grad():
        z_lpgd = lpgd_net.reconstruct(y.to(device))

    # Plot results
    f, axs = plt.subplots(2, 1, figsize=(10, 10))

    im1 = axs[0].imshow(x.cpu()[0, 0, :, :], cmap="gray")
    axs[0].set_title("Ground-truth image", fontsize=16)
    noaxis(axs[0])
    add_colorbar(im1, "bottom")

    im2 = axs[1].imshow(z_lpgd.cpu()[0, 0, :, :], cmap="gray")
    axs[1].set_title("LPGD", fontsize=16)
    noaxis(axs[1])
    add_colorbar(im2, "bottom")

    plt.show()