spyrit.core.warp.DeformationField

class spyrit.core.warp.DeformationField(field: tensor)

Bases: Module

Stores a discrete deformation field as a \((n\_frames,h,w,2)\) tensor.

Warps a single image according to an inverse deformation field \(u\), i.e. the field that maps the deformed image pixel coordinates to the original image pixel coordinates.

It is constructed from a tensor of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of frames in the animation and \(h\) and \(w\) are the height and width of the image in pixels. For each frame and each pixel of the warped image, the last dimension contains the \((x, y)\) coordinates of the original-image pixel that is displayed at that position.

Important

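The coordinates are given in the range [-1;1]. When referring to a pixel, its position is the position of its center: [-1;-1] corresponds to the center of the top-left pixel and [1;1] to the center of the bottom-right pixel, with x increasing from left to right and y from top to bottom.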
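This convention places the extreme values -1 and 1 on the centers of the corner pixels, which is the same convention PyTorch's `grid_sample` uses with `align_corners=True`. The minimal NumPy sketch below converts between 0-based pixel indices and [-1;1] coordinates under that convention; the helpers `index_to_normalized` and `normalized_to_index` are hypothetical and not part of spyrit.

```python
import numpy as np

# Hypothetical helpers (not part of spyrit) for the "pixel centers at -1 and +1"
# convention, i.e. the same convention as grid_sample with align_corners=True.

def index_to_normalized(idx, size):
    """Map 0-based pixel indices along an axis of `size` pixels to [-1, 1].

    Index 0 maps to -1 (center of the first pixel) and index size - 1 maps
    to +1 (center of the last pixel). Requires size >= 2.
    """
    return 2.0 * np.asarray(idx, dtype=np.float64) / (size - 1) - 1.0


def normalized_to_index(coord, size):
    """Inverse mapping: [-1, 1] coordinates to (possibly fractional) pixel indices."""
    return (np.asarray(coord, dtype=np.float64) + 1.0) * (size - 1) / 2.0


# Along an axis of 2 pixels, the two centers sit at -1 and +1.
print(index_to_normalized([0, 1], 2))
# Along an axis of 5 pixels, consecutive centers are 0.5 apart.
print(index_to_normalized(np.arange(5), 5))
```

A coordinate that falls between two centers, such as 0 on an axis of 4 pixels, maps to a fractional index (1.5 here); how such positions are sampled depends on the interpolation mode.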

Args:

field (torch.Tensor): Inverse deformation field \(u\) of shape \((n\_frames,h,w,2)\), where \(n\_frames\) is the number of frames in the animation, and \(h\) and \(w\) are the height and width of the image to be warped. For accuracy reasons, the dtype is converted to torch.float64.

Attributes:

self.field (torch.Tensor): Inverse deformation field \(u\) of shape \((n\_frames,h,w,2)\).

self.n_frames (int): Number of frames in the animation.

self.img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image respectively.

self.img_h (int): Height of the image to be warped in pixels.

self.img_w (int): Width of the image to be warped in pixels.
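All of these attributes follow from the shape of the field passed to the constructor. The class below is a hypothetical NumPy stand-in, not the spyrit implementation (which is a torch `Module`); it only illustrates how `n_frames`, `img_shape`, `img_h` and `img_w` relate to a field of shape \((n\_frames,h,w,2)\), and it mirrors the documented float64 conversion.

```python
import numpy as np


class DeformationFieldSketch:
    """Hypothetical stand-in mirroring the documented attributes (not spyrit code)."""

    def __init__(self, field):
        # Mirror the documented conversion to 64-bit floats.
        field = np.asarray(field, dtype=np.float64)
        if field.ndim != 4 or field.shape[-1] != 2:
            raise ValueError("expected a field of shape (n_frames, h, w, 2)")
        self.field = field
        self.n_frames = field.shape[0]            # first axis: frames
        self.img_shape = tuple(field.shape[1:3])  # (h, w)
        self.img_h, self.img_w = self.img_shape


# A 3-frame field for an image 4 pixels high and 5 pixels wide.
f = DeformationFieldSketch(np.zeros((3, 4, 5, 2)))
print(f.n_frames, f.img_shape, f.img_h, f.img_w)  # 3 (4, 5) 4 5
```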

Example 1: Rotating a 2x2 B&W image by 90 degrees counter-clockwise, using one frame
>>> u = torch.tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])

Example 2: Rotating a 2x2 B&W image by 90 degrees clockwise, using one frame
>>> u = torch.tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
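Both examples follow one pattern: for a 90 degree rotation of an n x n image, the warped pixel at row i, column j is read from source row j, column n-1-i (counter-clockwise) or from source row n-1-j, column i (clockwise). The hypothetical helper `rotation_field` below (NumPy, not part of spyrit) builds such an inverse field under the coordinate convention stated above. For n = 2, the clockwise field reproduces the values of Example 2, and the counter-clockwise field places every source position on a pixel center (±1).

```python
import numpy as np


def rotation_field(n, clockwise=False):
    """Hypothetical helper: inverse field of shape (1, n, n, 2) rotating an n x n image by 90 degrees.

    Entry [0, i, j] holds the (x, y) coordinate, in [-1, 1], of the source
    pixel displayed at row i, column j of the warped image.
    """
    i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    if clockwise:
        src_row, src_col = n - 1 - j, i
    else:
        src_row, src_col = j, n - 1 - i
    x = 2.0 * src_col / (n - 1) - 1.0  # x follows the column (left -> right)
    y = 2.0 * src_row / (n - 1) - 1.0  # y follows the row (top -> bottom)
    return np.stack([x, y], axis=-1)[np.newaxis]  # add the frame axis


print(rotation_field(2, clockwise=True).tolist())
# [[[[-1.0, 1.0], [-1.0, -1.0]], [[1.0, 1.0], [1.0, -1.0]]]]
```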

Methods

forward(img[, n0, n1, mode])

Warps a vectorized image or batch of vectorized images with the stored inverse deformation field \(u\).
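Assuming `forward` samples the image at the source positions stored in \(u\), as the description above suggests, a simplified version for a single vectorized image can be sketched with nearest-neighbour sampling. `warp_vectorized_nearest` is a hypothetical NumPy function, not spyrit's `forward`: it assumes row-major vectorization of length \(h \cdot w\), rounds each coordinate to the nearest pixel center, fills out-of-range positions with 0, and does not model the `n0`, `n1` or `mode` arguments.

```python
import numpy as np


def warp_vectorized_nearest(img_vec, field):
    """Hypothetical nearest-neighbour inverse warp (not spyrit's forward).

    img_vec: vectorized image of shape (h*w,), assumed row-major.
    field:   inverse deformation field of shape (n_frames, h, w, 2), holding
             (x, y) coordinates in [-1, 1] with pixel centers at -1 and +1.
    Returns an array of shape (n_frames, h*w): one warped image per frame.
    """
    n_frames, h, w, _ = field.shape
    img = np.asarray(img_vec, dtype=np.float64).reshape(h, w)
    # Convert normalized coordinates to pixel indices, rounded to the nearest center.
    col = np.rint((field[..., 0] + 1.0) * (w - 1) / 2.0).astype(int)
    row = np.rint((field[..., 1] + 1.0) * (h - 1) / 2.0).astype(int)
    # Source positions outside the image stay at 0.
    inside = (row >= 0) & (row < h) & (col >= 0) & (col < w)
    out = np.zeros((n_frames, h, w))
    out[inside] = img[row[inside], col[inside]]
    return out.reshape(n_frames, h * w)


# The clockwise-rotation field of Example 2 applied to a 2x2 image.
u = np.array([[[[-1, 1], [-1, -1]], [[1, 1], [1, -1]]]], dtype=np.float64)
img = np.array([[1.0, 2.0], [3.0, 4.0]])
out = warp_vectorized_nearest(img.ravel(), u)
print(out.reshape(2, 2))  # same as np.rot90(img, -1), i.e. a clockwise rotation
```

Nearest-neighbour rounding keeps the sketch short; an interpolating mode would instead blend the neighbouring pixels when a source position falls between pixel centers.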