spyrit.core.warp.DeformationField

class spyrit.core.warp.DeformationField(field: tensor)

Bases: Module

Stores a discrete deformation field as a \((n\_frames,h,w,2)\) tensor.

Warps a single image or a batch of images according to an inverse deformation field \(u\), i.e. the field that maps each pixel coordinate of the deformed image to the corresponding coordinate in the original image.
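The inverse ("backward") convention means the warp loops over the pixels of the output image and asks where each one comes from. The following is a minimal pure-Python sketch of that idea, not spyrit's implementation: it assumes nearest-neighbour sampling, images stored as lists of rows, and the \([-1, 1]\) pixel-centre coordinates described below. The names `to_index` and `backward_warp` are hypothetical.

```python
def to_index(coord, size):
    """Map a normalised coordinate in [-1, 1] to the nearest pixel index."""
    return round((coord + 1) * (size - 1) / 2)


def backward_warp(img, u):
    """Warp a 2D image (list of rows) with an inverse field u[i][j] = (x, y).

    Hypothetical nearest-neighbour sketch: output pixel (i, j) takes the value
    of the original pixel located at u[i][j]; out-of-range samples become 0.
    """
    h, w = len(img), len(img[0])
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            x, y = u[i][j]            # where output pixel (i, j) comes from
            src_col = to_index(x, w)  # x runs along the width
            src_row = to_index(y, h)  # y runs along the height
            if 0 <= src_row < h and 0 <= src_col < w:
                out[i][j] = img[src_row][src_col]
    return out


img = [[1, 2, 3],
       [4, 5, 6]]
h, w = len(img), len(img[0])
xs = [-1 + 2 * j / (w - 1) for j in range(w)]  # column centres
ys = [-1 + 2 * i / (h - 1) for i in range(h)]  # row centres
# Horizontal flip: output column j reads the mirrored column of the original.
u_flip = [[(-xs[j], ys[i]) for j in range(w)] for i in range(h)]
print(backward_warp(img, u_flip))  # [[3, 2, 1], [6, 5, 4]]
```

Because the field is indexed by output pixels, every output pixel receives exactly one value, which is why the inverse field is the convenient one to store.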

It is constructed from a tensor of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of frames in the animation, and \(h\) and \(w\) are the height and width of the image in pixels. For each frame and each pixel of the warped image, the last dimension contains the \((x, y)\) coordinates of the point in the original image whose value is displayed at that pixel.
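To make the layout concrete, here is a sketch that builds the identity field (the one that leaves every frame unchanged) with plain PyTorch. It assumes the \([-1, 1]\) pixel-centre convention stated in the note below and the \((x, y)\) ordering of the last dimension; `identity_field` is a hypothetical helper, not part of spyrit.

```python
import torch


def identity_field(n_frames: int, h: int, w: int, dtype=torch.float64) -> torch.Tensor:
    """Hypothetical helper: inverse field of shape (n_frames, h, w, 2) that leaves images unchanged."""
    ys = torch.linspace(-1.0, 1.0, h, dtype=dtype)  # one coordinate per row
    xs = torch.linspace(-1.0, 1.0, w, dtype=dtype)  # one coordinate per column
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # each of shape (h, w)
    field = torch.stack((grid_x, grid_y), dim=-1)   # last dim ordered (x, y)
    # Same field for every frame; clone() gives each frame its own memory.
    return field.unsqueeze(0).expand(n_frames, h, w, 2).clone()


u = identity_field(n_frames=3, h=2, w=4)
print(u.shape)      # torch.Size([3, 2, 4, 2])
print(u[0, 0, -1])  # x = 1 (last column), y = -1 (first row)
```

Note that \(x\) varies along the width (columns) and \(y\) along the height (rows), so the stacked order is the reverse of the `(row, column)` indexing of the tensor.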

Important

The coordinates are given in the range \([-1, 1]\). The position of a pixel is the position of its center: \((-1,-1)\) is the center of the top-left pixel and \((1,1)\) is the center of the bottom-right pixel.
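Under this convention, pixel \(j\) along an axis of \(w \geq 2\) pixels sits at \(x_j = -1 + 2j/(w-1)\), and a coordinate \(x\) corresponds to the (possibly fractional) index \((x+1)(w-1)/2\). A small sketch of both directions, with hypothetical helper names:

```python
def pixel_to_coord(index: int, size: int) -> float:
    """Centre of pixel `index` along an axis of `size` pixels (size >= 2), in [-1, 1]."""
    return -1.0 + 2.0 * index / (size - 1)


def coord_to_pixel(coord: float, size: int) -> float:
    """Inverse mapping; a non-integer result falls between two pixel centres."""
    return (coord + 1.0) * (size - 1) / 2.0


w = 5
print([pixel_to_coord(j, w) for j in range(w)])  # [-1.0, -0.5, 0.0, 0.5, 1.0]
print(coord_to_pixel(0.25, w))                   # 2.5
```

A fractional index such as 2.5 is where interpolation (e.g. bilinear) comes into play when the image is sampled.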

Args:

field (torch.Tensor): Inverse deformation field \(u\) of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of frames in the animation, and \(h\) and \(w\) are the height and width of the image to be warped. For accuracy, torch.float64 is the recommended dtype.

Attributes:

self.field (torch.Tensor): Inverse deformation field \(u\) of shape \((n\_frames,h,w,2)\).

self.n_frames (int): Number of frames in the animation.

self.img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image respectively.

self.img_h (int): Height of the image to be warped in pixels.

self.img_w (int): Width of the image to be warped in pixels.

self.align_corners (bool): Always True. It is passed to torch.nn.functional.grid_sample() and torch.nn.functional.affine_grid() so that the coordinates -1 and 1 refer to the centers of the corner pixels rather than to the outer edges of the image.
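The effect of `align_corners` can be checked directly with PyTorch's own `affine_grid`, independently of spyrit. This sketch builds the identity sampling grid for a 2x2 image with both settings and compares the coordinate assigned to the top-left pixel centre.

```python
import torch
import torch.nn.functional as F

# Identity affine transform for a batch of one single-channel image.
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]], dtype=torch.float64)
size = (1, 1, 2, 2)  # (N, C, H, W)

grid_true = F.affine_grid(theta, size, align_corners=True)
grid_false = F.affine_grid(theta, size, align_corners=False)

# align_corners=True: the top-left pixel centre is exactly (-1, -1),
# which is the convention used by DeformationField.
print(grid_true[0, 0, 0])
# align_corners=False: -1 and 1 are the outer image edges, so the same
# pixel centre of a 2x2 image is at (-0.5, -0.5).
print(grid_false[0, 0, 0])
```

Fixing the flag to True keeps the coordinates in the field independent of the image border definition, so a field written for an \(h \times w\) image always places its extreme values on the corner pixels.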

Example 1: Rotating a 2x2 B&W image by 90 degrees counter-clockwise, using one frame
>>> u = torch.tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])
>>> print(field.field.shape)
torch.Size([1, 2, 2, 2])
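Example 1 only stores the field. To see the rotation, the same field can be fed to `torch.nn.functional.grid_sample` with `align_corners=True`. This sketch calls PyTorch directly rather than `DeformationField.forward`, whose exact arguments are not reproduced here; the field is cast to float64 because `grid_sample` needs a floating-point grid with the same dtype as the image.

```python
import torch
import torch.nn.functional as F

# Field from Example 1, shape (1, 2, 2, 2), cast to float64.
u = torch.tensor([[[[ 1, -1], [ 1,  1]],
                   [[-1, -1], [-1,  1]]]], dtype=torch.float64)

# 2x2 image with distinct values so the rotation is visible, shape (N, C, H, W).
img = torch.tensor([[[[1.0, 2.0],
                      [3.0, 4.0]]]], dtype=torch.float64)

warped = F.grid_sample(img, u, mode="bilinear", align_corners=True)
print(warped[0, 0])
# Every sample lands on a pixel centre, so values are copied exactly:
# [[2., 4.], [1., 3.]], i.e. the image rotated 90 degrees counter-clockwise.
```

Reading the field row by row: the top-left output pixel samples \((1,-1)\), the top-right pixel of the original, which is exactly where a counter-clockwise rotation takes its value from.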
Example 2: Rotating a 2x2 B&W image by 90 degrees clockwise, using one frame
>>> u = torch.tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
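The two example fields are inverses of each other. A quick check with plain PyTorch (again bypassing the class, under the same float64 and `align_corners=True` assumptions): the clockwise field rotates the test image clockwise, and applying it after the counter-clockwise field restores the original.

```python
import torch
import torch.nn.functional as F

u_ccw = torch.tensor([[[[ 1, -1], [ 1,  1]],
                       [[-1, -1], [-1,  1]]]], dtype=torch.float64)
u_cw = torch.tensor([[[[-1,  1], [-1, -1]],
                      [[ 1,  1], [ 1, -1]]]], dtype=torch.float64)

img = torch.tensor([[[[1.0, 2.0],
                      [3.0, 4.0]]]], dtype=torch.float64)

cw = F.grid_sample(img, u_cw, align_corners=True)
print(cw[0, 0])  # [[3., 1.], [4., 2.]]: rotated 90 degrees clockwise

# Undo the counter-clockwise rotation with the clockwise field.
restored = F.grid_sample(F.grid_sample(img, u_ccw, align_corners=True),
                         u_cw, align_corners=True)
print(torch.equal(restored, img))  # True
```

This exact round trip holds only because every sample falls on a pixel centre; for general fields, interpolation makes repeated warping slightly lossy.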

Methods

forward(img[, n0, n1, mode])

Warps a batch of 2D images with the stored inverse deformation field \(u\).

grid_sample(img_frames, inverse_grid_frames, ...)

Warps frames of 2D images with a deformation field.
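As a rough picture of how one field with several frames can be applied to a batch of images, here is a sketch built only on `torch.nn.functional.grid_sample`. The helper `warp_frames` and its output layout `(b, n_frames, c, h, w)` are assumptions made for illustration; they are not the signature or output shape of `DeformationField.forward` or `DeformationField.grid_sample`.

```python
import torch
import torch.nn.functional as F


def warp_frames(img: torch.Tensor, field: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
    """Hypothetical helper: warp every image with every frame of the field.

    img:   (b, c, h, w) batch of images
    field: (n_frames, h, w, 2) inverse deformation field in [-1, 1], same dtype as img
    returns (b, n_frames, c, h, w)  (layout chosen for this sketch only)
    """
    b, c, h, w = img.shape
    n_frames = field.shape[0]
    # Pair each image with each frame and flatten the pairs into one batch,
    # using the same (image, frame) ordering for both tensors.
    img_rep = img.unsqueeze(1).expand(b, n_frames, c, h, w).reshape(b * n_frames, c, h, w)
    field_rep = field.unsqueeze(0).expand(b, n_frames, h, w, 2).reshape(b * n_frames, h, w, 2)
    out = F.grid_sample(img_rep, field_rep, mode=mode, align_corners=True)
    return out.reshape(b, n_frames, c, h, w)


img = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]],
                    [[[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float64)  # (2, 1, 2, 2)
field = torch.tensor([[[[ 1, -1], [ 1,  1]], [[-1, -1], [-1,  1]]],    # frame 0: CCW
                      [[[-1,  1], [-1, -1]], [[ 1,  1], [ 1, -1]]]],   # frame 1: CW
                     dtype=torch.float64)                              # (2, 2, 2, 2)
out = warp_frames(img, field)
print(out.shape)  # torch.Size([2, 2, 1, 2, 2])
```

Flattening (image, frame) pairs into a single batch lets one `grid_sample` call handle all frames at once, at the cost of materialising `b * n_frames` copies of the images.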