spyrit.core.recon.TikhoNet
- class spyrit.core.recon.TikhoNet(acqu: Linear, prep, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'), **tikho_kwargs)
Bases: _PrebuiltFullNet

Pre-built FullNet that uses a Tikhonov reconstruction operator. As a FullNet, this network has two modules: one for measurements and one for reconstruction. The measurement module only contains the acquisition operator. The reconstruction module contains a preprocessing operator, a Tikhonov inverse operator, and a denoising operator.
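The two-module composition can be pictured with a minimal pure-Python sketch. ToySequential and ToyFullNet are hypothetical stand-ins, not spyrit classes, and the scalar lambdas are toy placeholders for the acquisition, preprocessing, Tikhonov, and denoising operators. The sketch assumes, in line with the acquire, reconstruct, and forward methods of this class, that the full network applies the measurement module and then the reconstruction module.

```python
from typing import Callable


class ToySequential:
    """Apply callables in order, loosely mimicking torch.nn.Sequential."""

    def __init__(self, *modules: Callable):
        self.modules = modules

    def __call__(self, x):
        for module in self.modules:
            x = module(x)
        return x


class ToyFullNet:
    """Hypothetical stand-in for a FullNet with a measurement and a reconstruction module."""

    def __init__(self, acqu, prep, tikho, denoi):
        self.acqu_modules = ToySequential(acqu)                 # measurement module
        self.recon_modules = ToySequential(prep, tikho, denoi)  # reconstruction module

    def acquire(self, x):
        return self.acqu_modules(x)

    def reconstruct(self, y):
        return self.recon_modules(y)

    def forward(self, x):
        # Full network: measure, then reconstruct
        return self.reconstruct(self.acquire(x))


# Scalar toy operators: measure 3*x, undo the gain, shrink by 0.5, no denoising
net = ToyFullNet(
    acqu=lambda x: 3.0 * x,
    prep=lambda y: y / 3.0,
    tikho=lambda y: 0.5 * y,
    denoi=lambda z: z,
)
print(net.forward(4.0))  # 2.0
```

Keeping the two modules separate is what allows measurements to be simulated (acquire) or reconstructed (reconstruct) independently.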
This is a two-step reconstruction method [1]. The first step estimates the signal \(\tilde{x}\) from preprocessed measurements \(y\) by minimizing
\[\| y - Ax \|^2_{\Gamma^{-1}} + \|x\|^2_{\Sigma^{-1}}\]
where \(A\) is the measurement matrix, \(\Gamma\) is the covariance of the noise, and \(\Sigma\) is the signal covariance prior. The solution is computed as
\[\tilde{x} = \Sigma A^\top (A \Sigma A^\top + \Gamma)^{-1} y\]
The second step applies a learnable neural network \(\mathcal{G}_\theta\) to the output of the first step:
\[\hat{x} = \mathcal{G}_\theta(\tilde{x})\]
where \(\theta\) are the learnable parameters of the neural network.
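The closed-form expression for \(\tilde{x}\) can be checked numerically. The following NumPy sketch is independent of spyrit; the matrix sizes, the random data, and the way the positive definite \(\Sigma\) is built are arbitrary assumptions. It compares the closed form with the solution of the normal equations \((A^\top \Gamma^{-1} A + \Sigma^{-1})\tilde{x} = A^\top \Gamma^{-1} y\) and confirms that the gradient of the objective vanishes there.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 4  # toy signal size and number of measurements

A = rng.standard_normal((m, n))             # measurement matrix
L = rng.standard_normal((n, n))
Sigma = L @ L.T + n * np.eye(n)             # symmetric positive definite prior
Gamma = np.diag(rng.uniform(0.5, 2.0, m))   # diagonal noise covariance
y = rng.standard_normal(m)                  # preprocessed measurements

# Closed form: x = Sigma A^T (A Sigma A^T + Gamma)^{-1} y
x_tikho = Sigma @ A.T @ np.linalg.solve(A @ Sigma @ A.T + Gamma, y)

# Normal equations of the objective: (A^T Gamma^{-1} A + Sigma^{-1}) x = A^T Gamma^{-1} y
Gamma_inv = np.linalg.inv(Gamma)
Sigma_inv = np.linalg.inv(Sigma)
x_normal = np.linalg.solve(A.T @ Gamma_inv @ A + Sigma_inv, A.T @ Gamma_inv @ y)

# Gradient of ||y - Ax||^2_{Gamma^{-1}} + ||x||^2_{Sigma^{-1}} at the closed-form solution
grad = -2 * A.T @ Gamma_inv @ (y - A @ x_tikho) + 2 * Sigma_inv @ x_tikho

print(np.allclose(x_tikho, x_normal), np.allclose(grad, 0.0))  # True True
```

The closed form only inverts an \(M \times M\) matrix (\(M\) measurements), whereas the normal equations involve an \(N \times N\) system, which is why the former is attractive when there are fewer measurements than pixels.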
The optional keyword arguments passed at initialization are forwarded to the Tikhonov operator, so the Tikhonov step (e.g., its diagonal approximation or output reshaping) can be configured directly from the TikhoNet constructor.
- References:
- Args:
  - acqu (spyrit.core.meas): Acquisition operator (see meas).
  - prep (spyrit.core.prep): Preprocessing operator (see prep).
  - sigma (torch.Tensor): Image-domain covariance prior with shape \((hw, hw)\) (for details, see the Tikhonov() class).
  - denoi (torch.nn.Module, optional): Image denoising operator (see nnet). Default: Identity.
  - kwargs (dict): Optional keyword arguments passed to the Tikhonov() constructor. May contain the following keys:
    - approx (bool): If True, the Tikhonov inversion step is approximated using a diagonal matrix. Default is False.
    - reshape_output (bool): If True, the output of the Tikhonov inversion step is reshaped to match the acquisition operator input shape. Default is True.
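To illustrate what a diagonal approximation of the inversion can mean, here is a hedged NumPy sketch. It assumes, without reference to spyrit's source, that approx=True replaces \(A \Sigma A^\top + \Gamma\) by its diagonal before inverting; tikhonov_exact, tikhonov_diag, and sylvester_hadamard are hypothetical helpers. With orthogonal Hadamard rows and scaled-identity \(\Sigma\) and \(\Gamma\), that matrix is already diagonal and the approximation is exact; a non-uniform \(\Sigma\) introduces off-diagonal terms and the two results differ.

```python
import numpy as np


def sylvester_hadamard(n: int) -> np.ndarray:
    """n x n Hadamard matrix for n a power of two (Sylvester construction)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
    return H


def tikhonov_exact(A, Sigma, Gamma, y):
    # x = Sigma A^T (A Sigma A^T + Gamma)^{-1} y, with a full linear solve
    return Sigma @ A.T @ np.linalg.solve(A @ Sigma @ A.T + Gamma, y)


def tikhonov_diag(A, Sigma, Gamma, y):
    # ASSUMPTION (hypothetical approx=True behavior): keep only the diagonal
    # of A Sigma A^T + Gamma, so the inversion becomes an elementwise division
    d = np.diag(A @ Sigma @ A.T + Gamma)
    return Sigma @ A.T @ (y / d)


n, m = 16, 8
A = sylvester_hadamard(n)[:m]        # m mutually orthogonal Hadamard rows
Gamma = 0.5 * np.eye(m)
y = np.arange(m, dtype=float)

# Scaled-identity prior: A Sigma A^T + Gamma = 32.5 I is already diagonal
Sigma_iso = 2.0 * np.eye(n)
print(np.allclose(tikhonov_exact(A, Sigma_iso, Gamma, y),
                  tikhonov_diag(A, Sigma_iso, Gamma, y)))     # True

# Non-uniform prior: off-diagonal terms appear and the approximation differs
Sigma_nonuni = np.diag([2.0] + [1.0] * (n - 1))
print(np.allclose(tikhonov_exact(A, Sigma_nonuni, Gamma, y),
                  tikhonov_diag(A, Sigma_nonuni, Gamma, y)))  # False
```

The trade-off sketched here is speed versus accuracy: the diagonal version avoids a linear solve, at the cost of ignoring correlations between measurements.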
- Attributes:
  - acqu: Acquisition operator initialized as acqu.
  - acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.
  - prep: Preprocessing operator initialized as prep.
  - tikho: Data consistency layer initialized as Tikhonov(acqu, sigma, **tikho_kwargs).
  - denoi: Image denoising operator initialized as denoi.
  - recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov operator, and the denoising operator.
- Input / Output:
  - input (torch.Tensor): Ground-truth images with shape \((b,c,h,w)\).
  - output (torch.Tensor): Reconstructed images with shape \((b,c,h,w)\), or \((b,c,hw)\) when reshape_output=False.
- Example 1:
>>> import torch
>>> import spyrit
>>> from spyrit.core.recon import TikhoNet
>>> noise = spyrit.core.noise.Poisson(100)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, device=torch.device("cpu"))
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 8, 8])
- Example 2:
>>> noise = spyrit.core.noise.Gaussian(1.0)
>>> acqu = spyrit.core.meas.HadamSplit2d(8, noise_model=noise)
>>> prep = spyrit.core.prep.UnsplitRescale()
>>> sigma = torch.ones(64, 64)
>>> tikho = TikhoNet(acqu, prep, sigma, approx=True, reshape_output=False)
>>> x = torch.rand(10, 1, 8, 8)
>>> z = tikho(x)
>>> print(z.shape)
torch.Size([10, 1, 64])
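The shape difference between the two examples comes from reshape_output. A minimal NumPy sketch, assuming (as an illustration, not a statement about spyrit's internals) that each \(h \times w\) image is stored as a row-major vector of length \(hw\) when the output is not reshaped:

```python
import numpy as np

b, c, h, w = 10, 1, 8, 8
rng = np.random.default_rng(0)

# Vectorized images, as returned when reshape_output=False: shape (b, c, h*w)
x_flat = rng.random((b, c, h * w))

# ASSUMPTION: images are stored row-major, so reshaping the last axis recovers them
x_img = x_flat.reshape(b, c, h, w)

print(x_flat.shape, x_img.shape)  # (10, 1, 64) (10, 1, 8, 8)
# Pixel (row 1, column 0) sits at flat index 1 * w + 0 = 8
print(x_img[0, 0, 1, 0] == x_flat[0, 0, 8])  # True
```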
Methods
acquire(x): Apply the measurement modules to the input signal.
forward(x): Apply the full network to the input signal.
reconstruct(y): Reconstruct an image from measurements.
Reconstruct an image from measurements without denoising.