spyrit.core.meas.FreeformLinear.apply_mask
- FreeformLinear.apply_mask(x: tensor) → tensor
Applies the saved mask to the input tensor, where the masked dimensions are collapsed into one.
This method first selects elements from the input tensor along the dimensions given by self.meas_dims, at the positions specified by the mask. The selected elements are then flattened into a single dimension, which becomes the last dimension of the output tensor.
- Args:
x (torch.tensor): The input tensor to which the mask is applied. The dimensions indexed by self.meas_dims should match the measurement shape self.meas_shape.
- Returns:
torch.tensor: A tensor of shape (*, self.N) where * denotes all the dimensions of the input tensor not included in self.meas_dims.
- Example: Select every second point on the diagonal of a batch of images
>>> import torch
>>> from spyrit.core.meas import FreeformLinear
>>> images = torch.rand(17, 3, 40, 40)  # b, c, h, w
>>> # create a (2, 20) mask
>>> mask = torch.tensor([[i, i] for i in range(0, 40, 2)]).T
>>> H = torch.randn(13, 20)
>>> meas_op = FreeformLinear(H, mask, meas_shape=(40, 40), dim=(-1, -2))
>>> y = meas_op.apply_mask(images)
>>> print(y.shape)
torch.Size([17, 3, 20])
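The select-then-flatten behaviour described above can be illustrated in plain PyTorch, without constructing a measurement operator. This is only a minimal sketch and not the library's implementation: the helper name apply_mask_sketch is hypothetical, and it assumes that the first row of the mask indexes the first measured dimension (here the height) and the second row indexes the second (here the width). On the diagonal used in the example, that ordering makes no difference.

```python
import torch


def apply_mask_sketch(x, mask, meas_dims=(-2, -1)):
    """Hypothetical stand-in for the selection step (not spyrit's code)."""
    # Move the measured dimensions to the end so they can be indexed together.
    x = torch.movedim(x, meas_dims, (-2, -1))
    # Indexing with two index vectors of length N picks N elements and
    # collapses the two measured dimensions into one trailing dimension.
    return x[..., mask[0], mask[1]]


images = torch.rand(17, 3, 40, 40)  # b, c, h, w
# (2, 20) mask: every second point on the diagonal
mask = torch.tensor([[i, i] for i in range(0, 40, 2)]).T
y = apply_mask_sketch(images, mask)
print(y.shape)  # torch.Size([17, 3, 20])
```

All dimensions not listed in meas_dims (here the batch and channel dimensions) pass through unchanged, which matches the (*, self.N) output shape documented above.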