06. Denoised Completion Network (DCNet)

This tutorial shows how to perform image reconstruction using the denoised completion network (DCNet) with a trainable image denoiser. In the next tutorial, we will plug a denoiser that requires no training into a DCNet.

Reconstruction and neural network denoising architecture sketch using split measurements

Note

As in the previous tutorials, we consider a split Hadamard operator and measurements corrupted by Poisson noise (see Tutorial 5).

Load a batch of images

Update search path

# sphinx_gallery_thumbnail_path = 'fig/tuto6.png'
import os

import torch
import torchvision
import matplotlib.pyplot as plt

import spyrit.core.torch as spytorch
from spyrit.misc.disp import imagesc
from spyrit.misc.statistics import transform_gray_norm

spyritPath = os.getcwd()
imgs_path = os.path.join(spyritPath, "images/")

Neural networks are trained on images \(x\) with values in [-1, 1]. The images are normalized and resized using the transform_gray_norm() function.

h = 64  # image is resized to h x h
transform = transform_gray_norm(img_size=h)
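The mapping to [-1, 1] can be sketched as an affine rescaling of 8-bit gray levels. This is only an illustration under that assumption: `to_minus_one_one` is a hypothetical helper, not the implementation of transform_gray_norm() (which also resizes the image).

```python
import numpy as np


def to_minus_one_one(img_uint8):
    """Hypothetical helper: map 8-bit gray levels {0, ..., 255} to [-1, 1]."""
    x01 = np.asarray(img_uint8, dtype=np.float64) / 255.0  # [0, 255] -> [0, 1]
    return 2.0 * x01 - 1.0  # [0, 1] -> [-1, 1]


img = np.array([[0, 255], [51, 204]], dtype=np.uint8)
x = to_minus_one_one(img)
print(x.min(), x.max())  # -1.0 1.0
```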

Create a data loader from some dataset (images must be in the folder images/test/)

dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)

x, _ = next(iter(dataloader))
print(f"Shape of input images: {x.shape}")
Shape of input images: torch.Size([7, 1, 64, 64])

Select the i-th image in the batch

i = 1  # Image index (modify to change the image)
x = x[i : i + 1, :, :, :]
x = x.detach().clone()
print(f"Shape of selected image: {x.shape}")
b, c, h, w = x.shape
Shape of selected image: torch.Size([1, 1, 64, 64])
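Slicing with `i : i + 1` rather than indexing with `i` keeps the batch dimension, so that `b, c, h, w = x.shape` remains valid. A NumPy sketch of the difference, using a zero array with the same shape as the batch above (PyTorch tensors behave the same way):

```python
import numpy as np

batch = np.zeros((7, 1, 64, 64))  # stand-in for the (7, 1, 64, 64) batch above
i = 1
kept = batch[i : i + 1, :, :, :]  # slicing keeps the batch axis
dropped = batch[i]  # integer indexing removes it
print(kept.shape)  # (1, 1, 64, 64)
print(dropped.shape)  # (1, 64, 64)
```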

Plot the selected image

imagesc(x[0, 0, :, :], r"$x$ in [-1, 1]")
$x$ in [-1, 1]

Forward operators for split measurements

We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to Tutorial 5).

First, we download the covariance matrix from our warehouse.

import girder_client
from spyrit.misc.load_data import download_girder

# Get covariance matrix
url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
dataId = "672207cbf03a54733161e95d"
data_folder = "./stat/"
cov_name = "Cov_64x64.pt"
# download
file_abs_path = download_girder(url, dataId, data_folder, cov_name)

try:
    Cov = torch.load(file_abs_path, weights_only=True)
    print(f"Cov matrix {cov_name} loaded")
except Exception:  # e.g. missing or unreadable file; avoid a bare except
    Cov = torch.eye(h * h)
    print(f"Cov matrix {cov_name} not found! Set to the identity")
File already exists at ./stat/Cov_64x64.pt
Cov matrix Cov_64x64.pt loaded

We define the measurement, noise and preprocessing operators and then simulate a measurement vector corrupted by Poisson noise. As in the previous tutorials, we simulate an accelerated acquisition by subsampling the measurement matrix: we retain only the first M rows of a Hadamard matrix whose rows have been reordered according to the diagonal of the covariance matrix, i.e., by decreasing variance.
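The variance-based subsampling can be sketched as follows: the diagonal of the covariance matrix (the variance of each Hadamard coefficient) is reshaped to an h x h map, and the M coefficients with the largest variance are kept. This NumPy sketch only illustrates the idea behind spytorch.Cov2Var and the ordering passed to HadamSplit; the helper names and the toy covariance are assumptions, not spyrit's code.

```python
import numpy as np


def variance_map(cov, h):
    """Diagonal of an (h*h, h*h) covariance matrix reshaped to an h x h map."""
    return np.diag(cov).reshape(h, h)


def keep_largest_variance(var_map, M):
    """Boolean mask selecting the M coefficients with the largest variance."""
    flat = var_map.ravel()
    order = np.argsort(-flat, kind="stable")  # indices sorted by decreasing variance
    mask = np.zeros(flat.size, dtype=bool)
    mask[order[:M]] = True
    return mask.reshape(var_map.shape)


# Toy example: a 4 x 4 "image", i.e. a 16 x 16 covariance with a known diagonal
h = 4
cov = np.diag(np.arange(h * h, 0, -1, dtype=float))  # variances 16, 15, ..., 1
mask = keep_largest_variance(variance_map(cov, h), M=h * h // 4)
print(mask.astype(int))
```

With these toy variances stored in row-major order, the four largest values sit in the first row, so only that row is retained.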

from spyrit.core.meas import HadamSplit
from spyrit.core.noise import Poisson
from spyrit.core.prep import SplitPoisson

# Measurement parameters
M = h**2 // 4  # Number of measurements (here, 1/4 of the pixels)
alpha = 100.0  # number of photons

# Measurement and noise operators
Ord = spytorch.Cov2Var(Cov)
meas_op = HadamSplit(M, h, Ord)
noise_op = Poisson(meas_op, alpha)
prep_op = SplitPoisson(alpha, meas_op)

print(f"Shape of image: {x.shape}")

# Measurements
y = noise_op(x)  # a noisy measurement vector
m = prep_op(y)  # preprocessed measurement vector

m_plot = spytorch.meas2img(m, Ord)
imagesc(m_plot[0, 0, :, :], r"Measurements $m$")
Measurements $m$
Shape of image: torch.Size([1, 1, 64, 64])
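The split acquisition and its preprocessing can be sketched in NumPy. A Hadamard matrix has entries ±1, so it is split into nonnegative parts \(H = H^+ - H^-\) that can be displayed as light patterns. Here we assume that an image \(x\) in [-1, 1] is mapped to \((x + 1)/2\) in [0, 1] and scaled by \(\alpha\) photons; then \(y^+ - y^- = \alpha H (x + 1)/2\), so \(2(y^+ - y^-)/\alpha - H\mathbf{1}\) recovers \(Hx\) up to noise. This follows the idea of HadamSplit, Poisson and SplitPoisson, but the function names and the scaling convention are our own assumptions; a Hadamard matrix of order 16 stands in for the 4096 x 4096 one.

```python
import numpy as np


def sylvester_hadamard(n):
    """Hadamard matrix of order n (n must be a power of 2), Sylvester construction."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def split_forward(x, H, alpha):
    """Expected photon counts for the positive and negative patterns (assumed convention)."""
    x01 = (x + 1.0) / 2.0  # image mapped from [-1, 1] to [0, 1]
    H_pos = np.maximum(H, 0.0)  # H+ (nonnegative part)
    H_neg = np.maximum(-H, 0.0)  # H- (so that H = H+ - H-)
    return alpha * (H_pos @ x01), alpha * (H_neg @ x01)


def split_preprocess(y_pos, y_neg, H, alpha):
    """Estimate m = Hx from the split counts: 2 (y+ - y-) / alpha - H @ 1."""
    return 2.0 * (y_pos - y_neg) / alpha - H @ np.ones(H.shape[1])


rng = np.random.default_rng(0)
n, alpha = 16, 100.0
H = sylvester_hadamard(n)
x = rng.uniform(-1.0, 1.0, size=n)  # flattened toy image in [-1, 1]

y_pos, y_neg = split_forward(x, H, alpha)
m_clean = split_preprocess(y_pos, y_neg, H, alpha)
print(np.allclose(m_clean, H @ x))  # True: the noiseless case recovers Hx

# Poisson noise on the photon counts, as in the noisy acquisition above
m_noisy = split_preprocess(rng.poisson(y_pos), rng.poisson(y_neg), H, alpha)
```

Larger \(\alpha\) means more photons, hence a smaller relative Poisson noise on m_noisy.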

Pseudo inverse solution

We compute the pseudo inverse solution using the spyrit.core.recon.PinvNet class, as in the previous tutorial.

# Instantiate a PinvNet (with no denoising by default)
from spyrit.core.recon import PinvNet

pinvnet = PinvNet(noise_op, prep_op)

# Use GPU, if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)
pinvnet = pinvnet.to(device)
y = y.to(device)

# Reconstruction
with torch.no_grad():
    z_invnet = pinvnet.reconstruct(y)
Using device:  cpu
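For a Hadamard matrix of order N, \(H^\top H = N I\), so the pseudo inverse of its first M rows \(H_M\) is simply \(H_M^\top / N\): the missing measurements are set to zero and the inverse transform is applied. A NumPy sketch of this zero-filling reconstruction; the helper name is hypothetical and PinvNet may compute it differently (e.g. with a fast transform).

```python
import numpy as np


def sylvester_hadamard(n):
    """Hadamard matrix of order n (n must be a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def pinv_subsampled_hadamard(m, H_M):
    """Pseudo inverse of M orthogonal Hadamard rows: x = H_M^T m / N."""
    N = H_M.shape[1]
    return H_M.T @ m / N


rng = np.random.default_rng(0)
N, M = 16, 4
H = sylvester_hadamard(N)
x = rng.uniform(-1.0, 1.0, size=N)

H_M = H[:M]  # keep the first M rows
z = pinv_subsampled_hadamard(H_M @ x, H_M)
print(np.allclose(z, np.linalg.pinv(H_M) @ (H_M @ x)))  # True
```

With M < N the missing coefficients are implicitly zero, which is one reason the pseudo inverse looks pixelized.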

Denoised completion network (DCNet)

Sketch of the DCNet architecture

The DCNet is based on four sequential steps:

  1. Denoising in the measurement domain.

  2. Estimation of the missing measurements from the denoised ones.

  3. Image-domain mapping.

  4. (Learned) Denoising in the image domain.

Typically, only the last step involves learnable parameters.
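The four steps can be sketched in NumPy, assuming a zero-mean Gaussian prior on the (reordered) Hadamard coefficients with covariance blocks \(\Sigma_{11}\) (acquired) and \(\Sigma_{21}\) (missing vs. acquired), and a noise covariance \(\Sigma_\alpha\) on the acquired ones. Step 1 computes \(\hat m_1 = \Sigma_{11}(\Sigma_{11} + \Sigma_\alpha)^{-1} m_1\), step 2 computes \(\hat m_2 = \Sigma_{21}\Sigma_{11}^{-1}\hat m_1\), step 3 applies the inverse Hadamard transform and step 4 applies an image-domain denoiser (the identity by default). The function name, the zero-mean assumption and the dense matrices are illustrative; spyrit's DCNet does not necessarily compute these steps this way.

```python
import numpy as np


def sylvester_hadamard(n):
    """Hadamard matrix of order n (n must be a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def dcnet_sketch(m1, Sigma, Sigma_alpha, H, denoiser=lambda img: img):
    """Hypothetical sketch of the four DCNet steps (zero-mean Gaussian prior)."""
    M = m1.shape[0]
    S11, S21 = Sigma[:M, :M], Sigma[M:, :M]
    # 1. Denoising in the measurement domain
    m1_hat = S11 @ np.linalg.solve(S11 + Sigma_alpha, m1)
    # 2. Estimation of the missing measurements from the denoised ones
    m2_hat = S21 @ np.linalg.solve(S11, m1_hat)
    # 3. Image-domain mapping (H^T H = N I for a Hadamard matrix)
    x_hat = H.T @ np.concatenate([m1_hat, m2_hat]) / H.shape[0]
    # 4. Denoising in the image domain (the learnable step in DCNet)
    return denoiser(x_hat)


rng = np.random.default_rng(0)
N, M = 16, 4
H = sylvester_hadamard(N)
A = rng.normal(size=(N, N))
Sigma = A @ A.T + N * np.eye(N)  # toy SPD covariance of the Hadamard coefficients
x = rng.uniform(-1.0, 1.0, size=N)
m1 = (H @ x)[:M] + rng.normal(scale=0.1, size=M)  # noisy acquired measurements
x_hat = dcnet_sketch(m1, Sigma, 0.01 * np.eye(M), H)
print(x_hat.shape)  # (16,)
```

In DCNet, the `denoiser` argument is where a trainable network such as a UNet would be plugged in.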

Denoised completion

The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements \(y = Hx\), where \(H\) is the measurement matrix and \(x\) is the unknown image, denoised completion estimates \(x\) from \(y\) by minimizing

\[\| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}},\]

where \(\Sigma\) is a covariance prior and \(\Sigma_\alpha\) is the noise covariance. Denoised completion can be performed using the TikhonovMeasurementPriorDiag class (see the documentation for more details).
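Setting the gradient of this quadratic objective to zero gives \(\hat x = (H^\top \Sigma_\alpha^{-1} H + \Sigma^{-1})^{-1} H^\top \Sigma_\alpha^{-1} y\), which, by a standard matrix identity, equals \(\hat x = \Sigma H^\top (H \Sigma H^\top + \Sigma_\alpha)^{-1} y\). The second form only requires solving an M x M system. The NumPy check below compares both forms on random toy matrices; it illustrates the estimator and is not the implementation of TikhonovMeasurementPriorDiag.

```python
import numpy as np


def tikhonov_normal_equations(y, H, Sigma, Sigma_alpha):
    """Minimizer from the N x N normal equations."""
    Sa_inv = np.linalg.inv(Sigma_alpha)
    lhs = H.T @ Sa_inv @ H + np.linalg.inv(Sigma)
    return np.linalg.solve(lhs, H.T @ Sa_inv @ y)


def tikhonov_measurement_form(y, H, Sigma, Sigma_alpha):
    """Same minimizer, solving only an M x M system."""
    return Sigma @ H.T @ np.linalg.solve(H @ Sigma @ H.T + Sigma_alpha, y)


rng = np.random.default_rng(0)
N, M = 16, 4
H = rng.normal(size=(M, N))
A = rng.normal(size=(N, N))
Sigma = A @ A.T + np.eye(N)  # SPD prior covariance
Sigma_alpha = np.diag(rng.uniform(0.5, 1.0, M))  # diagonal noise covariance
y = rng.normal(size=M)

x1 = tikhonov_normal_equations(y, H, Sigma, Sigma_alpha)
x2 = tikhonov_measurement_form(y, H, Sigma, Sigma_alpha)
print(np.allclose(x1, x2))  # True
```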

In practice, it is more convenient to use the spyrit.core.recon.DCNet class, which relies on a forward operator, a preprocessing operator, and a covariance prior.

from spyrit.core.recon import DCNet

dcnet = DCNet(noise_op, prep_op, Cov)

# Use GPU, if available
dcnet = dcnet.to(device)
y = y.to(device)

with torch.no_grad():
    z_dcnet = dcnet.reconstruct(y)

Note

In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction.

(Learned) Denoising in the image domain

To implement denoising in the image domain, we provide a spyrit.core.nnet.Unet denoiser to a spyrit.core.recon.DCNet.

from spyrit.core.nnet import Unet

denoi = Unet()
dcnet_unet = DCNet(noise_op, prep_op, Cov, denoi)
dcnet_unet = dcnet_unet.to(device)  # Use GPU, if available

We load pretrained weights for the UNet

from spyrit.core.train import load_net

local_folder = "./model/"
# Create model folder
if os.path.exists(local_folder):
    print(f"{local_folder} found")
else:
    os.mkdir(local_folder)
    print(f"Created {local_folder}")

# Load pretrained model
url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
dataID = "67221559f03a54733161e960"  # unique ID of the file
data_name = "tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth"
model_unet_path = os.path.join(local_folder, data_name)

if os.path.exists(model_unet_path):
    print(f"Model found : {data_name}")

else:
    print(f"Model not found : {data_name}")
    print("Downloading model... ", end="")
    try:
        gc = girder_client.GirderClient(apiUrl=url)
        gc.downloadFile(dataID, model_unet_path)
        print("Done")
    except Exception as e:
        print("Failed with error: ", e)

# Load pretrained model
load_net(model_unet_path, dcnet_unet, device, False)
./model/ found
Model not found : tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth
Downloading model... Done
Model Loaded: ./model/tuto6_dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_light.pth
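The download logic above follows a simple cache pattern: fetch the file only if it is not already on disk. A standard-library sketch, where `ensure_file` and `fake_fetch` are hypothetical stand-ins (no call to girder_client is made here):

```python
import os
import tempfile


def ensure_file(path, fetch):
    """Call fetch(path) only if path does not exist; return True if a download happened."""
    if os.path.exists(path):
        return False
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)  # create the parent folder
    fetch(path)
    return True


def fake_fetch(path):
    """Hypothetical stand-in for a real download such as gc.downloadFile."""
    with open(path, "wb") as f:
        f.write(b"weights")


with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "model", "weights.pth")
    first = ensure_file(model_path, fake_fetch)  # downloads
    second = ensure_file(model_path, fake_fetch)  # already cached, skipped
    print(first, second)  # True False
```

Unlike the tutorial code, which only prints a download error, this sketch lets exceptions from `fetch` propagate, so that loading is never attempted on a missing file.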

We reconstruct the image

with torch.no_grad():
    z_dcnet_unet = dcnet_unet.reconstruct(y)

Results

from spyrit.misc.disp import add_colorbar, noaxis

f, axs = plt.subplots(2, 2, figsize=(10, 10))

# Plot the ground-truth image
im1 = axs[0, 0].imshow(x[0, 0, :, :], cmap="gray")
axs[0, 0].set_title("Ground-truth image", fontsize=16)
noaxis(axs[0, 0])
add_colorbar(im1, "bottom")

# Plot the pseudo inverse solution
im2 = axs[0, 1].imshow(z_invnet.cpu()[0, 0, :, :], cmap="gray")
axs[0, 1].set_title("Pseudo inverse", fontsize=16)
noaxis(axs[0, 1])
add_colorbar(im2, "bottom")

# Plot the solution obtained from denoised completion
im3 = axs[1, 0].imshow(z_dcnet.cpu()[0, 0, :, :], cmap="gray")
axs[1, 0].set_title("Denoised completion", fontsize=16)
noaxis(axs[1, 0])
add_colorbar(im3, "bottom")

# Plot the solution obtained from denoised completion with UNet denoising
im4 = axs[1, 1].imshow(z_dcnet_unet.cpu()[0, 0, :, :], cmap="gray")
axs[1, 1].set_title("Denoised completion with UNet denoising", fontsize=16)
noaxis(axs[1, 1])
add_colorbar(im4, "bottom")

plt.show()
Ground-truth image, Pseudo inverse, Denoised completion, Denoised completion with UNet denoising

Note

While the pseudo inverse reconstruction is pixelized, the solution obtained by denoised completion is smoother. DCNet with UNet denoising in the image domain provides the best reconstruction of the three.
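To quantify this visual comparison, one can compute the peak signal-to-noise ratio (PSNR) of each reconstruction against the ground truth. A minimal NumPy sketch, assuming images in [-1, 1] so that the dynamic range is 2; the helper name is hypothetical and this tutorial does not report PSNR values.

```python
import numpy as np


def psnr(x_true, x_est, data_range=2.0):
    """PSNR in dB; data_range = 2 for images in [-1, 1]."""
    mse = np.mean((np.asarray(x_true, float) - np.asarray(x_est, float)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range**2 / mse)


x_true = np.zeros((64, 64))
x_est = np.full((64, 64), 0.1)
print(round(psnr(x_true, x_est), 2))  # 26.02
```

With the tensors above, one would call, for example, `psnr(x[0, 0].numpy(), z_dcnet_unet.cpu()[0, 0].numpy())`.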

Note

We refer to the spyrit-examples tutorials, which can be run in Colab, for a comparison of different solutions (PinvNet, DCNet and DRUNet).

Total running time of the script: (0 minutes 4.727 seconds)
