Source code for spyrit.core.inverse

"""
Inverse methods for inverse problems.
"""

from typing import Union

import torch
import torch.nn as nn

import spyrit.core.meas as meas
import spyrit.core.torch as spytorch


# =============================================================================
[docs] class PseudoInverse(nn.Module): r"""Pseudoinverse. Solves the linear problem :math:`Ax = B`, either by calling a linear solver or by computing the pseudo-inverse matrix of :math:`A`. This behavior is defined by the attribute :attr:`store_H_pinv`. This class allows for regularization. Available regularizations are `rcond` (truncation of the singular values), `L2` and `H1`. .. note:: When :attr:`store_H_pinv` is `True`, additional parameters (such as regularization parameters) can be passed as keyword arguments to the class constructor. .. note:: When :attr:`store_pinv` is `False`, additional parameters (such as regularization parameters) can be passed as keyword arguments to the forward method. Args: :attr:`meas_op` (:mod:`spyrit.core.meas`): Measurement operator. :attr:`regularization` (str): Regularization methods among 'rcond', 'L2', or 'H1'. Defaults to 'rcond'. Attributes: :attr:`store_H_pinv` (bool): If False, the least squares solution is computed at each forward pass using the function :func:`torch.linalg.lstsq`. If True, computes and stores at initialization the pseudo-inverse of the measurement matrix using the function :func:`torch.linalg.pinv`. Default: False :attr:`use_fast_pinv` (bool): If True, uses a fast computation of either the measurement matrix pseudo-inverse or the least squares solution. This only works if the measurement operator has a fast pseudo-inverse method. Default: True. :attr:`reshape_output` (bool): If True, reshapes the output to the shape of the image using :meth:`meas_op.unvectorize`. Default: True. :attr:`reg_kwargs`: Additional keyword arguments that are passed to: - :func:`spyrit.core.torch.regularized_pinv` when :attr:`store_pinv` is True. - :func:`spyrit.core.torch.resularized_lstsq` when :attr:`store_pinv` is False. Here, 'driver' is set to 'gels' by default (see :func:`torch.linalg.lstsq`). Attributes: :attr:`meas_op`: Measurement operator initialized as :attr:`meas_op`. :attr:`regularization`: Regularization method initialized as :attr:`regularization`. :attr:`store_H_pinv`: Indicates if the pseudo-inverse is stored. :attr:`use_fast_pinv`: Indicates if the fast pseudo-inverse is used. :attr:`reshape_output`: Indicates if the output is reshaped. :attr:`reg_kwargs`: Additional keyword arguments passed to either :func:`spyrit.core.torch.regularized_pinv` or :func:`torch.linalg.lstsq`. :attr:`pinv`: The pseudo-inverse of the measurement matrix. Computed only if :attr:`store_H_pinv` is True. Example 1: >>> from spyrit.core.meas import Linear >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = Linear(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 2: LinearSplit, pseudo-inverse of H (default) >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op.measure_H(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 3: LinearSplit, pseudo-inverse of A >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> meas_op.set_matrix_to_inverse('A') >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) """ def __init__( self, meas_op: Union[meas.Linear, meas.DynamicLinear], regularization: str = "rcond", *, store_H_pinv: bool = False, use_fast_pinv: bool = True, reshape_output: bool = True, **reg_kwargs, ) -> None: super().__init__() self.meas_op = meas_op self.regularization = regularization self.store_H_pinv = store_H_pinv self.use_fast_pinv = use_fast_pinv self.reshape_output = reshape_output # only for torch.linalg.lstsq if not self.store_H_pinv: defaultKwargs = {"driver": "gels"} reg_kwargs = defaultKwargs | reg_kwargs self.reg_kwargs = reg_kwargs if self.store_H_pinv: # do we have a fast pseudo-inverse computation available? if self.use_fast_pinv and hasattr(self.meas_op, "fast_H_pinv"): self.pinv = meas_op.fast_H_pinv() else: self.pinv = spytorch.regularized_pinv( self.meas_op.get_matrix_to_inverse, regularization, **reg_kwargs ) if type(meas_op) is meas.HadamSplit2d: self.reshape_output = False
[docs] def forward(self, y: torch.tensor) -> torch.tensor: r"""Computes pseudo-inverse of measurements. If :attr:`self.store_H_pinv` is True, computes the product of the stored pseudo-inverse and the measurements. If :attr:`self.store_H_pinv` is False, computes the least squares solution of the measurements. In this case, any additional keyword arguments passed to the :class:`PseudoInverse` constructor (and store in :attr:`self.reg_kwargs` are used here. These can include: - :attr:`rcond` (float): Cutoff for small singular values. It is used only when :attr:`regularization` is 'rcond'. This parameter is fed directly to :func:`torch.linalg.pinv`. - Any other keyword arguments that are passed to :func:`torch.linalg.lstsq`. Used only when :attr:`regularization` is 'rcond'. - :attr:`eta` (float): Regularization parameter. It is used only when :attr:`regularization` is 'L2' or 'H1'. This parameter determines the amount of regularization applied to the pseudo-inverse. Args: :attr:`y` (torch.tensor): Batch of measurement vectors of shape :math:`(*, M)`, where :math:`*` is any number of dimensions and :math:`M` is the number of measurements of the measurement operator (:attr:`meas_op.M`). Returns: :attr:`output` (torch.tensor): Batch of reconstructed images of shape :math:`(*, N)` or the image shape as defined in the measurement operator (in `meas_op.meas_shape`) depending on the value of :attr:`self.reshape_output`. Example 1: >>> from spyrit.core.meas import Linear >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = Linear(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 2: LinearSplit, pseudo-inverse of H (default) >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op.measure_H(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) Example 3: LinearSplit, pseudo-inverse of A >>> from spyrit.core.meas import LinearSplit >>> from spyrit.core.inverse import PseudoInverse >>> H = torch.randn(10, 15) >>> meas_op = LinearSplit(H) >>> meas_op.set_matrix_to_inverse('A') >>> pinv_op = PseudoInverse(meas_op) >>> x = torch.randn(3, 4, 15) >>> y = meas_op(x) >>> x = pinv_op(y) >>> print(x.shape) torch.Size([3, 4, 15]) """ if self.store_H_pinv: # Expand the pseudo-inverse to the batch size of y pinv = self.pinv.expand(*y.shape[:-1], *self.pinv.shape) y = y.unsqueeze(-1) y = torch.matmul(pinv, y) y = y.squeeze(-1) else: if self.use_fast_pinv and hasattr(self.meas_op, "fast_pinv"): y = self.meas_op.fast_pinv(y) else: y = spytorch.regularized_lstsq( self.meas_op.get_matrix_to_inverse, y, self.regularization, **self.reg_kwargs, ) if self.reshape_output: y = self.meas_op.unvectorize(y) return y
# =============================================================================
[docs] class Tikhonov(nn.Module): r"""Tikhonov regularization. It estimates the signal :math:`x\in\mathbb{R}^{N}` the from linear measurements :math:`y = Ax\in\mathbb{R}^{M}` corrupted by noise by solving .. math:: \| y - Ax \|^2_{\Gamma{-1}} + \|x\|^2_{\Sigma^{-1}}, where :math:`\Gamma` is covariance of the noise, and :math:`\Sigma` is the signal covariance. In the case :math:`M\le N`, the solution can be computed as .. math:: \hat{x} = \Sigma A^\top (A \Sigma A^\top + \Gamma)^{-1} y, where we assume that both covariance matrices are positive definite. The class is constructed from :math:`A` and :math:`\Sigma`, while :math:`\Gamma` is passed as an argument to :meth:`forward()`. Passing :math:`\Gamma` to :meth:`forward()` is useful in the presence of signal- dependent noise. .. note:: * :math:`x` can be a 1d signal or a vectorized image/volume. This can be specified by setting the :attr:`meas_shape` attribute of the measurement operator. * The above formulation assumes that the signal :math:`x` has zero mean. Args: :attr:`meas_op` : Measurement operator (see :class:`~spyrit.core.meas`). :attr:`sigma` : Signal (image) covariance prior, of shape :math:`(N, N)`. :attr:`approx` : A boolean indicating whether to set the non-diagonal elements of :math:`A \Sigma A^T` to zero. Default is False. If True, this speeds up the computation of the inverse :math:`(A \Sigma A^T + \Sigma_\alpha)^{-1}`. :attr:`reshape_output` : A boolean indicating whether to reshape the output to the shape of the image. Default is True. Attributes: :attr:`meas_op` : Measurement operator initialized as :attr:`meas_op`. :attr:`sigma` : Signal covariance prior initialized as :attr:`sigma`. :attr:`approx` : Indicates if the diagonal approximation is used. :attr:`reshape_output` : Indicates if the output is reshaped. :attr:`img_shape` : Shape of the image, initialized as :attr:`meas_op.img_shape`. :attr:`sigma_meas` : Measurement covariance prior initialized as :math:`A \Sigma A^T`. If :attr:`approx` is True, the non-diagonal elements are set to zero. It is pre-computed at initialization to speed up future computations. :attr:`sigma_A_T` : Covariance of the missing measurements initialized as :math:`\Sigma A^T`. It is computed at initialization to speed up future computations. Example: >>> from spyrit.core.meas import Linear >>> B, H, M, N = 85, 17, 32, 64 >>> sigma = torch.rand(N, N) >>> gamma = torch.rand(M, M) >>> A = torch.rand([M,N]) >>> meas = Linear(A, meas_shape=(1,N)) >>> recon = Tikhonov(meas, sigma) >>> y = torch.rand(B,H,M) >>> x = recon(y, gamma) >>> print(y.shape) torch.Size([85, 17, 32]) >>> print(x.shape) torch.Size([85, 17, 1, 64]) """ def __init__( self, meas_op: meas.Linear, sigma: torch.tensor, approx=False, reshape_output: bool = True, ): super().__init__() self.meas_op = meas_op self.sigma = sigma self.approx = approx self.reshape_output = reshape_output self.img_shape = meas_op.meas_shape A = meas_op.get_matrix_to_inverse # get H or A # *measurement* covariance if approx: # store onle the diagonal sigma_meas = torch.einsum("ij,jk,ik->i", A, sigma, A) else: sigma_meas = A @ sigma @ A.T self.register_buffer("sigma_meas", sigma_meas) # estimation of the missing measurements sigma_A_T = torch.mm(sigma, A.mT) self.register_buffer("sigma_A_T", sigma_A_T)
[docs] def divide(self, y: torch.tensor, gamma: torch.tensor) -> torch.tensor: r"""Computes the division :math:`y \cdot (\Sigma \alpha + (A \Sigma A^T))^{-1}`. Measurements `y` are divided by the sum of the measurement covariance. If :attr:`self.approx` is True, the inverse is approximated as a diagonal matrix, speeding up the computation. Otherwise, the inverse is computed with the whole matrix. Args: :attr:`y` (torch.tensor): Input measurement tensor. Shape :math:`(*, M)`. :attr:`gamma` (torch.tensor): Noise covariance tensor. Shape :math:`(*, M, M)`. Returns: torch.tensor: The divided tensor. Shape :math:`(*, M)`. Example: >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=True) >>> y = torch.rand(85, 3, M) >>> gamma = torch.eye(M).expand(85, 3, M, M) >>> print(tikho.divide(y, gamma).shape) torch.Size([85, 3, 32]) """ if self.approx: return y / (self.sigma_meas + torch.diagonal(gamma, dim1=-2, dim2=-1)) else: # we need to expand the matrices for the solve batch_shape = y.shape[:-1] expand_shape = batch_shape + (self.sigma_meas.shape) y = y.unsqueeze(-1) # add a dimension to y for batch matrix multiplications y = torch.linalg.solve((self.sigma_meas + gamma).expand(expand_shape), y) return y.squeeze(-1)
[docs] def forward(self, y: torch.tensor, gamma: torch.tensor) -> torch.tensor: r"""Reconstructs the signal from measurements and noise covariance. The Tikhonov solution is computed as .. math:: \hat{x} = B^\top (C + \Gamma)^{-1} y with :math:`B = \Sigma A^\top` and :math:`C = A \Sigma A^\top`. When :attr:`self.approx` is True, it is approximated as .. math:: \hat{x} = B^\top \frac{y}{\text{diag}(C + \Gamma)} Args: :attr:`y` (torch.tensor): A batch of measurement vectors :math:`y` of shape :math:`(*, M)`. :attr:`gamma` (torch.tensor): A batch of noise covariance :math:`\Gamma` of shape :math:`(*, M, M)`. Returns: (torch.tensor): A batch of reconstructed images of shape :math:`(*, N)` or the :math:`meas\_op.unvectorize` version of the image shape. Example 1: With reshape_output = True >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> b, c, h, w = 85, 3, 8, 8 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> x = torch.randn(b,c,h,w) >>> y = meas_op(x) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=True) >>> gamma = torch.eye(M).expand(b, c, M, M) >>> print(tikho(y, gamma).shape) torch.Size([85, 3, 1, 64]) Example 2: With reshape_output = False >>> from spyrit.core.meas import Linear >>> M, N = 32, 64 >>> b, c, h, w = 85, 3, 8, 8 >>> meas_op = Linear(torch.rand([M,N]), meas_shape=(1,N)) >>> x = torch.randn(b,c,h,w) >>> y = meas_op(x) >>> sigma = torch.rand(N,N) >>> tikho = Tikhonov(meas_op, sigma, approx=False, reshape_output=False) >>> gamma = torch.eye(M).expand(b, c, M, M) >>> print(tikho(y, gamma).shape) torch.Size([85, 3, 64]) """ y = self.divide(y, gamma) y = torch.matmul(self.sigma_A_T, y.unsqueeze(-1)).squeeze(-1) if self.reshape_output: y = self.meas_op.unvectorize(y) return y
# =============================================================================
[docs] class TikhonovMeasurementPriorDiag(nn.Module): r"""Tikhonov regularisation with prior in the measurement domain. Considering linear measurements :math:`m = Hx \in\mathbb{R}^M`, where :math:`H = GF` is the measurement matrix and :math:`x\in\mathbb{R}^N` is a vectorized image, it estimates :math:`x` from :math:`m` by approximately minimizing .. math:: \|m - GFx \|^2_{\Sigma^{-1}_\alpha} + \|F(x - x_0)\|^2_{\Sigma^{-1}} where :math:`x_0\in\mathbb{R}^N` is a mean image prior, :math:`\Sigma\in\mathbb{R}^{N\times N}` is a covariance prior, and :math:`\Sigma_\alpha\in\mathbb{R}^{M\times M}` is the measurement noise covariance. The matrix :math:`G\in\mathbb{R}^{M\times N}` is a subsampling matrix. .. note:: The class is instantiated from :math:`\Sigma`, which represents the covariance of :math:`Fx`. Args: :attr:`meas_op`: A Hadamard measurement operator (see :mod:`spyrit.core.meas.HadamSplit2d`). :attr:`sigma`: Measurement covariance prior with shape :math:`N` x :math:`N`. Attributes: :attr:`meas_op`: Measurement operator initialized as :attr:`meas_op`. :attr:`denoise_weights`: The learnable denoising layer initialized from :math:`\Sigma_1`. This layer is a :class:`nn.Parameter`. :attr:`comp`: The learnable matrix initialized from :math:`\Sigma_{21}`. This matrix is a :class:`nn.Parameter`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> x0 = torch.rand([10, 3, 32, 32]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op(y, x0, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ def __init__( self, meas_op: meas.HadamSplit2d, sigma: torch.tensor, reshape_output: bool = False, ): super().__init__() self.meas_op = meas_op self.reshape_output = reshape_output M = self.meas_op.M var_prior = sigma.diag()[:M] self.denoise_weights = nn.Parameter(torch.sqrt(var_prior), requires_grad=False) Sigma1 = sigma[:M, :M] Sigma21 = sigma[M:, :M] W = torch.linalg.solve(Sigma1.T, Sigma21.T).T self.comp = nn.Parameter(W, requires_grad=False)
[docs] def wiener_denoise(self, x: torch.tensor, var: torch.tensor) -> torch.tensor: r"""Returns a denoised version of the input tensor using the variance prior. This uses the attribute self.denoise_weights, which is a learnable parameter. Inputs: x (:class:`torch.tensor`): The input tensor to be denoised. var (:class:`torch.tensor`): The variance prior. Returns: :class:`torch.tensor`: The denoised tensor. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> var = torch.rand([10, 3, 400]) >>> print(recon_op.wiener_denoise(y, var).shape) torch.Size([10, 3, 400]) """ weights_squared = self.denoise_weights**2 return torch.mul((weights_squared / (weights_squared + var)), x)
[docs] def forward_no_prior(self, x, var): r"""Computes the Tikhonov regularization with prior in the measurement domain. We approximate the solution as: .. math:: \hat{x} = x_0 + F^{-1} \begin{bmatrix} m_1 \\ m_2\end{bmatrix} with :math:`m_1 = D_1(D_1 + \Sigma_\alpha)^{-1} (m - GF x_0)` and :math:`m_2 = \Sigma_1 \Sigma_{21}^{-1} m_1`, where :math:`\Sigma = \begin{bmatrix} \Sigma_1 & \Sigma_{21}^\top \\ \Sigma_{21} & \Sigma_2\end{bmatrix}` and :math:`D_1 =\textrm{Diag}(\Sigma_1)`. Assuming the noise covariance :math:`\Sigma_\alpha` is diagonal, the matrix inversion involved in the computation of :math:`m_1` is straightforward. This is an approximation to the exact solution .. math:: \hat{x} &= x_0 + F^{-1}\begin{bmatrix}\Sigma_1 \\ \Sigma_{21} \end{bmatrix} [\Sigma_1 + \Sigma_\alpha]^{-1} (m - GF x_0). See Lemma B.0.5 of the PhD dissertation of A. Lorente Mur (2021): https://theses.hal.science/tel-03670825v1/file/these.pdf Args: :attr:`x` (:class:`torch.tensor`): A batch of measurement vectors :math:`m` of shape :math:`(*, M)`. :attr:`var` (:class:`torch.tensor`): A batch of measurement noise variances :math:`\Sigma_\alpha` of shape :math:`(*, M)`. Returns: :class:`torch.tensor`: Batch of reconstructed image of shape :math:`(*, \sqrt{N}, \sqrt{N})`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op.forward_no_prior(y, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ y1 = self.wiener_denoise(x, var) y2 = y1 @ self.comp.T y = torch.cat((y1, y2), -1) y = self.meas_op.fast_pinv(y) # if self.reshape_output: # y = self.meas_op.unvectorize(y) return y
[docs] def forward( self, x: torch.tensor, x_0: torch.tensor, var: torch.tensor, ) -> torch.tensor: r"""Computes the Tikhonov regularization with prior in the measurement domain. This method, unlike the :meth:`forward_no_prior` method, allows for a non-zero mean image prior :math:`x_0`. We approximate the solution as: .. math:: \hat{x} = x_0 + F^{-1} \begin{bmatrix} m_1 \\ m_2\end{bmatrix} with :math:`m_1 = D_1(D_1 + \Sigma_\alpha)^{-1} (m - GF x_0)` and :math:`m_2 = \Sigma_1 \Sigma_{21}^{-1} m_1`, where :math:`\Sigma = \begin{bmatrix} \Sigma_1 & \Sigma_{21}^\top \\ \Sigma_{21} & \Sigma_2\end{bmatrix}` and :math:`D_1 =\textrm{Diag}(\Sigma_1)`. Assuming the noise covariance :math:`\Sigma_\alpha` is diagonal, the matrix inversion involved in the computation of :math:`m_1` is straightforward. This is an approximation to the exact solution .. math:: \hat{x} &= x_0 + F^{-1}\begin{bmatrix}\Sigma_1 \\ \Sigma_{21} \end{bmatrix} [\Sigma_1 + \Sigma_\alpha]^{-1} (m - GF x_0) See Lemma B.0.5 of the PhD dissertation of A. Lorente Mur (2021): https://theses.hal.science/tel-03670825v1/file/these.pdf Args: :attr:`x` (:class:`torch.tensor`): A batch of measurement vectors :math:`m` with shape :math:`(*, M)`. :attr:`x_0` (:class:`torch.tensor`): A batch of prior images :math:`x_0` with shape :math:`(*, \sqrt{N}, \sqrt{N})`. :attr:`var` (:class:`torch.tensor`): A batch of measurement noise variances :math:`\Sigma_\alpha` with shape :math:`(*, M)`. :attr:`meas_op` (:class:`torch.tensor`): A measurement operator that provides :math:`GF` and :math:`F^{-1}`. Output: (:class:`torch.tensor`): A batch of images with shape :math:`(*, \sqrt{N}, \sqrt{N})`. Example: >>> from spyrit.core.meas import HadamSplit2d >>> from spyrit.core.inverse import TikhonovMeasurementPriorDiag >>> import torch >>> acqu = HadamSplit2d(32, 400) >>> sigma = torch.rand([32*32, 32*32]) >>> recon_op = TikhonovMeasurementPriorDiag(acqu, sigma) >>> y = torch.rand([10, 3, 400]) >>> x0 = torch.rand([10, 3, 32, 32]) >>> var = torch.rand([10, 3, 400]) >>> x = recon_op(y, x0, var) >>> print(x.shape) torch.Size([10, 3, 32, 32]) """ x = x - self.meas_op.forward_H(x_0) x = self.forward_no_prior(x, var) x += x_0 return x