spyrit.core.warp.DeformationField
class spyrit.core.warp.DeformationField(field: tensor)
Bases: Module

Stores a discrete deformation field \(u\) of shape \((n\_frames,h,w,2)\).
The field contains \(n\_frames\) frames, and its height and width are denoted by \(h\) and \(w\). The last dimension contains the x and y coordinates of the deformation field with respect to the reference time \(t_0\):
\[f(t, x, y) = f(t_0, u(t, x, y))\]

where \(f(t_0, x, y)\) is the reference image and \(u(t, x, y)\) is the deformation field.
The forward call generates a video by warping the input image according to the deformation field \(u\).
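The warping relation above can be illustrated with a minimal sketch built on `torch.nn.functional.grid_sample`. It assumes the reference image is a single `(c, h, w)` tensor and that the field already has the output height and width. The function name `warp_image` is hypothetical, and this is not the library's `forward` implementation; it only shows how \(f(t, x, y) = f(t_0, u(t, x, y))\) maps onto `grid_sample` with `align_corners=True`.

```python
import torch
import torch.nn.functional as F


def warp_image(img: torch.Tensor, field: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: warp one reference image into a video.

    img:     reference image f(t0) of shape (c, h, w)
    field:   deformation field u of shape (n_frames, h, w, 2), coordinates in [-1, 1]
    returns: video of shape (n_frames, c, h, w)
    """
    n_frames = field.shape[0]
    # Copy the reference image once per frame; grid_sample needs matching dtypes.
    batch = img.unsqueeze(0).repeat(n_frames, 1, 1, 1).to(field.dtype)
    # For every output pixel, read the input value at the (x, y) position stored in field.
    # align_corners=True makes -1 and 1 refer to the centers of the corner pixels.
    return F.grid_sample(
        batch, field, mode="bilinear", padding_mode="zeros", align_corners=True
    )
```

Bilinear interpolation and zero padding are choices made for this sketch only, so positions that fall outside \([-1, 1]\) are read as 0.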
Important
The coordinates are given in the range [-1, 1]. When referring to a pixel, its position is the position of its center. The position (-1, -1) corresponds to the center of the top-left pixel.
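The pixel-center convention above matches `align_corners=True` in PyTorch's sampling functions. The sketch below, using a hypothetical helper `pixel_to_normalized`, converts integer pixel indices (row `i`, column `j`) to the normalized `(x, y)` coordinates a deformation field would store. It assumes that \(h\) and \(w\) are both at least 2.

```python
import torch


def pixel_to_normalized(i: torch.Tensor, j: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Hypothetical helper: map pixel centers (row i, column j) to (x, y) in [-1, 1].

    The center of pixel (0, 0) maps to (-1, -1) and the center of
    pixel (h - 1, w - 1) maps to (1, 1).
    """
    x = 2.0 * j / (w - 1) - 1.0  # column index -> x coordinate
    y = 2.0 * i / (h - 1) - 1.0  # row index -> y coordinate
    return torch.stack((x, y), dim=-1)


# Identity field for a 3x4 image: every pixel points at its own center.
rows, cols = torch.meshgrid(torch.arange(3), torch.arange(4), indexing="ij")
identity = pixel_to_normalized(rows, cols, h=3, w=4)  # shape (3, 4, 2)
```

Stacking one such grid per frame gives a field of shape \((n\_frames,h,w,2)\) that leaves the reference image unchanged.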
- Args:
    field (torch.tensor): Deformation field \(u\) of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of deformation frames and \(h\) and \(w\) are the height and width of the deformation field. For accuracy reasons, the dtype should preferably be torch.float64.
- Attributes:
    self.field (torch.tensor): Deformation field \(u\) of shape \((n\_frames,h,w,2)\).
    self.n_frames (int): Number of frames in the video.
    self.img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image, respectively.
    img_h (int): Height of the image to be warped, in pixels.
    img_w (int): Width of the image to be warped, in pixels.
    self.align_corners (bool): Always True. This argument is passed to the functions torch.nn.functional.grid_sample() and torch.nn.functional.affine_grid() to ensure the corners of the image are aligned with the corners of the grid.
- Example:
Storing a 90-degree counter-clockwise rotation for a 2x2 image.

    >>> import torch
    >>> from spyrit.core.warp import DeformationField
    >>>
    >>> u = torch.tensor([[[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]])
    >>> field = DeformationField(u)
    >>> print(field.field)
    tensor([[[[ 1, -1],
              [ 1,  1]],

             [[-1, -1],
              [-1,  1]]]])
    >>> print(field.field.shape)
    torch.Size([1, 2, 2, 2])
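To check that the field in the example is indeed a counter-clockwise rotation, the sketch below samples a 2x2 image at the positions stored in the same field with `torch.nn.functional.grid_sample` and compares the result with `torch.rot90`. It calls PyTorch directly rather than `DeformationField.forward`, whose output layout is not described here, and it assumes a `(batch, channel, h, w)` image layout.

```python
import torch
import torch.nn.functional as F

# Same field as in the example, cast to float64 as recommended in Args.
u = torch.tensor([[[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]], dtype=torch.float64)

# A 2x2 single-channel image [[a, b], [c, d]] = [[1, 2], [3, 4]] in (batch, channel, h, w) layout.
img = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float64)

# Each output pixel takes the input value at the (x, y) position stored in u.
warped = F.grid_sample(img, u, mode="bilinear", align_corners=True)

# torch.rot90 with k=1 over the (h, w) dims rotates counter-clockwise.
expected = torch.rot90(img, k=1, dims=(2, 3))

print(warped[0, 0].tolist())  # [[2.0, 4.0], [1.0, 3.0]]
print(torch.allclose(warped, expected))  # True
```

Because every stored coordinate is exactly \(\pm 1\), each output pixel lands on a pixel center and the bilinear interpolation reduces to a plain copy.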
Methods
forward(img[, n0, n1, mode])
    Generates a video from a batch of 2D images according to the deformation field \(u\).