spyrit.core.recon.PinvNet
- class spyrit.core.recon.PinvNet(acqu: Linear, prep, denoi=Identity(), *, device: device = device(type='cpu'), **pinv_kwargs)
Bases: _PrebuiltFullNet

Pre-built FullNet that uses a pseudo inverse. As a FullNet, this network has two modules: one for measurements and one for reconstruction. The measurement module only contains the acquisition operator. The reconstruction module contains a preprocessing operator, a pseudo inverse operator, and a denoising operator.
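The two-module layout can be illustrated with a small NumPy toy. This is a sketch under stated assumptions, not spyrit's implementation: `ToyPinvNet`, the dense matrix `H`, and the `prep`/`denoi` callables are hypothetical stand-ins, and `np.linalg.pinv` plays the role of `PseudoInverse`. Images are flattened per channel so that input and output keep the shape (b, c, h, w), and extra keyword arguments are stored and forwarded to the pseudo inverse the way `**pinv_kwargs` is described above.

```python
import numpy as np


def _identity(t):
    return t


class ToyPinvNet:
    """Hypothetical toy mirroring PinvNet's two-module layout (not spyrit code)."""

    def __init__(self, H, prep=None, denoi=None, **pinv_kwargs):
        self.H = H                                   # acquisition matrix, shape (M, h*w)
        self.prep = prep if prep is not None else _identity
        self.denoi = denoi if denoi is not None else _identity
        self.pinv_kwargs = pinv_kwargs               # stored, then forwarded unchanged
        self.H_pinv = np.linalg.pinv(H, **pinv_kwargs)  # e.g. hermitian=True for symmetric H

    def acquire(self, x):
        # Measurement module: images (b, c, h, w) -> measurements (b, c, M)
        b, c, h, w = x.shape
        return x.reshape(b, c, h * w) @ self.H.T

    def reconstruct(self, y, img_shape):
        # Reconstruction module: preprocessing -> pseudo inverse -> denoising
        x = self.prep(y) @ self.H_pinv.T             # (b, c, M) -> (b, c, h*w)
        x = x.reshape(*y.shape[:2], *img_shape)
        return self.denoi(x)

    def forward(self, x):
        # Full network: measure, then reconstruct at the input resolution
        return self.reconstruct(self.acquire(x), x.shape[-2:])


rng = np.random.default_rng(0)
H = rng.standard_normal((16, 16))  # square and (almost surely) invertible
net = ToyPinvNet(H)
x = rng.random((2, 1, 4, 4))
x_rec = net.forward(x)
print(x_rec.shape)                 # (2, 1, 4, 4)
```

With a square, invertible `H`, no noise, and identity `prep` and `denoi`, the round trip recovers the input; in spyrit the same composition is expressed with `nn.Sequential` modules (`acqu_modules`, `recon_modules`).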
The optional keyword arguments passed at initialization are forwarded to the pseudo inverse operator, so the regularization can be controlled directly from the PinvNet constructor.

- Args:
  acqu (spyrit.core.meas): Acquisition operator (see meas).
  prep (spyrit.core.prep): Preprocessing operator (see prep).
  denoi (torch.nn.Module, optional): Image denoising operator. Default is Identity.
  **pinv_kwargs: Optional keyword arguments passed to the pseudo inverse operator (see PseudoInverse).
- Attributes:
  acqu (spyrit.core.meas): Acquisition operator.
  acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.
  prep (spyrit.core.prep): Preprocessing operator.
  inv (spyrit.core.inverse.PseudoInverse): Pseudo inverse operator.
  denoi (torch.nn.Module): Image denoising operator.
  recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the pseudo inverse operator, and the denoising operator.
  pinv_kwargs (dict): Optional keyword arguments passed to the pseudo inverse operator.
- Input / Output:
  input: Ground-truth images with shape \((b,c,h,w)\), with \(b\) being the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.
  output: Reconstructed images with shape \((b,c,h,w)\).
- Example:
>>> import torch
>>> import spyrit
>>> from spyrit.core.recon import PinvNet
>>> acqu = spyrit.core.meas.HadamSplit2d(32)
>>> prep = spyrit.core.prep.Rescale(1.0)
>>> pinv = PinvNet(acqu, prep, device=torch.device("cpu"))
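`HadamSplit2d(32)` acquires 32×32 images with Hadamard patterns. The sketch below shows why a pseudo inverse is cheap for such an acquisition, under assumptions it does not check against spyrit: full sampling, a Sylvester-ordered Hadamard matrix, a separable 2D measurement Y = H X Hᵀ, and no modelling of the positive/negative pattern split, subsampling, or reordering. Because H Hᵀ = n I, the pseudo inverse reduces to X = Hᵀ Y H / n². The helpers `sylvester_hadamard`, `measure_2d`, and `fast_pinv_2d` are hypothetical; whether spyrit's `use_fast_pinv` option relies on exactly this identity is also an assumption.

```python
import numpy as np


def sylvester_hadamard(n):
    """Hypothetical helper: n x n Sylvester Hadamard matrix (n a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def measure_2d(X, H):
    # Separable 2D Hadamard acquisition, full sampling: Y = H X H^T
    return H @ X @ H.T


def fast_pinv_2d(Y, H):
    # H H^T = n I, so H^+ = H^T / n and X = H^T Y H / n^2
    n = H.shape[0]
    return H.T @ Y @ H / n**2


n = 32
H = sylvester_hadamard(n)
X = np.random.default_rng(0).random((n, n))
Y = measure_2d(X, H)

# Dense reference: pseudo inverse of the (n^2 x n^2) Kronecker operator acting on vec(X)
H2d = np.kron(H, H)
X_dense = (np.linalg.pinv(H2d) @ Y.reshape(-1)).reshape(n, n)
X_fast = fast_pinv_2d(Y, H)
print(np.allclose(X_fast, X), np.allclose(X_fast, X_dense))  # True True
```

The fast path costs two small matrix products per image instead of a dense (n² × n²) pseudo inverse, which is the trade-off that storing `H_pinv` explicitly (as `store_H_pinv=True` suggests in the next example) gives up.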
- Example with a regularized pseudo inverse:
>>> import spyrit
>>> from spyrit.core.recon import PinvNet
>>> noise_model = spyrit.core.noise.Poisson(100)
>>> acqu = spyrit.core.meas.HadamSplit2d(32, noise_model=noise_model)
>>> prep = spyrit.core.prep.Rescale(100)
>>> pinv = PinvNet(acqu, prep, use_fast_pinv=False, store_H_pinv=True, regularization='H1', eta=1e-6, img_shape=(32, 32))
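The regularized example combines Poisson noise (`Poisson(100)`), a matching rescaling (`Rescale(100)`), and a regularization weight `eta`. The NumPy sketch below illustrates the same chain under two loudly stated assumptions: the noise model is taken to draw Poisson counts from `alpha * (H x)`, so dividing by `alpha` returns measurements to the scale of `H x`; and a plain Tikhonov (L2) penalty, x = (HᵀH + eta I)⁻¹ Hᵀ y, stands in for spyrit's `'H1'` regularization, which instead penalizes image gradients. `tikhonov_pinv` is a hypothetical helper, and the `eta` value is chosen for this toy problem, not taken from spyrit.

```python
import numpy as np


def tikhonov_pinv(H, y, eta):
    """Hypothetical helper: Tikhonov-regularized inverse (H^T H + eta I)^-1 H^T y."""
    n = H.shape[1]
    return np.linalg.solve(H.T @ H + eta * np.eye(n), H.T @ y)


rng = np.random.default_rng(0)
alpha = 100.0                           # intensity scale, as in Poisson(100) / Rescale(100)
H = rng.random((64, 64))                # nonnegative patterns keep Poisson rates >= 0
x = rng.random(64)

y_counts = rng.poisson(alpha * (H @ x))  # noisy measurements (photon counts)
y = y_counts / alpha                     # preprocessing: rescale to the scale of H x

x_pinv = np.linalg.pinv(H) @ y           # unregularized pseudo inverse
x_reg = tikhonov_pinv(H, y, eta=1e-2)    # regularized estimate

print("pinv error:", np.linalg.norm(x_pinv - x))
print("regularized error:", np.linalg.norm(x_reg - x))
```

For any eta > 0 the regularized solution damps the components along small singular values of H, which the unregularized pseudo inverse amplifies along with the noise; larger eta trades noise for bias.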
Methods
acquire(x): Apply the measurement modules to the input signal.
forward(x): Apply the full network to the input signal.
reconstruct(y): Apply the reconstruction modules to the input measurements.
Reconstructs measurement vectors without denoising.