spyrit.core.warp.ElasticDeformation.forward
- ElasticDeformation.forward(img: tensor, n0: int = 0, n1: int = None, mode: str = 'bilinear') → tensor
Generates a video from a batch of 2D images according to the deformation field \(u\).
The deformation is taken between the frames \(n0\) (included) and \(n1\) (excluded).
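The frame range can be illustrated in plain PyTorch. This is a sketch under assumptions, not spyrit's code: `frame_indices` is a hypothetical helper, and the toy field stores frames along its first dimension. The notes below state that frames are taken backwards when n0 > n1, like a Python list slice with step -1. PyTorch tensors reject negative slice steps, so the sketch builds the indices with `torch.arange` instead.

```python
import torch

def frame_indices(n0, n1):
    """Hypothetical helper: indices of the field frames from n0 (included)
    to n1 (excluded), walking backwards when n0 > n1."""
    step = 1 if n0 <= n1 else -1
    return torch.arange(n0, n1, step)

# Toy "field" with 6 frames; every entry of frame k holds the value k.
field = torch.arange(6.0).view(6, 1, 1, 1).expand(6, 2, 2, 2)

forward = field[frame_indices(1, 4)]   # frames 1, 2, 3
backward = field[frame_indices(4, 1)]  # frames 4, 3, 2

print(forward[:, 0, 0, 0].tolist())    # [1.0, 2.0, 3.0]
print(backward[:, 0, 0, 0].tolist())   # [4.0, 3.0, 2.0]
print(list(range(6))[4:1:-1])          # [4, 3, 2], same order as a list slice
```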
- Args:
img (torch.tensor): Batch of 2D images to deform, of shape \((c, h, w)\) or \((b, c, h, w)\), where \(b\) is the number of images in the batch, \(c\) is the number of channels, and \(h\) and \(w\) are the height and width of the images.
n0 (int, optional): The index of the first frame to use in the deformation field. Defaults to 0.
n1 (int, optional): The index of the first frame to exclude in the deformation field. If None, the last available frame is used. Defaults to None.
mode (str, optional): The interpolation mode to use. It must be one of 'nearest', 'bilinear', 'bicubic', or 'biquintic'. The nearest, bilinear, and bicubic modes are directly supported by torch.nn.functional.grid_sample(); the biquintic mode relies on scikit-image. Defaults to 'bilinear'.
Note
If using mode='bicubic' or mode='biquintic', the warped image may contain values outside the original range.
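The overshoot comes from the negative lobes of the cubic interpolation kernel. The sketch below calls torch.nn.functional.grid_sample directly rather than spyrit, and assumes align_corners=True. It samples a 1x4 step-edge image halfway between pixels: bilinear stays within [0, 1], while bicubic dips below 0 next to the edge and rises above 1 on the other side. The biquintic mode is not shown because it goes through scikit-image.

```python
import torch
import torch.nn.functional as F

# A 1x4 grayscale image with a sharp step edge; all values lie in [0, 1].
img = torch.tensor([0.0, 0.0, 1.0, 1.0]).view(1, 1, 1, 4)

# 7 sample positions along the row. With align_corners=True, x = -1 and x = 1
# map to the first and last pixel centres, so these hit pixels 0, 0.5, 1, ..., 3.
xs = torch.linspace(-1.0, 1.0, 7)
grid = torch.stack((xs, torch.zeros(7)), dim=-1).view(1, 1, 7, 2)  # (x, y) order

for mode in ("bilinear", "bicubic"):
    out = F.grid_sample(img, grid, mode=mode, align_corners=True)
    print(mode, out.min().item(), out.max().item())
```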
Note
If \(n0 < n1\), field is sliced as follows: field[n0:n1, :, :, :].
Note
If \(n0 > n1\), field is sliced "backwards". The first frame of the warped animation corresponds to the index \(n0\), and the last frame corresponds to the index \(n1+1\). This behavior is identical to slicing a list with a step of -1.
Note
If the number of pixels differs between the image and the field, the field is interpolated to match the image size (see the behavior of torch.nn.functional.grid_sample()).
Note
If the input image is three-dimensional, a batch dimension is added as the first dimension.
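One way to match a coarse field to a larger image is to resize it bilinearly before sampling. This is only a sketch, not necessarily what spyrit does internally. `resize_field` and `identity_field` are hypothetical helpers, and the field is assumed to have shape (n, h, w, 2) with (x, y) coordinates in [-1, 1]. An identity field resized this way stays an identity field, because bilinear interpolation reproduces linear ramps exactly, so warping with it returns the image unchanged.

```python
import torch
import torch.nn.functional as F

def resize_field(field, h, w):
    """Hypothetical helper: bilinearly resize a field (n, h', w', 2) to (n, h, w, 2)."""
    f = field.permute(0, 3, 1, 2)  # (n, 2, h', w'): channels first for interpolate
    f = F.interpolate(f, size=(h, w), mode="bilinear", align_corners=True)
    return f.permute(0, 2, 3, 1).contiguous()  # back to (n, h, w, 2)

def identity_field(h, w):
    """Hypothetical helper: each (x, y) entry points at its own pixel position."""
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
    return torch.stack((xs, ys), dim=-1).unsqueeze(0)  # (1, h, w, 2)

img = torch.rand(1, 1, 5, 5)
coarse = identity_field(3, 3)       # field coarser than the image
fine = resize_field(coarse, 5, 5)   # now matches the 5x5 image
out = F.grid_sample(img, fine, mode="bilinear", align_corners=True)
print(out.shape)                            # torch.Size([1, 1, 5, 5])
print(torch.allclose(out, img, atol=1e-6))  # True: identity warp leaves img unchanged
```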
- Returns:
output (torch.tensor): The deformed batch of 2D images, of shape \((|n1-n0|, c, h, w)\) or \((b, |n1-n0|, c, h, w)\) depending on the input shape, where each image in the batch is deformed according to the deformation field \(u\).
- Example:
Rotating a 2x2 grayscale image by 90 degrees counter-clockwise, using one frame:
>>> import torch
>>> from spyrit.core.warp import DeformationField
>>>
>>> u = torch.tensor([[[[ 1., -1.], [ 1., 1.]],
...                    [[-1., -1.], [-1., 1.]]]])
>>> field = DeformationField(u)
>>> image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
>>> print(image)
tensor([[[[0.0000, 0.3000],
          [0.7000, 1.0000]]]])
>>> deformed_image = field(image, 0, 1)
>>> print(deformed_image)
tensor([[[[[0.3000, 1.0000],
           [0.0000, 0.7000]]]]])
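The sketch below shows, in plain PyTorch, how a forward pass with the documented shapes and frame handling could look for the grid_sample modes. It is not spyrit's implementation. `warp_video` is a hypothetical function, and it assumes the field is stored as (n_frames, h, w, 2) with (x, y) coordinates in [-1, 1], that align_corners=True is used (this reproduces the doctest values above), and that n1=None means slicing to the end. The biquintic mode is not covered.

```python
import torch
import torch.nn.functional as F

def warp_video(img, field, n0=0, n1=None, mode="bilinear"):
    """Hypothetical sketch of the documented behavior (not spyrit's code).

    img:   (c, h, w) or (b, c, h, w)
    field: (n_frames, h, w, 2), (x, y) sampling positions normalized to [-1, 1]
    """
    no_batch = img.dim() == 3
    if no_batch:
        img = img.unsqueeze(0)            # add a batch dimension first
    if n1 is None:
        n1 = field.shape[0]               # assumption: None means "up to the end"
    step = 1 if n0 <= n1 else -1          # backwards when n0 > n1
    frames = field[torch.arange(n0, n1, step)]  # (t, h, w, 2)
    b, c = img.shape[:2]
    t, h, w = frames.shape[:3]
    # Pair every image with every selected frame: b * t warps in a single call.
    imgs = img.unsqueeze(1).expand(b, t, *img.shape[1:]).reshape(b * t, *img.shape[1:])
    grid = frames.unsqueeze(0).expand(b, t, h, w, 2).reshape(b * t, h, w, 2)
    out = F.grid_sample(imgs, grid, mode=mode, align_corners=True)
    out = out.reshape(b, t, c, h, w)
    return out.squeeze(0) if no_batch else out

# Same rotation as in the doctest above.
u = torch.tensor([[[[ 1., -1.], [ 1., 1.]],
                   [[-1., -1.], [-1., 1.]]]])
image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
deformed = warp_video(image, u, 0, 1)
print(deformed.shape)  # torch.Size([1, 1, 1, 2, 2]); values [[0.3, 1.0], [0.0, 0.7]]
```

Batching the image and frame dimensions together keeps the sketch to one grid_sample call, at the cost of materializing b * t copies of the image.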