spyrit.core.recon.PinvNet
class spyrit.core.recon.PinvNet(acqu: Linear, prep=Identity(), denoi=Identity(), *, device: device = device(type='cpu'), **pinv_kwargs)
Bases: _PrebuiltFullNet

A FullNet with a pseudo inverse-based reconstruction module. It simulates noisy measurements
\[y = \mathcal{N}\left(Ax\right),\]
where \(\mathcal{N}\) represents a noise operator (e.g., Gaussian), \(A\) is the acquisition matrix, and \(x\) is the signal of interest.
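A minimal NumPy sketch of the measurement model follows. It assumes that \(\mathcal{N}\) is additive Gaussian noise with a hypothetical standard deviation `sigma`; the function `simulate_measurements` and the random matrix are illustrative only and are not part of spyrit, whose noise operators may be parametrized differently.

```python
import numpy as np

def simulate_measurements(A, x, sigma=0.0, rng=None):
    """Return y = N(A x), with N assumed here to be additive Gaussian noise (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    clean = A @ x  # noiseless measurements A x
    return clean + sigma * rng.standard_normal(clean.shape)  # apply the assumed noise operator N

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 16))  # hypothetical acquisition matrix (8 measurements, 16 pixels)
x = rng.random(16)                # signal of interest
y = simulate_measurements(A, x, sigma=0.1, rng=rng)
print(y.shape)  # (8,)
```

With `sigma=0.0` the sketch reduces to the noiseless product \(Ax\).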
It estimates the signal from the noisy measurements in three steps [1]:
1. Preprocessing of the measurements:
   \[\tilde{m} = By,\]
   where \(B\) represents a preprocessing step.
2. Pseudo-inverse reconstruction:
   \[x^\dagger = H^\dagger \tilde{m},\]
   where \(H^\dagger\) denotes the Moore-Penrose pseudo-inverse of \(H=BA\).
3. Denoising/artefact correction:
   \[\hat{x} = \mathcal{G}_\theta(x^\dagger),\]
   where \(\mathcal{G}_\theta\) is a neural network with learnable parameters \(\theta\).
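The three steps can be illustrated with plain NumPy matrices. This is a sketch under stated assumptions, not spyrit's implementation: the acquisition matrix is assumed to stack the positive and negative parts of a Sylvester Hadamard matrix, the preprocessing \(B = [I, -I]\) is assumed to undo this split so that \(H = BA\) is the Hadamard matrix itself, and \(\mathcal{G}_\theta\) is replaced by the identity. The helper `sylvester_hadamard` is hypothetical.

```python
import numpy as np

def sylvester_hadamard(n):
    """Hadamard matrix of order n (n a power of 2), Sylvester construction (hypothetical helper)."""
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H

n = 16
H_full = sylvester_hadamard(n)
# Assumed split acquisition: positive and negative parts stacked, so A has nonnegative entries.
A = np.vstack([np.maximum(H_full, 0), np.maximum(-H_full, 0)])
# Assumed preprocessing B = [I, -I], which "unsplits" the measurements so that H = B A.
B = np.hstack([np.eye(n), -np.eye(n)])
H = B @ A

def denoi(z):
    # Placeholder for G_theta: identity, i.e. no denoising.
    return z

rng = np.random.default_rng(0)
x = rng.random(n)
y = A @ x                                # noiseless measurements, for this check only
m_tilde = B @ y                          # step 1: preprocessing
x_dagger = np.linalg.pinv(H) @ m_tilde   # step 2: pseudo-inverse reconstruction
x_hat = denoi(x_dagger)                  # step 3: denoising (identity here)
print(np.allclose(x_hat, x))  # True
```

Because \(H\) is invertible here and no noise is added, \(H^\dagger = H^{-1}\) and the reconstruction is exact; with noise or fewer measurements, \(x^\dagger\) is only an estimate, which the denoiser is meant to improve.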
Args:
    acqu (spyrit.core.meas): Acquisition operator \(\mathcal{N}\circ A\).
    prep (spyrit.core.prep): Preprocessing operator \(B\). Defaults to no preprocessing (i.e., spyrit.core.prep.Identity).
    denoi (torch.nn.Module, optional): Image denoising operator \(\mathcal{G}_\theta\). Defaults to no denoising (i.e., torch.nn.Identity).
    **pinv_kwargs: Optional keyword arguments passed to the pseudo inverse operator (see spyrit.core.inverse.PseudoInverse).

Attributes:
    acqu (spyrit.core.meas): Acquisition operator \(\mathcal{N}\circ A\).
    acqu_modules (torch.nn.Sequential): Acquisition modules. Contains only acqu.
    prep (spyrit.core.prep): Preprocessing operator \(B\).
    pinv (spyrit.core.inverse.PseudoInverse): Pseudo inverse operator \(H^\dagger\).
    denoi (torch.nn.Module): Image denoising operator \(\mathcal{G}_\theta\).
    recon_modules (torch.nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the pseudo inverse operator, and the denoising operator, i.e., \(\mathcal{G}_\theta\circ H^\dagger \circ B\).
    pinv_kwargs (dict): Optional keyword arguments passed to the pseudo inverse operator.

Input / Output:
    input: Ground-truth images with shape \((b,c,h,w)\), where \(b\) is the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.
    output: Reconstructed images with shape \((b,c,h,w)\).

References:
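The input/output convention can be checked at the level of array shapes. The sketch assumes that each \(h \times w\) image is flattened before a measurement matrix with \(M\) rows is applied along the last dimension; it uses a random matrix and a plain NumPy pseudo inverse in place of spyrit's operators, and the sizes are arbitrary. Choosing \(M = 2hw\) matches the pattern of the examples, where 32×32 images give 2048 measurements.

```python
import numpy as np

b, c, h, w = 10, 1, 8, 8
M = 2 * h * w  # number of measurements per image (arbitrary choice for this sketch)
rng = np.random.default_rng(0)
A = rng.standard_normal((M, h * w))  # hypothetical acquisition matrix

x = rng.random((b, c, h, w))                           # ground-truth images (b, c, h, w)
y = x.reshape(b, c, h * w) @ A.T                       # measurements (b, c, M)
x_rec = (y @ np.linalg.pinv(A).T).reshape(b, c, h, w)  # pseudo inverse, reshaped to images
print(y.shape, x_rec.shape)  # (10, 1, 128) (10, 1, 8, 8)
```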
Example:
>>> import torch
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> acqu = meas.HadamSplit2d(32)
>>> pinv = recon.PinvNet(acqu)
>>> x = torch.rand(10, 1, 32, 32)
>>> y = pinv.acquire(x)
>>> z = pinv.reconstruct(y)
>>> print(y.shape)
torch.Size([10, 1, 2048])
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
Same as above, with preprocessing (unsplitting). Note the keyword arguments use_fast_pinv and reshape_output, which are forwarded to the pseudo inverse operator so that it works on the H matrix and reshapes its output.
>>> import torch
>>> import spyrit.core.meas as meas
>>> import spyrit.core.recon as recon
>>> import spyrit.core.prep as sprep
>>> acqu = meas.HadamSplit2d(32)
>>> prep = sprep.Unsplit()
>>> pinv = recon.PinvNet(acqu, prep, use_fast_pinv=True, reshape_output=True)
>>> x = torch.rand(10, 1, 32, 32)
>>> y = pinv.acquire(x)
>>> z = pinv.reconstruct(y)
>>> print(y.shape)
torch.Size([10, 1, 2048])
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
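The unsplitting step can be sketched as a difference of positive and negative measurements. The ordering used here (positive patterns on even indices, negative patterns on odd indices of the last dimension) is an assumption, as is the hypothetical `unsplit` function; refer to spyrit.core.prep.Unsplit for the actual convention.

```python
import numpy as np

def unsplit(y):
    """Hypothetical unsplit: difference of interleaved positive/negative measurements (assumed ordering)."""
    return y[..., 0::2] - y[..., 1::2]

rng = np.random.default_rng(0)
H = rng.choice([-1.0, 1.0], size=(4, 6))  # hypothetical +/-1 pattern matrix
A = np.empty((8, 6))
A[0::2] = np.maximum(H, 0)   # assumed: positive parts on even rows
A[1::2] = np.maximum(-H, 0)  # assumed: negative parts on odd rows

x = rng.random((10, 1, 6))   # batch of flattened signals (b, c, N)
y = x @ A.T                  # split measurements (b, c, 2M)
m = unsplit(y)               # unsplit measurements (b, c, M)
print(y.shape, m.shape, np.allclose(m, x @ H.T))  # (10, 1, 8) (10, 1, 4) True
```

Since \(\max(H,0) - \max(-H,0) = H\), the difference recovers \(Hx\) from nonnegative patterns.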
Methods
acquire(x): Apply the measurement modules to the input signal.
forward(x): Apply the full network to the input signal.
reconstruct(y): Apply the reconstruction modules to the input measurements.
Reconstructs measurement vectors without denoising.
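The relation between the methods can be sketched with a toy class. The method descriptions suggest that forward(x) amounts to reconstruct(acquire(x)), and that reconstruction chains preprocessing, pseudo inverse and denoising as in recon_modules. The class `ToyPinvNet` is hypothetical, works on NumPy arrays, and omits noise; it only mirrors the structure described above, not spyrit's code.

```python
import numpy as np

class ToyPinvNet:
    """Hypothetical NumPy analogue of the acquire / reconstruct / forward split."""

    def __init__(self, A, B=None, denoi=None):
        self.A = A                                               # acquisition matrix
        self.B = np.eye(A.shape[0]) if B is None else B          # preprocessing (identity by default)
        self.denoi = (lambda z: z) if denoi is None else denoi   # G_theta (identity by default)
        self.H_pinv = np.linalg.pinv(self.B @ self.A)            # H^dagger with H = B A, computed once

    def acquire(self, x):
        # Measurement step (noiseless in this sketch): y = A x along the last dimension.
        return x @ self.A.T

    def reconstruct(self, y):
        # Reconstruction step: G_theta(H^dagger (B y)).
        return self.denoi((y @ self.B.T) @ self.H_pinv.T)

    def forward(self, x):
        # Full network: reconstruction applied to the simulated measurements.
        return self.reconstruct(self.acquire(x))

rng = np.random.default_rng(0)
net = ToyPinvNet(rng.standard_normal((12, 8)))
x = rng.random((3, 1, 8))
print(np.allclose(net.forward(x), x))  # True
```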