spyrit.core.meas.LinearSplit
- class spyrit.core.meas.LinearSplit(H: tensor, pinv: bool = False, rtol: float = None, Ord: tensor = None, meas_shape: tuple = None)
Bases: Linear
Simulates split measurements \(y = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}x\).
Computes linear measurements from incoming images: \(y = Px\), where \(P\) is a linear operator (matrix) and \(x\) is a vectorized image or a batch of vectorized images.
The matrix \(P\) contains only nonnegative values. It is obtained by splitting a measurement matrix \(H\) so that P[0::2, :] = \(H_{+}\) and P[1::2, :] = \(H_{-}\), where \(H_{+} = \max(0,H)\) and \(H_{-} = \max(0,-H)\).
The class is constructed from the \(M\) by \(N\) matrix \(H\), where \(N\) is the number of pixels in the image and \(M\) is the number of measurements. Therefore, \(P\) has shape \((2M, N)\).
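The row-interleaved splitting described above can be sketched in plain PyTorch. This is a minimal illustration, not spyrit's implementation: the helper `split_matrix` is hypothetical, and it simply places \(H_{+}\) on the even rows and \(H_{-}\) on the odd rows of a \((2M, N)\) matrix.

```python
import torch

def split_matrix(H: torch.Tensor) -> torch.Tensor:
    """Build the (2M, N) split matrix P from an (M, N) matrix H.

    Hypothetical helper for illustration only (not part of spyrit).
    Even rows of P hold H_+ = max(0, H), odd rows hold H_- = max(0, -H).
    """
    M, N = H.shape
    P = torch.zeros(2 * M, N, dtype=H.dtype)
    P[0::2, :] = torch.clamp(H, min=0)   # H_+ on rows 0, 2, 4, ...
    P[1::2, :] = torch.clamp(-H, min=0)  # H_- on rows 1, 3, 5, ...
    return P

torch.manual_seed(0)
H = torch.randn(4, 9)        # M = 4 measurements, N = 9 pixels (a 3 x 3 image)
P = split_matrix(H)
x = torch.rand(9)            # one vectorized image
y = P @ x                    # split measurements y = Px, shape (2M,)

print(P.shape, y.shape)      # torch.Size([8, 9]) torch.Size([8])
print(bool((P >= 0).all()))  # True: P has no negative entries
```

Because \(P\) is nonnegative, it can be realized by hardware that only supports nonnegative patterns, at the cost of doubling the number of measurements.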
- Args:
H (torch.tensor): Measurement matrix (linear operator) with shape \((M, N)\).
pinv (bool): Whether to store the pseudo inverse of the measurement matrix \(H\). If True, the pseudo inverse is initialized as \(H^\dagger\) and stored in the attribute H_pinv. The pseudo inverse can always be computed and stored later using the method build_H_pinv(). Defaults to False.
rtol (float, optional): Cutoff for small singular values (see torch.linalg.pinv). Only relevant when pinv is True.
Ord (torch.tensor, optional): Order matrix used to reorder the rows of the measurement matrix \(H\). The first row of the reordered \(H\) corresponds to the highest value in Ord. Must contain \(M\) values. Rows with equal values keep their original relative order. Defaults to None.
meas_shape (tuple, optional): Shape of the measurement patterns, given as a tuple of two integers (height, width). If not specified, the patterns are assumed to be square; if \(N\) is not a perfect square, an error is raised. Defaults to None.
- Attributes:
H (torch.tensor): The learnable measurement matrix of shape \((M, N)\), initialized as \(H\).
H_static (torch.tensor): Alias for H.
P (torch.tensor): The split measurement matrix of shape \((2M, N)\).
H_pinv (torch.tensor, optional): The learnable pseudo inverse measurement matrix \(H^\dagger\) of shape \((N, M)\).
M (int): Number of measurements performed by the linear operator.
N (int): Number of pixels in the image.
h (int): Measurement pattern height.
w (int): Measurement pattern width.
meas_shape (tuple): Shape of the measurement patterns (height, width). Equal to (self.h, self.w).
indices (torch.tensor): Indices used to sort the rows of H. Used by the method reindex().
Ord (torch.tensor): Order matrix used to sort the rows of H. Used by sort_by_significance().
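The Ord and meas_shape arguments can be illustrated with a short sketch. It assumes that reordering amounts to a stable descending sort of the flattened Ord values, and that a missing meas_shape means square patterns of side \(\sqrt{N}\). Both helpers, `sort_rows_by_order` and `infer_meas_shape`, are hypothetical and do not reproduce spyrit's internal code.

```python
import math
import torch

def sort_rows_by_order(H: torch.Tensor, Ord: torch.Tensor):
    """Hypothetical helper: reorder rows of H so row 0 matches the largest Ord value.

    A stable sort keeps rows with equal Ord values in their original order.
    """
    _, indices = torch.sort(Ord.flatten(), descending=True, stable=True)
    return H[indices, :], indices

def infer_meas_shape(N: int) -> tuple:
    """Hypothetical helper: assume square patterns when meas_shape is not given."""
    side = math.isqrt(N)
    if side * side != N:
        raise ValueError(f"N={N} is not a perfect square; pass meas_shape explicitly")
    return (side, side)

H = torch.arange(12.0).reshape(4, 3)       # M = 4 rows, N = 3 columns
Ord = torch.tensor([1.0, 5.0, 5.0, 2.0])   # one value per row of H
H_sorted, indices = sort_rows_by_order(H, Ord)

print(indices.tolist())        # [1, 2, 3, 0]: the tie between rows 1 and 2 keeps their order
print(infer_meas_shape(1600))  # (40, 40)
```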
Note
If you know the pseudo inverse of \(H\) and want to store it, it is best to initialize the class with pinv set to False and then call build_H_pinv() to store the pseudo inverse.
Note
\(H = H_{+} - H_{-}\)
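Because \(H = H_{+} - H_{-}\) and the transform is linear, \(Hx = H_{+}x - H_{-}x\). The unsplit measurement can therefore be recovered by subtracting the odd-row measurements from the even-row ones. The following plain-PyTorch sketch checks this numerically. It is an illustration of the identity, not a spyrit method.

```python
import torch

torch.manual_seed(0)
H = torch.randn(4, 9)
H_pos = torch.clamp(H, min=0)    # H_+ = max(0, H)
H_neg = torch.clamp(-H, min=0)   # H_- = max(0, -H)

x = torch.rand(9)
y_pos = H_pos @ x                # what the rows P[0::2, :] measure
y_neg = H_neg @ x                # what the rows P[1::2, :] measure
m = y_pos - y_neg                # matches H @ x up to floating-point rounding

print(torch.equal(H_pos - H_neg, H))         # True
print(torch.allclose(m, H @ x, atol=1e-6))   # True
```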
- Example:
>>> H = torch.randn(400, 1600)
>>> meas_op = LinearSplit(H, False)
>>> print(meas_op)
LinearSplit(
  (M): 400
  (N): 1600
  (H.shape): torch.Size([400, 1600])
  (meas_shape): (40, 40)
  (H_pinv): False
  (P.shape): torch.Size([800, 1600])
)
Methods
adjoint(x): Applies the adjoint transform to incoming measurements: \(y = H^{T}x\).
build_H_pinv([reg, eta]): Sets the pseudo inverse of the measurement matrix \(H\) using torch.linalg.pinv.
forward(x): Applies the linear transform to incoming images: \(y = Px\).
forward_H(x): Applies the linear transform to incoming images: \(m = Hx\).
get_H(): Deprecated method.
pinv(x[, reg, eta]): Computes the pseudo inverse solution \(y = H^\dagger x\).
reindex(x[, axis, inverse_permutation]): Sorts a tensor along a specified axis using the indices tensor.
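The matrix products behind forward, forward_H, adjoint and pinv can be mimicked in plain PyTorch for a batch of images. This sketch assumes that a batch is stored as the rows of a (B, N) tensor, so that applying a matrix \(A\) to every image is written x @ A.T. It does not call spyrit and is not its implementation. The regularization options of pinv (reg, eta) are not modeled.

```python
import torch

torch.manual_seed(0)
M, N, B = 16, 64, 3                # measurements, pixels (8 x 8 image), batch size
H = torch.randn(M, N)
P = torch.zeros(2 * M, N)
P[0::2] = torch.clamp(H, min=0)    # H_+
P[1::2] = torch.clamp(-H, min=0)   # H_-

x = torch.rand(B, N)               # batch of vectorized images, one per row

y = x @ P.T                        # like forward:   y = Px,          shape (B, 2M)
m = x @ H.T                        # like forward_H: m = Hx,          shape (B, M)
z = m @ H                          # like adjoint:   H^T m,           shape (B, N)

H_pinv = torch.linalg.pinv(H)      # H^dagger, shape (N, M); an rtol cutoff could be passed here
x_hat = m @ H_pinv.T               # like pinv:      H^dagger m,      shape (B, N)

print(y.shape, m.shape, z.shape, x_hat.shape)
```

With \(M < N\) and \(H\) of full row rank, \(HH^\dagger = I\), so the pseudo inverse solution reproduces the measurements: x_hat @ H.T is close to m, even though x_hat generally differs from x.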