spyrit.core.inverse.PseudoInverse.forward

PseudoInverse.forward(y: tensor) → tensor

Applies the pseudo-inverse of the measurement operator to the measurements.

If self.store_H_pinv is True, computes the product of the stored pseudo-inverse and the measurements.
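As an illustration of the stored pseudo-inverse path, the sketch below uses plain PyTorch rather than spyrit. The matrix H, the seed and the variable names are hypothetical, and it assumes that applying the stored pseudo-inverse to a batch amounts to right-multiplying by its transpose over the last dimension.

```python
import torch

torch.manual_seed(0)

# Hypothetical measurement matrix H of shape (M, N) = (10, 15).
H = torch.randn(10, 15)

# Computed once and stored: the pseudo-inverse has shape (N, M) = (15, 10).
H_pinv = torch.linalg.pinv(H)

# Batch of signals with arbitrary leading dimensions, shape (*, N).
x = torch.randn(3, 4, 15)

# Measurements y = H x for every vector in the batch, shape (*, M).
y = x @ H.T

# Reconstruction: a single matrix product with the stored pseudo-inverse, shape (*, N).
x_rec = y @ H_pinv.T
print(x_rec.shape)  # torch.Size([3, 4, 15])
```

Because M < N here, x_rec is the minimum-norm solution consistent with y rather than x itself: multiplying it by H reproduces y up to floating-point error.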

If self.store_H_pinv is False, computes the least-squares solution for the measurements. In this case, any additional keyword arguments passed to the PseudoInverse constructor (and stored in self.reg_kwargs) are used here. These can include:

  • rcond (float): Cutoff for small singular values. It is used only when regularization is ‘rcond’. This parameter is fed directly to torch.linalg.pinv().
  • Any other keyword arguments, which are passed on to torch.linalg.lstsq(). These are used only when regularization is ‘rcond’.
  • eta (float): Regularization parameter. It is used only when regularization is ‘L2’ or ‘H1’ and determines the amount of regularization applied to the pseudo-inverse.
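The non-stored paths can be sketched in plain PyTorch under two assumptions: that ‘rcond’ corresponds to a least-squares solve with torch.linalg.lstsq, and that ‘L2’ and ‘H1’ correspond to Tikhonov-type problems (HᵀH + eta DᵀD) x = Hᵀy, with D the identity or a finite-difference operator. spyrit's actual formulation (in particular the gradient operator used for ‘H1’ on 2-D images) may differ, and solve_lstsq, solve_tikhonov and D are hypothetical names.

```python
import torch

torch.manual_seed(0)
M, N = 10, 15
H = torch.randn(M, N)     # hypothetical measurement matrix (M, N)
x = torch.randn(3, 4, N)  # batch of signals (*, N)
y = x @ H.T               # batch of measurements (*, M)


def solve_lstsq(H, y, rcond=None):
    """Least-squares solution of H x = y for each vector in a batch y of shape (*, M)."""
    batch_shape = y.shape[:-1]
    # torch.linalg.lstsq expects the right-hand sides as columns: (M, K).
    Y = y.reshape(-1, y.shape[-1]).T
    # The 'gelsd' driver (CPU only) returns the minimum-norm solution and honours rcond.
    X = torch.linalg.lstsq(H, Y, rcond=rcond, driver="gelsd").solution  # (N, K)
    return X.T.reshape(*batch_shape, H.shape[1])


def solve_tikhonov(H, y, eta, D=None):
    """Solve (H^T H + eta D^T D) x = H^T y; D defaults to the identity (L2 penalty)."""
    N = H.shape[1]
    if D is None:
        D = torch.eye(N)
    A = H.T @ H + eta * (D.T @ D)             # (N, N), positive definite here
    B = H.T @ y.reshape(-1, y.shape[-1]).T    # (N, K)
    X = torch.linalg.solve(A, B)              # (N, K)
    return X.T.reshape(*y.shape[:-1], N)


# Hypothetical 1-D finite-difference operator of shape (N - 1, N) for an H1-type penalty.
D = torch.diff(torch.eye(N), dim=0)

x_ls = solve_lstsq(H, y)
x_l2 = solve_tikhonov(H, y, eta=0.1)
x_h1 = solve_tikhonov(H, y, eta=0.1, D=D)
print(x_ls.shape, x_l2.shape, x_h1.shape)
```

With D the identity, increasing eta monotonically shrinks the norm of the solution, trading fidelity to the measurements for stability.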

Args:

y (torch.Tensor): Batch of measurement vectors of shape (*, M), where * denotes any number of leading dimensions and M is the number of measurements of the measurement operator (meas_op.M).

Returns:

output (torch.Tensor): Batch of reconstructed images, either of shape (*, N), where N is the number of pixels in the image, or with the image shape defined by the measurement operator (meas_op.meas_shape), depending on the value of self.reshape_output.
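The relationship between the two output shapes can be sketched as a reshape of the last dimension. This is an illustration only: reshape_output is a hypothetical helper, the image shape is made up, and the assumption that images are flattened in row-major (C) order is not confirmed by this page.

```python
import math

import torch


def reshape_output(x, meas_shape):
    """Reshape a batch of flattened images (*, N) to (*, *meas_shape)."""
    n_pixels = math.prod(meas_shape)
    # The last dimension must hold exactly one value per pixel.
    if x.shape[-1] != n_pixels:
        raise ValueError(f"expected last dimension {n_pixels}, got {x.shape[-1]}")
    return x.reshape(*x.shape[:-1], *meas_shape)


meas_shape = (3, 5)             # hypothetical image shape (h, w), so N = 15
x_flat = torch.randn(2, 7, 15)  # batch of flattened reconstructions (*, N)
x_img = reshape_output(x_flat, meas_shape)
print(x_img.shape)  # torch.Size([2, 7, 3, 5])
```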

Example 1: Linear
>>> import torch
>>> from spyrit.core.meas import Linear
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = Linear(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x_rec = pinv_op(y)
>>> print(x_rec.shape)
torch.Size([3, 4, 15])

Example 2: LinearSplit, pseudo-inverse of H (default)
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op.measure_H(x)
>>> x_rec = pinv_op(y)
>>> print(x_rec.shape)
torch.Size([3, 4, 15])

Example 3: LinearSplit, pseudo-inverse of A
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> from spyrit.core.inverse import PseudoInverse
>>> H = torch.randn(10, 15)
>>> meas_op = LinearSplit(H)
>>> meas_op.set_matrix_to_inverse('A')
>>> pinv_op = PseudoInverse(meas_op)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> x_rec = pinv_op(y)
>>> print(x_rec.shape)
torch.Size([3, 4, 15])
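To make the difference between inverting H and inverting A concrete, the following plain-PyTorch sketch builds a split matrix A from H by hand. The construction (nonnegative parts H_pos and H_neg with interleaved rows) is an assumption made for illustration; it is not taken from spyrit's LinearSplit source, and all names are hypothetical.

```python
import torch

torch.manual_seed(0)
H = torch.randn(10, 15)  # hypothetical (M, N) measurement matrix

# Assumed split: H = H_pos - H_neg with nonnegative parts, and the rows of A
# interleaved as [H_pos[0], H_neg[0], H_pos[1], H_neg[1], ...], shape (2M, N).
H_pos = torch.clamp(H, min=0)
H_neg = torch.clamp(-H, min=0)
A = torch.empty(2 * H.shape[0], H.shape[1])
A[0::2] = H_pos
A[1::2] = H_neg

x = torch.randn(3, 4, 15)
y_split = x @ A.T                              # split measurements, (*, 2M)
y_H = y_split[..., 0::2] - y_split[..., 1::2]  # differences give H x, (*, M)

# Pseudo-inverse of H: minimum-norm solution, generally != x because M < N.
x_from_H = y_H @ torch.linalg.pinv(H).T
# Pseudo-inverse of A: A is (20, 15), and a random H like this one generically
# gives A full column rank, so x is recovered up to floating-point error.
x_from_A = y_split @ torch.linalg.pinv(A).T
print(x_from_H.shape, x_from_A.shape)
```

Under this assumption, inverting A uses all 2M split measurements, so recovery can be exact whenever A has full column rank, whereas inverting H can only return the minimum-norm solution when M < N.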