spyrit.core.meas.LinearSplit

class spyrit.core.meas.LinearSplit(H: tensor, pinv: bool = False, rtol: float = None, Ord: tensor = None, meas_shape: tuple = None)

Bases: Linear

Simulates split measurements \(y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}x\).

Computes linear measurements from incoming images: \(y = Px\), where \(P\) is a linear operator (matrix) and \(x\) is a vectorized image or batch of vectorized images.

The matrix \(P\) contains only nonnegative values. It is obtained by splitting a measurement matrix \(H\) into its positive and negative parts and interleaving their rows: P[0::2, :] = \(H_{+}\) and P[1::2, :] = \(H_{-}\), where \(H_{+} = \max(0,H)\) and \(H_{-} = \max(0,-H)\) are taken elementwise.

The class is constructed from the \(M\) by \(N\) matrix \(H\), where \(N\) represents the number of pixels in the image and \(M\) the number of measurements. Therefore, the shape of \(P\) is \((2M, N)\).
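The splitting rule can be illustrated with a few lines of PyTorch. This is a minimal sketch of the definition above, not spyrit's own implementation; the helper name `split_matrix` is hypothetical.

```python
import torch

def split_matrix(H: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build the (2M, N) split matrix P from an (M, N) matrix H.

    Even rows of P hold H_+ = max(0, H), odd rows hold H_- = max(0, -H).
    """
    M, N = H.shape
    P = torch.zeros(2 * M, N, dtype=H.dtype, device=H.device)
    P[0::2, :] = torch.clamp(H, min=0)   # H_+
    P[1::2, :] = torch.clamp(-H, min=0)  # H_-
    return P

H = torch.randn(4, 9)
P = split_matrix(H)
print(P.shape)                               # torch.Size([8, 9])
print(bool((P >= 0).all()))                  # True: P is nonnegative
print(torch.allclose(P[0::2] - P[1::2], H))  # True: H = H_+ - H_-
```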

Args:

H (torch.tensor): measurement matrix (linear operator) with shape \((M, N)\).

pinv (bool): Whether to store the pseudo inverse of the measurement matrix \(H\). If True, the pseudo inverse is initialized as \(H^\dagger\) and stored in the attribute H_pinv. It is always possible to compute and store the pseudo inverse later using the method set_H_pinv(). Defaults to False.

rtol (float, optional): Cutoff for small singular values (see torch.linalg.pinv). Only relevant when pinv is True.

Ord (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix \(H\). The first row of the reordered \(H\) corresponds to the highest value in \(Ord\). Must contain \(M\) values. Rows with equal values in \(Ord\) keep their original relative order. Defaults to None.

meas_shape (tuple, optional): Shape of the measurement patterns. Must be a tuple of two integers representing the height and width of the patterns. If not specified, the patterns are assumed to be square, i.e. \(h = w = \sqrt{N}\); if \(N\) is not a perfect square, an error is raised. Defaults to None.
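The default-shape rule for meas_shape can be sketched as follows. The helper `infer_meas_shape` is hypothetical, and the check that an explicit shape satisfies \(h \times w = N\) is an assumption added for illustration, not documented spyrit behavior.

```python
import math

def infer_meas_shape(N: int, meas_shape=None) -> tuple:
    """Hypothetical helper: return (h, w) for measurement patterns of N pixels."""
    if meas_shape is not None:
        h, w = meas_shape
        # Assumed consistency check: the pattern must have exactly N pixels.
        if h * w != N:
            raise ValueError(f"meas_shape {meas_shape} does not match N = {N}")
        return (h, w)
    # No shape given: assume square patterns, which requires N to be a perfect square.
    side = math.isqrt(N)
    if side * side != N:
        raise ValueError(f"N = {N} is not a perfect square; please set meas_shape")
    return (side, side)

print(infer_meas_shape(1600))            # (40, 40)
print(infer_meas_shape(1200, (30, 40)))  # (30, 40)
```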

Attributes:

H (torch.tensor): The learnable measurement matrix of shape \((M, N)\) initialized as \(H\).

H_static (torch.tensor): alias for H.

P (torch.tensor): The split measurement matrix of shape \((2M, N)\).

H_pinv (torch.tensor, optional): The learnable pseudo inverse measurement matrix \(H^\dagger\) of shape \((N, M)\).

M (int): Number of measurements performed by the linear operator.

N (int): Number of pixels in the image.

h (int): Measurement pattern height.

w (int): Measurement pattern width.

meas_shape (tuple): Shape of the measurement patterns (height, width), equal to (self.h, self.w).

indices (torch.tensor): Indices used to sort the rows of H. It is used by the method reindex().

Ord (torch.tensor): Order matrix used to sort the rows of H. It is used by sort_by_significance().
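The relation between Ord and the row order of \(H\) can be sketched with a stable descending sort. That spyrit derives its indices attribute exactly this way is an assumption; the toy values below are illustrative only.

```python
import torch

Ord = torch.tensor([1.0, 3.0, 3.0, 2.0])  # one value per row of H (M = 4)
H = torch.arange(12.0).reshape(4, 3)       # toy (M, N) matrix; row i is [3i, 3i+1, 3i+2]

# Stable descending sort: the highest Ord value comes first, and rows
# with equal Ord values keep their original relative order.
_, indices = torch.sort(Ord, descending=True, stable=True)
H_sorted = H[indices, :]

print(indices)      # tensor([1, 2, 3, 0])
print(H_sorted[0])  # tensor([3., 4., 5.])
```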

Note

If you know the pseudo inverse of \(H\) and want to store it, it is best to initialize the class with pinv set to False and then call set_H_pinv() to store the pseudo inverse.
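For reference, the pseudo inverse controlled by the pinv and rtol arguments corresponds to torch.linalg.pinv. The sketch below calls it directly, outside spyrit, to show the expected shape \((N, M)\) of H_pinv; the sizes and the rtol value are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
M, N = 40, 100
H = torch.randn(M, N)                      # (M, N) with M < N
H_pinv = torch.linalg.pinv(H, rtol=1e-6)   # singular values below rtol * max are cut off

print(H_pinv.shape)  # torch.Size([100, 40]), i.e. (N, M)
# A random Gaussian H has full row rank, so H @ H_pinv is numerically the M x M identity.
print(torch.allclose(H @ H_pinv, torch.eye(M), atol=1e-4))
```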

Note

\(H = H_{+} - H_{-}\)
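Because \(H = H_{+} - H_{-}\), the measurements \(Hx\) can be recovered from the split measurements \(y = Px\) by subtracting each odd entry from the preceding even entry. The sketch below assumes images are batched as rows of shape (b, N), so that \(y = Px\) is computed as x @ P.T; it does not claim that forward_H is implemented this way.

```python
import torch

torch.manual_seed(0)
M, N = 4, 9
H = torch.randn(M, N)

# Split matrix with interleaved rows: P[0::2] = H_+, P[1::2] = H_-
P = torch.empty(2 * M, N)
P[0::2] = H.clamp(min=0)
P[1::2] = (-H).clamp(min=0)

x = torch.rand(2, N)         # batch of 2 vectorized images, shape (b, N)
y = x @ P.T                  # split measurements, shape (b, 2M)
m = y[:, 0::2] - y[:, 1::2]  # H_+ x - H_- x, shape (b, M)

print(y.shape, m.shape)      # torch.Size([2, 8]) torch.Size([2, 4])
print(torch.allclose(m, x @ H.T, atol=1e-5))
```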

Example:
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> H = torch.randn(400, 1600)
>>> meas_op = LinearSplit(H, pinv=False)
>>> print(meas_op)
LinearSplit(
  (M): 400
  (N): 1600
  (H.shape): torch.Size([400, 1600])
  (meas_shape): (40, 40)
  (H_pinv): False
  (P.shape): torch.Size([800, 1600])
)

Methods

adjoint(x)

Applies the adjoint transform to incoming measurements: \(y = H^{T}x\).

forward(x)

Applies linear transform to incoming images: \(y = Px\).

forward_H(x)

Applies linear transform to incoming images: \(m = Hx\).

get_H()

Deprecated method.

pinv(x[, reg, eta])

Computes the pseudo inverse solution \(y = H^\dagger x\).

reindex(x[, axis, inverse_permutation])

Sorts a tensor along a specified axis using the indices tensor.

set_H_pinv([rtol])

Used to set the pseudo inverse of the measurement matrix \(H\) using torch.linalg.pinv.
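The reindexing idea behind reindex() can be sketched as a permutation along one axis, with the inverse permutation obtained by argsort. The function below is a hypothetical stand-in: its default axis and the direction treated as "inverse" are assumptions, not spyrit's documented behavior.

```python
import torch

def reindex(x: torch.Tensor, indices: torch.Tensor, axis: int = -1,
            inverse_permutation: bool = False) -> torch.Tensor:
    """Hypothetical stand-in for LinearSplit.reindex: permute x along `axis`."""
    axis = axis % x.dim()                # normalize a negative axis
    if inverse_permutation:
        indices = torch.argsort(indices)  # argsort of a permutation is its inverse
    return torch.index_select(x, axis, indices)

indices = torch.tensor([1, 2, 3, 0])
x = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
x_sorted = reindex(x, indices, axis=1)   # tensor([[20., 30., 40., 10.]])
x_back = reindex(x_sorted, indices, axis=1, inverse_permutation=True)
print(torch.equal(x_back, x))            # True
```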