Source code for spyrit.core.inverse

"""
Inverse methods for inverse problems.
"""

from typing import Union
from typing import Union

import torch
import torch.nn as nn

import spyrit.core.meas as meas
import spyrit.core.torch as spytorch


# =============================================================================
[docs] class PseudoInverse(nn.Module): r"""Moore-Penrose pseudoinverse. This allows to solve the linear problem :math:`Ax = B`, by either computing the least-squares solution of the equation, or by computing the pseudo-inverse matrix of :math:`A`. This behavior is defined by the keyword parameter :attr:`store_H_pinv`. This class can also handle regularization in the computation of the least-squares solution or the matrix pseudo-inverse. The available regularization methods are `rcond` (which truncates the matrix's SVD below a certain threshold), `L2` and `H1`. .. note:: When :attr:`store_H_pinv` is `True`, additional parameters (such as regularization parameters) can be passed as keyword arguments to the class constructor. .. note:: When :attr:`store_pinv` is `False`, additional parameters (such as regularization parameters) can be passed as keyword arguments to the forward method of this class. Args: :attr:`meas_op`: Measurement operator. See :mod:`spyrit.core.meas`. :attr:`regularization` (str): Regularization method. Can be 'rcond', 'L2', or 'H1'. Default: 'rcond'. Keyword Args: :attr:`store_H_pinv` (bool): If False, the least squares solution is computed at each forward pass using the function :func:`torch.linalg.lstsq`. If True, computes and stores at initialization the pseudo-inverse of the measurement matrix using the function :func:`torch.linalg.pinv`. Default: False :attr:`use_fast_pinv` (bool): If True, uses a fast computation of either the measurement matrix pseudo-inverse or the least squares solution. This only works if the measurement operator has a fast pseudo-inverse method. Default: True. :attr:`reshape_output` (bool): If True, reshapes the output to the shape of the image using :meth:`meas_op.unvectorize`. Default: True. :attr:`reg_kwargs`: Additional keyword arguments that are passed to :func:`spyrit.core.torch.regularized_pinv` when :attr:`store_pinv` is True or to :func:`spyrit.core.torch.resularized_lstsq` when :attr:`store_pinv` is False. Attributes: :attr:`meas_op`: Measurement operator initialized as :attr:`meas_op`. :attr:`regularization`: Regularization method initialized as :attr:`regularization`. :attr:`store_H_pinv`: Indicates if the pseudo-inverse is stored. :attr:`use_fast_pinv`: Indicates if the fast pseudo-inverse is used. :attr:`reshape_output`: Indicates if the output is reshaped. :attr:`reg_kwargs`: Additional keyword arguments passed to the :func:`spyrit.core.torch.regularized_pinv` or :func:`torch.linalg.lstsq` functions. :attr:`pinv`: The pseudo-inverse of the measurement matrix. It is computed only if :attr:`store_H_pinv` is True. Example 1: >>> from spyrit.core.meas import Linear >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = Linear(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 2: LinearSplit, pseudo-inverse of H (default) >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op.measure_H(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 3: LinearSplit, pseudo-inverse of A >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> meas_op.set_matrix_to_inverse('A') >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) """ def __init__( self, meas_op: Union[meas.Linear, meas.DynamicLinear], regularization: str = "rcond", *, store_H_pinv: bool = False, use_fast_pinv: bool = True, reshape_output: bool = True, **reg_kwargs, ) -> None: super().__init__() self.meas_op = meas_op self.regularization = regularization self.store_H_pinv = store_H_pinv self.use_fast_pinv = use_fast_pinv self.reshape_output = reshape_output self.reg_kwargs = reg_kwargs if self.store_H_pinv: # do we have a fast pseudo-inverse computation available? if self.use_fast_pinv and hasattr(self.meas_op, "fast_H_pinv"): self.pinv = meas_op.fast_H_pinv() else: self.pinv = spytorch.regularized_pinv( self.meas_op.get_matrix_to_inverse, regularization, **reg_kwargs ) if type(meas_op) is meas.HadamSplit2d: self.reshape_output = False
[docs] def forward(self, y: torch.tensor) -> torch.tensor: r"""Computes pseudo-inverse of measurements. If :attr:`self.store_H_pinv` is True, computes the product of the stored pseudo-inverse and the measurements. If :attr:`self.store_H_pinv` is False, computes the least squares solution of the measurements. In this case, any additional keyword arguments passed to the :class:`PseudoInverse` constructor (and store in :attr:`self.reg_kwargs` are used here. These can include: - :attr:`rcond` (float): Cutoff for small singular values. It is used only when :attr:`regularization` is 'rcond'. This parameter is fed directly to :func:`torch.linalg.pinv`. - Any other keyword arguments that are passed to :func:`torch.linalg.lstsq`. Used only when :attr:`regularization` is 'rcond'. - :attr:`eta` (float): Regularization parameter. It is used only when :attr:`regularization` is 'L2' or 'H1'. This parameter determines the amount of regularization applied to the pseudo-inverse. Args: :attr:`y` (torch.tensor): Batch of measurement vectors of shape :math:`(*, M)`, where :math:`*` is any number of dimensions and :math:`M` is the number of measurements of the measurement operator (:attr:`meas_op.M`). Returns: :attr:`output` (torch.tensor): Batch of reconstructed images of shape :math:`(*, N)` or the image shape as defined in the measurement operator (in `meas_op.meas_shape`) depending on the value of :attr:`self.reshape_output`. Example 1: >>> from spyrit.core.meas import Linear >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = Linear(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 2: LinearSplit, pseudo-inverse of H (default) >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op.measure_H(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 3: LinearSplit, pseudo-inverse of A >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> meas_op.set_matrix_to_inverse('A') >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) """ if self.store_H_pinv: # Expand the pseudo-inverse to the batch size of y pinv = self.pinv.expand(*y.shape[:-1], *self.pinv.shape) y = y.unsqueeze(-1) y = torch.matmul(pinv, y) y = y.squeeze(-1) else: if self.use_fast_pinv and hasattr(self.meas_op, "fast_pinv"): y = self.meas_op.fast_pinv(y) else: y = spytorch.regularized_lstsq( self.meas_op.get_matrix_to_inverse, y, self.regularization, **self.reg_kwargs, ) if self.reshape_output: y = self.meas_op.unvectorize(y) return y
# =============================================================================
[docs] class Tikhonov(nn.Module): r"""Tikhonov regularization. It estimates the signal :math:`x\in\mathbb{R}^{N}` the from linear measurements :math:`y = Ax\in\mathbb{R}^{M}` corrupted by noise by solving .. math:: \| y - Ax \|^2_{\Gamma{-1}} + \|x\|^2_{\Sigma^{-1}}, where :math:`\Gamma` is covariance of the noise, and :math:`\Sigma` is the signal covariance. In the case :math:`M\le N`, the solution can be computed as .. math:: \hat{x} = \Sigma A^\top (A \Sigma A^\top + \Gamma)^{-1} y, where we assume that both covariance matrices are positive definite. The class is constructed from :math:`A` and :math:`\Sigma`, while :math:`\Gamma` is passed as an argument to :meth:`forward()`. Passing :math:`\Gamma` to :meth:`forward()` is useful in the presence of signal- dependent noise. .. note:: * :math:`x` can be a 1d signal or a vectorized image/volume. This can be specified by setting the :attr:`meas_shape` attribute of the measurement operator. * The above formulation assumes that the signal :math:`x` has zero mean. Args: - :attr:`meas_op` : Measurement operator (see :class:`~spyrit.core.meas`). - :attr:`sigma` : Signal (image) covariance prior, of shape :math:`(N, N)`. - :attr:`approx` : A boolean indicating whether to set the non-diagonal elements of :math:`A \Sigma A^T` to zero. Default is False. If True, this speeds up the computation of the inverse :math:`(A \Sigma A^T + \Sigma_\alpha)^{-1}`. - :attr:`reshape_output` : A boolean indicating whether to reshape the output to the shape of the image. Default is True. Attributes: - :attr:`meas_op` : Measurement operator initialized as :attr:`meas_op`. - :attr:`sigma` : Signal covariance prior initialized as :attr:`sigma`. - :attr:`approx` : Indicates if the diagonal approximation is used. - :attr:`reshape_output` : Indicates if the output is reshaped. - :attr:`img_shape` : Shape of the image, initialized as :attr:`meas_op.img_shape`. - :attr:`sigma_meas` : Measurement covariance prior initialized as :math:`A \Sigma A^T`. If :attr:`approx` is True, the non-diagonal elements are set to zero. It is pre-computed at initialization to speed up future computations. - :attr:`sigma_A_T` : Covariance of the missing measurements initialized as :math:`\Sigma A^T`. It is computed at initialization to speed up future computations. Example: >>> from spyrit.core.meas import Linear >>> B, H, M, N = 85, 17, 32, 64 >>> sigma = torch.rand(N, N) >>> gamma = torch.rand(M, M) >>> A = torch.rand([M,N]) >>> meas = Linear(A, meas_shape=(1,N)) >>> recon = Tikhonov(meas, sigma) >>> y = torch.rand(B,H,M) >>> x = recon(y, gamma) >>> print(y.shape) torch.Size([85, 17, 32]) >>> print(x.shape) torch.Size([85, 17, 1, 64]) """ def __init__( self, meas_op: meas.Linear, sigma: torch.tensor, approx=False, reshape_output: bool = True, ): super().__init__() self.meas_op = meas_op self.sigma = sigma self.approx = approx self.reshape_output = reshape_output self.img_shape = meas_op.meas_shape A = meas_op.get_matrix_to_inverse # get H or A # *measurement* covariance if approx: # store onle the diagonal sigma_meas = torch.einsum("ij,jk,ik->i", A, sigma, A) else: sigma_meas = A @ sigma @ A.T self.register_buffer("sigma_meas", sigma_meas) # estimation of the missing measurements sigma_A_T = torch.mm(sigma, A.mT) self.register_buffer("sigma_A_T", sigma_A_T)
[docs] def divide(self, y: torch.tensor, gamma: torch.tensor) -> torch.tensor: r"""Computes the division :math:`y \cdot (\Sigma \alpha + (A \Sigma A^T))^{-1}`. Measurements `y` are divided by the sum of the measurement covariance. If :attr:`self.approx` is True, the inverse is approximated as a diagonal matrix, speeding up the computation. Otherwise, the inverse is computed with the whole matrix. Args: y (torch.tensor): Input measurement tensor. Shape :math:`(*, M)`. gamma (torch.tensor): Noise covariance tensor. Shape :math:`(*, M, M)`. Returns: torch.tensor: The divided tensor. Shape :math:`(*, M)`. Example: >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=True) >>> y = torch.rand(85, 3, M) >>> gamma = torch.eye(M).expand(85, 3, M, M) >>> print(tikho.divide(y, gamma).shape) torch.Size([85, 3, 32]) """ if self.approx: return y / (self.sigma_meas + torch.diagonal(gamma, dim1=-2, dim2=-1)) else: # we need to expand the matrices for the solve batch_shape = y.shape[:-1] expand_shape = batch_shape + (self.sigma_meas.shape) y = y.unsqueeze(-1) # add a dimension to y for batch matrix multiplications y = torch.linalg.solve((self.sigma_meas + gamma).expand(expand_shape), y) return y.squeeze(-1)
[docs] def forward(self, y: torch.tensor, gamma: torch.tensor) -> torch.tensor: r"""Reconstructs the signal from measurements and noise covariance. The Tikhonov solution is computed as .. math:: \hat{x} = B^\top (C + \Gamma)^{-1} y with :math:`B = \Sigma A^\top` and :math:`C = A \Sigma A^\top`. When :attr:`self.approx` is True, it is approximated as .. math:: \hat{x} = B^\top \frac{y}{\text{diag}(C + \Gamma)} Args: :attr:`y` (torch.tensor): A batch of measurement vectors :math:`y` of shape :math:`(*, M)`. :attr:`gamma` (torch.tensor): A batch of noise covariance :math:`\Gamma` of shape :math:`(*, M, M)`. Returns: (torch.tensor): A batch of reconstructed images of shape :math:`(*, N)` or the :meth:`meas_op.unvectorize`d version of the image shape. Example 1: With reshape_output = True >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> b, c, h, w = 85, 3, 8, 8 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> x = torch.randn(b,c,h,w) >>> y = meas_op(x) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=True) >>> gamma = torch.eye(M).expand(b, c, M, M) >>> print(tikho(y, gamma).shape) torch.Size([85, 3, 1, 64]) Example 2: With reshape_output = False >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> b, c, h, w = 85, 3, 8, 8 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> x = torch.randn(b,c,h,w) >>> y = meas_op(x) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=False) >>> gamma = torch.eye(M).expand(b, c, M, M) >>> print(tikho(y, gamma).shape) torch.Size([85, 3, 64]) """ y = self.divide(y, gamma) y = torch.matmul(self.sigma_A_T, y.unsqueeze(-1)).squeeze(-1) if self.reshape_output: y = self.meas_op.unvectorize(y) return y
# =============================================================================
[docs] class TikhonovMeasurementPriorDiag(nn.Module): r"""Tikhonov regularisation with prior in the measurement domain. Considering linear measurements :math:`m = Hx \in\mathbb{R}^M`, where :math:`H = GF` is the measurement matrix and :math:`x\in\mathbb{R}^N` is a vectorized image, it estimates :math:`x` from :math:`m` by approximately minimizing .. math:: \|m - GFx \|^2_{\Sigma^{-1}_\alpha} + \|F(x - x_0)\|^2_{\Sigma^{-1}} where :math:`x_0\in\mathbb{R}^N` is a mean image prior, :math:`\Sigma\in\mathbb{R}^{N\times N}` is a covariance prior, and :math:`\Sigma_\alpha\in\mathbb{R}^{M\times M}` is the measurement noise covariance. The matrix :math:`G\in\mathbb{R}^{M\times N}` is a subsampling matrix. .. note:: The class is instantiated from :math:`\Sigma`, which represents the covariance of :math:`Fx`. Args: :attr:`meas_op`: A Hadamard measurement operator (see :mod:`spyrit.core.meas.HadamSplit2d`). :attr:`sigma`: Measurement covariance prior with shape :math:`N` x :math:`N`. Attributes: :attr:`meas_op`: Measurement operator initialized as :attr:`meas_op`. :attr:`denoise_weights`: The learnable denoising layer initialized from :math:`\Sigma_1`. This layer is a :class:`nn.Parameter`. :attr:`comp`: The learnable matrix initialized from :math:`\Sigma_{21}`. This matrix is a :class:`nn.Parameter`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> x0 = torch.rand([10, 3, 32, 32]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op(y, x0, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ def __init__( self, meas_op: meas.HadamSplit2d, sigma: torch.tensor, reshape_output: bool = False, ): super().__init__() self.meas_op = meas_op self.reshape_output = reshape_output M = self.meas_op.M var_prior = sigma.diag()[:M] self.denoise_weights = nn.Parameter(torch.sqrt(var_prior), requires_grad=False) Sigma1 = sigma[:M, :M] Sigma21 = sigma[M:, :M] W = torch.linalg.solve(Sigma1.T, Sigma21.T).T self.comp = nn.Parameter(W, requires_grad=False)
[docs] def wiener_denoise(self, x: torch.tensor, var: torch.tensor) -> torch.tensor: r"""Returns a denoised version of the input tensor using the variance prior. This uses the attribute self.denoise_weights, which is a learnable parameter. Inputs: x (:class:`torch.tensor`): The input tensor to be denoised. var (:class:`torch.tensor`): The variance prior. Returns: :class:`torch.tensor`: The denoised tensor. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> var = torch.rand([10, 3, 400]) >>> print(recon_op.wiener_denoise(y, var).shape) torch.Size([10, 3, 400]) """ weights_squared = self.denoise_weights**2 return torch.mul((weights_squared / (weights_squared + var)), x)
[docs] def forward_no_prior(self, x, var): r"""Computes the Tikhonov regularization with prior in the measurement domain. We approximate the solution as: .. math:: \hat{x} = x_0 + F^{-1} \begin{bmatrix} m_1 \\ m_2\end{bmatrix} with :math:`m_1 = D_1(D_1 + \Sigma_\alpha)^{-1} (m - GF x_0)` and :math:`m_2 = \Sigma_1 \Sigma_{21}^{-1} m_1`, where :math:`\Sigma = \begin{bmatrix} \Sigma_1 & \Sigma_{21}^\top \\ \Sigma_{21} & \Sigma_2\end{bmatrix}` and :math:`D_1 =\textrm{Diag}(\Sigma_1)`. Assuming the noise covariance :math:`\Sigma_\alpha` is diagonal, the matrix inversion involved in the computation of :math:`m_1` is straightforward. This is an approximation to the exact solution .. math:: \hat{x} &= x_0 + F^{-1}\begin{bmatrix}\Sigma_1 \\ \Sigma_{21} \end{bmatrix} [\Sigma_1 + \Sigma_\alpha]^{-1} (m - GF x_0). See Lemma B.0.5 of the PhD dissertation of A. Lorente Mur (2021): https://theses.hal.science/tel-03670825v1/file/these.pdf Args: :attr:`x` (:class:`torch.tensor`): A batch of measurement vectors :math:`m` of shape :math:`(*, M)`. :attr:`var` (:class:`torch.tensor`): A batch of measurement noise variances :math:`\Sigma_\alpha` of shape :math:`(*, M)`. Returns: :class:`torch.tensor`: Batch of reconstructed image of shape :math:`(*, \sqrt{N}, \sqrt{N})`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op.forward_no_prior(y, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ y1 = self.wiener_denoise(x, var) y2 = y1 @ self.comp.T y = torch.cat((y1, y2), -1) y = self.meas_op.fast_pinv(y) # if self.reshape_output: # y = self.meas_op.unvectorize(y) return y
[docs] def forward( self, x: torch.tensor, x_0: torch.tensor, var: torch.tensor, ) -> torch.tensor: r"""Computes the Tikhonov regularization with prior in the measurement domain. This method, unlike the :meth:`forward_no_prior` method, allows for a non-zero mean image prior :math:`x_0`. We approximate the solution as: .. math:: \hat{x} = x_0 + F^{-1} \begin{bmatrix} m_1 \\ m_2\end{bmatrix} with :math:`m_1 = D_1(D_1 + \Sigma_\alpha)^{-1} (m - GF x_0)` and :math:`m_2 = \Sigma_1 \Sigma_{21}^{-1} m_1`, where :math:`\Sigma = \begin{bmatrix} \Sigma_1 & \Sigma_{21}^\top \\ \Sigma_{21} & \Sigma_2\end{bmatrix}` and :math:`D_1 =\textrm{Diag}(\Sigma_1)`. Assuming the noise covariance :math:`\Sigma_\alpha` is diagonal, the matrix inversion involved in the computation of :math:`m_1` is straightforward. This is an approximation to the exact solution .. math:: \hat{x} &= x_0 + F^{-1}\begin{bmatrix}\Sigma_1 \\ \Sigma_{21} \end{bmatrix} [\Sigma_1 + \Sigma_\alpha]^{-1} (m - GF x_0) See Lemma B.0.5 of the PhD dissertation of A. Lorente Mur (2021): https://theses.hal.science/tel-03670825v1/file/these.pdf Args: :attr:`x` (:class:`torch.tensor`): A batch of measurement vectors :math:`m` with shape :math:`(*, M)`. :attr:`x_0` (:class:`torch.tensor`): A batch of prior images :math:`x_0` with shape :math:`(*, \sqrt{N}, \sqrt{N})`. :attr:`var` (:class:`torch.tensor`): A batch of measurement noise variances :math:`\Sigma_\alpha` with shape :math:`(*, M)`. :attr:`meas_op` (:class:`torch.tensor`): A measurement operator that provides :math:`GF` and :math:`F^{-1}`. Output: (:class:`torch.tensor`): A batch of images with shape :math:`(*, \sqrt{N}, \sqrt{N})`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> x0 = torch.rand([10, 3, 32, 32]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op(y, x0, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ x = x - self.meas_op.forward_H(x_0) x = self.forward_no_prior(x, var) x += x_0 return x