spyrit.core.inverse.PseudoInverse

class spyrit.core.inverse.PseudoInverse(meas_op: Linear | DynamicLinear, regularization: str = 'rcond', *, store_H_pinv: bool = False, use_fast_pinv: bool = True, reshape_output: bool = True, **reg_kwargs)

Bases: Module

Moore-Penrose pseudoinverse.

This class solves the linear problem \(Ax = B\) either by computing the least-squares solution of the equation or by computing the pseudo-inverse matrix of \(A\). The choice between the two is set by the keyword parameter store_H_pinv.
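Without regularization, both strategies return the same vector, because the pseudo-inverse gives the minimum-norm least-squares solution. The sketch below checks this with plain PyTorch rather than spyrit: the random matrix H, its shape, and the batch shape are illustrative assumptions. It also assumes CPU tensors, since the SVD-based "gelsd" driver of torch.linalg.lstsq() (which returns the minimum-norm solution of an underdetermined system) is only available on CPU.

```python
import torch

torch.manual_seed(0)
M, N = 10, 15                                    # 10 measurements, 15 unknowns (underdetermined)
H = torch.randn(M, N, dtype=torch.float64)       # illustrative measurement matrix
y = torch.randn(3, 4, M, dtype=torch.float64)    # batch of measurement vectors, shape (*, M)

# Strategy 1 (analogous to store_H_pinv=True): compute H^+ once, then reuse it.
H_pinv = torch.linalg.pinv(H)                    # shape (N, M)
x_pinv = y @ H_pinv.mT                           # shape (*, N)

# Strategy 2 (analogous to store_H_pinv=False): solve a least-squares problem.
# lstsq solves H @ X = B column by column, so each measurement is a column of B.
B = y.reshape(-1, M).mT                          # shape (M, 12)
X = torch.linalg.lstsq(H, B, driver="gelsd").solution  # shape (N, 12), minimum-norm
x_lstsq = X.mT.reshape(3, 4, N)                  # back to shape (*, N)

print(torch.allclose(x_pinv, x_lstsq))           # True
```

On CUDA the only available driver is "gels", which assumes a full-rank matrix, so the driver argument would have to be dropped there.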

This class can also apply regularization when computing the least-squares solution or the matrix pseudo-inverse. The available regularization methods are rcond (which discards the singular values of the matrix that fall below a threshold), L2, and H1.
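The rcond and L2 options correspond to two classical ways of filtering the singular values of a matrix. The sketch below is a generic textbook version written with torch.linalg.svd(), not spyrit's regularized_pinv(): the helper names truncated_pinv and tikhonov_pinv are hypothetical, the threshold for rcond is assumed to be relative to the largest singular value, and "L2" is assumed to mean the Tikhonov penalty \(\lambda \|x\|_2^2\). The H1 option is not covered.

```python
import torch


def truncated_pinv(A: torch.Tensor, rcond: float) -> torch.Tensor:
    """Hypothetical helper: pseudo-inverse ignoring singular values <= rcond * sigma_max."""
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    keep = S > rcond * S.max()
    S_inv = torch.zeros_like(S)
    S_inv[keep] = 1.0 / S[keep]          # invert only the kept singular values
    return Vh.mH @ torch.diag(S_inv) @ U.mH   # A^+ = V diag(1/s) U^T


def tikhonov_pinv(A: torch.Tensor, lam: float) -> torch.Tensor:
    """Hypothetical helper: L2 (Tikhonov) inverse (A^T A + lam I)^-1 A^T via the SVD."""
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    filt = S / (S**2 + lam)              # smoothly damps small singular values
    return Vh.mH @ torch.diag(filt) @ U.mH


# A diagonal matrix with one very small singular value.
A = torch.diag(torch.tensor([10.0, 1.0, 1e-3], dtype=torch.float64))
# With rcond=1e-2 the last entry is zeroed instead of becoming 1 / 1e-3 = 1000.
print(torch.diagonal(truncated_pinv(A, rcond=1e-2)))
```

Truncation is a hard cut-off on the spectrum, whereas the Tikhonov filter \(s / (s^2 + \lambda)\) shrinks every component a little and the small ones a lot.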

Note

When store_H_pinv is True, additional parameters (such as regularization parameters) can be passed as keyword arguments to the class constructor.

Note

When store_H_pinv is False, additional parameters (such as regularization parameters) can be passed as keyword arguments to the forward method of this class.

Args:

meas_op: Measurement operator. See spyrit.core.meas.

regularization (str): Regularization method. Can be ‘rcond’, ‘L2’, or ‘H1’. Default: ‘rcond’.

Keyword Args:

store_H_pinv (bool): If False, the least-squares solution is computed at each forward pass using the function torch.linalg.lstsq(). If True, the pseudo-inverse of the measurement matrix is computed once at initialization using the function torch.linalg.pinv() and stored. Default: False.

use_fast_pinv (bool): If True, uses a fast method to compute either the pseudo-inverse of the measurement matrix or the least-squares solution. This only works if the measurement operator provides a fast pseudo-inverse method. Default: True.

reshape_output (bool): If True, reshapes the output to the shape of the image using meas_op.unvectorize(). Default: True.

reg_kwargs: Additional keyword arguments that are passed to spyrit.core.torch.regularized_pinv() when store_H_pinv is True, or to spyrit.core.torch.regularized_lstsq() when store_H_pinv is False.
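The documentation does not say which measurement operators provide a fast pseudo-inverse. As an illustration only (an assumption, not a description of spyrit's code), a Sylvester-type Hadamard matrix \(H\) of order \(N\) satisfies \(H H^\top = N I\), so its pseudo-inverse is simply \(H^\top / N\) and no SVD is needed. The helper sylvester_hadamard below is hypothetical.

```python
import torch


def sylvester_hadamard(order_log2: int) -> torch.Tensor:
    """Hypothetical helper: Sylvester Hadamard matrix of size 2**order_log2 (entries +/-1)."""
    H = torch.ones(1, 1, dtype=torch.float64)
    H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
    for _ in range(order_log2):
        H = torch.kron(H2, H)  # doubling step: [[H, H], [H, -H]]
    return H


H = sylvester_hadamard(4)           # 16 x 16
N = H.shape[0]
fast_pinv = H.mT / N                # valid because H @ H.T == N * I
slow_pinv = torch.linalg.pinv(H)    # generic SVD-based pseudo-inverse

print(torch.allclose(fast_pinv, slow_pinv))  # True
```

For such matrices the product with \(H^\top\) can itself be evaluated in \(O(N \log N)\) operations with a fast Walsh-Hadamard transform instead of a dense matrix product.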

Attributes:

meas_op: Measurement operator initialized as meas_op.

regularization: Regularization method initialized as regularization.

store_H_pinv: Indicates if the pseudo-inverse is stored.

use_fast_pinv: Indicates if the fast pseudo-inverse is used.

reshape_output: Indicates if the output is reshaped.

reg_kwargs: Additional keyword arguments passed to the spyrit.core.torch.regularized_pinv() or torch.linalg.lstsq() functions.

pinv: The pseudo-inverse of the measurement matrix. It is computed only if store_H_pinv is True.

Example 1:
>>> import torch
>>> from spyrit.core.meas import Linear
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = Linear(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])

Example 2: LinearSplit, pseudo-inverse of H (default)
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op.measure_H(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])

Example 3: LinearSplit, pseudo-inverse of A
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> meas_op.set_matrix_to_inverse('A')
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x = pinv_op(y)
>>> print(x.shape)
torch.Size([3, 4, 15])

Methods

forward(y)

Computes the pseudo-inverse solution from the measurements y.