spyrit.core.warp.AffineDeformationField.forward
- AffineDeformationField.forward(img: tensor, n0: int = 0, n1: int = None, mode: str = 'bilinear') → tensor
Generates a video from a batch of 2D images according to the deformation field \(u\).
The deformation is taken between the frames \(n0\) (included) and \(n1\) (excluded).
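To make the mechanism concrete, here is a minimal pure-PyTorch sketch of how a time-dependent affine map can turn one image into a sequence of frames. It is not spyrit's implementation. The helper `warp_frames` is hypothetical, and it assumes that the 3x3 matrix acts on normalized coordinates in [-1, 1] and maps each output pixel to the input location it samples, which is the convention of `torch.nn.functional.affine_grid()`. spyrit's own coordinate convention may differ.

```python
import torch
import torch.nn.functional as F


def warp_frames(img, affine_fn, time_vector, mode="bilinear"):
    """Hypothetical helper: warp one image of shape (c, h, w) into a video
    of shape (n_frames, c, h, w), one frame per entry of time_vector.

    Assumes affine_fn(t) returns a 3x3 homogeneous matrix whose top 2x3 block
    maps normalized output coordinates in [-1, 1] to the input coordinates to
    sample, which is the convention of F.affine_grid.
    """
    c, h, w = img.shape
    # Keep the top 2x3 block of each 3x3 matrix: shape (n_frames, 2, 3)
    theta = torch.stack([affine_fn(float(t))[:2, :] for t in time_vector])
    theta = theta.to(img.dtype)
    n_frames = theta.shape[0]
    # One sampling grid per frame: shape (n_frames, h, w, 2)
    grid = F.affine_grid(theta, [n_frames, c, h, w], align_corners=False)
    # Copy the image once per frame, then sample each copy with its own grid
    frames = img.unsqueeze(0).repeat(n_frames, 1, 1, 1)
    return F.grid_sample(frames, grid, mode=mode, padding_mode="zeros",
                         align_corners=False)


def scaling(t):
    # Same time function as in Example 1
    scale = 1 - t / 10
    return torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])


image = torch.randn(1, 64, 64)  # one grayscale image
video = warp_frames(image, scaling, torch.linspace(0, 1, 10))
print(video.shape)  # torch.Size([10, 1, 64, 64])
```

At \(t = 0\) the matrix is the identity, so the first frame reproduces the input image up to floating-point error.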
- Args:
img (torch.tensor): Batch of 2D images to deform, of shape \((c, h, w)\) or \((b, c, h, w)\), where \(b\) is the number of images in the batch, \(c\) is the number of channels, and \(h\) and \(w\) are the height and width of the images.
n0 (int, optional): Index of the first frame of the deformation field to use. Defaults to 0.
n1 (int, optional): Index of the first frame of the deformation field to exclude. If None, all frames up to and including the last available one are used. Defaults to None.
mode (str, optional): The interpolation mode. It must be one of 'nearest', 'bilinear', 'bicubic', or 'biquintic'. The 'nearest', 'bilinear', and 'bicubic' modes are supported directly by torch.nn.functional.grid_sample(); the 'biquintic' mode relies on scikit-image. Defaults to 'bilinear'.
Note
If using mode='bicubic' or mode='biquintic', the warped image may contain values outside the value range of the original image (interpolation overshoot near sharp edges).
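This behavior can be reproduced with PyTorch alone. The sketch below samples a 1D step edge halfway between pixel centers with torch.nn.functional.grid_sample(). The cubic kernel used for 'bicubic' has negative lobes, so it dips slightly below 0 and rises slightly above 1 near the edge, while 'bilinear' stays within the original range. It does not call spyrit, and the 'biquintic' path (scikit-image) is not exercised; the clamping step at the end is just one possible workaround, not something the library is known to do.

```python
import torch
import torch.nn.functional as F

# A 1 x 1 x 1 x 8 image with a sharp step from 0 to 1 (values in [0, 1])
img = torch.tensor([[[[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]]]])
W = img.shape[-1]

# Normalized x coordinates of the pixel centers (align_corners=False),
# shifted by half a pixel (one pixel spans 2 / W in normalized units)
xs = (2 * torch.arange(W, dtype=torch.float32) + 1) / W - 1 + 1.0 / W
ys = torch.zeros(W)  # the single row sits at normalized y = 0
grid = torch.stack([xs, ys], dim=-1).view(1, 1, W, 2)  # (N, H_out, W_out, 2)

out_bilinear = F.grid_sample(img, grid, mode="bilinear",
                             padding_mode="border", align_corners=False)
out_bicubic = F.grid_sample(img, grid, mode="bicubic",
                            padding_mode="border", align_corners=False)

print(out_bilinear.min().item(), out_bilinear.max().item())  # stays within [0, 1]
print(out_bicubic.min().item(), out_bicubic.max().item())    # below 0 and above 1

# If the original range must be preserved, one option is to clamp afterwards
out_clamped = out_bicubic.clamp(img.min().item(), img.max().item())
```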
Note
If \(n0 < n1\), field is sliced as follows: field[n0:n1, :, :, :].
Note
If \(n0 > n1\), field is sliced “backwards”: the first frame of the warped animation corresponds to index \(n0\), and the last frame corresponds to index \(n1+1\). This behavior is identical to slicing a list with a step of -1.
Note
If the number of pixels differs between the image and the field, the field is interpolated to match the image size (see the behavior of torch.nn.functional.grid_sample()).
- Returns:
output (torch.tensor): The deformed batch of 2D images, of shape \((|n1-n0|, c, h, w)\) or \((b, |n1-n0|, c, h, w)\) depending on the input shape, where each image in the batch is deformed according to the deformation field \(u\).
- Example 1: Progressive scaling
>>> import torch
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def scaling(t):
...     scale = 1 - t/10
...     return torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])
>>>
>>> time_vector = torch.linspace(0, 1, 10)
>>> def_field = AffineDeformationField(scaling, time_vector, (64, 64))
>>> images = torch.randn(16, 1, 64, 64)  # Batch of 16 grayscale images
>>> scaled_video = def_field(images)
>>> print(scaled_video.shape)
torch.Size([16, 10, 1, 64, 64])
- Example 2: Counter-clockwise rotation at 1 Hz
>>> import torch
>>> import math
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def rotation(t):
...     angle = 2 * math.pi * t  # 1 Hz rotation
...     c, s = math.cos(angle), math.sin(angle)
...     return torch.tensor([[c, s, 0], [-s, c, 0], [0, 0, 1]], dtype=torch.float64)
>>>
>>> time_vector = torch.linspace(0, 1, 30)  # 30 frames for 1 second
>>> def_field = AffineDeformationField(rotation, time_vector, (128, 128), dtype=torch.float64)
>>> image = torch.randn(3, 128, 128).to(dtype=torch.float64)  # a single RGB image
>>> rotated_video = def_field(image)
>>> print(rotated_video.shape)
torch.Size([30, 3, 128, 128])
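The slicing rules for \(n0\) and \(n1\) described in the notes can be checked with plain Python lists. The function `selected_frames` below is a hypothetical helper, not part of spyrit; it only mirrors the documented behavior, and it assumes that n1=None means "up to and including the last frame", which matches the frame counts in both examples (10 and 30 frames).

```python
def selected_frames(n_frames, n0=0, n1=None):
    """Hypothetical helper: list the field frame indices used for (n0, n1).

    Mirrors the documented rules: field[n0:n1] when n0 < n1, and a
    step of -1 (from n0 down to n1 + 1) when n0 > n1.
    """
    frames = list(range(n_frames))
    if n1 is None:
        # Assumption: None means "up to and including the last frame"
        n1 = n_frames
    if n0 <= n1:
        return frames[n0:n1]
    return frames[n0:n1:-1]


print(selected_frames(30))             # [0, 1, ..., 29], as in Example 2
print(selected_frames(10, 2, 5))       # [2, 3, 4]
print(selected_frames(10, 7, 3))       # [7, 6, 5, 4]
print(len(selected_frames(10, 7, 3)))  # 4, i.e. |n1 - n0|
```

In both directions the number of selected frames is \(|n1-n0|\), which is the leading (or second) dimension of the returned tensor.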