spyrit.core.recon.LearnedPGD

class spyrit.core.recon.LearnedPGD(noise, prep, denoi=Identity(), iter_stop=3, x0=0, step=None, step_estimation=False, step_grad=False, step_decay=1, wls=False, gt=None, log_fidelity=False, res_learn=False)

Bases: Module

Learned Proximal Gradient Descent (LPGD) reconstruction network. This iterative algorithm alternates between a gradient step on the data-fidelity term and a proximal step, where the proximal operator is replaced by a learned denoiser. The update rule is:

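\(x_{k+1} = \operatorname{prox}\big(x_k - \text{step}\, H^T (Hx_k - y)\big) = \text{denoi}\big(x_k - \text{step}\, H^T (Hx_k - y)\big)\)

where \(H\) is the measurement matrix, \(y\) the measurements, and \(H^T(Hx_k - y)\) the gradient of the data-fidelity term \(\tfrac{1}{2}\|Hx - y\|^2\) at \(x_k\).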
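The update rule can be illustrated with a minimal NumPy sketch. It is not SPyRiT's implementation: the dense matrix `H`, the hypothetical function name `lpgd_sketch`, and the use of an arbitrary Python callable in place of a learned `denoi` are assumptions made only for this example.

```python
import numpy as np


def lpgd_sketch(H, y, denoi, step, iter_stop=3, x0=None):
    """Hypothetical LPGD loop: gradient step on 0.5*||Hx - y||^2, then the denoiser."""
    x = np.zeros(H.shape[1]) if x0 is None else np.array(x0, dtype=float)
    for _ in range(iter_stop):
        grad = H.T @ (H @ x - y)      # gradient of the data-fidelity term
        x = denoi(x - step * grad)    # proximal step replaced by the denoiser
    return x


# Toy check: 4x4 Hadamard matrix (H^T H = 4 I) and an identity "denoiser"
H = np.array([[1, 1, 1, 1],
              [1, -1, 1, -1],
              [1, 1, -1, -1],
              [1, -1, -1, 1]], dtype=float)
x_true = np.array([0.5, -0.2, 0.1, 0.9])
y = H @ x_true
x_rec = lpgd_sketch(H, y, denoi=lambda z: z, step=1 / 4, iter_stop=1)
```

Because \(H^TH = 4I\) here, the step \(1/L = 1/4\) makes a single iteration from \(x_0 = 0\) recover `x_true` exactly; with noisy or undersampled measurements and a trained denoiser, several iterations are typically needed.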

Args:

noise: Acquisition operator (see noise)

prep: Preprocessing operator (see prep)

denoi (optional): Image denoising operator (see nnet). Default Identity

iter_stop (int): Number of LPGD iterations, typically 3 to 10 as a trade-off between accuracy and speed. Default 3, which keeps reconstruction fast while still being more accurate than post-processing denoising.

step (float): Step size of the LPGD algorithm. Default None, in which case it is estimated as the inverse of the Lipschitz constant \(L\) of the gradient of the data-fidelity term:

  • If meas_op.N is available, the step size is set to \(\text{step} = 1/L = 1/N\) with \(N\) = meas_op.N, which holds for Hadamard operators.

  • Otherwise, \(L\) is computed as the largest singular value of the Hessian, \(L = \lambda_{\max}(H^TH)\). If this fails, the step size is set to 1e-4.
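The step-size rule above can be sketched with power iteration, which estimates \(\lambda_{\max}(H^TH)\) without forming an eigendecomposition. This is an illustrative NumPy sketch, not the library's `stepsize_gd` or `hessian_sv` methods; the helper names `estimate_lipschitz` and `default_step`, the iteration count, and the exact fallback behaviour are assumptions.

```python
import numpy as np


def estimate_lipschitz(H, n_iter=50, seed=0):
    """Power iteration for lambda_max(H^T H), the Lipschitz constant of H^T(Hx - y)."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(H.shape[1])
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(n_iter):
        w = H.T @ (H @ v)            # apply the Hessian H^T H
        lam = np.linalg.norm(w)      # ||H^T H v|| tends to lambda_max as v converges
        if lam == 0.0:
            break
        v = w / lam
    return lam


def default_step(H, fallback=1e-4):
    """Hypothetical helper: step = 1/L, or a small fallback if estimation fails."""
    try:
        lam = estimate_lipschitz(H)
        return 1.0 / lam if lam > 0 else fallback
    except Exception:
        return fallback


# Sylvester construction of a 16x16 Hadamard matrix, for which H^T H = 16 I
N = 16
H = np.array([[1.0]])
while H.shape[0] < N:
    H = np.block([[H, H], [H, -H]])
L = estimate_lipschitz(H)
step = default_step(H)
```

For a full Hadamard matrix of order \(N\), the sketch returns \(L = N\) and therefore \(\text{step} = 1/N\), which is consistent with the shortcut based on meas_op.N.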

step_estimation (bool): Default False. See step for details.

step_grad (bool): Default False. If True, the step size is learned as a parameter of the network. Not tested yet.

wls (bool): Default False. If True, the data fidelity term is modified to be the weighted least squares (WLS) term, which approximates the Poisson likelihood. In this case, the data fidelity term is \(\|Hx-y\|^2_{C^{-1}}\), where \(C\) is the covariance matrix. We assume that \(C\) is diagonal, and the diagonal elements are the measurement noise variances, estimated from sigma.
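With a diagonal covariance, the weighted gradient only rescales the residual elementwise before applying \(H^T\). The NumPy sketch below assumes the variances `var` are already available (how they are derived from sigma in SPyRiT is not shown here), and `wls_gradient` is a hypothetical helper name.

```python
import numpy as np


def wls_gradient(H, x, y, var):
    """Gradient of 0.5 * (Hx - y)^T C^{-1} (Hx - y) with C = diag(var)."""
    residual = H @ x - y
    return H.T @ (residual / var)    # C^{-1} is diagonal: elementwise division


H = np.array([[1.0, 1.0], [1.0, -1.0]])
x = np.array([0.3, 0.1])
y = np.array([0.0, 0.0])
g_ls = H.T @ (H @ x - y)                          # ordinary least-squares gradient
g_unit = wls_gradient(H, x, y, var=np.ones(2))     # unit variances: same as least squares
g_scaled = wls_gradient(H, x, y, var=np.array([2.0, 2.0]))
```

Note that the Lipschitz constant of this weighted gradient is \(\lambda_{\max}(H^TC^{-1}H)\) rather than \(\lambda_{\max}(H^TH)\).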

gt (torch.Tensor): Ground-truth images. If available, the mean squared error (MSE) is computed and logged. Default None.

log_fidelity (bool): Default False. If True, the data fidelity term is logged for each iteration of the LPGD algorithm.
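The kind of per-iteration logging enabled by gt and log_fidelity can be sketched as follows. This is a hypothetical NumPy loop, not the library's `cost_fun` or `mse_fun`; the choice of \(\|Hx_k - y\|^2\) as the logged fidelity and of the pixel-wise mean for the MSE are assumptions.

```python
import numpy as np


def lpgd_with_logs(H, y, denoi, step, iter_stop=3, gt=None):
    """Hypothetical LPGD loop that records the data fidelity and, if gt is given, the MSE."""
    x = np.zeros(H.shape[1])
    fidelity, mse = [], []
    for _ in range(iter_stop):
        x = denoi(x - step * (H.T @ (H @ x - y)))
        fidelity.append(float(np.sum((H @ x - y) ** 2)))   # ||H x_k - y||^2
        if gt is not None:
            mse.append(float(np.mean((x - gt) ** 2)))       # MSE against ground truth
    return x, fidelity, mse


H = np.array([[1.0, 1.0], [1.0, -1.0]])   # H^T H = 2 I, so L = 2
x_true = np.array([1.0, -0.5])
y = H @ x_true
# step = 0.25 < 1/L, so the error halves at each iteration
x_rec, fidelity, mse = lpgd_with_logs(H, y, lambda z: z, step=0.25, iter_stop=3, gt=x_true)
```

Logging both quantities makes it easy to see whether additional iterations still reduce the reconstruction error or only the data misfit.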

Input / Output:

input: Ground-truth images with shape \((B,C,H,W)\)

output: Reconstructed images with shape \((B,C,H,W)\)

Attributes:

Acq: Acquisition operator initialized as noise

prep: Preprocessing operator initialized as prep

pinv: Analytical reconstruction operator initialized as PseudoInverse()

Denoi: Image denoising operator initialized as denoi

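Example:
>>> import torch
>>> from spyrit.core.meas import HadamSplit
>>> from spyrit.core.noise import NoNoise
>>> from spyrit.core.prep import SplitPoisson
>>> from spyrit.core.recon import LearnedPGD
>>> B, C, H, M = 10, 1, 64, 64**2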
>>> Ord = torch.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = LearnedPGD(noise, prep)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
>>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
tensor(5.8912e-06)

Methods

acquire(x)

Simulate data acquisition

cost_fun(x, y)

forward(x)

Full pipeline of reconstruction network

hessian_sv()

mse_fun(x, x_gt)

reconstruct(x)

Reconstruction step of a reconstruction network

reconstruct_expe(x)

Reconstruction step of a reconstruction network

set_stepsize(step)

step_schedule(step)

stepsize_gd()
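The split between acquire, reconstruct, and forward, together with the \((B,C,H,W)\) input/output shapes, can be mirrored by a small toy class. `ToyLPGD` is hypothetical and uses a dense NumPy matrix with noiseless measurements (as in the NoNoise example above); it only illustrates the data flow, not how SPyRiT implements these methods.

```python
import numpy as np


class ToyLPGD:
    """Hypothetical stand-in mirroring the acquire / reconstruct / forward split."""

    def __init__(self, H, img_shape, denoi=lambda z: z, iter_stop=3, step=None):
        self.H = H                      # (M, N) measurement matrix, N = h * w
        self.img_shape = img_shape      # (h, w)
        self.denoi = denoi
        self.iter_stop = iter_stop
        # Default step 1/L with L = lambda_max(H^T H)
        self.step = step if step is not None else 1.0 / np.linalg.eigvalsh(H.T @ H).max()

    def acquire(self, x):
        """(B, C, h, w) images -> (B, C, M) noiseless measurements."""
        B, C = x.shape[:2]
        return x.reshape(B, C, -1) @ self.H.T

    def reconstruct(self, y):
        """(B, C, M) measurements -> (B, C, h, w) images via LPGD iterations."""
        B, C, _ = y.shape
        x = np.zeros((B, C, self.H.shape[1]))
        for _ in range(self.iter_stop):
            grad = (x @ self.H.T - y) @ self.H      # batched H^T (Hx - y)
            x = self.denoi(x - self.step * grad)
        return x.reshape(B, C, *self.img_shape)

    def forward(self, x):
        """Full pipeline: simulate the measurements, then reconstruct."""
        return self.reconstruct(self.acquire(x))


H = np.array([[1.0, 1.0], [1.0, -1.0]])
H = np.block([[H, H], [H, -H]])          # 4x4 Hadamard matrix, H^T H = 4 I
recnet = ToyLPGD(H, img_shape=(2, 2))
x = np.random.default_rng(0).uniform(-1, 1, size=(3, 1, 2, 2))
z = recnet.forward(x)
```

Keeping acquisition separate from reconstruction lets the same reconstruction be applied to simulated measurements during training and to measured data at test time.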