spyrit.core.warp.ElasticDeformation.forward

ElasticDeformation.forward(img: tensor, n0: int = 0, n1: int = None, mode: str = 'bilinear') → tensor

Warps a batch of 2D images with the stored inverse deformation field \(u\).

Deforms the batch of 2D images according to the inverse deformation field \(u\) contained in the attribute field, sliced between the frames \(n0\) (included) and \(n1\) (excluded). \(u\) is the field that maps the pixels of the deformed image to the pixels of the original image.
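A minimal sketch of this warping in plain PyTorch is given below. It assumes the field stores, for each output pixel, the normalized (x, y) coordinates in [-1, 1] of the original-image location to read from, and that sampling uses align_corners=True. Under these assumptions the rotation example at the end of this page produces the printed output. The helper name warp_with_inverse_field is hypothetical and is not part of spyrit.

```python
import torch
import torch.nn.functional as F


def warp_with_inverse_field(img, field, n0=0, n1=None, mode="bilinear"):
    """Hypothetical helper, not the spyrit implementation.

    img:   (b, c, h, w) batch of images.
    field: (n_frames, h, w, 2) inverse field; the last dim holds the normalized
           (x, y) position in the original image that each output pixel reads from.
    Returns a tensor of shape (b, n1 - n0, c, h, w). Only n0 <= n1 is handled here.
    """
    if n1 is None:
        n1 = field.shape[0]  # assumed: None means "through the last frame"
    frames = field[n0:n1]  # (t, h, w, 2)
    b, c, h, w = img.shape
    t, gh, gw, _ = frames.shape
    # Pair every image with every frame by flattening (b, t) into one batch axis.
    img_rep = img.unsqueeze(1).expand(b, t, c, h, w).reshape(b * t, c, h, w)
    grid_rep = frames.unsqueeze(0).expand(b, t, gh, gw, 2).reshape(b * t, gh, gw, 2)
    # align_corners=True makes -1 and 1 the centres of the corner pixels (assumed).
    out = F.grid_sample(img_rep, grid_rep, mode=mode, padding_mode="zeros", align_corners=True)
    return out.reshape(b, t, c, gh, gw)
```

Flattening the (b, t) pair into a single batch axis lets one grid_sample call handle every image/frame combination, at the cost of repeating each image t times in memory.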

Args:

img (torch.tensor): The batch of 2D images to deform, of shape \((c, h, w)\) or \((b, c, h, w)\), where \(b\) is the number of images in the batch, \(c\) is the number of channels (usually 1 or 3), and \(h\) and \(w\) are the number of pixels along the height and width of the image respectively.

n0 (int, optional): The index of the first frame to use in the inverse deformation field. Defaults to 0.

n1 (int, optional): The index of the first frame to exclude from the inverse deformation field. If None, the slice extends through the last available frame. Defaults to None.

mode (str, optional): The interpolation mode to use. It must be one of ‘nearest’, ‘bilinear’, ‘bicubic’, or ‘biquintic’. ‘nearest’, ‘bilinear’, and ‘bicubic’ are passed directly to torch.nn.functional.grid_sample(); ‘biquintic’ is handled by the scikit-image package, which requires skimage and numpy to be installed. Defaults to ‘bilinear’.
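The mode dispatch described above could look roughly like the sketch below. The function name sample, the align_corners=True convention, and the exact scikit-image call (skimage.transform.warp with order=5, which is bi-quintic spline interpolation, fed an explicit (row, col) coordinate array) are assumptions for illustration, not spyrit's code. skimage and numpy are imported lazily so they are only needed for ‘biquintic’.

```python
import torch
import torch.nn.functional as F

TORCH_MODES = ("nearest", "bilinear", "bicubic")


def sample(img, grid, mode="bilinear"):
    """Hypothetical dispatcher: img (b, c, h, w), grid (b, h_out, w_out, 2) in [-1, 1]."""
    if mode in TORCH_MODES:
        # These three modes are passed straight through to PyTorch.
        return F.grid_sample(img, grid, mode=mode, align_corners=True)
    if mode == "biquintic":
        # Delegated to scikit-image; only this branch needs skimage and numpy.
        import numpy as np
        from skimage.transform import warp

        b, c, h, w = img.shape
        g = grid.detach().cpu()
        # Normalized (x, y) -> pixel (row, col), assuming align_corners=True.
        cols = ((g[..., 0] + 1) / 2 * (w - 1)).numpy()
        rows = ((g[..., 1] + 1) / 2 * (h - 1)).numpy()
        src = img.detach().cpu().numpy()
        out = np.empty((b, c, g.shape[1], g.shape[2]), dtype=np.float64)
        for i in range(b):
            coords = np.stack([rows[i], cols[i]])  # (2, h_out, w_out)
            for j in range(c):
                # order=5 is bi-quintic; clip=False keeps any overshoot visible.
                out[i, j] = warp(src[i, j], coords, order=5, clip=False, preserve_range=True)
        return torch.from_numpy(out).to(img.dtype)
    raise ValueError(f"unknown mode {mode!r}; expected one of {TORCH_MODES + ('biquintic',)}")
```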

Note

If mode='bicubic' or mode='biquintic', the warped image may contain values outside the original range. Use torch.clamp() to bring the values back into the correct range.
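The overshoot comes from the negative lobes of the cubic interpolation kernel. The small sketch below samples a one-pixel-high step image half-way between pixels with bicubic interpolation, then clamps the result. Clamping to the input's own minimum and maximum is an assumption about what "the correct range" means in a given application.

```python
import torch
import torch.nn.functional as F

# A 1x6 "step" image with values in [0, 1], shape (b, c, h, w) = (1, 1, 1, 6).
img = torch.tensor([0., 0., 1., 1., 1., 1.]).view(1, 1, 1, 6)
w = img.shape[-1]

# Sample half-way between neighbouring pixels (pixel coordinates 0.5 ... 4.5).
px = torch.arange(w - 1, dtype=torch.float32) + 0.5
x = px / (w - 1) * 2 - 1          # normalize to [-1, 1] for align_corners=True
y = torch.zeros_like(x)           # the single row
grid = torch.stack([x, y], dim=-1).view(1, 1, w - 1, 2)

raw = F.grid_sample(img, grid, mode="bicubic", padding_mode="border", align_corners=True)
# Bring the values back into the range of the original image.
clamped = torch.clamp(raw, img.min().item(), img.max().item())
print(raw.min().item(), raw.max().item())
print(clamped.min().item(), clamped.max().item())
```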

Note

If \(n0 < n1\), field is sliced as follows: field[n0:n1, :, :, :]

Note

If \(n0 > n1\), field is sliced “backwards”. The first frame of the warped animation corresponds to the index \(n0\), and the last frame corresponds to the index \(n1+1\). This behavior is identical to slicing a list with a step of -1.
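Both slicing rules can be sketched with the hypothetical helper below. PyTorch tensors do not accept a negative slice step (field[n0:n1:-1] raises an error), so the backward case builds the frame indices explicitly with torch.arange, which yields n0, n0-1, ..., n1+1 exactly as a list slice with step -1 would.

```python
import torch


def slice_frames(field, n0=0, n1=None):
    """Hypothetical helper: select frames of a (n_frames, h, w, 2) field."""
    if n1 is None:
        n1 = field.shape[0]  # assumed: None means "through the last frame"
    if n0 <= n1:
        # Forward slice: frames n0, n0+1, ..., n1-1.
        return field[n0:n1]
    # Backward slice: frames n0, n0-1, ..., n1+1 (like list[n0:n1:-1]).
    idx = torch.arange(n0, n1, -1)
    return field[idx]


# Toy field whose frame k is filled with the value k, to make the order visible.
field = torch.arange(6.).view(6, 1, 1, 1).expand(6, 2, 2, 2)
print(slice_frames(field, 1, 4)[:, 0, 0, 0].tolist())
print(slice_frames(field, 4, 1)[:, 0, 0, 0].tolist())
```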

Note

If the image and the field have different numbers of pixels, torch.nn.functional.grid_sample() still works, and the field is interpolated to match the image size.
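On its own, torch.nn.functional.grid_sample() returns an output with the spatial size of the grid, not of the input image. One way to obtain an output at the image's resolution is to resample the field first, as in the sketch below. The helpers identity_field and resize_field are hypothetical, and bilinear resizing with align_corners=True is an assumed choice, not necessarily what spyrit does internally.

```python
import torch
import torch.nn.functional as F


def identity_field(n_frames, h, w):
    """Inverse field that leaves the image unchanged, shape (n_frames, h, w, 2)."""
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
    grid = torch.stack([xs, ys], dim=-1)  # last dim is (x, y)
    return grid.unsqueeze(0).repeat(n_frames, 1, 1, 1)


def resize_field(field, size):
    """Hypothetical helper: bilinearly resample a (n, h, w, 2) field to (n, H, W, 2)."""
    f = field.permute(0, 3, 1, 2)  # (n, 2, h, w): channels first for F.interpolate
    f = F.interpolate(f, size=size, mode="bilinear", align_corners=True)
    return f.permute(0, 2, 3, 1).contiguous()  # back to (n, H, W, 2)


img = torch.rand(1, 1, 4, 4)
coarse = identity_field(1, 2, 2)                    # field stored at 2x2
# Without resizing, the output takes the grid's 2x2 size.
small = F.grid_sample(img, coarse, mode="bilinear", align_corners=True)
fine = resize_field(coarse, tuple(img.shape[-2:]))  # resampled to the image's 4x4 grid
out = F.grid_sample(img, fine, mode="bilinear", align_corners=True)
print(small.shape, out.shape)
```

With align_corners=True the corner values of the field are preserved by the resize, so a field expressed in normalized [-1, 1] coordinates keeps its meaning at the new resolution.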

Returns:

output (torch.tensor): The deformed batch of 2D images of shape \((|n1-n0|, c, h, w)\) or \((b, |n1-n0|, c, h, w)\) depending on the input shape, where each image in the batch is deformed according to the inverse deformation field \(u\) contained in the attribute field.

Shape:

img: \((c, h, w)\) or \((b, c, h, w)\), where \(b\) is the number of images in the batch, \(c\) is the number of channels (usually 1 or 3), and \(h\) and \(w\) are the number of pixels along the height and width of the image respectively.

output: \((|n1-n0|, c, h, w)\) or \((b, |n1-n0|, c, h, w)\)
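As a worked instance of these shape rules, the hypothetical helper below computes the expected output shape from the input shape and the slice bounds. It assumes the output keeps the image's \(h\) and \(w\), and that n1=None means the slice runs through the last of n_frames frames.

```python
def expected_output_shape(img_shape, n_frames, n0=0, n1=None):
    """Hypothetical helper: shape returned by forward() for a given input shape."""
    if n1 is None:
        n1 = n_frames  # assumed: None means "through the last frame"
    n_out = abs(n1 - n0)  # number of frames, whether sliced forwards or backwards
    if len(img_shape) == 3:  # (c, h, w) -> (|n1-n0|, c, h, w)
        return (n_out, *img_shape)
    if len(img_shape) == 4:  # (b, c, h, w) -> (b, |n1-n0|, c, h, w)
        b, c, h, w = img_shape
        return (b, n_out, c, h, w)
    raise ValueError(f"expected a 3D or 4D shape, got {img_shape}")


print(expected_output_shape((3, 64, 64), n_frames=10, n0=0, n1=5))
print(expected_output_shape((2, 1, 64, 64), n_frames=10, n0=7, n1=2))
```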

Example 1: Rotating a 2x2 B&W image by 90 degrees counter-clockwise, using one frame

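>>> import torch
>>> from spyrit.core.warp import DeformationField
>>> # Inverse deformation field of shape (n_frames, h, w, 2) = (1, 2, 2, 2)
>>> v = torch.tensor([[[[ 1., -1.], [ 1., 1.]], [[-1., -1.], [-1., 1.]]]])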
>>> field = DeformationField(v)
>>> image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
>>> deformed_image = field(image, 0, 1)
>>> print(deformed_image)
tensor([[[[[0.3000, 1.0000],
           [0.0000, 0.7000]]]]])