spyrit.core.meas.LinearSplit.forward

LinearSplit.forward(x: tensor) → tensor

Applies the linear transform \(y = Px\) to incoming images.

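This is equivalent to computing \(x \cdot P^T\) on the images flattened to shape \((*, N)\). The input images must be unvectorized, i.e., of shape \((*, h, w)\). The matrix \(P\), of shape \((2M, N)\), is obtained by splitting the measurement matrix \(H\), of shape \((M, N)\), into its elementwise positive and negative parts \(H_{+} = \max(0, H)\) and \(H_{-} = \max(0, -H)\), so that \(H = H_{+} - H_{-}\). These parts are interleaved row by row: \(P[0::2, :] = H_{+}\) and \(P[1::2, :] = H_{-}\).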
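The construction of \(P\) and the product \(x \cdot P^T\) can be sketched in plain PyTorch as follows. This is a minimal re-implementation under stated assumptions, not spyrit's source code: the helpers `split_matrix` and `split_forward` are hypothetical names, and images are flattened in PyTorch's default row-major order.

```python
import torch


def split_matrix(H: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build P of shape (2M, N) from H of shape (M, N)."""
    M, N = H.shape
    P = torch.zeros(2 * M, N, dtype=H.dtype, device=H.device)
    P[0::2, :] = torch.clamp(H, min=0)   # even rows: H_+ = max(0, H)
    P[1::2, :] = torch.clamp(-H, min=0)  # odd rows: H_- = max(0, -H)
    return P


def split_forward(x: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten (*, h, w) images to (*, N), then compute x @ P^T."""
    x_flat = x.flatten(start_dim=-2)  # (*, h, w) -> (*, h * w)
    return x_flat @ P.T               # (*, N) @ (N, 2M) -> (*, 2M)


H = torch.randn(400, 1600)
P = split_matrix(H)
x = torch.randn(10, 40, 40)
y = split_forward(x, P)
print(P.shape)  # torch.Size([800, 1600])
print(y.shape)  # torch.Size([10, 800])
```

Writing the rows of \(H_{+}\) and \(H_{-}\) through strided slices mirrors the definition \(P[0::2, :] = H_{+}\), \(P[1::2, :] = H_{-}\) directly.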

Warning

This method uses the split measurement matrix \(P\) to compute the linear measurements of the incoming images. To apply the operator \(H\) directly, use the method forward_H().
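Because \(H = H_{+} - H_{-}\), the unsplit measurements \(Hx\) can be recovered from the split measurements by subtracting the odd-indexed entries from the even-indexed ones. The sketch below checks this numerically in plain PyTorch; it does not call spyrit and makes no claim about how forward_H() is implemented internally.

```python
import torch

torch.manual_seed(0)
M, h, w = 4, 3, 3
N = h * w

# float64 keeps the round-off between the two computations negligible
H = torch.randn(M, N, dtype=torch.float64)
P = torch.zeros(2 * M, N, dtype=torch.float64)
P[0::2] = torch.clamp(H, min=0)   # H_+ on even rows
P[1::2] = torch.clamp(-H, min=0)  # H_- on odd rows

x = torch.randn(5, h, w, dtype=torch.float64)
x_flat = x.flatten(start_dim=-2)  # (5, N)

y_split = x_flat @ P.T   # (5, 2M): split measurements y = Px
y_direct = x_flat @ H.T  # (5, M): unsplit measurements Hx

# Even entries minus odd entries: H_+ x - H_- x = (H_+ - H_-) x = H x
y_diff = y_split[..., 0::2] - y_split[..., 1::2]
print(torch.allclose(y_diff, y_direct))  # True
```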

Args:

\(x\) (torch.tensor): Batch of images of shape \((*, h, w)\), where * denotes any number of leading dimensions, for instance (b, c) with b the batch size and c the number of channels, and h and w are the height and width of the images.

Output:

torch.tensor: The linear measurements of the input images, of shape \((*, 2M)\), where * denotes the same leading dimensions as the input and M is the number of measurements, i.e., the number of rows of the measurement matrix \(H\) given at initialization.

Shape:

\(x\): \((*, h, w)\), where * denotes any number of batch dimensions and \(h \times w = N\) is the total number of pixels in the image.

Output: \((*, 2M)\), where * denotes the same batch dimensions as the input and M is the number of rows of the measurement matrix \(H\) given at initialization.
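To illustrate these shape conventions, the sketch below passes a batch with two leading dimensions (b, c) through the same computation in plain PyTorch. It assumes, for illustration only and not as a description of spyrit's internals, that images are flattened in row-major order, so that pixel (i, j) maps to column i·w + j of \(H\).

```python
import torch

b, c, h, w = 2, 3, 40, 40
M, N = 400, h * w

H = torch.randn(M, N)
P = torch.zeros(2 * M, N)
P[0::2] = torch.clamp(H, min=0)   # H_+
P[1::2] = torch.clamp(-H, min=0)  # H_-

x = torch.randn(b, c, h, w)       # two leading dimensions: batch and channels
x_flat = x.flatten(start_dim=-2)  # (b, c, N), row-major: pixel (i, j) -> column i * w + j
y = x_flat @ P.T                  # matmul broadcasts over (b, c): result is (b, c, 2M)

print(x_flat.shape)  # torch.Size([2, 3, 1600])
print(y.shape)       # torch.Size([2, 3, 800])
```

The leading dimensions pass through unchanged because matrix multiplication broadcasts over every dimension except the last two.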

Example:
>>> import torch
>>> from spyrit.core.meas import LinearSplit
>>> H = torch.randn(400, 1600)
>>> meas_op = LinearSplit(H)
>>> x = torch.randn(10, 40, 40)
>>> y = meas_op(x)
>>> print(y.shape)
torch.Size([10, 800])