spyrit.core.warp.DeformationField.forward

DeformationField.forward(img: tensor, n0: int = 0, n1: int = None, mode: str = 'bilinear') → tensor

Generates a video from a batch of 2D images according to the deformation field \(u\).

The deformation field is applied from frame \(n0\) (included) to frame \(n1\) (excluded).

Args:

img (torch.tensor): Batch of 2D images to deform, of shape \((c, h, w)\) or \((b, c, h, w)\), where \(b\) is the number of images in the batch, \(c\) is the number of channels, and \(h\) and \(w\) are the height and width of the images.

n0 (int, optional): The index of the first frame to use in the deformation field. Defaults to 0.

n1 (int, optional): The index of the first frame to exclude from the deformation field. If None, all frames up to and including the last one are used. Defaults to None.

mode (str, optional): The interpolation mode to use. It must be one of 'nearest', 'bilinear', 'bicubic', or 'biquintic'. The nearest, bilinear, and bicubic modes are directly supported by torch.nn.functional.grid_sample(). The biquintic mode relies on scikit-image. Defaults to 'bilinear'.

Note

If using mode='bicubic' or mode='biquintic', the warped image may contain values outside the range of the original image.
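The reason is that cubic kernels have negative lobes, so near a sharp edge the weighted sum of neighbouring pixels can overshoot. The 1D sketch below illustrates this with the Keys cubic convolution kernel and a = -0.75; that coefficient is an assumption here (it is the value commonly used for bicubic resampling), not something taken from spyrit, and the biquintic mode is not modelled. The helpers keys_kernel and cubic_interp_1d are hypothetical.

```python
def keys_kernel(x, a=-0.75):
    """Keys cubic convolution kernel; `a` = -0.75 is an assumed coefficient."""
    x = abs(x)
    if x <= 1:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    if x < 2:
        # Negative lobe: these weights are below zero for 1 < |x| < 2.
        return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
    return 0.0


def cubic_interp_1d(samples, t):
    """Interpolate between samples[1] and samples[2] at fraction t in [0, 1].

    `samples` holds four consecutive pixel values p0, p1, p2, p3 located at
    positions -1, 0, 1, 2; the query point sits at position t.
    """
    distances = (1 + t, t, 1 - t, 2 - t)
    return sum(p * keys_kernel(d) for p, d in zip(samples, distances))


# A step edge with values in [0, 1]: the interpolated value exceeds 1.
print(cubic_interp_1d([0.0, 1.0, 1.0, 1.0], 0.5))  # 1.09375
```

Clipping the result (for example with torch.clamp) is one way to restore the original range if it matters downstream.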

Note

If \(n0 < n1\), field is sliced as follows: field[n0:n1, :, :, :]

Note

If \(n0 > n1\), field is sliced "backwards". The first frame of the warped animation corresponds to the index \(n0\), and the last frame corresponds to the index \(n1+1\). This behavior is identical to slicing a list with a step of -1, i.e. field[n0:n1:-1].
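The frame-selection rules in these notes can be mimicked with ordinary Python list slicing. The helper below is hypothetical (selected_frames is not part of spyrit): it only returns the frame indices in output order and never touches a tensor. Treating n1=None as the number of frames is an assumption based on the description of n1.

```python
def selected_frames(n_frames, n0=0, n1=None):
    """Hypothetical helper: indices of the field frames used, in output order."""
    if n1 is None:
        # Assumption: None means "up to and including the last frame".
        n1 = n_frames
    frames = list(range(n_frames))
    if n0 < n1:
        return frames[n0:n1]  # forward: field[n0:n1, :, :, :]
    return frames[n0:n1:-1]  # backward: first frame n0, last frame n1 + 1


print(selected_frames(5, 1, 4))  # [1, 2, 3]
print(selected_frames(5, 4, 1))  # [4, 3, 2]
```

In both directions the number of frames is \(|n1-n0|\), which is the length of the frame axis in the returned tensor.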

Note

If the image and the field have a different number of pixels, the field is interpolated to match the image size (see the behavior of torch.nn.functional.grid_sample()).

Note

If the input image is three-dimensional, a batch dimension is added as the first dimension.

Returns:

output (torch.tensor): The deformed batch of 2D images of shape \((|n1-n0|, c, h, w)\) or \((b, |n1-n0|, c, h, w)\) depending on the input shape, where each image in the batch is deformed according to the deformation field \(u\).
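The shape rule above can be written out as a small function. This is a hypothetical helper (expected_output_shape is not a spyrit function) that only restates the documented shapes for explicit n0 and n1; it does not perform any warping.

```python
def expected_output_shape(img_shape, n0, n1):
    """Hypothetical helper: documented output shape of DeformationField.forward.

    (c, h, w)    -> (|n1 - n0|, c, h, w)
    (b, c, h, w) -> (b, |n1 - n0|, c, h, w)
    """
    n_frames = abs(n1 - n0)  # same count for forward and backward slicing
    if len(img_shape) == 3:
        c, h, w = img_shape
        return (n_frames, c, h, w)
    if len(img_shape) == 4:
        b, c, h, w = img_shape
        return (b, n_frames, c, h, w)
    raise ValueError("img must have shape (c, h, w) or (b, c, h, w)")


print(expected_output_shape((1, 1, 2, 2), 0, 1))  # (1, 1, 1, 2, 2)
```

The first call matches the five nested brackets printed in the example that follows: one image, one frame, one channel, and a 2x2 grid.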

Example:

Rotating a 2x2 grayscale image by 90 degrees counter-clockwise, using one frame:

>>> import torch
>>> from spyrit.core.warp import DeformationField
>>>
>>> u = torch.tensor([[[[ 1., -1.], [ 1., 1.]], [[-1., -1.], [-1., 1.]]]])
>>> field = DeformationField(u)
>>> image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
>>> print(image)
tensor([[[[0.0000, 0.3000],
          [0.7000, 1.0000]]]])
>>> deformed_image = field(image, 0, 1)
>>> print(deformed_image)
tensor([[[[[0.3000, 1.0000],
           [0.0000, 0.7000]]]]])
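The example above can be checked by hand. The pure-Python sketch below is an illustration, not spyrit's implementation: it assumes that each field value stores normalized sampling coordinates (x, y) in [-1, 1], with x along the width and y along the height, and that -1 and 1 map to the centers of the edge pixels (the align_corners=True convention of torch.nn.functional.grid_sample). These assumptions are consistent with the output shown, but the spyrit source is the authority. The names sample_bilinear and warp_frame are hypothetical, and out-of-range coordinates are clamped to the border for simplicity.

```python
import math


def sample_bilinear(image, x, y):
    """Bilinearly sample a 2D list `image` at normalized coordinates (x, y).

    Assumed conventions: x runs along the width, y along the height, and
    -1 / 1 map to the centers of the first / last pixels. Coordinates
    outside [-1, 1] are clamped to the border.
    """
    h, w = len(image), len(image[0])
    # Convert normalized coordinates to continuous (row, col) pixel positions.
    col = min(max((x + 1) / 2 * (w - 1), 0.0), w - 1)
    row = min(max((y + 1) / 2 * (h - 1), 0.0), h - 1)
    c0, r0 = math.floor(col), math.floor(row)
    c1, r1 = min(c0 + 1, w - 1), min(r0 + 1, h - 1)
    dc, dr = col - c0, row - r0
    # Interpolate along the width on the two neighbouring rows, then along the height.
    top = (1 - dc) * image[r0][c0] + dc * image[r0][c1]
    bottom = (1 - dc) * image[r1][c0] + dc * image[r1][c1]
    return (1 - dr) * top + dr * bottom


def warp_frame(image, frame):
    """Warp `image` with one field frame, where frame[i][j] = (x, y)."""
    return [[sample_bilinear(image, x, y) for (x, y) in row] for row in frame]


# Same values as the doctest above, as plain nested lists.
u0 = [[(1.0, -1.0), (1.0, 1.0)], [(-1.0, -1.0), (-1.0, 1.0)]]
image = [[0.0, 0.3], [0.7, 1.0]]
print(warp_frame(image, u0))  # [[0.3, 1.0], [0.0, 0.7]]
```

Each output pixel reads the input at the location stored in the field: for instance, the top-left output pixel holds (1, -1), the top-right input pixel, which is why it becomes 0.3.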