03. Pseudoinverse solution + CNN denoising
This tutorial shows how to simulate measurements and perform image reconstruction using PinvNet (pseudoinverse linear network) with CNN denoising as a last layer. This tutorial is a continuation of the Pseudoinverse solution tutorial but uses a CNN denoiser instead of the identity operator in order to remove artefacts.
The measurement operator is a Hadamard matrix restricted to its positive coefficients; any other matrix can be used instead.
These tutorials load image samples from /images/.
Load a batch of images
Training expects images \(x\) with values in [-1, 1]. The images are normalized
to this range with the transform_gray_norm() function.
import os
import girder_client
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from spyrit.misc.disp import imagesc
from spyrit.misc.statistics import transform_gray_norm
# sphinx_gallery_thumbnail_path = 'fig/tuto3.png'
h = 64 # image size hxh
i = 1 # Image index (modify to change the image)
spyritPath = os.getcwd()
imgs_path = os.path.join(spyritPath, "images/")
# Create a transform for natural images to normalized grayscale image tensors
transform = transform_gray_norm(img_size=h)
# Create dataset and loader (expects class folder 'images/test/')
dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
x, _ = next(iter(dataloader))
print(f"Shape of input images: {x.shape}")
# Select image
x = x[i : i + 1, :, :, :]
x = x.detach().clone()
b, c, h, w = x.shape
# plot
x_plot = x.view(-1, h, h).cpu().numpy()
imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]")
![$x$ in [-1, 1]](../_images/sphx_glr_tuto_03_pseudoinverse_cnn_linear_001.png)
Shape of input images: torch.Size([7, 1, 64, 64])
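As a rough illustration of the [-1, 1] convention used above, the sketch below maps intensities from [0, 1] to [-1, 1] and back with NumPy. It assumes the normalization is the usual affine map with mean 0.5 and standard deviation 0.5; this has not been checked against the implementation of transform_gray_norm(), and the helper names are hypothetical.

```python
import numpy as np


def to_signed_range(img01):
    """Map intensities from [0, 1] to [-1, 1] (hypothetical helper)."""
    img01 = np.asarray(img01, dtype=np.float64)
    return (img01 - 0.5) / 0.5


def to_unit_range(img_signed):
    """Inverse map, from [-1, 1] back to [0, 1] (hypothetical helper)."""
    img_signed = np.asarray(img_signed, dtype=np.float64)
    return (img_signed + 1.0) / 2.0


# Toy 2x2 "image" with intensities in [0, 1]
img_unit = np.array([[0.0, 0.25], [0.5, 1.0]])
img_signed = to_signed_range(img_unit)
img_back = to_unit_range(img_signed)
```

The inverse map is handy when displaying reconstructions with a viewer that expects values in [0, 1].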
Define a measurement operator
We consider the case where the measurement matrix is the positive
component of a Hadamard matrix and the sampling operator preserves only
the first M low-frequency coefficients
(see Positive Hadamard matrix for a full explanation).
import math
from spyrit.misc.sampling import sort_by_significance
from spyrit.misc.walsh_hadamard import walsh2_matrix
F = walsh2_matrix(h)
F = np.where(F > 0, F, 0)
und = 4 # undersampling factor
M = h**2 // und # number of measurements (undersampling factor = 4)
Sampling_map = np.ones((h, h))
M_xy = math.ceil(M**0.5)
Sampling_map[:, M_xy:] = 0
Sampling_map[M_xy:, :] = 0
F = sort_by_significance(F, Sampling_map, "rows", False)
H = F[:M, :]
print(f"Shape of the measurement matrix: {H.shape}")
imagesc(Sampling_map, "low-frequency sampling map")

Shape of the measurement matrix: (1024, 4096)
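The construction above can be checked on a small example. The sketch below builds a 2D Hadamard matrix for 8x8 images, keeps its positive part, and verifies that the square low-frequency map selects exactly M coefficients. It is an illustration only: the Hadamard matrix is generated in natural (Sylvester) order with a hypothetical helper, which may differ from the ordering returned by walsh2_matrix().

```python
import math

import numpy as np


def sylvester_hadamard(n):
    """n x n Hadamard matrix in natural order; n must be a power of 2 (hypothetical helper)."""
    H = np.array([[1]])
    while H.shape[0] < n:
        H = np.kron(np.array([[1, 1], [1, -1]]), H)
    return H


h_demo = 8    # toy image size (the tutorial uses 64)
und_demo = 4  # undersampling factor

# 2D transform acting on vectorized h x h images, entries in {-1, +1}
H1 = sylvester_hadamard(h_demo)
F_demo = np.kron(H1, H1)

# Positive part: -1 entries become 0, +1 entries are kept
F_pos = np.where(F_demo > 0, F_demo, 0)

# Square low-frequency sampling map with M_xy x M_xy ones
M_demo = h_demo**2 // und_demo
M_xy_demo = math.ceil(M_demo**0.5)
map_demo = np.zeros((h_demo, h_demo))
map_demo[:M_xy_demo, :M_xy_demo] = 1

print(F_pos.shape, M_demo, int(map_demo.sum()))
```

For a matrix with entries in {-1, +1}, the positive part equals (F + 1) / 2, which is why each row other than the first contains exactly half ones.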
Then, we instantiate a spyrit.core.meas.Linear measurement operator
Noiseless case
In the noiseless case, we consider the spyrit.core.noise.NoNoise noise
operator
Shape of raw measurements: torch.Size([1, 1024])
We now compute and plot the preprocessed measurements corresponding to an image in [-1,1]
from spyrit.core.prep import DirectPoisson
prep = DirectPoisson(N0, meas_op) # "Undo" the NoNoise operator
m = prep(y)
print(f"Shape of the preprocessed measurements: {m.shape}")
Shape of the preprocessed measurements: torch.Size([1, 1024])
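To make the role of the preprocessing step concrete, here is a minimal NumPy sketch. It assumes that the noiseless operator measures the image shifted to [0, 1], i.e. y = N0 * H (x + 1) / 2, and that the preprocessing computes m = 2 y / N0 - H 1. Both formulas are assumptions made for illustration and are not taken from the text above; all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pix, n_meas = 16, 6

# Toy measurement matrix with entries in {0, 1} and a toy image in [-1, 1]
H_toy = rng.integers(0, 2, size=(n_meas, n_pix)).astype(float)
x_toy = rng.uniform(-1.0, 1.0, size=n_pix)
N0_toy = 1.0

# Assumed noiseless acquisition of the image shifted to [0, 1]
y_toy = N0_toy * H_toy @ ((x_toy + 1.0) / 2.0)

# Assumed preprocessing: rescale, then remove the offset H @ 1
m_toy = 2.0 * y_toy / N0_toy - H_toy @ np.ones(n_pix)

# Under these assumptions, m equals the measurements of the image in [-1, 1]
print(np.allclose(m_toy, H_toy @ x_toy))
```

Under these assumptions the preprocessed vector m equals H x, the measurements that would be obtained directly from the image in [-1, 1].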
To display the subsampled measurement vector as an image in the transformed
domain, we use the spyrit.misc.sampling.meas2img() function

Shape of the preprocessed measurement image: (64, 64)
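The idea of displaying a subsampled measurement vector as an image can be sketched as follows. This is not the implementation of meas2img(); it is a hypothetical helper that places the measurements at the most significant positions of the sampling map and leaves the other coefficients at zero, with ties resolved in raster order (an assumption).

```python
import numpy as np


def scatter_measurements(m, sampling_map):
    """Place len(m) values at the most significant map positions (hypothetical helper)."""
    sampling_map = np.asarray(sampling_map, dtype=float)
    flat = sampling_map.ravel()
    # Most significant positions first; the stable sort keeps raster order for ties
    order = np.argsort(-flat, kind="stable")
    img = np.zeros(flat.size)
    img[order[: len(m)]] = m
    return img.reshape(sampling_map.shape)


# Toy 4x4 map keeping the 2x2 low-frequency corner
map_toy = np.zeros((4, 4))
map_toy[:2, :2] = 1
m_img_toy = scatter_measurements(np.array([1.0, 2.0, 3.0, 4.0]), map_toy)
print(m_img_toy)
```

With the square low-frequency map used in this tutorial, the nonzero coefficients would all fall in the top-left corner of the image.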
PinvNet Network
We consider the spyrit.core.recon.PinvNet class that reconstructs an
image by computing the pseudoinverse solution, which is fed to a neural
network denoiser. To compute the pseudoinverse solution only, the denoiser
can be set to the identity operator, or equivalently omitted so that the
default denoiser is used.
Then, we reconstruct the image from the measurement vector y using the
reconstruct() method
x_rec = pinv_net.reconstruct(y)
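For intuition, the pseudoinverse solution can be sketched with plain NumPy. The example below uses a small random matrix in place of H (an assumption made so that the matrix has full row rank) and np.linalg.pinv; it illustrates the principle and is not how PinvNet is implemented. All names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pix, n_meas = 16, 4  # fewer measurements than unknowns

# Toy measurement matrix, image, and noiseless measurements
A = rng.standard_normal((n_meas, n_pix))
x_true = rng.uniform(-1.0, 1.0, size=n_pix)
m_vec = A @ x_true

# Pseudoinverse solution: the minimum-norm image consistent with the measurements
A_pinv = np.linalg.pinv(A)
x_pinv = A_pinv @ m_vec

print(np.allclose(A @ x_pinv, m_vec))
print(np.linalg.norm(x_pinv) <= np.linalg.norm(x_true))
```

The solution reproduces the measurements but generally differs from the true image, since the component of the image outside the row space of the matrix is lost; this is one source of the artefacts that a denoiser is meant to reduce.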
Removing artefacts with a CNN
Artefacts can be removed by selecting a neural network denoiser
(the last layer of PinvNet). We select a simple CNN using the
spyrit.core.nnet.ConvNet class, but it can be replaced by any
neural network (e.g., the U-Net provided by spyrit.core.nnet.Unet).
from spyrit.core.nnet import ConvNet, Unet
from spyrit.core.recon import PinvNet
from spyrit.core.train import load_net
# Define PinvNet with a ConvNet denoising layer
denoi = ConvNet()
pinv_net_cnn = PinvNet(noise, prep, denoi)
# Send to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pinv_net_cnn = pinv_net_cnn.to(device)
As an example, we use a simple ConvNet that has been pretrained on the STL-10 dataset. We download the pretrained weights and load them into the network.
from spyrit.misc.load_data import download_girder
# Load pretrained model
url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
dataID = "668267b3baa5a9000705896a" # unique ID of the file
local_folder = "./model/"
data_name = "tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth"
# download the model and save it in the local folder
model_cnn_path = download_girder(url, dataID, local_folder, data_name)
# Load model weights
load_net(model_cnn_path, pinv_net_cnn, device, False)
Local folder not found, creating it... done.
Downloading file... done.
Model Loaded: /home/docs/checkouts/readthedocs.org/user_builds/spyrit/checkouts/2.3.2/tutorial/model/tuto3_pinv-net_cnn_stl10_N0_1_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth
We now reconstruct the image using PinvNet with the pretrained CNN denoiser and plot the result side by side with the PinvNet reconstruction without denoising.
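with torch.no_grad():
    # Reconstruct from the measurements y simulated earlier
    x_rec_cnn = pinv_net_cnn.reconstruct(y.to(device))
    # Alternatively, run the full forward pass on the image; this overwrites the result above
    x_rec_cnn = pinv_net_cnn(x.to(device))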
# plot
x_plot = x.squeeze().cpu().numpy()
x_plot2 = x_rec.squeeze().cpu().numpy()
x_plot3 = x_rec_cnn.squeeze().cpu().numpy()
from spyrit.misc.disp import add_colorbar, noaxis
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
im1 = ax1.imshow(x_plot, cmap="gray")
ax1.set_title("Ground-truth image", fontsize=20)
noaxis(ax1)
add_colorbar(im1, "bottom", size="20%")
im2 = ax2.imshow(x_plot2, cmap="gray")
ax2.set_title("Pinv reconstruction", fontsize=20)
noaxis(ax2)
add_colorbar(im2, "bottom", size="20%")
im3 = ax3.imshow(x_plot3, cmap="gray")
ax3.set_title("Pinv + CNN (trained 30 epochs)", fontsize=20)
noaxis(ax3)
add_colorbar(im3, "bottom", size="20%")

We show the best result again (it serves as the tutorial thumbnail).
# Plot
imagesc(x_plot3, "Pinv + CNN (trained 30 epochs)", title_fontsize=20)
plt.show()

In the next tutorial, we will show how to train PinvNet + CNN denoiser.