spyrit.core.recon.TikhoNet

class spyrit.core.recon.TikhoNet(acqu: Linear, prep, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'), **tikho_kwargs)

Bases: _PrebuiltFullNet

Pre-built FullNet that uses a Tikhonov reconstruction operator.

As a FullNet, this network has two modules: one for measurements and one for reconstruction.

The measurement module only contains the acquisition operator. The reconstruction module contains a preprocessing operator, a Tikhonov inverse operator, and a denoising operator.
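The data flow described above can be sketched with plain callables. This is a minimal illustration, not spyrit's code: `MiniFullNet` and `run_sequence` are hypothetical names, the toy acquisition is a 2×2 Hadamard matrix, an exact matrix inverse stands in for the Tikhonov step, and clamping to [0, 1] stands in for a denoiser. The method names mirror the Methods section of this page, and `reconstruct_pinv` is assumed to run the reconstruction chain without its final denoising module, as its description suggests.

```python
from typing import Callable, List

import numpy as np

Module = Callable[[np.ndarray], np.ndarray]


def run_sequence(modules: List[Module], x: np.ndarray) -> np.ndarray:
    """Apply modules one after another, in the spirit of torch.nn.Sequential."""
    for module in modules:
        x = module(x)
    return x


class MiniFullNet:
    """Hypothetical stand-in for a FullNet; not spyrit's actual class."""

    def __init__(self, acqu: Module, prep: Module, inverse: Module, denoi: Module):
        self.acqu_modules = [acqu]                   # measurement side
        self.recon_modules = [prep, inverse, denoi]  # reconstruction side

    def acquire(self, x):
        return run_sequence(self.acqu_modules, x)

    def reconstruct(self, y):
        return run_sequence(self.recon_modules, y)

    def reconstruct_pinv(self, y):
        # Assumption: same chain as reconstruct(), minus the final denoiser
        return run_sequence(self.recon_modules[:-1], y)

    def forward(self, x):
        return self.reconstruct(self.acquire(x))


H = np.array([[1.0, 1.0], [1.0, -1.0]])
H_inv = np.linalg.inv(H)
net = MiniFullNet(
    acqu=lambda x: x @ H.T,            # y = H x for each row of x
    prep=lambda y: y,                  # no preprocessing in this toy
    inverse=lambda y: y @ H_inv.T,     # exact inverse stands in for Tikhonov
    denoi=lambda z: np.clip(z, 0, 1),  # toy "denoiser": clamp to [0, 1]
)
x = np.array([[1.5, -0.5]])
print(net.forward(x))                          # clamped by the toy denoiser
print(net.reconstruct_pinv(net.acquire(x)))    # recovers x without clamping
```

Keeping the two module groups separate is what lets `acquire` and `reconstruct` be called independently, for example to reconstruct from measurements recorded elsewhere.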

The optional keyword arguments passed at initialization are forwarded to the Tikhonov operator, so the regularization can be controlled directly from the TikhoNet constructor.
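The forwarding pattern can be illustrated with two stub classes. This is a hypothetical sketch, not spyrit's source: `TikhonovStub` and `TikhoNetSketch` are invented names that only record their options, and the defaults (`approx=False`, `reshape_output=True`) are taken from the argument list on this page. The keyword-only `device` is consumed by the outer constructor, while every other extra keyword is passed through unchanged.

```python
class TikhonovStub:
    """Hypothetical stand-in for spyrit's Tikhonov; it only records its options."""

    def __init__(self, acqu, sigma, approx=False, reshape_output=True):
        self.acqu = acqu
        self.sigma = sigma
        self.approx = approx
        self.reshape_output = reshape_output


class TikhoNetSketch:
    """Hypothetical: shows how **tikho_kwargs can be forwarded verbatim."""

    def __init__(self, acqu, prep, sigma, denoi=None, *, device="cpu", **tikho_kwargs):
        self.acqu = acqu
        self.prep = prep
        self.denoi = denoi
        self.device = device  # keyword-only, not forwarded
        # Extra keywords go straight to the inner constructor, which
        # raises TypeError for any name it does not accept.
        self.tikho = TikhonovStub(acqu, sigma, **tikho_kwargs)


net = TikhoNetSketch("acqu", "prep", "sigma", approx=True, reshape_output=False)
print(net.tikho.approx, net.tikho.reshape_output)  # True False

try:
    TikhoNetSketch("acqu", "prep", "sigma", aprox=True)  # misspelled keyword
except TypeError as err:
    print("rejected:", err)
```

A side benefit of forwarding without a whitelist is that a misspelled option fails loudly instead of being silently ignored.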

Args:

acqu (spyrit.core.meas): Acquisition operator (see meas)

prep (spyrit.core.prep): Preprocessing operator (see prep)

sigma (torch.tensor): Image-domain covariance prior (for details, see the Tikhonov() class)

denoi (torch.nn.Module, optional): Image denoising operator (see nnet). Default: Identity().

tikho_kwargs (dict): Optional keyword arguments passed to the Tikhonov() constructor. May contain the following keys:

  • approx (bool): If True, the Tikhonov inversion step is approximated using a diagonal matrix. Default is False.

  • reshape_output (bool): If True, the output of the Tikhonov inversion step is reshaped to match the input shape of the acquisition operator. Default is True.
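With a zero-mean Gaussian image prior of covariance Σ and Gaussian measurement noise of covariance Γ, a Tikhonov (Wiener) estimate is commonly written x̂ = Σ Hᵀ (H Σ Hᵀ + Γ)⁻¹ y. Here Σ is an N×N matrix with N = h·w, which is why the examples use `torch.ones(64, 64)` for 8×8 images. The NumPy sketch below assumes this standard form with a diagonal Γ, and it also assumes that `approx=True` amounts to replacing H Σ Hᵀ + Γ by its diagonal; spyrit's actual implementation may differ. `tikhonov`, `tikhonov_diag`, and `sylvester_hadamard` are hypothetical helpers.

```python
import numpy as np


def tikhonov(y, H, sigma, noise_var):
    """Exact estimate x_hat = Sigma H^T (H Sigma H^T + Gamma)^{-1} y, Gamma = diag(noise_var).

    y: (B, M) measurements, H: (M, N), sigma: (N, N), noise_var: (M,).
    """
    A = H @ sigma @ H.T + np.diag(noise_var)   # (M, M), symmetric positive definite
    z = np.linalg.solve(A, y.T).T               # (B, M): solve A z = y, no explicit inverse
    return z @ (sigma @ H.T).T                  # (B, N): each row is Sigma H^T z


def tikhonov_diag(y, H, sigma, noise_var):
    """Assumed model of approx=True: keep only diag(A), so the solve becomes a division."""
    a = np.einsum("ij,jk,ik->i", H, sigma, H) + noise_var   # diag(H Sigma H^T) + noise
    return (y / a) @ (sigma @ H.T).T


def sylvester_hadamard(k):
    """2^k x 2^k Hadamard matrix built by repeated Kronecker products."""
    H = np.array([[1.0]])
    for _ in range(k):
        H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
    return H


rng = np.random.default_rng(0)
H = sylvester_hadamard(3)                      # 8 x 8, with H H^T = 8 I
x = rng.random((5, 8))                         # 5 flattened signals
y = x @ H.T                                    # noiseless measurements
noise_var = np.full(8, 0.1)

# Identity prior: H Sigma H^T = 8 I is already diagonal, so both estimates agree
x_exact = tikhonov(y, H, np.eye(8), noise_var)
x_diag = tikhonov_diag(y, H, np.eye(8), noise_var)
print(np.allclose(x_exact, x_diag))            # True

# Correlated prior: off-diagonal terms are dropped, so the estimates differ
idx = np.arange(8)
sigma_c = np.exp(-np.abs(idx[:, None] - idx[None, :]) / 2.0)
print(np.allclose(tikhonov(y, H, sigma_c, noise_var),
                  tikhonov_diag(y, H, sigma_c, noise_var)))   # False
```

The diagonal variant trades accuracy for speed: it avoids solving an M×M linear system, and it is exact only when H Σ Hᵀ is itself diagonal, as in the identity-prior case above.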

Attributes:

acqu: Acquisition operator initialized as acqu.

acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.

prep: Preprocessing operator initialized as prep.

tikho: Data consistency layer initialized as Tikhonov(acqu, sigma, **tikho_kwargs).

denoi: Image denoising operator initialized as denoi.

recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov inverse operator, and the denoising operator.

Input / Output:

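input (torch.tensor): Ground-truth images with shape (b, c, h, w).

output (torch.tensor): Reconstructed images with shape (b, c, h, w). When reshape_output=False is passed to the Tikhonov operator, the output keeps the flattened shape (b, c, h*w), as in Example 2.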
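Going from the flattened layout (b, c, h*w) back to images is a plain reshape. A minimal NumPy sketch, assuming the flattened pixels are stored in row-major (C) order; `torch.Tensor.reshape` uses the same order, so the same call works on a torch tensor.

```python
import numpy as np

b, c, h, w = 10, 1, 8, 8
flat = np.random.default_rng(0).random((b, c, h * w))  # stands in for a (b, c, h*w) output
images = flat.reshape(b, c, h, w)                        # back to (b, c, h, w)

# Row-major order: pixel (i, j) of each image sits at flat index i * w + j
print(images.shape)                                      # (10, 1, 8, 8)
print(images[3, 0, 2, 5] == flat[3, 0, 2 * w + 5])       # True
```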

Example 1:
>>> import torch
>>> import spyrit
>>> from spyrit.core.recon import TikhoNet
>>> noise = spyrit.core.noise.Poisson(100)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, device=torch.device("cpu"))
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 8, 8])

Example 2:
>>> noise = spyrit.core.noise.Gaussian(1.0)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, approx=True, reshape_output=False)
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 64])
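Both examples pair a split Hadamard acquisition with an unsplit-and-rescale preprocessing step. The idea can be sketched without spyrit: a ±1 Hadamard matrix H is split into nonnegative parts H⁺ and H⁻ with H = H⁺ − H⁻, both parts are measured with a gain alpha, and preprocessing subtracts the two and divides by alpha. This is a noiseless sketch under stated assumptions: the Poisson noise is omitted, the even/odd interleaving of the two parts is an assumed layout, a 64×64 Sylvester Hadamard matrix acting on flattened 8×8 images stands in for the 2D acquisition, and `split_acquire`, `unsplit_rescale`, and `sylvester_hadamard` are hypothetical helpers rather than spyrit functions.

```python
import numpy as np


def sylvester_hadamard(k: int) -> np.ndarray:
    """2^k x 2^k Hadamard matrix with entries +1/-1."""
    H = np.array([[1.0]])
    for _ in range(k):
        H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
    return H


def split_acquire(x: np.ndarray, H: np.ndarray, alpha: float) -> np.ndarray:
    """Hypothetical noiseless split acquisition with interleaved outputs."""
    H_pos = np.maximum(H, 0.0)   # positive part, entries in {0, 1}
    H_neg = np.maximum(-H, 0.0)  # negative part, so that H = H_pos - H_neg
    y = np.empty(x.shape[:-1] + (2 * H.shape[0],))
    y[..., 0::2] = alpha * x @ H_pos.T   # even slots: positive patterns (assumed layout)
    y[..., 1::2] = alpha * x @ H_neg.T   # odd slots: negative patterns (assumed layout)
    return y


def unsplit_rescale(y: np.ndarray, alpha: float) -> np.ndarray:
    """Subtract each negative measurement from its positive partner, then divide by alpha."""
    return (y[..., 0::2] - y[..., 1::2]) / alpha


H = sylvester_hadamard(6)                           # 64 x 64, for flattened 8 x 8 images
x = np.random.default_rng(0).random((10, 1, 64))    # batch of flattened images in [0, 1]
y = split_acquire(x, H, alpha=100.0)                # shape (10, 1, 128)
m = unsplit_rescale(y, alpha=100.0)                 # shape (10, 1, 64)
print(y.shape, m.shape, np.allclose(m, x @ H.T))    # (10, 1, 128) (10, 1, 64) True
```

Splitting is needed because a physical light modulator can only display nonnegative patterns; the subtraction in preprocessing recovers the signed Hadamard measurements H x that the Tikhonov step then inverts.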

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Reconstruct an image from measurements.

reconstruct_pinv(y)

Reconstruct an image from measurements without denoising.