spyrit.core.recon.LearnedPGD

class spyrit.core.recon.LearnedPGD(acqu: LinearSplit, prep: UnsplitRescale, denoi=Identity(), *, iter_stop=3, x0=0.5, step=None, step_estimation=False, step_grad=False, step_decay=1, wls=False, gt=None, log_fidelity=False, res_learn=False, **pinv_kwargs)

Bases: Module

Learned Proximal Gradient Descent reconstruction network.

Iterative algorithm that alternates between a gradient step on the data fidelity term and a proximal step, where the proximal operator is replaced by a learned denoiser. The update rule is given by

\[x_{k+1} = \texttt{denoi}\left(x_k - \gamma \, H^T (Hx_k - m)\right)\]

where \(x_k\in\mathbb{R}^N\) is the current estimate, \(\gamma\in\mathbb{R}\) is the step size, \(H\in\mathbb{R}^{M\times N}\) is the forward model, and \(m\in\mathbb{R}^{M}\) are the measurements.
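To make the update rule concrete, here is a minimal sketch in plain PyTorch. It assumes a dense forward matrix `H` and flattened signals stored as the rows of a batch. The function `lpgd` and its arguments are hypothetical and not the spyrit API. `denoi` can be any callable; with the identity, the loop reduces to plain gradient descent on \(\tfrac12\|Hx-m\|^2\).

```python
import torch


def lpgd(H, m, x0, step, denoi=lambda z: z, iter_stop=3):
    """Unrolled proximal gradient descent with a plug-in denoiser (sketch).

    H: (M, N) forward matrix; m: (B, M) measurements; x0: (B, N) initial estimate.
    Signals are stored as rows, so the product H x is written x @ H.T.
    """
    x = x0
    for _ in range(iter_stop):
        residual = x @ H.T - m        # H x_k - m, shape (B, M)
        grad = residual @ H           # H^T (H x_k - m), shape (B, N)
        x = denoi(x - step * grad)    # denoiser in place of the proximal operator
    return x
```

Passing `denoi=lambda z: z.clamp(-1, 1)` turns the loop into projected gradient descent. Clamping is the exact proximal operator of the indicator of the box \([-1,1]^N\). In LearnedPGD, a learned denoiser plays this role instead.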

Args:

acqu: Acquisition operator (see meas)

prep: Preprocessing operator (see prep)

denoi (optional): Image denoising operator (see nnet). Default Identity

iter_stop (int): Number of iterations of the LPGD algorithm, commonly between 3 and 10 as a trade-off between accuracy and speed. Default 3, which is fast while already giving higher accuracy than post-processing denoising.

step (float): Step size of the LPGD algorithm. Default None, in which case it is set to \(\gamma=1/L\), where \(L\) is the Lipschitz constant of the gradient of the data fidelity term.

  • If meas_op.N is available, the step size is set to \(\gamma=1/N\), which is exact for Hadamard operators.

  • Otherwise, \(L\) is computed as the largest singular value of the Hessian, \(L=\lambda_{\max}(H^TH)\). If this computation fails, the step size is set to 1e-4.

step_estimation (bool): Default False. See step for details.

step_grad (bool): Default False. If True, the step size is learned as a parameter of the network. This option has not been tested yet.

wls (bool): Default False. If True, the data fidelity term is replaced by a weighted least squares (WLS) term, which approximates the Poisson negative log-likelihood. In this case, the data fidelity term is \(\|Hx-m\|^2_{C^{-1}}\), where \(C\) is the covariance matrix of the measurement noise. We assume that \(C\) is diagonal, with diagonal elements equal to the measurement noise variances, estimated from sigma.

gt (torch.Tensor): Ground-truth images. If available, the mean squared error (MSE) is computed and logged. Default None.

log_fidelity (bool): Default False. If True, the data fidelity term is logged for each iteration of the LPGD algorithm.
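The step-size rule described under step can be sketched as follows, assuming a dense forward matrix. The helpers `hadamard` and `estimate_step` are hypothetical. `estimate_step` estimates \(L=\lambda_{\max}(H^TH)\) by power iteration. That is one common way to compute it, not necessarily the method spyrit uses. The helper falls back to 1e-4 when the estimate is not a positive finite number.

```python
import math

import torch


def hadamard(n: int) -> torch.Tensor:
    """Unnormalized +/-1 Sylvester-Hadamard matrix of order n (a power of 2)."""
    h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.kron(H, h2)
    return H


def estimate_step(H: torch.Tensor, n_iter: int = 50, fallback: float = 1e-4) -> float:
    """Return 1 / lambda_max(H^T H), estimated by power iteration, or a fallback."""
    v = torch.randn(H.shape[1], dtype=H.dtype)
    for _ in range(n_iter):
        w = H.T @ (H @ v)  # apply H^T H without forming the N x N matrix
        v = w / w.norm()
    L = torch.dot(v, H.T @ (H @ v)).item()  # Rayleigh quotient ~ largest eigenvalue
    if not math.isfinite(L) or L <= 0:
        return fallback  # e.g. H == 0 gives NaN above
    return 1.0 / L
```

For any subset of rows of an unnormalized ±1 Hadamard matrix of order \(N\), \(HH^T=NI\). The estimate therefore returns \(1/N\), which matches the shortcut \(\gamma=1/N\).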

Input / Output:

input: Ground-truth images with shape \((B,C,H,W)\)

output: Reconstructed images with shape \((B,C,H,W)\)
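The input/output convention can be illustrated with a short sketch. It flattens each \((B,C,H,W)\) image to \((B,C,N)\) with \(N=HW\), simulates measurements, runs gradient steps, and reshapes the result back. The function `simulate_and_reconstruct` is hypothetical and uses the identity as denoiser. As a stand-in forward model, it takes an orthogonal matrix from a QR factorization.

```python
import torch


def simulate_and_reconstruct(x, H, step, iter_stop=3):
    """Hypothetical pipeline: (B, C, H, W) images -> measurements -> (B, C, H, W)."""
    B, C, h, w = x.shape
    x_flat = x.reshape(B, C, h * w)         # (B, C, N) with N = h * w
    m = x_flat @ H.T                        # simulated measurements, (B, C, M)
    z = torch.zeros_like(x_flat)            # initial estimate
    for _ in range(iter_stop):
        z = z - step * ((z @ H.T - m) @ H)  # gradient step, identity denoiser
    return z.reshape(B, C, h, w)
```

With an orthogonal square \(H\), \(H^TH=I\), so a single step with \(\gamma=1\) from zero returns the input exactly. A subsampled \(H\) (fewer rows) keeps the output shape but not the exact recovery.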

Attributes:

acqu: Acquisition operator initialized as acqu

prep: Preprocessing operator initialized as prep

pinv: Analytical reconstruction operator initialized as PseudoInverse()

denoi: Image denoising operator initialized as denoi
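The pinv attribute is an analytical reconstruction operator. For intuition, the sketch below computes a Moore-Penrose pseudo-inverse reconstruction with `torch.linalg.pinv` on a random underdetermined matrix. It is generic PyTorch, not spyrit's PseudoInverse class, and the sizes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
M, N = 6, 16
H = torch.randn(M, N, dtype=torch.float64)  # underdetermined forward model (assumption)
x = torch.randn(3, N, dtype=torch.float64)  # batch of flattened signals
m = x @ H.T                                 # measurements H x, shape (3, M)

H_pinv = torch.linalg.pinv(H)               # Moore-Penrose pseudo-inverse, shape (N, M)
x_pinv = m @ H_pinv.T                       # analytical reconstruction H^+ m, per row
```

Because \(M<N\) and \(H\) has full row rank, \(H^+m\) is the minimum-norm solution of \(Hx=m\). It reproduces the measurements exactly but recovers only the component of \(x\) lying in the row space of \(H\).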

Example:
>>> from spyrit.core.meas import HadamSplit2d
>>> from spyrit.core.prep import UnsplitRescale
>>> from spyrit.core.recon import LearnedPGD
>>> import torch
>>> acqu = HadamSplit2d(32, M=400)
>>> prep = UnsplitRescale()
>>> recnet = LearnedPGD(acqu, prep)
>>> x = torch.empty(10, 1, 32, 32).uniform_(-1, 1)  # batch of random images
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
>>> y = torch.randn(10, 1, 800)  # 2 * M = 800 split measurements
>>> z = recnet.reconstruct(y)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
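The step_grad, wls and log_fidelity options listed under Args can be combined in a small module, sketched here under explicit assumptions. `TinyLPGD` is a hypothetical class operating on flattened signals, not the LearnedPGD implementation. WLS weights are passed as a per-measurement variance `var`, standing in for the variances the documentation says are estimated from sigma. The WLS term is taken with a factor \(\tfrac12\), so its gradient is \(H^TC^{-1}(Hx-m)\). The logged fidelity is \(\|Hx-m\|^2_{C^{-1}}\), averaged over the batch.

```python
import torch
from torch import nn


class TinyLPGD(nn.Module):
    """Hypothetical unrolled LPGD on flattened signals: x is (B, N), m is (B, M)."""

    def __init__(self, H, step, iter_stop=3, step_grad=False, denoi=None):
        super().__init__()
        self.register_buffer("H", H)  # fixed forward model, shape (M, N)
        # The step is always stored as a parameter; step_grad decides if it is trained
        self.step = nn.Parameter(torch.tensor(float(step)), requires_grad=step_grad)
        self.iter_stop = iter_stop
        self.denoi = denoi if denoi is not None else nn.Identity()
        self.fidelity_log = []

    def forward(self, m, x0, var=None, log_fidelity=False):
        self.fidelity_log = []
        # WLS weights diag(C^{-1}); plain least squares when no variance is given
        w = 1.0 if var is None else 1.0 / var
        x = x0
        for _ in range(self.iter_stop):
            r = x @ self.H.T - m         # residual H x_k - m
            grad = (w * r) @ self.H      # H^T C^{-1} (H x_k - m)
            x = self.denoi(x - self.step * grad)
            if log_fidelity:
                r = x @ self.H.T - m
                # ||H x - m||^2_{C^{-1}}, averaged over the batch
                self.fidelity_log.append((w * r**2).sum(dim=-1).mean().item())
        return x
```

Storing the step as an `nn.Parameter` with `requires_grad=step_grad` keeps it in the module state in both cases. An optimizer can only update it when step_grad is True.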

Methods

acquire(x)

Simulate data acquisition

cost_fun(x, y)

forward(x)

Full pipeline of reconstruction network

hessian_sv()

mse_fun(x, x_gt)

reconstruct(x)

Reconstruction step of a reconstruction network

reconstruct_expe(x)

Reconstruction step of a reconstruction network

set_stepsize(step)

step_schedule(step)

stepsize_gd()