spyrit.core.warp.ElasticDeformation.grid_sample

ElasticDeformation.grid_sample(img_frames, inverse_grid_frames, mode)

Warps a batch of 2D image frames with a deformation field. Each frame is warped by its own deformation field, so every image in the batch can receive a different deformation. This function matches the behavior of torch.nn.functional.grid_sample().
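The per-frame pairing can be illustrated with PyTorch alone. The sketch below calls torch.nn.functional.grid_sample directly rather than spyrit, and assumes the grid uses grid_sample's normalized [-1, 1] coordinates; the frame sizes and the one-pixel shift are arbitrary illustration values.

```python
import torch
import torch.nn.functional as F

# Illustration only (plain PyTorch, not spyrit): three single-channel 4x4 frames.
n_frames, c, h, w = 3, 1, 4, 4
img_frames = torch.arange(n_frames * c * h * w, dtype=torch.float32).reshape(n_frames, c, h, w)

# One identity sampling grid per frame, in grid_sample's normalized [-1, 1] coordinates.
theta = torch.eye(2, 3).unsqueeze(0).repeat(n_frames, 1, 1)                  # (n_frames, 2, 3)
grids = F.affine_grid(theta, size=(n_frames, c, h, w), align_corners=False)  # (n_frames, h, w, 2)

# Give frame 1 its own field: each output pixel samples the input pixel to its right.
# One pixel equals 2 / w in normalized units when align_corners=False.
grids[1, ..., 0] += 2.0 / w

# grid_sample pairs batch element i with grid i, so each frame gets its own deformation.
out = F.grid_sample(img_frames, grids, mode="bilinear", padding_mode="zeros", align_corners=False)
print(out.shape)  # torch.Size([3, 1, 4, 4])
```

Frames 0 and 2 come back unchanged, while the content of frame 1 moves one pixel to the left, with zeros filling its rightmost column.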

Inputs:

img_frames (torch.Tensor): batch of 2D images of shape (n_frames, c, h, w), where n_frames is the number of frames in the animation, c is the number of channels, and h and w are the height and width of each image, respectively.

inverse_grid_frames (torch.Tensor): batch of inverse deformation fields of shape (n_frames, h, w, 2). For each pixel of the warped image, the field gives the coordinates of the location in the original image that this pixel displays.
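Since the function is described as matching torch.nn.functional.grid_sample, a field expressed in pixel coordinates presumably has to be rescaled to grid_sample's normalized [-1, 1] range first. The sketch below assumes that convention; pixel_to_normalized is a hypothetical helper, not part of spyrit, and the last dimension is assumed to store (x, y) = (column, row), which is the order grid_sample expects.

```python
import torch
import torch.nn.functional as F

def pixel_to_normalized(coords_xy, h, w, align_corners=False):
    """Hypothetical helper (not part of spyrit): rescale pixel coordinates stored as
    (x, y) = (column, row) in the last dimension to grid_sample's [-1, 1] range."""
    size = torch.tensor([w, h], dtype=coords_xy.dtype, device=coords_xy.device)
    if align_corners:
        # -1 and 1 are the centers of the corner pixels.
        return 2.0 * coords_xy / (size - 1) - 1.0
    # -1 and 1 are the outer edges of the corner pixels.
    return (2.0 * coords_xy + 1.0) / size - 1.0

# Example field: every output pixel displays the input pixel one row above it,
# which moves the content of each frame down by one row.
n_frames, c, h, w = 2, 1, 5, 6
ys, xs = torch.meshgrid(torch.arange(h, dtype=torch.float32),
                        torch.arange(w, dtype=torch.float32), indexing="ij")
pixel_grid = torch.stack((xs, ys - 1.0), dim=-1).expand(n_frames, h, w, 2)
inverse_grid_frames = pixel_to_normalized(pixel_grid, h, w)  # (n_frames, h, w, 2)

img_frames = torch.rand(n_frames, c, h, w)
out = F.grid_sample(img_frames, inverse_grid_frames, mode="nearest",
                    padding_mode="zeros", align_corners=False)
print(out.shape)  # torch.Size([2, 1, 5, 6])
```

The two branches mirror the align_corners flag of grid_sample: with align_corners=False the value -1 refers to the outer edge of the first pixel rather than its center, which shifts the mapping by half a pixel.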

mode (str): The interpolation mode to use. It must be one of the following: ‘nearest’, ‘bilinear’, ‘bicubic’, ‘biquintic’. If ‘nearest’, ‘bilinear’, or ‘bicubic’, it is passed directly to torch.nn.functional.grid_sample(). If ‘biquintic’, the warp is delegated to scikit-image, which requires the skimage and numpy packages to be installed.
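The mode dispatch described above can be sketched as follows. warp_frames is a hypothetical stand-in, not spyrit's implementation. Its biquintic branch assumes normalized grid coordinates, converts them to pixel coordinates, and calls skimage.transform.warp with order=5 (bi-quintic spline) one channel at a time; zero padding is assumed for out-of-range samples in both branches.

```python
import numpy as np
import torch
import torch.nn.functional as F
from skimage.transform import warp

def warp_frames(img_frames, inverse_grid_frames, mode="bilinear", align_corners=False):
    """Hypothetical sketch of the dispatch (not spyrit's code)."""
    if mode in ("nearest", "bilinear", "bicubic"):
        # These modes are handled natively by PyTorch, frame i with grid i.
        return F.grid_sample(img_frames, inverse_grid_frames, mode=mode,
                             padding_mode="zeros", align_corners=align_corners)
    if mode != "biquintic":
        raise ValueError(f"Unsupported mode: {mode!r}")

    n_frames, c, h, w = img_frames.shape
    grid = inverse_grid_frames.detach().cpu().numpy().astype(np.float64)
    # Assumed convention: normalized (x, y) in [-1, 1] -> pixel (col, row).
    if align_corners:
        cols = (grid[..., 0] + 1) * (w - 1) / 2
        rows = (grid[..., 1] + 1) * (h - 1) / 2
    else:
        cols = ((grid[..., 0] + 1) * w - 1) / 2
        rows = ((grid[..., 1] + 1) * h - 1) / 2

    imgs = img_frames.detach().cpu().numpy().astype(np.float64)
    out = np.empty_like(imgs)
    for f in range(n_frames):
        coords = np.stack((rows[f], cols[f]))  # (2, h, w); skimage expects (row, col)
        for ch in range(c):
            # order=5 selects bi-quintic spline interpolation; cval=0 mimics zero padding.
            out[f, ch] = warp(imgs[f, ch], coords, order=5, mode="constant",
                              cval=0.0, preserve_range=True)
    return torch.from_numpy(out).to(dtype=img_frames.dtype, device=img_frames.device)
```

In this sketch the biquintic branch round-trips through NumPy, so unlike the grid_sample branch it does not propagate gradients, and its per-channel loop makes it noticeably slower on large batches.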

Returns:

out (torch.Tensor): The warped batch of 2D images, of shape (n_frames, c, h, w). Each frame is warped according to its corresponding inverse deformation field u in inverse_grid_frames.