spyrit.core.recon.TikhoNet
class spyrit.core.recon.TikhoNet(acqu: Linear, prep, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'), **tikho_kwargs)
Bases: _PrebuiltFullNet

Pre-built FullNet that uses a Tikhonov reconstruction operator. As a FullNet, this network has two modules: one for measurements and one for reconstruction. The measurement module contains only the acquisition operator. The reconstruction module contains a preprocessing operator, a Tikhonov inverse operator, and a denoising operator.
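The two-module layout can be sketched with plain Python callables. This is a minimal illustration, not spyrit's implementation: `ToyFullNet` and `chain()` are hypothetical names, `chain()` plays the role of `torch.nn.Sequential`, and the toy operators are chosen only so the round trip is easy to check. The `acquire`, `reconstruct` and `forward` methods mirror the methods listed in the Methods section.

```python
# Minimal sketch of the two-module FullNet layout.
# ToyFullNet and chain() are hypothetical; they are not spyrit classes.
from typing import Callable, Sequence

import numpy as np

Op = Callable[[np.ndarray], np.ndarray]


def chain(ops: Sequence[Op]) -> Op:
    """Apply operators left to right, like torch.nn.Sequential."""
    def run(x: np.ndarray) -> np.ndarray:
        for op in ops:
            x = op(x)
        return x
    return run


class ToyFullNet:
    def __init__(self, acqu: Op, prep: Op, tikho: Op, denoi: Op) -> None:
        # Measurement module: the acquisition operator only.
        self.acqu_modules = chain([acqu])
        # Reconstruction module: preprocessing -> Tikhonov inverse -> denoising.
        self.recon_modules = chain([prep, tikho, denoi])

    def acquire(self, x: np.ndarray) -> np.ndarray:
        return self.acqu_modules(x)

    def reconstruct(self, y: np.ndarray) -> np.ndarray:
        return self.recon_modules(y)

    def forward(self, x: np.ndarray) -> np.ndarray:
        # Full network = reconstruction applied to simulated measurements.
        return self.reconstruct(self.acquire(x))


# Toy operators: a gain-2 "acquisition", a rescaling that undoes it,
# an identity "Tikhonov" step and an identity denoiser.
def identity(v: np.ndarray) -> np.ndarray:
    return v


net = ToyFullNet(lambda v: 2.0 * v, lambda v: v / 2.0, identity, identity)
x = np.arange(4.0)
print(net.forward(x))  # [0. 1. 2. 3.]
```

Keeping the measurement and reconstruction stages as separate sequences is what lets the measurements be simulated once and reconstructed several times.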
The optional keyword arguments passed at initialization are forwarded to the Tikhonov operator, so the regularization can be controlled directly from the TikhoNet constructor.

Args:
    acqu (spyrit.core.meas): Acquisition operator (see meas).
    prep (spyrit.core.prep): Preprocessing operator (see prep).
    sigma (torch.tensor): Image-domain covariance prior (for details, see the Tikhonov() class).
    denoi (torch.nn.Module, optional): Image denoising operator (see nnet). Default: Identity.
    tikho_kwargs (dict): Optional keyword arguments passed to the Tikhonov() constructor. May contain the following keys:
        - approx (bool): If True, the Tikhonov inversion step is approximated using a diagonal matrix. Default: False.
        - reshape_output (bool): If True, the output of the Tikhonov inversion step is reshaped to match the input shape of the acquisition operator. Default: True.
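To show what `sigma` and `approx` act on, here is a NumPy sketch of the standard linear-Gaussian Tikhonov estimate with an image-domain covariance prior, x_hat = sigma H^T (H sigma H^T + noise_cov)^{-1} y. This is an assumption about the form of the operator, not a copy of spyrit's `Tikhonov` class: `H`, `noise_cov`, `tikhonov()` and `toy_tikhonet_step()` are hypothetical, and the diagonal variant (keeping only the diagonal of the matrix to invert) is one plausible reading of `approx=True`. The wrapper also shows how extra keywords can be forwarded unchanged, as TikhoNet does with `tikho_kwargs`.

```python
import numpy as np


def tikhonov(H, sigma, y, noise_cov, approx=False):
    """Sketch of x_hat = sigma H^T (H sigma H^T + noise_cov)^{-1} y.

    H: (M, N) measurement matrix; sigma: (N, N) image-domain covariance
    prior; noise_cov: (M, M) measurement-noise covariance (assumed given).
    """
    A = H @ sigma @ H.T + noise_cov  # covariance of the measurements
    if approx:
        # Hypothetical diagonal approximation: keep only diag(A), so the
        # inverse reduces to an elementwise division.
        z = y / np.diag(A)
    else:
        z = np.linalg.solve(A, y)
    return sigma @ H.T @ z


def toy_tikhonet_step(H, sigma, y, noise_cov, **tikho_kwargs):
    # Hypothetical wrapper: extra keywords (here `approx`) are forwarded
    # unchanged to the inner Tikhonov step.
    return tikhonov(H, sigma, y, noise_cov, **tikho_kwargs)


H = np.array([[2.0, 1.0], [1.0, 3.0]])
sigma = np.eye(2)
x = np.array([1.0, -1.0])
y = H @ x
noise_cov = 1e-10 * np.eye(2)

x_exact = toy_tikhonet_step(H, sigma, y, noise_cov)              # ~ x for tiny noise
x_diag = toy_tikhonet_step(H, sigma, y, noise_cov, approx=True)  # ~ [0.2, -0.4]
```

With almost no noise the exact solve recovers `x`, while the diagonal shortcut trades accuracy for a much cheaper inversion whenever H sigma H^T is far from diagonal.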
Attributes:
    acqu: Acquisition operator, initialized as acqu.
    acqu_modules (nn.Sequential): Measurement modules. Contains only the acquisition operator.
    prep: Preprocessing operator, initialized as prep.
    tikho: Data consistency layer, initialized as Tikhonov(acqu, sigma, **tikho_kwargs).
    denoi: Image denoising operator, initialized as denoi.
    recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov operator, and the denoising operator.

Input / Output:
    input (torch.tensor): Ground-truth images with shape \((b,c,h,w)\).
    output (torch.tensor): Reconstructed images with shape \((b,c,h,w)\), or \((b,c,h\cdot w)\) when reshape_output=False.

Example 1:
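    >>> import torch
    >>> import spyrit.core.meas, spyrit.core.noise, spyrit.core.prep
    >>> from spyrit.core.recon import TikhoNet
    >>> noise = spyrit.core.noise.Poisson(100)
    >>> acqu = spyrit.core.meas.HadamSplit2d(64, noise_model=noise)
    >>> prep = spyrit.core.prep.Rescale(100)
    >>> sigma = torch.ones(64, 64)
    >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    >>> tikho = TikhoNet(acqu, prep, sigma, device=device)
    >>> x = torch.rand(10, 1, 64, 64, device=device)
    >>> z = tikho(x)
    >>> print(z.shape)
    torch.Size([10, 1, 64, 64])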
Example 2:
    >>> noise = spyrit.core.noise.Gaussian(1.0)
    >>> acqu = spyrit.core.meas.HadamSplit2d(64, noise_model=noise)
    >>> prep = spyrit.core.prep.Rescale(1.0)
    >>> sigma = torch.ones(64, 64)
    >>> tikho = TikhoNet(acqu, prep, sigma, approx=True, reshape_output=False)
    >>> x = torch.rand(10, 1, 64, 64)
    >>> z = tikho(x)
    >>> print(z.shape)
    torch.Size([10, 1, 4096])
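The two output shapes in the examples are consistent with the inversion working on flattened images of N = h * w pixels, which are then optionally reshaped back. The sketch below only reproduces that shape bookkeeping under this assumption; `toy_tikhonov_shapes()` is a hypothetical helper and performs no actual reconstruction.

```python
import numpy as np


def toy_tikhonov_shapes(x, reshape_output=True):
    """Hypothetical shape bookkeeping only; no actual reconstruction is done."""
    b, c, h, w = x.shape
    # The inversion is assumed to act on flattened images of N = h * w pixels.
    x_hat_flat = x.reshape(b, c, h * w)  # stand-in for the Tikhonov estimate
    if reshape_output:
        # Restore the (b, c, h, w) layout of the acquisition operator input.
        return x_hat_flat.reshape(b, c, h, w)
    return x_hat_flat


x = np.random.rand(10, 1, 64, 64)
print(toy_tikhonov_shapes(x).shape)                        # (10, 1, 64, 64)
print(toy_tikhonov_shapes(x, reshape_output=False).shape)  # (10, 1, 4096)
```

Leaving the output flat (reshape_output=False) can be convenient when a downstream step expects vectors rather than images.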
Methods
    acquire(x): Apply the measurement modules to the input signal.
    forward(x): Apply the full network to the input signal.
    reconstruct(y): Reconstruct an image from measurements.
Reconstruct an image from measurements without denoising.