spyrit.core.warp.AffineDeformationField.forward
- AffineDeformationField.forward(img: tensor, n0: int = 0, n1: int = None, mode: str = 'bilinear') → tensor
Warps a vectorized image or batch of vectorized images with the stored inverse deformation field \(u\).
Deforms the vectorized image according to the inverse deformation field \(u\) contained in the attribute field, sliced between the frames \(n0\) (included) and \(n1\) (excluded). \(u\) is the field that maps the pixels of the deformed image to the pixels of the original image. This method assumes the vectorized image has the same number of pixels as the deformation field.
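For a single frame, inverse warping means that each output pixel \((i, j)\) reads the source image at the position \(u(i, j)\). The pure-Python sketch below imitates torch.nn.functional.grid_sample() for one channel with mode 'nearest' or 'bilinear'. It assumes that \(u\) stores \((x, y)\) coordinates normalized to \([-1, 1]\) (x along the width, y along the height), corner-aligned (align_corners=True) and zero-padded outside the image. These conventions are assumptions, chosen because they reproduce the rotation in Example 1. The helper names are hypothetical; they are not part of spyrit.

```python
import math
from typing import List

Image = List[List[float]]        # one channel, indexed [row][col]
Field = List[List[List[float]]]  # one frame, indexed [row][col] -> [x, y]


def _to_pixel(coord: float, size: int) -> float:
    """Map a normalized coordinate in [-1, 1] to a pixel position.

    Assumes align_corners=True: -1 is the first pixel centre, +1 the last.
    """
    return (coord + 1) / 2 * (size - 1)


def _read(img: Image, row: int, col: int) -> float:
    """Return img[row][col], or 0.0 outside the image (assumed zero padding)."""
    h, w = len(img), len(img[0])
    if 0 <= row < h and 0 <= col < w:
        return img[row][col]
    return 0.0


def warp_frame(img: Image, u: Field, mode: str = "bilinear") -> Image:
    """Inverse warp: output[i][j] is img sampled at the position u[i][j]."""
    h, w = len(img), len(img[0])
    out = []
    for field_row in u:
        out_row = []
        for x_norm, y_norm in field_row:
            x = _to_pixel(x_norm, w)  # column position in the source image
            y = _to_pixel(y_norm, h)  # row position in the source image
            if mode == "nearest":
                # round() rounds exact halves to even
                out_row.append(_read(img, round(y), round(x)))
            elif mode == "bilinear":
                c0, r0 = math.floor(x), math.floor(y)
                dx, dy = x - c0, y - r0
                # Weighted average of the four surrounding pixels
                out_row.append(
                    (1 - dy) * (1 - dx) * _read(img, r0, c0)
                    + (1 - dy) * dx * _read(img, r0, c0 + 1)
                    + dy * (1 - dx) * _read(img, r0 + 1, c0)
                    + dy * dx * _read(img, r0 + 1, c0 + 1)
                )
            else:
                raise ValueError(f"unsupported mode: {mode}")
        out.append(out_row)
    return out
```

For instance, warp_frame([[0.0, 0.3], [0.7, 1.0]], [[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]) returns [[0.3, 1.0], [0.0, 0.7]], the same values as Example 1.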
- Args:
img (torch.tensor): The vectorized image to deform, of shape \((c, h*w)\), where \(c\) is the number of channels (usually 1 or 3), and \(h\) and \(w\) are the number of pixels along the height and width of the image, respectively.
n0 (int, optional): The index of the first frame to use in the inverse deformation field. Defaults to 0.
n1 (int, optional): The index of the first frame to exclude in the inverse deformation field. If None, slicing runs up to and including the last available frame. Defaults to None.
mode (str, optional): The interpolation mode to use. It is passed directly to torch.nn.functional.grid_sample() and must be one of 'nearest', 'bilinear' or 'bicubic'. Defaults to 'bilinear'.
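A vectorized image of shape \((c, h*w)\) and an image of shape \((c,h,w)\) hold the same pixel values; moving between them is a reshape. The minimal pure-Python sketch below assumes row-major (C-order) flattening, the order used by torch.Tensor.reshape, so pixel \(k\) of a channel lands at row \(k // w\) and column \(k \% w\). The helper names are hypothetical, and whether spyrit vectorizes images in exactly this order is an assumption.

```python
from typing import List


def unvectorize(vec: List[List[float]], h: int, w: int) -> List[List[List[float]]]:
    """(c, h*w) -> (c, h, w), assuming row-major order."""
    if any(len(channel) != h * w for channel in vec):
        raise ValueError("each channel must hold exactly h*w pixels")
    # Row r of each channel is the slice [r*w, (r+1)*w) of the flat vector
    return [[channel[r * w:(r + 1) * w] for r in range(h)] for channel in vec]


def vectorize(img: List[List[List[float]]]) -> List[List[float]]:
    """(c, h, w) -> (c, h*w), concatenating the rows of each channel."""
    return [[pixel for row in channel for pixel in row] for channel in img]
```

On tensors, the same conversions are img.reshape(c, h, w) and img.reshape(c, h * w).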
Note
If \(n0 < n1\), field is sliced as follows: field[n0:n1, :, :, :]
Note
If \(n0 > n1\), field is sliced “backwards”. The first frame of the warped animation corresponds to the index \(n0\), and the last frame corresponds to the index \(n1+1\). This behavior is identical to slicing a list with a step of -1.
- Returns:
output (torch.tensor): The deformed batch of images of shape \((|n1-n0|,c,h,w)\), where each image in the batch is deformed according to the inverse deformation field \(u\) contained in the attribute field.
- Shape:
img: \((c,h,w)\), where \(c\) is the number of channels, and \(h\) and \(w\) are the number of pixels along the height and width of the image, respectively.
output: \((|n1-n0|,c,h,w)\)
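The frame-selection rules in the notes, and the resulting output shape, can be checked with plain Python lists. The sketch below mirrors them with the hypothetical helpers frame_indices and output_shape; treating n1=None as "slice to the end of the field" is an assumption based on the description of n1.

```python
from typing import List, Optional, Tuple


def frame_indices(n_frames: int, n0: int = 0, n1: Optional[int] = None) -> List[int]:
    """Indices of the field frames used, in the order they are applied."""
    frames = list(range(n_frames))
    if n1 is None:
        n1 = n_frames  # assumption: None slices through the last frame
    if n0 <= n1:
        return frames[n0:n1]      # forward: field[n0:n1]
    return frames[n0:n1:-1]       # backward: n0, n0-1, ..., n1+1


def output_shape(n_frames: int, n0: int, n1: Optional[int],
                 c: int, h: int, w: int) -> Tuple[int, int, int, int]:
    """Shape of the warped batch: one (c, h, w) image per selected frame."""
    return (len(frame_indices(n_frames, n0, n1)), c, h, w)
```

For example, with a 10-frame field, n0=5 and n1=2 select frames 5, 4 and 3, so the output holds |2-5| = 3 warped images.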
Example 1: Rotating a 2x2 grayscale image by 90 degrees counter-clockwise, using one frame
>>> import torch
>>> from spyrit.core.warp import DeformationField
>>> v = torch.tensor([[[[ 1., -1.], [ 1., 1.]], [[-1., -1.], [-1., 1.]]]])
>>> field = DeformationField(v)
>>> image = torch.tensor([[[0. , 0.3], [0.7, 1. ]]])
>>> deformed_image = field(image, 0, 1)
>>> print(deformed_image)
tensor([[[[0.3000, 1.0000],
          [0.0000, 0.7000]]]])