spyrit.core.recon.DCNet

class spyrit.core.recon.DCNet(acqu: HadamSplit2d, prep: Rescale | RescaleEstim, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'))

Bases: _PrebuiltFullNet

Pre-built FullNet that uses a TikhonovMeasurementPriorDiag reconstruction operator.

As a FullNet, this network has two modules: one for the measurements and one for the reconstruction.

The measurement module contains only the acquisition operator, which must be a spyrit.core.meas.HadamSplit2d operator. The reconstruction module contains a preprocessing operator, a spyrit.core.inverse.TikhonovMeasurementPriorDiag operator that performs the Tikhonov-regularized reconstruction, and a denoising operator.
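To make the split acquisition and its preprocessing concrete, here is a minimal NumPy sketch on a toy 4 x 4 image. It assumes full sampling, no noise, and an interleaved ordering of positive and negative patterns. The names `sylvester_hadamard`, `H_pos` and `H_neg` are hypothetical, and this is not how spyrit implements HadamSplit2d or its preprocessing operators. Splitting doubles the number of measurements, which is consistent with the 2048 = 2 x 32 x 32 values used in the example.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    """Return an n x n Hadamard matrix built by Sylvester's construction (n a power of 2)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2")
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


h = 4                                  # toy image is h x h, so N = h * h pixels
H1 = sylvester_hadamard(h)
H = np.kron(H1, H1)                    # 2D Hadamard transform acting on a flattened image

# Split H into two non-negative pattern sets so that H = H_pos - H_neg
H_pos = np.maximum(H, 0.0)
H_neg = np.maximum(-H, 0.0)

# Stack positive and negative patterns, interleaved row by row (ordering assumed)
A = np.empty((2 * H.shape[0], H.shape[1]))
A[0::2] = H_pos
A[1::2] = H_neg

rng = np.random.default_rng(0)
x = rng.random(h * h)                  # flattened ground-truth image
y = A @ x                              # split measurements: 2 * N values

m = y[0::2] - y[1::2]                  # "unsplit" preprocessing recovers H @ x
x_hat = H.T @ m / (h * h)              # H.T @ H = N * I, so this undoes full sampling
print(y.shape, np.allclose(x_hat, x))  # (32,) True
```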

Args:

acqu: Acquisition operator (see HadamSplit2d)

prep: Preprocessing operator (see prep)

sigma: Measurement covariance prior (for details, see the TikhonovMeasurementPriorDiag class)
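Under a zero-mean Gaussian prior with diagonal covariance \(\Sigma\) and independent Gaussian noise of variance \(\alpha_i\), Tikhonov regularization in the measurement domain reduces to a per-coefficient shrinkage \(\hat{y}_i = \Sigma_{ii} m_i / (\Sigma_{ii} + \alpha_i)\). The sketch below illustrates only this simplified, fully sampled case. The names `tikhonov_diag`, `prior_var` and `noise_var` are hypothetical, and spyrit's TikhonovMeasurementPriorDiag may differ, for instance in how it handles subsampled measurements and the noise model.

```python
import numpy as np


def tikhonov_diag(m, prior_var, noise_var):
    """Diagonal Tikhonov (Wiener-type) estimate of measurement-domain coefficients.

    Assumes a zero-mean Gaussian prior with diagonal covariance `prior_var` and
    independent Gaussian noise with variance `noise_var` on each coefficient.
    """
    m = np.asarray(m, dtype=float)
    gain = prior_var / (prior_var + noise_var)  # per-coefficient shrinkage in [0, 1]
    return gain * m


sigma = np.eye(4)                     # toy prior covariance, like torch.eye in the example
prior_var = np.diag(sigma)            # a diagonal prior only needs its diagonal
m = np.array([2.0, -4.0, 1.0, 0.5])
print(tikhonov_diag(m, prior_var, noise_var=1.0))  # every coefficient is halved: gain = 1 / (1 + 1)
```

Coefficients with a large prior variance relative to the noise are kept almost unchanged, while poorly supported ones are pulled toward zero.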

denoi (optional): Image denoising operator (see nnet). Defaults to Identity().

Attributes:

acqu: Acquisition operator initialized as acqu

acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.

prep: Preprocessing operator initialized as prep

tikho: Tikhonov regularization operator initialized as a TikhonovMeasurementPriorDiag operator.

denoi: Image denoising operator initialized as denoi

recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov regularization operator, and the denoising operator.

Input / Output:

input: Ground-truth images with shape \((b,c,h,w)\), with \(b\) being the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.

output: Reconstructed images with shape \((b,c,h,w)\).
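The shape bookkeeping can be checked with a small batched NumPy sketch: each image is flattened, multiplied by an acquisition matrix with 2M rows for split measurements, and mapped back to \((b,c,h,w)\). With \(h = w = 32\) and full sampling this gives 2 x 1024 = 2048 values per image, matching the measurement tensor in the example. Here `A` is a random stand-in rather than a Hadamard split matrix, and a plain pseudo-inverse replaces the Tikhonov operator; both are assumptions made only to illustrate the shapes.

```python
import numpy as np

b, c, h, w = 2, 1, 8, 8          # batch, channels, height, width (toy sizes)
N = h * w                         # number of pixels per image
M = N                             # full sampling assumed

rng = np.random.default_rng(0)
x = rng.random((b, c, h, w))      # ground-truth images, shape (b, c, h, w)

# Random stand-in for a split acquisition matrix with 2 * M rows (not a Hadamard matrix)
A = rng.random((2 * M, N))

y = x.reshape(b, c, N) @ A.T      # measurements, shape (b, c, 2 * M)
z = (y @ np.linalg.pinv(A).T).reshape(b, c, h, w)  # back to (b, c, h, w)
print(y.shape, z.shape)           # (2, 1, 128) (2, 1, 8, 8)
```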

Example:
>>> from spyrit.core.meas import HadamSplit2d
>>> from spyrit.core.prep import UnsplitRescale
>>> from spyrit.core.recon import DCNet
>>> import torch
>>> acqu = HadamSplit2d(32)
>>> prep = UnsplitRescale()
>>> sigma = torch.eye(32*32,32*32)
>>> dcnet = DCNet(acqu, prep, sigma)
>>> y = torch.randn(10, 1, 2048)
>>> z = dcnet.reconstruct(y)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Reconstruct an image from measurements.

reconstruct_pinv(y)

Reconstruct an image from measurements without denoising.
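The relationship between the four methods can be pictured with a small plain-Python mock that uses scalar stand-ins for the operators. It is a hypothetical sketch inferred from the method descriptions, not spyrit's implementation; `ToyFullNet` and `sequential` are illustrative names, and the real reconstruct_pinv may use a different (pseudo-inverse) operator rather than simply skipping the denoiser.

```python
from typing import Callable, Sequence

Op = Callable[[float], float]


def sequential(ops: Sequence[Op]) -> Op:
    """Chain operators left to right, in the spirit of torch.nn.Sequential."""
    def run(v: float) -> float:
        for op in ops:
            v = op(v)
        return v
    return run


class ToyFullNet:
    """Hypothetical mock of the DCNet method layout; not spyrit's implementation."""

    def __init__(self, acqu: Op, prep: Op, tikho: Op, denoi: Op = lambda v: v):
        self.acqu_modules = sequential([acqu])                 # measurement side
        self.recon_modules = sequential([prep, tikho, denoi])  # reconstruction side
        self._no_denoi = sequential([prep, tikho])             # assumed pinv-style path

    def acquire(self, x: float) -> float:
        return self.acqu_modules(x)

    def reconstruct(self, y: float) -> float:
        return self.recon_modules(y)

    def reconstruct_pinv(self, y: float) -> float:
        # Assumption: same pipeline with the denoiser skipped
        return self._no_denoi(y)

    def forward(self, x: float) -> float:
        return self.reconstruct(self.acquire(x))


net = ToyFullNet(acqu=lambda x: 2 * x, prep=lambda y: y - 1,
                 tikho=lambda y: y / 2, denoi=lambda v: v + 10)
print(net.forward(3), net.reconstruct_pinv(6))  # 12.5 2.5
```

Here forward(x) is simply reconstruct(acquire(x)), so training end to end and reconstructing from stored measurements share the same reconstruction modules.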