spyrit.core.warp.DeformationField

class spyrit.core.warp.DeformationField(field: tensor)

Bases: Module

Stores a discrete deformation field \(u\) of shape \((n\_frames,h,w,2)\).

Here \(n\_frames\) is the number of frames, and \(h\) and \(w\) are the height and width of the field. The last dimension contains the x and y coordinates of the deformation field with respect to the reference time \(t_0\). The image at time \(t\) is related to the reference image by

\[f(t, x, y) = f(t_0, u(t, x, y))\]

where \(f(t_0, x, y)\) is the reference image and \(u(t, x, y)\) is the deformation field.

The forward call generates a video by warping the input image according to the deformation field \(u\).

Important

The coordinates are given in the range [-1, 1]. When referring to a pixel, its position is the position of its center. The position (-1, -1) corresponds to the center of the top-left pixel.
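The coordinate convention above can be illustrated with a small conversion sketch. It assumes the usual align_corners=True mapping, in which pixel centers are spread linearly between -1 and 1. The helper names `pixel_to_normalized` and `normalized_to_pixel` are hypothetical and are not part of spyrit.

```python
def pixel_to_normalized(index: int, size: int) -> float:
    """Map a pixel index in [0, size - 1] to a coordinate in [-1, 1].

    Hypothetical helper: index 0 (first pixel center) maps to -1 and
    index size - 1 (last pixel center) maps to 1.
    """
    if size < 2:
        raise ValueError("size must be at least 2 to span [-1, 1]")
    return -1.0 + 2.0 * index / (size - 1)


def normalized_to_pixel(coord: float, size: int) -> float:
    """Inverse mapping: a coordinate in [-1, 1] to a (fractional) pixel index."""
    return (coord + 1.0) * (size - 1) / 2.0


# Centers of the first and last pixels of a 4-pixel-wide image.
left = pixel_to_normalized(0, 4)
right = pixel_to_normalized(3, 4)
```

Under this assumption, a coordinate that falls between two pixel centers corresponds to a fractional index, which is why an interpolation mode is needed when warping.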

Args:

field (torch.tensor): Deformation field \(u\) of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of deformation frames, and \(h\) and \(w\) are the height and width of the deformation field. For accuracy, a dtype of torch.float64 is recommended.
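As an illustration of the expected shape and dtype, the sketch below builds an identity field, i.e. one where \(u(t, x, y) = (x, y)\) for every frame, using only PyTorch. The function name `identity_field` is hypothetical; the (x, y) ordering of the last dimension follows the description above.

```python
import torch


def identity_field(n_frames: int, h: int, w: int) -> torch.Tensor:
    """Return a (n_frames, h, w, 2) float64 field that maps each pixel to itself."""
    # Pixel-center coordinates in [-1, 1] along each axis.
    xs = torch.linspace(-1, 1, w, dtype=torch.float64)
    ys = torch.linspace(-1, 1, h, dtype=torch.float64)
    # "ij" indexing: first output varies along rows (y), second along columns (x).
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    # Last dimension stores (x, y), in that order.
    grid = torch.stack((xx, yy), dim=-1)  # shape (h, w, 2)
    # Repeat the same grid for every frame.
    return grid.unsqueeze(0).repeat(n_frames, 1, 1, 1)


u = identity_field(n_frames=3, h=4, w=5)
```

A tensor built this way could then be passed as the `field` argument; non-trivial deformations are obtained by perturbing these coordinates frame by frame.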

Attributes:

self.field (torch.tensor): Deformation field \(u\) of shape \((n\_frames,h,w,2)\).

self.n_frames (int): Number of frames in the video.

self.img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image respectively.

self.img_h (int): Height of the image to be warped in pixels.

self.img_w (int): Width of the image to be warped in pixels.

self.align_corners (bool): Always True. This argument is passed to the functions torch.nn.functional.grid_sample() and torch.nn.functional.affine_grid() to ensure the corners of the image are aligned with the corners of the grid.

Example:

Storing a 90-degree counter-clockwise rotation for a 2x2 image.

>>> import torch
>>> from spyrit.core.warp import DeformationField
>>>
>>> u = torch.tensor([[[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[ 1, -1],
          [ 1,  1]],

         [[-1, -1],
          [-1,  1]]]])
>>> print(field.field.shape)
torch.Size([1, 2, 2, 2])
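To see why this field encodes a counter-clockwise rotation, the sketch below applies it with torch.nn.functional.grid_sample() directly, using align_corners=True as stated in the attributes above. This is only an illustration of the sampling convention; it does not call DeformationField.forward(), whose exact behavior is not assumed here. The image values and the comparison with torch.rot90() are illustrative choices.

```python
import torch
import torch.nn.functional as F

# Same field as in the example, cast to float64 so it can be used as a sampling grid.
u = torch.tensor(
    [[[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]], dtype=torch.float64
)

# A 2x2 test image with shape (batch, channels, h, w).
img = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64).reshape(1, 1, 2, 2)

# Each output pixel (i, j) takes the value of the input at coordinates u[0, i, j] = (x, y).
warped = F.grid_sample(img, u, mode="bilinear", align_corners=True)

# Reference result: 90-degree counter-clockwise rotation of the image.
expected = torch.rot90(img[0, 0], k=1)

print(warped[0, 0])
print(torch.allclose(warped[0, 0], expected))
```

Because every coordinate in this field lands exactly on a pixel center, the bilinear interpolation reduces to copying pixel values.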

Methods

forward(img[, n0, n1, mode])

Generates a video from a batch of 2D images according to the deformation field \(u\).