spyrit.core.inverse.PseudoInverse

class spyrit.core.inverse.PseudoInverse(meas_op: Linear | DynamicLinear, regularization: str = 'rcond', *, store_H_pinv: bool = False, use_fast_pinv: bool = True, reshape_output: bool = True, **reg_kwargs)[source]

Bases: Module

Pseudoinverse.

Solves the linear problem \(Ax = B\), either by calling a linear solver or by computing the pseudo-inverse matrix of \(A\). This behavior is defined by the attribute store_H_pinv.
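The two strategies can be compared outside of SPyRiT with plain PyTorch. The sketch below is only an illustration, not the class's implementation; the matrix sizes, variable names, and the choice of the "gelsd" driver are assumptions made for the example.

```python
import torch

torch.manual_seed(0)

# Hypothetical underdetermined system: 10 equations, 15 unknowns.
A = torch.randn(10, 15, dtype=torch.float64)
x_true = torch.randn(15, 1, dtype=torch.float64)
B = A @ x_true

# Strategy 1: call a linear solver each time a right-hand side arrives.
# The SVD-based 'gelsd' driver (CPU only) returns the minimum-norm solution.
x_solver = torch.linalg.lstsq(A, B, driver="gelsd").solution

# Strategy 2: compute the pseudo-inverse once, then reuse it.
A_pinv = torch.linalg.pinv(A)  # shape (15, 10)
x_pinv = A_pinv @ B

# Both strategies should agree up to numerical precision.
print(torch.allclose(x_solver, x_pinv, atol=1e-8))
# The solution reproduces B, although it generally differs from x_true.
print(torch.allclose(A @ x_pinv, B, atol=1e-8))
```

Storing the pseudo-inverse pays off when many right-hand sides share the same matrix, while calling the solver avoids forming and keeping the pseudo-inverse in memory.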

This class allows for regularization. Available regularizations are rcond (truncation of the singular values), L2 and H1.
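To make the first two options concrete, here is a minimal PyTorch sketch of singular-value truncation and of an L2 (Tikhonov) penalty. It does not call SPyRiT; the name `eta`, its value, and the exact form of the L2 system are assumptions, and the library's own scaling may differ. H1 regularization, which penalizes image gradients, is not shown.

```python
import torch

# Hypothetical 3x3 matrix with one very small singular value.
A = torch.diag(torch.tensor([1.0, 0.5, 1e-3], dtype=torch.float64))

# Truncation: singular values below rtol * (largest singular value) are
# treated as zero. PyTorch names this cutoff rtol; rcond is the
# NumPy-compatible name for the same idea.
pinv_trunc = torch.linalg.pinv(A, rtol=1e-2)

# L2 (Tikhonov) sketch: solve (A^T A + eta * I) X = A^T for X.
eta = 1e-2  # assumed name and value for the regularization weight
eye = torch.eye(A.shape[1], dtype=A.dtype)
pinv_l2 = torch.linalg.solve(A.T @ A + eta * eye, A.T)

print(pinv_trunc.diagonal())  # the smallest singular value is discarded
print(pinv_l2.diagonal())     # small singular values are damped instead
```

Without regularization, the smallest singular value would be inverted to 1000 and would amplify noise by that factor; truncation removes it, while the L2 penalty shrinks its contribution.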

Note

When store_H_pinv is True, additional parameters (such as regularization parameters) can be passed as keyword arguments to the class constructor.

Note

When store_H_pinv is False, additional parameters (such as regularization parameters) can be passed as keyword arguments to the forward method.

Args:

meas_op (spyrit.core.meas): Measurement operator.

regularization (str): Regularization method, one of ‘rcond’, ‘L2’, or ‘H1’. Defaults to ‘rcond’.

Keyword Args:

store_H_pinv (bool): If False, the least squares solution is computed at each forward pass using the function torch.linalg.lstsq(). If True, computes and stores at initialization the pseudo-inverse of the measurement matrix using the function torch.linalg.pinv(). Default: False

use_fast_pinv (bool): If True, uses a fast computation of either the measurement matrix pseudo-inverse or the least squares solution. This only works if the measurement operator has a fast pseudo-inverse method. Default: True.

reshape_output (bool): If True, reshapes the output to the shape of the image using meas_op.unvectorize(). Default: True.

reg_kwargs: Additional keyword arguments that are passed to:

  • spyrit.core.torch.regularized_pinv() when store_H_pinv is True.

  • spyrit.core.torch.regularized_lstsq() when store_H_pinv is False. Here, ‘driver’ is set to ‘gels’ by default (see torch.linalg.lstsq()).
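The ‘driver’ keyword mentioned above is an argument of torch.linalg.lstsq(). The sketch below calls PyTorch directly rather than SPyRiT, with hypothetical matrix sizes, to show that ‘gels’ and an SVD-based driver agree when the matrix has full rank.

```python
import torch

torch.manual_seed(0)

# Hypothetical full-rank, overdetermined system: 20 equations, 15 unknowns.
A = torch.randn(20, 15, dtype=torch.float64)
B = torch.randn(20, 1, dtype=torch.float64)

# 'gels' is QR-based and assumes that A has full rank.
x_gels = torch.linalg.lstsq(A, B, driver="gels").solution

# 'gelsd' (CPU only) is SVD-based and also handles rank-deficient matrices.
x_gelsd = torch.linalg.lstsq(A, B, driver="gelsd").solution

print(torch.allclose(x_gels, x_gelsd, atol=1e-8))
```

For an ill-conditioned or rank-deficient measurement matrix, an SVD-based driver is usually the safer choice, at a higher computational cost.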

Attributes:

meas_op: Measurement operator initialized as meas_op.

regularization: Regularization method initialized as regularization.

store_H_pinv: Indicates if the pseudo-inverse is stored.

use_fast_pinv: Indicates if the fast pseudo-inverse is used.

reshape_output: Indicates if the output is reshaped.

reg_kwargs: Additional keyword arguments passed to either spyrit.core.torch.regularized_pinv() or torch.linalg.lstsq().

pinv: The pseudo-inverse of the measurement matrix. Computed only if store_H_pinv is True.
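As an illustration of what a stored pseudo-inverse enables, the following PyTorch sketch applies a precomputed matrix to a batch of measurements. How SPyRiT applies the stored matrix internally is not stated here, so the batched product below is an assumption, as are the variable names and sizes.

```python
import torch

torch.manual_seed(0)

# Hypothetical measurement matrix with M = 10 measurements and N = 15 pixels.
H = torch.randn(10, 15, dtype=torch.float64)
H_pinv = torch.linalg.pinv(H)  # computed once, shape (15, 10)

# Batch of vectorized signals and the corresponding measurements.
x = torch.randn(3, 4, 15, dtype=torch.float64)
y = x @ H.T  # shape (3, 4, 10)

# Apply the stored pseudo-inverse along the last dimension.
x_rec = y @ H_pinv.T  # shape (3, 4, 15)

print(x_rec.shape)
# The reconstruction reproduces the measurements, even though it generally
# differs from x because the system is underdetermined.
print(torch.allclose(x_rec @ H.T, y, atol=1e-8))
```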

Example 1:
>>> import torch
>>> from spyrit.core.meas import Linear
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = Linear(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])
Example 2: LinearSplit, pseudo-inverse of H (default)
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op.measure_H(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])
Example 3: LinearSplit, pseudo-inverse of A
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> meas_op.set_matrix_to_inverse('A')
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])

Methods

forward(y)

Computes the pseudo-inverse solution from the measurements y.