spyrit.core.recon.PinvNet

class spyrit.core.recon.PinvNet(acqu: Linear, prep=Identity(), denoi=Identity(), *, device: device = device(type='cpu'), **pinv_kwargs)

Bases: _PrebuiltFullNet

A FullNet with a pseudo-inverse-based reconstruction module.

It simulates noisy measurements

\[y =\mathcal{N}\left(Ax\right),\]

where \(\mathcal{N}\) represents a noise operator (e.g., Gaussian), \(A\) is the acquisition matrix, and \(x\) is the signal of interest.
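As a minimal sketch of this measurement model, the snippet below uses NumPy with a random stand-in matrix for \(A\) and additive Gaussian noise for \(\mathcal{N}\). The sizes, the noise level `sigma`, and the helper `gaussian_noise` are illustrative assumptions, not part of spyrit.

```python
import numpy as np

# Illustrative toy sizes (assumptions): N pixels in the signal, M measurements.
rng = np.random.default_rng(0)
N, M = 16, 8

A = rng.standard_normal((M, N))  # stand-in acquisition matrix A
x = rng.random(N)                # signal of interest x


def gaussian_noise(z, sigma=0.1, rng=rng):
    """Hypothetical noise operator N: additive Gaussian noise with std sigma."""
    return z + sigma * rng.standard_normal(z.shape)


y_clean = A @ x              # noiseless measurements Ax
y = gaussian_noise(y_clean)  # noisy measurements y = N(Ax)
print(y.shape)  # (8,)
```

Setting `sigma=0.0` makes the noise operator the identity, which is convenient for checking the reconstruction steps in isolation.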

It estimates the signal from the noisy measurements in three steps [1]:

  1. Preprocessing of the measurements:

\[\tilde{m} = By,\]

where \(B\) represents a preprocessing step.

  2. Pseudo-inverse reconstruction:

\[x^\dagger = H^\dagger \tilde{m},\]

where \(H^\dagger\) denotes the Moore-Penrose pseudo-inverse of \(H=BA\).

  3. Denoising/artefact correction:

\[\hat{x} = \mathcal{G}_\theta(x^\dagger),\]

where \(\mathcal{G}_\theta\) is a neural network with learnable parameters \(\theta\).
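The three steps can be sketched with plain matrices. This is a NumPy illustration under stated assumptions, not spyrit's implementation: \(A\) and \(B\) are random stand-ins, `np.linalg.pinv` computes \(H^\dagger\) for \(H=BA\), and a clipping function stands in for the learned network \(\mathcal{G}_\theta\). With noiseless measurements and a full-column-rank \(H\), the pseudo-inverse recovers \(x\) exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, P = 16, 32, 24  # toy sizes: N unknowns, M raw measurements, P preprocessed values

A = rng.standard_normal((M, N))  # stand-in acquisition matrix A (M x N)
B = rng.standard_normal((P, M))  # stand-in preprocessing matrix B (P x M)
H = B @ A                        # H = BA, shape (P, N)
H_pinv = np.linalg.pinv(H)       # Moore-Penrose pseudo-inverse of H


def denoiser(z):
    """Placeholder for G_theta (assumption): clip to [0, 1], not a neural network."""
    return np.clip(z, 0.0, 1.0)


def reconstruct(y):
    m_tilde = B @ y              # step 1: preprocessing, m~ = B y
    x_dagger = H_pinv @ m_tilde  # step 2: pseudo-inverse, x† = H† m~
    return denoiser(x_dagger)    # step 3: denoising, x^ = G_theta(x†)


x = rng.random(N)            # signal in [0, 1), so clipping leaves it unchanged
x_hat = reconstruct(A @ x)   # noiseless measurements for simplicity
print(np.allclose(x_hat, x))  # True
```

The pseudo-inverse is computed once, outside `reconstruct`, because it depends only on \(A\) and \(B\) and can be reused for every measurement vector.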

Args:

acqu (spyrit.core.meas): Acquisition operator \(\mathcal{N}\circ A\).

prep (spyrit.core.prep, optional): Preprocessing operator \(B\). Defaults to no preprocessing (i.e., spyrit.core.prep.Identity).

denoi (torch.nn.Module, optional): Image denoising operator \(\mathcal{G}_\theta\). Defaults to no denoising (i.e., to torch.nn.Identity).

**pinv_kwargs: Optional keyword arguments passed to the pseudo-inverse operator (see spyrit.core.inverse.PseudoInverse).

Attributes:

acqu (spyrit.core.meas): Acquisition operator \(\mathcal{N}\circ A\).

acqu_modules (torch.nn.Sequential): Acquisition modules. Contains only acqu.

prep (spyrit.core.prep): Preprocessing operator \(B\).

pinv (spyrit.core.inverse.PseudoInverse): Pseudo-inverse operator \(H^\dagger\).

denoi (torch.nn.Module): Image denoising operator \(\mathcal{G}_\theta\).

recon_modules (torch.nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the pseudo-inverse operator, and the denoising operator, i.e., \(\mathcal{G}_\theta\circ H^\dagger \circ B\).

pinv_kwargs (dict): Optional keyword arguments passed to the pseudo-inverse operator.

Input / Output:

input: Ground-truth images with shape \((b,c,h,w)\), with \(b\) being the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.

output: Reconstructed images with shape \((b,c,h,w)\).
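One plausible way to apply a matrix to batched images of shape \((b,c,h,w)\) is to flatten each image to a vector of length \(hw\), multiply, and reshape back. The sketch below assumes row-major flattening and uses NumPy; it shows only the shape bookkeeping and is not taken from spyrit's code.

```python
import numpy as np

b, c, h, w = 10, 1, 4, 4  # toy batch: 10 single-channel 4x4 images
N, M = h * w, 8           # N pixels per image, M measurements (assumption)
rng = np.random.default_rng(0)
A = rng.standard_normal((M, N))  # stand-in acquisition matrix

x = rng.random((b, c, h, w))
x_flat = x.reshape(b, c, N)  # (b, c, h*w), row-major flattening
y = x_flat @ A.T             # (b, c, M): A applied to every image in the batch

# Map back to image space with the pseudo-inverse and restore (b, c, h, w).
x_back = (y @ np.linalg.pinv(A).T).reshape(b, c, h, w)
print(y.shape, x_back.shape)  # (10, 1, 8) (10, 1, 4, 4)
```

Multiplying by `A.T` on the right lets NumPy broadcast over the leading batch and channel dimensions without an explicit loop.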

References:

Example:
>>> import torch
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> acqu = meas.HadamSplit2d(32)
>>> pinv = recon.PinvNet(acqu)
>>> x = torch.rand(10, 1, 32, 32)
>>> y = pinv.acquire(x)
>>> z = pinv.reconstruct(y)
>>> print(y.shape)
torch.Size([10, 1, 2048])
>>> print(z.shape)
torch.Size([10, 1, 32, 32])

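Same as above, with a preprocessing step (unsplitting). Note the keyword arguments use_fast_pinv and reshape_output, which are passed to the pseudo-inverse operator; they control how the pseudo-inverse of \(H\) is computed and whether the output is reshaped to image dimensions.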

>>> import torch
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> import spyrit.core.prep as sprep
>>> acqu = meas.HadamSplit2d(32)
>>> prep = sprep.Unsplit()
>>> pinv = recon.PinvNet(acqu, prep, use_fast_pinv=True, reshape_output=True)
>>> x = torch.rand(10, 1, 32, 32)
>>> y = pinv.acquire(x)
>>> z = pinv.reconstruct(y)
>>> print(y.shape)
torch.Size([10, 1, 2048])
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
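In the examples, the 32×32 images (1024 pixels) yield 2048 measurements. A common explanation for this factor of two is split acquisition: a Hadamard matrix \(H\) with entries ±1 is implemented as two nonnegative matrices \(H^+\) and \(H^-\) with \(H = H^+ - H^-\), and unsplitting takes their difference. The sketch below assumes the positive and negative rows are interleaved; that layout, the helper names `sylvester_hadamard` and `unsplit`, and the 16-pixel size are illustrative and have not been checked against spyrit.core.meas.HadamSplit2d or spyrit.core.prep.Unsplit.

```python
import numpy as np


def sylvester_hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


N = 16
H = sylvester_hadamard(N)

# Assumed split layout: positive and negative parts interleaved along the rows.
A = np.empty((2 * N, N))
A[0::2] = np.maximum(H, 0)   # H+
A[1::2] = np.maximum(-H, 0)  # H-, so that H = H+ - H-


def unsplit(y):
    """Difference of even and odd entries, recovering H x from split data."""
    return y[..., 0::2] - y[..., 1::2]


x = np.random.default_rng(0).random(N)
y = A @ x       # 2N nonnegative-pattern measurements
m = unsplit(y)  # equals H x
print(y.shape, m.shape)       # (32,) (16,)
print(np.allclose(m, H @ x))  # True

# A Sylvester Hadamard matrix satisfies H H^T = N I, so x = H^T m / N.
print(np.allclose(H.T @ m / N, x))  # True
```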

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Apply the reconstruction modules to the input measurements.

reconstruct_pinv(y)

Reconstruct the input measurements without denoising.
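The four methods can be pictured with a toy class. `ToyPinvNet` below is hypothetical: it mirrors the split between acquisition and reconstruction described above using NumPy matrices, additive Gaussian noise, an identity preprocessing by default, and a pass-through denoiser. It does not reproduce spyrit's internals or its torch modules.

```python
import numpy as np


class ToyPinvNet:
    """Hypothetical NumPy analogue of PinvNet's method structure (not spyrit code)."""

    def __init__(self, A, img_shape, B=None, denoiser=None, sigma=0.0, seed=0):
        self.A = A
        self.img_shape = tuple(img_shape)                    # (h, w)
        self.B = np.eye(A.shape[0]) if B is None else B      # identity prep by default
        self.denoiser = (lambda z: z) if denoiser is None else denoiser
        self.H_pinv = np.linalg.pinv(self.B @ A)            # pseudo-inverse of H = BA
        self.sigma = sigma
        self.rng = np.random.default_rng(seed)

    def acquire(self, x):
        """Measurement step: y = N(Ax) with additive Gaussian noise."""
        b, c, h, w = x.shape
        y = x.reshape(b, c, h * w) @ self.A.T
        return y + self.sigma * self.rng.standard_normal(y.shape)

    def reconstruct_pinv(self, y):
        """Preprocessing and pseudo-inverse only (no denoising)."""
        x_dag = (y @ self.B.T) @ self.H_pinv.T
        return x_dag.reshape(y.shape[:2] + self.img_shape)

    def reconstruct(self, y):
        """Full reconstruction: denoiser applied to the pseudo-inverse estimate."""
        return self.denoiser(self.reconstruct_pinv(y))

    def forward(self, x):
        """Full network: acquire, then reconstruct."""
        return self.reconstruct(self.acquire(x))


rng = np.random.default_rng(1)
A = rng.standard_normal((32, 16))  # full column rank with probability 1
net = ToyPinvNet(A, img_shape=(4, 4))
x = rng.random((10, 1, 4, 4))
z = net.forward(x)
print(z.shape)            # (10, 1, 4, 4)
print(np.allclose(z, x))  # True (noiseless, identity prep and denoiser)
```

Keeping `reconstruct_pinv` separate from `reconstruct` makes it easy to compare the raw pseudo-inverse estimate with the denoised output.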