spyrit.core.recon.TikhoNet

class spyrit.core.recon.TikhoNet(acqu: Linear, prep, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'), **tikho_kwargs)

Bases: _PrebuiltFullNet

Pre-built FullNet that uses a Tikhonov reconstruction operator.

As a FullNet, this network has two modules: one for measurements and one for reconstruction.

The measurement module only contains the acquisition operator. The reconstruction module contains a preprocessing operator, a Tikhonov inverse operator, and a denoising operator.

This is a two-step reconstruction method [1]. The first step estimates the signal \(\tilde{x}\) from preprocessed measurements \(y\) by minimizing

\[\| y - Ax \|^2_{\Gamma^{-1}} + \|x\|^2_{\Sigma^{-1}}\]

where \(A\) is the measurement matrix, \(\Gamma\) is the covariance of the noise, and \(\Sigma\) is the signal covariance prior. The solution is computed as

\[\tilde{x} = \Sigma A^\top (A \Sigma A^\top + \Gamma)^{-1} y\]
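As a numerical check of the closed form above, the following NumPy sketch uses toy random matrices and is not spyrit's implementation. It computes \(\tilde{x}\) with a linear solve and compares it with the equivalent image-domain form \((A^\top \Gamma^{-1} A + \Sigma^{-1})^{-1} A^\top \Gamma^{-1} y\), which is obtained by setting the gradient of the objective to zero. The helper names `tikhonov_solution` and `tikhonov_normal_equations` are hypothetical.

```python
import numpy as np

def tikhonov_solution(A, Sigma, Gamma, y):
    """x_tilde = Sigma A^T (A Sigma A^T + Gamma)^{-1} y (measurement-domain form)."""
    # Solve (A Sigma A^T + Gamma) z = y instead of forming the inverse explicitly
    z = np.linalg.solve(A @ Sigma @ A.T + Gamma, y)
    return Sigma @ A.T @ z

def tikhonov_normal_equations(A, Sigma, Gamma, y):
    """Same minimizer in image-domain form:
    (A^T Gamma^{-1} A + Sigma^{-1})^{-1} A^T Gamma^{-1} y."""
    Gi = np.linalg.inv(Gamma)
    lhs = A.T @ Gi @ A + np.linalg.inv(Sigma)
    return np.linalg.solve(lhs, A.T @ Gi @ y)

rng = np.random.default_rng(0)
M, N = 4, 6                        # fewer measurements than pixels
A = rng.standard_normal((M, N))    # toy measurement matrix
B = rng.standard_normal((N, N))
Sigma = B @ B.T + N * np.eye(N)    # symmetric positive-definite prior covariance
Gamma = 0.1 * np.eye(M)            # toy noise covariance
y = rng.standard_normal(M)

x1 = tikhonov_solution(A, Sigma, Gamma, y)
x2 = tikhonov_normal_equations(A, Sigma, Gamma, y)
print(np.allclose(x1, x2))  # True
```

The measurement-domain form only requires solving an \(M \times M\) system. This is cheaper than the \(N \times N\) image-domain system when there are fewer measurements than pixels.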

The second step applies a learnable neural network \(\mathcal{G}_\theta\) to the output of the first step:

\[\hat{x} = \mathcal{G}_\theta(\tilde{x})\]

where \(\theta\) are the learnable parameters of the neural network.

The optional keyword arguments passed at initialization are forwarded to the Tikhonov operator. This way, the behavior of the Tikhonov operator can be controlled directly from the TikhoNet constructor.
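The following minimal sketch shows the keyword-forwarding pattern described above. It uses hypothetical `ToyTikhonov` and `ToyTikhoNet` classes that only mirror the documented signature; they are not spyrit's code.

```python
class ToyTikhonov:
    """Hypothetical stand-in for the Tikhonov operator (not spyrit's implementation)."""

    def __init__(self, acqu, sigma, approx=False, reshape_output=True):
        self.acqu = acqu
        self.sigma = sigma
        self.approx = approx
        self.reshape_output = reshape_output


class ToyTikhoNet:
    """Hypothetical wrapper showing how **tikho_kwargs can reach the inner operator."""

    def __init__(self, acqu, prep, sigma, denoi=None, **tikho_kwargs):
        self.acqu = acqu
        self.prep = prep
        # Every extra keyword (e.g. approx, reshape_output) is forwarded unchanged
        self.tikho = ToyTikhonov(acqu, sigma, **tikho_kwargs)
        self.denoi = denoi


net = ToyTikhoNet("acqu", "prep", "sigma", approx=True, reshape_output=False)
print(net.tikho.approx, net.tikho.reshape_output)  # True False
```

In this sketch, a misspelled keyword reaches the inner constructor and raises a `TypeError` there rather than being silently ignored.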

References:
Args:

acqu (spyrit.core.meas): Acquisition operator (see meas)

prep (spyrit.core.prep): Preprocessing operator (see prep)

sigma (torch.Tensor): Image-domain covariance prior (for details, see the Tikhonov() class)

denoi (torch.nn.Module, optional): Image denoising operator (see nnet). Default: Identity().

tikho_kwargs (dict): Optional keyword arguments passed to the Tikhonov() constructor. May contain the following keys:

  • approx (bool): If True, the Tikhonov inversion step is approximated using a diagonal matrix. Default is False.

  • reshape_output (bool): If True, the output of the Tikhonov inversion step is reshaped to match the acquisition operator input shape. Default is True.
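The docstring does not specify which matrix the diagonal approximation replaces. For illustration only, the sketch below assumes that \(A \Sigma A^\top + \Gamma\) is replaced by its diagonal, so that the inverse becomes an elementwise division. With an orthonormal Hadamard matrix and an isotropic prior \(\Sigma = s I\), that matrix is already diagonal and the approximation is exact. The function names are hypothetical.

```python
import numpy as np

def tikhonov_exact(A, Sigma, Gamma, y):
    # Full solve of (A Sigma A^T + Gamma) z = y
    return Sigma @ A.T @ np.linalg.solve(A @ Sigma @ A.T + Gamma, y)

def tikhonov_diag_approx(A, Sigma, Gamma, y):
    # ASSUMPTION: keep only the diagonal of A Sigma A^T + Gamma,
    # so the inversion reduces to an elementwise division
    d = np.diag(A @ Sigma @ A.T + Gamma)
    return Sigma @ A.T @ (y / d)

# Orthonormal 4x4 Hadamard matrix (Sylvester construction, scaled by 1/2)
H = np.array([[1,  1,  1,  1],
              [1, -1,  1, -1],
              [1,  1, -1, -1],
              [1, -1, -1,  1]], dtype=float) / 2.0
Sigma = 2.0 * np.eye(4)   # isotropic prior: H Sigma H^T is then diagonal
Gamma = 0.5 * np.eye(4)   # toy noise covariance
y = np.array([1.0, -2.0, 0.5, 3.0])

print(np.allclose(tikhonov_exact(H, Sigma, Gamma, y),
                  tikhonov_diag_approx(H, Sigma, Gamma, y)))  # True
```

With a non-isotropic prior such as `np.diag([1.0, 2.0, 3.0, 4.0])`, \(H \Sigma H^\top\) has non-zero off-diagonal entries, and the two results differ.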

Attributes:

acqu: Acquisition operator initialized as acqu.

acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.

prep: Preprocessing operator initialized as prep.

tikho: Data consistency layer initialized as Tikhonov(acqu, sigma, **tikho_kwargs).

denoi: Image denoising operator initialized as denoi.

recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov operator, and the denoising operator.

Input / Output:

input (torch.Tensor): Ground-truth images with shape \((b,c,h,w)\).

output (torch.Tensor): Reconstructed images with shape \((b,c,h,w)\), or \((b,c,h \cdot w)\) when reshape_output=False (see Example 2).
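The following NumPy sketch illustrates the shape flow. It assumes, as the output shape in Example 2 suggests, that the operators act on flattened images along the last axis. The matrices are toy stand-ins with as many measurements as pixels; they are not spyrit's Hadamard operators.

```python
import numpy as np

b, c, h, w = 10, 1, 8, 8
N = h * w                                   # pixels per image
rng = np.random.default_rng(0)
A = rng.standard_normal((N, N))             # toy square measurement matrix
Sigma = np.eye(N)                           # toy prior covariance
Gamma = 0.1 * np.eye(N)                     # toy noise covariance

x = rng.random((b, c, h, w))                # ground-truth batch
x_flat = x.reshape(b, c, N)                 # (b, c, h*w)
y = x_flat @ A.T                            # measurements, (b, c, M) with M = N here

# Tikhonov operator W = Sigma A^T (A Sigma A^T + Gamma)^{-1}, applied along the last axis
W = Sigma @ A.T @ np.linalg.inv(A @ Sigma @ A.T + Gamma)
x_tilde = y @ W.T                           # (b, c, N): analogous to reshape_output=False
x_img = x_tilde.reshape(b, c, h, w)         # (b, c, h, w): analogous to reshape_output=True

print(x_tilde.shape, x_img.shape)  # (10, 1, 64) (10, 1, 8, 8)
```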

Example 1:
>>> import torch
>>> import spyrit
>>> from spyrit.core.recon import TikhoNet
>>> noise = spyrit.core.noise.Poisson(100)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, device=torch.device("cpu"))
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 8, 8])

Example 2:
>>> noise = spyrit.core.noise.Gaussian(1.0)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, approx=True, reshape_output=False)
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 64])

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Reconstruct an image from measurements.

reconstruct_pinv(y)

Reconstruct an image from measurements without denoising.
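Judging from the method descriptions above, the four methods plausibly compose as sketched below. `ToyFullNet` is a hypothetical class in which plain callables stand in for the spyrit modules; it is not the library's implementation.

```python
class ToyFullNet:
    """Hypothetical sketch of how the four methods could fit together (not spyrit's code)."""

    def __init__(self, acqu, prep, tikho, denoi):
        self.acqu, self.prep, self.tikho, self.denoi = acqu, prep, tikho, denoi

    def acquire(self, x):
        return self.acqu(x)                          # measurement modules only

    def reconstruct_pinv(self, y):
        return self.tikho(self.prep(y))              # preprocessing + Tikhonov, no denoising

    def reconstruct(self, y):
        return self.denoi(self.reconstruct_pinv(y))  # all reconstruction modules

    def forward(self, x):
        return self.reconstruct(self.acquire(x))     # measure, then reconstruct


# Toy scalar operators: the "denoiser" simply clips negative values to zero
net = ToyFullNet(acqu=lambda x: 2 * x, prep=lambda y: y - 1,
                 tikho=lambda y: y / 2, denoi=lambda x: max(x, 0.0))
print(net.forward(3.0))  # 2.5
```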