spyrit.core.recon.LearnedPGD
- class spyrit.core.recon.LearnedPGD(acqu: LinearSplit, prep: UnsplitRescale, denoi=Identity(), *, iter_stop=3, x0=0.5, step=None, step_estimation=False, step_grad=False, step_decay=1, wls=False, gt=None, log_fidelity=False, res_learn=False, **pinv_kwargs)
Bases: Module

Learned Proximal Gradient Descent reconstruction network.
Iterative algorithm that alternates between a gradient step and a proximal step, where the proximal operator is replaced by a learned denoiser. The update rule is given by
\[x_{k+1} = \texttt{denoi}\left(x_k - \gamma \, H^T (Hx_k - m)\right)\]

where \(x_k\in\mathbb{R}^N\) is the current estimate, \(\gamma\in\mathbb{R}\) is the step size, \(H\in\mathbb{R}^{M\times N}\) is the forward model, and \(m\in\mathbb{R}^{M}\) are the measurements. The term \(H^T (Hx_k - m)\) is the gradient of the data-fidelity term \(\frac{1}{2}\|Hx - m\|^2\) evaluated at \(x_k\).
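The update rule above can be sketched in plain NumPy, independently of spyrit. This is a minimal illustration under stated assumptions: a dense random matrix H, noiseless measurements, a zero initial estimate, and an identity function standing in for the learned denoiser. The names lpgd and identity are hypothetical and not part of the spyrit API. With an identity denoiser the iteration reduces to plain gradient descent on \(\frac{1}{2}\|Hx - m\|^2\), so with step \(\gamma = 1/\lambda_{\max}(H^TH)\) the residual should shrink.

```python
import numpy as np


def lpgd(H, m, denoi, gamma, x0, iter_stop=3):
    """Run iter_stop updates x <- denoi(x - gamma * H^T (H x - m))."""
    x = x0.copy()
    for _ in range(iter_stop):
        grad = H.T @ (H @ x - m)     # gradient of 0.5 * ||Hx - m||^2 at x
        x = denoi(x - gamma * grad)  # gradient step, then denoising ("proximal") step
    return x


def identity(x):
    """Hypothetical stand-in for the learned denoiser."""
    return x


rng = np.random.default_rng(0)
M, N = 20, 16
H = rng.standard_normal((M, N))                  # dense forward model (assumption)
x_true = rng.standard_normal(N)
m = H @ x_true                                   # noiseless measurements
gamma = 1.0 / np.linalg.eigvalsh(H.T @ H).max()  # 1/L with L = lambda_max(H^T H)
x0 = np.zeros(N)                                 # initial estimate (assumption)

x3 = lpgd(H, m, identity, gamma, x0, iter_stop=3)
res0 = np.linalg.norm(H @ x0 - m)
res3 = np.linalg.norm(H @ x3 - m)
print(res3 < res0)  # True: with step 1/L each gradient step decreases the residual
```

In LearnedPGD the denoiser is a trainable network, so the decrease shown here is only guaranteed for the identity case.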
- Args:
acqu: Acquisition operator (see meas).
prep: Preprocessing operator (see prep).
denoi (optional): Image denoising operator (see nnet). Default Identity.
iter_stop (int): Number of iterations of the LPGD algorithm (commonly 3 to 10, a trade-off between accuracy and speed). Default 3 (fast, and already more accurate than post-processing denoising).
step (float): Step size of the LPGD algorithm. Default None, in which case it is estimated as the inverse of the Lipschitz constant of the gradient of the data-fidelity term. If meas_op.N is available, the step size is estimated as \(\gamma=1/N\), which holds for Hadamard operators. Otherwise, the Lipschitz constant is computed as the largest eigenvalue of the Hessian, \(L=\lambda_{\max}(H^TH)\), and \(\gamma=1/L\). If this computation fails, the step size is set to 1e-4.
step_estimation (bool): Default False. See step for details.
step_grad (bool): Default False. If True, the step size is learned as a parameter of the network. Not tested yet.
wls (bool): Default False. If True, the data-fidelity term is replaced by a weighted least-squares (WLS) term, which approximates the Poisson likelihood: \(\|Hx-m\|^2_{C^{-1}}\), where \(C\) is the covariance matrix of the measurements. \(C\) is assumed to be diagonal, with diagonal elements equal to the measurement noise variances estimated from sigma.
gt (torch.Tensor): Ground-truth images. If provided, the mean squared error (MSE) is computed and logged. Default None.
log_fidelity (bool): Default False. If True, the data-fidelity term is logged at each iteration of the LPGD algorithm.
- Input / Output:
input: Ground-truth images with shape \((B,C,H,W)\).
output: Reconstructed images with shape \((B,C,H,W)\).
- Attributes:
acqu: Acquisition operator initialized as acqu.
prep: Preprocessing operator initialized as prep.
pinv: Analytical reconstruction operator initialized as PseudoInverse().
denoi: Image denoising operator initialized as denoi.
- Example:
>>> from spyrit.core.meas import HadamSplit2d
>>> from spyrit.core.prep import UnsplitRescale
>>> from spyrit.core.recon import LearnedPGD
>>> import torch

>>> acqu = HadamSplit2d(32, M=400)
>>> prep = UnsplitRescale()
>>> recnet = LearnedPGD(acqu, prep)
>>> x = torch.FloatTensor(10,1,32,32).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])

>>> y = torch.randn(10, 1, 800)
>>> z = recnet.reconstruct(y)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
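The default rule for the step argument (\(\gamma=1/N\) for Hadamard operators, otherwise \(\gamma=1/\lambda_{\max}(H^TH)\), with a fallback of 1e-4) can be illustrated with a small NumPy sketch. This is not spyrit's implementation: estimate_step and sylvester_hadamard are hypothetical helpers, the Hadamard matrix uses the Sylvester construction with \(\pm 1\) entries, and the eigenvalue computation relies on numpy.linalg.eigvalsh. The check shows that for a row-subsampled Hadamard matrix both branches agree, since \(\lambda_{\max}(H^TH)=N\).

```python
import numpy as np


def sylvester_hadamard(n):
    """n x n Hadamard matrix with +/-1 entries (n must be a power of 2)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def estimate_step(H, N=None, fallback=1e-4):
    """Hypothetical helper mirroring the documented default for step=None."""
    if N is not None:
        return 1.0 / N  # Hadamard case: lambda_max(H^T H) = N
    try:
        # Lipschitz constant of the gradient of 0.5 * ||Hx - m||^2
        L = float(np.linalg.eigvalsh(H.T @ H).max())
    except np.linalg.LinAlgError:
        return fallback  # eigenvalue computation failed
    # Assumption: also fall back when L is not positive (e.g. H = 0)
    return 1.0 / L if L > 0 else fallback


N = 64
H = sylvester_hadamard(N)[:16]  # keep M = 16 rows: a subsampled Hadamard operator
print(np.isclose(estimate_step(H), estimate_step(H, N=N)))  # True: both give 1/64
```

The closed-form \(1/N\) branch avoids an eigendecomposition of an \(N\times N\) matrix, which becomes expensive for large images.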
Methods
acquire(x): Simulate data acquisition.
cost_fun(x, y)
forward(x): Full pipeline of reconstruction network.
mse_fun(x, x_gt)
reconstruct(x): Reconstruction step of a reconstruction network.
set_stepsize(step)
step_schedule(step)
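The methods cost_fun(x, y) and mse_fun(x, x_gt) are listed without descriptions. The NumPy sketch below computes the two quantities that the wls, gt and log_fidelity arguments refer to: the data-fidelity term, optionally in its WLS form \(\|Hx-m\|^2_{C^{-1}}\) with diagonal \(C\), and the MSE against a ground truth. It is an illustration under stated assumptions, not a description of what those methods return: data_fidelity, wls_gradient, mse and the per-measurement variances sigma2 are hypothetical names. It also shows that the gradient of \(\frac{1}{2}\|Hx-m\|^2_{C^{-1}}\) is \(H^T C^{-1}(Hx - m)\).

```python
import numpy as np


def data_fidelity(H, x, m, sigma2=None):
    """||Hx - m||^2, or the WLS form ||Hx - m||^2_{C^{-1}} with C = diag(sigma2)."""
    r = H @ x - m
    if sigma2 is None:
        return float(r @ r)
    return float(r @ (r / sigma2))  # r^T C^{-1} r for a diagonal C


def wls_gradient(H, x, m, sigma2):
    """Gradient of 0.5 * ||Hx - m||^2_{C^{-1}}: H^T C^{-1} (Hx - m)."""
    return H.T @ ((H @ x - m) / sigma2)


def mse(x, x_gt):
    """Mean squared error between an estimate and the ground truth."""
    return float(np.mean((x - x_gt) ** 2))


rng = np.random.default_rng(1)
H = rng.standard_normal((6, 4))
x_gt = rng.standard_normal(4)
m = H @ x_gt
sigma2 = np.full(6, 2.0)  # hypothetical per-measurement noise variances
x = np.zeros(4)

# With equal variances, the WLS term is the plain term divided by that variance.
print(np.isclose(data_fidelity(H, x, m, sigma2), data_fidelity(H, x, m) / 2.0))  # True
print(mse(x_gt, x_gt))  # 0.0
```

Weighting each residual by its inverse variance down-weights noisier measurements, which is how the WLS term approximates a Poisson likelihood whose variance grows with the signal.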