spyrit.core.recon.PinvNet

class spyrit.core.recon.PinvNet(acqu: Linear, prep=Identity(), denoi=Identity(), *, device: device = device(type='cpu'), **pinv_kwargs)

Bases: _PrebuiltFullNet

Pre-built FullNet that uses a pseudo inverse.

As a FullNet, this network has two modules: one for measurements and one for reconstruction.

The measurement module only contains the acquisition operator. The reconstruction module contains a preprocessing operator, a pseudo inverse operator, and a denoising operator.

Any optional keyword arguments passed at initialization are forwarded to the pseudo inverse operator. This way, the regularization of the pseudo inverse can be controlled directly from the PinvNet constructor.
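The keyword forwarding described above can be illustrated with a minimal pure-Python sketch. The classes HypotheticalPseudoInverse and HypotheticalPinvNet are illustrative stand-ins, not spyrit classes, and the option names are simply copied from the regularized example further down; the sketch only shows that extra keywords are collected into a dict and handed unchanged to the pseudo inverse constructor.

```python
class HypotheticalPseudoInverse:
    """Stand-in for spyrit.core.inverse.PseudoInverse; it only records its options."""

    def __init__(self, acqu, **kwargs):
        self.acqu = acqu
        self.options = kwargs


class HypotheticalPinvNet:
    """Stand-in showing how **pinv_kwargs could be forwarded (not spyrit's code)."""

    def __init__(self, acqu, prep=None, denoi=None, **pinv_kwargs):
        self.pinv_kwargs = pinv_kwargs
        # Every extra keyword goes straight to the pseudo inverse constructor.
        self.pinv = HypotheticalPseudoInverse(acqu, **pinv_kwargs)


acqu = "placeholder acquisition operator"  # hypothetical placeholder object
net = HypotheticalPinvNet(acqu, regularization="H1", eta=1e-6)
print(net.pinv.options)  # {'regularization': 'H1', 'eta': 1e-06}
```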

Args:

acqu (spyrit.core.meas): Acquisition operator.

prep (spyrit.core.prep, optional): Preprocessing operator. Defaults to no preprocessing (i.e., Identity).

denoi (torch.nn.Module, optional): Image denoising operator. Defaults to no denoising (i.e., Identity).

**pinv_kwargs: Optional keyword arguments passed to the pseudo inverse operator (see spyrit.core.inverse.PseudoInverse).

Attributes:

acqu (spyrit.core.meas): Acquisition operator.

acqu_modules (torch.nn.Sequential): Measurement modules. Only contains the acquisition operator.

prep (spyrit.core.prep): Preprocessing operator.

pinv (spyrit.core.inverse.PseudoInverse): Pseudo inverse operator.

denoi (torch.nn.Module): Image denoising operator.

recon_modules (torch.nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the pseudo inverse operator, and the denoising operator.

pinv_kwargs (dict): Optional keyword arguments passed to the pseudo inverse operator.

Input / Output:

input: Ground-truth images with shape \((b,c,h,w)\), where \(b\) is the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.

output: Reconstructed images with shape \((b,c,h,w)\).
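The input and output shapes above can be traced with a small NumPy sketch. It assumes a generic dense measurement matrix H acting on flattened images, which is an illustrative simplification rather than spyrit's internal layout (HadamSplit2d, for instance, may arrange its measurements differently). With full sampling (M = h*w) and an invertible H, the pseudo inverse recovers the images up to rounding error.

```python
import numpy as np

b, c, h, w = 2, 1, 4, 4
N = h * w
M = N  # full sampling, so the pseudo inverse recovers x up to rounding

rng = np.random.default_rng(0)
H = rng.standard_normal((M, N))   # hypothetical measurement matrix
x = rng.random((b, c, h, w))      # batch of ground-truth images

# Measurements: flatten each image and apply H, giving shape (b, c, M)
y = x.reshape(b, c, N) @ H.T

# Reconstruction: apply the Moore-Penrose pseudo inverse, reshape back to (b, c, h, w)
H_pinv = np.linalg.pinv(H)        # shape (N, M)
x_rec = (y @ H_pinv.T).reshape(b, c, h, w)

print(y.shape, x_rec.shape)  # (2, 1, 16) (2, 1, 4, 4)
```

With M < N the same code still runs, but the pseudo inverse then returns the minimum-norm solution, which generally differs from the ground truth.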

Example:
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> acqu = meas.HadamSplit2d(32)
>>> pinv = recon.PinvNet(acqu)

Example with a regularized pseudo inverse:
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> acqu = meas.HadamSplit2d(32)
>>> pinv = recon.PinvNet(
...     acqu,
...     use_fast_pinv=False,
...     store_H_pinv=True,
...     regularization='H1',
...     eta=1e-6,
...     img_shape=(32, 32),
... )
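To give an intuition for what a regularization weight such as eta does, the NumPy sketch below uses plain L2 (Tikhonov) regularization, (H^T H + eta I)^{-1} H^T. This is a swapped-in technique chosen for illustration: it is not claimed to be what spyrit computes for regularization='H1', and the helper tikhonov_pinv is hypothetical. It shows that a tiny eta reproduces the Moore-Penrose solution of an undersampled system, while a larger eta shrinks the solution.

```python
import numpy as np


def tikhonov_pinv(H, eta):
    """Hypothetical helper: L2-regularized pseudo inverse (H^T H + eta I)^{-1} H^T."""
    N = H.shape[1]
    return np.linalg.solve(H.T @ H + eta * np.eye(N), H.T)


rng = np.random.default_rng(0)
H = rng.standard_normal((8, 16))  # undersampled: 8 measurements of a 16-pixel image
y = rng.standard_normal(8)

x_small = tikhonov_pinv(H, 1e-6) @ y
x_large = tikhonov_pinv(H, 10.0) @ y

# As eta -> 0, the regularized solution approaches the Moore-Penrose solution
print(np.allclose(x_small, np.linalg.pinv(H) @ y, atol=1e-4))  # True
# A larger eta shrinks the norm of the solution
print(np.linalg.norm(x_large) < np.linalg.norm(x_small))  # True
```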

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Apply the reconstruction modules to the input measurements.

reconstruct_pinv(y)

Reconstruct measurement vectors without denoising.
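The way these four methods relate can be sketched in pure Python. PinvNetSketch is a hypothetical stand-in that uses plain lists of callables instead of the torch.nn.Sequential attributes listed above. It assumes that forward(x) equals reconstruct(acquire(x)) and that reconstruct_pinv applies the preprocessing and pseudo inverse but skips the denoiser. The toy lambdas are placeholders, not spyrit operators.

```python
class PinvNetSketch:
    """Hypothetical stand-in mirroring how acquire/reconstruct/forward compose."""

    def __init__(self, acqu, prep, pinv, denoi):
        self.acqu_modules = [acqu]                # measurement modules
        self.recon_modules = [prep, pinv, denoi]  # reconstruction modules

    @staticmethod
    def _apply(modules, z):
        # Apply each module in order, like a Sequential container
        for module in modules:
            z = module(z)
        return z

    def acquire(self, x):
        return self._apply(self.acqu_modules, x)

    def reconstruct(self, y):
        return self._apply(self.recon_modules, y)

    def reconstruct_pinv(self, y):
        # Same as reconstruct, but without the final denoising module
        return self._apply(self.recon_modules[:-1], y)

    def forward(self, x):
        return self.reconstruct(self.acquire(x))


net = PinvNetSketch(
    acqu=lambda x: 2 * x + 1,       # toy measurement with an offset
    prep=lambda y: y - 1,           # toy preprocessing: remove the offset
    pinv=lambda y: y / 2,           # toy inverse of the scaling
    denoi=lambda z: max(z, 0.0),    # toy "denoiser": clamp negative values
)

print(net.forward(3))                        # 3.0
print(net.reconstruct_pinv(net.acquire(-1)))  # -1.0
print(net.reconstruct(net.acquire(-1)))       # 0.0
```

In the actual class, which is a torch.nn.Module, forward is normally invoked by calling the network directly, e.g. pinv(x).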