spyrit.core.warp.AffineDeformationField
class spyrit.core.warp.AffineDeformationField(func, time_vector: tensor, img_shape: tuple, dtype: dtype = torch.float32, device: device = device(type='cpu'))
Bases: DeformationField

Stores and applies affine deformation fields defined by transformation matrices.
This class generates video sequences by warping images according to time-varying affine transformations. It constructs a discrete deformation field \(u\) by evaluating a user-defined function, which returns a \(3 \times 3\) affine transformation matrix, at each time point.
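The construction step can be pictured with the minimal pure-Python sketch below. It is not the spyrit implementation: it assumes the pixel-centre convention with coordinates in [-1, 1] described on this page, the (x, y) ordering of the last dimension used by `torch.nn.functional.grid_sample`, and a `func` that returns nested lists instead of a tensor. The helper names `pixel_centres` and `build_field` are hypothetical.

```python
# Sketch (not the spyrit code): build a field u of layout (n_frames, h, w, 2)
# from a function of time returning a 3x3 homogeneous affine matrix.


def pixel_centres(n):
    """Normalized coordinates in [-1, 1] of the centres of n pixels (n >= 2)."""
    if n < 2:
        raise ValueError("need at least two pixels along each axis")
    return [-1.0 + 2.0 * k / (n - 1) for k in range(n)]


def build_field(func, time_vector, img_shape):
    """Return a nested list with field[frame][row][col] == (u_x, u_y)."""
    h, w = img_shape
    xs = pixel_centres(w)  # x runs along the width (columns)
    ys = pixel_centres(h)  # y runs along the height (rows)
    field = []
    for t in time_vector:
        m = func(t)  # 3x3 homogeneous matrix of u at time t
        frame = []
        for y in ys:
            row = []
            for x in xs:
                # Apply the matrix to the homogeneous point (x, y, 1).
                ux = m[0][0] * x + m[0][1] * y + m[0][2]
                uy = m[1][0] * x + m[1][1] * y + m[1][2]
                row.append((ux, uy))
            frame.append(row)
        field.append(frame)
    return field


def scaling(t):
    """Same scaling law as Example 1, returning nested lists."""
    s = 1 - t / 10
    return [[s, 0, 0], [0, s, 0], [0, 0, 1]]


# Approximately the values of torch.linspace(0, 1, 10).
time_vector = [k / 9 for k in range(10)]
field = build_field(scaling, time_vector, (4, 6))
print(len(field), len(field[0]), len(field[0][0]), len(field[0][0][0]))  # 10 4 6 2
```

Each frame holds one point per pixel, which gives the \((n\_frames, h, w, 2)\) layout documented for the attribute field.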
The forward call generates a video by warping the input image according to the deformation field \(u = v^{-1}\), i.e. the inverse of the deformation \(v\).
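Reading \(u = v^{-1}\) literally, if the motion is easier to describe in the direct sense (a matrix \(v\) sending reference positions to their positions at time \(t\), which is our interpretation here), the matrix to return from `func` is its inverse. For a homogeneous affine matrix \(\begin{pmatrix} A & b \\ 0 & 1 \end{pmatrix}\), the inverse is \(\begin{pmatrix} A^{-1} & -A^{-1} b \\ 0 & 1 \end{pmatrix}\). The pure-Python sketch below computes it with a hypothetical `invert_affine` helper; with tensors, `torch.linalg.inv` gives the same result.

```python
# Sketch: invert a 3x3 homogeneous affine matrix [[A, b], [0, 1]] given as
# nested lists, using [[A^-1, -A^-1 b], [0, 1]].


def invert_affine(v):
    """Return the inverse of a homogeneous 2D affine matrix."""
    a, b, tx = v[0]
    c, d, ty = v[1]
    det = a * d - b * c
    if det == 0:
        raise ValueError("the linear part of the matrix is singular")
    # Inverse of the 2x2 linear part A.
    ia, ib = d / det, -b / det
    ic, id_ = -c / det, a / det
    # Translation part of the inverse: -A^-1 b.
    itx = -(ia * tx + ib * ty)
    ity = -(ic * tx + id_ * ty)
    return [[ia, ib, itx], [ic, id_, ity], [0.0, 0.0, 1.0]]


def matmul3(p, q):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(p[i][k] * q[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


# Hypothetical direct motion: scale by 2, then shift by 0.5 along x (normalized units).
v = [[2.0, 0.0, 0.5], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]
u = invert_affine(v)      # matrix that func would return
product = matmul3(u, v)   # identity, up to rounding
```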
\[f(t, x, y) = f(t_0, u(t, x, y)),\]
where \(f(t_0, x, y)\) is the reference image and \(u(t, x, y)\) is the deformation field.
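To make the formula concrete, the pure-Python sketch below computes one frame \(f(t, \cdot)\) from a reference image and a single \(3 \times 3\) matrix of \(u\). It assumes bilinear interpolation with zeros outside the image, i.e. behaviour similar to `torch.nn.functional.grid_sample` with `mode='bilinear'`, `padding_mode='zeros'` and `align_corners=True`; the actual `forward` method may use other options. `warp_frame` and its helpers are hypothetical names.

```python
import math

# Sketch of f(t, x, y) = f(t0, u(t, x, y)) for one frame.


def to_normalized(k, n):
    """Pixel index k (0-based, n >= 2 pixels) -> coordinate in [-1, 1]."""
    return -1.0 + 2.0 * k / (n - 1)


def to_pixel(c, n):
    """Coordinate c in [-1, 1] -> fractional pixel index."""
    return (c + 1.0) * (n - 1) / 2.0


def sample_bilinear(img, px, py):
    """Bilinear interpolation at (px, py); pixels outside the image count as 0."""
    h, w = len(img), len(img[0])
    x0, y0 = math.floor(px), math.floor(py)
    value = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            weight = (1 - abs(px - xx)) * (1 - abs(py - yy))
            if 0 <= xx < w and 0 <= yy < h:
                value += weight * img[yy][xx]
    return value


def warp_frame(ref, m):
    """Each output pixel reads the reference image at u(x, y) = m @ (x, y, 1)."""
    h, w = len(ref), len(ref[0])
    out = []
    for i in range(h):
        y = to_normalized(i, h)
        row = []
        for j in range(w):
            x = to_normalized(j, w)
            ux = m[0][0] * x + m[0][1] * y + m[0][2]
            uy = m[1][0] * x + m[1][1] * y + m[1][2]
            row.append(sample_bilinear(ref, to_pixel(ux, w), to_pixel(uy, h)))
        out.append(row)
    return out


ref = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
shift = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # u_x = x + 1
out = warp_frame(ref, shift)
print(out)  # [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Because \(u\) tells each output pixel where to read in the reference, a translation of \(+1\) along x in the matrix (one pixel here, since neighbouring centres are \(2/(w-1) = 1\) apart for \(w = 3\)) moves the content one pixel to the left.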
Important
The coordinates are given in the range [-1, 1]. When referring to a pixel, its position is the position of its center. The position (-1, -1) corresponds to the center of the top-left pixel.
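Under this convention, the centres of \(n\) pixels are evenly spaced from \(-1\) to \(1\), with a step of \(2/(n-1)\) between neighbours. The small sketch below (hypothetical helper names) lists the centres and converts a displacement in pixels into normalized units, which is the unit of the translation entries of the affine matrices.

```python
# Sketch of the pixel-centre convention: (-1, -1) is the centre of the
# top-left pixel and (1, 1) the centre of the bottom-right one.


def pixel_centres(n):
    """Normalized coordinates of the centres of n pixels (n >= 2)."""
    if n < 2:
        raise ValueError("need at least two pixels along the axis")
    return [-1.0 + 2.0 * k / (n - 1) for k in range(n)]


def pixels_to_normalized(d_pixels, n):
    """Convert a displacement of d_pixels along an axis of n pixels to [-1, 1] units."""
    return 2.0 * d_pixels / (n - 1)


print(pixel_centres(5))  # [-1.0, -0.5, 0.0, 0.5, 1.0]

# Translation of 3 pixels along x for a 64-pixel-wide image, as a 3x3 matrix.
tx = pixels_to_normalized(3, 64)
translation = [[1.0, 0.0, tx], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```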
Note
The image size is requested upon construction, but the warping can be done with images of different sizes: the grid is simply interpolated to match the image size. The image size can also be changed after construction by setting the attribute img_shape, or the attributes img_h and img_w.
- Args:
func (Callable[[float], torch.tensor]): Function of one parameter (time) that returns a tensor of shape \((3,3)\) representing an affine homogeneous transformation matrix. This matrix corresponds to the deformation field \(u\).
time_vector (torch.tensor): Vector of time points at which the transformation function is evaluated to generate the deformation field. Shape \((n\_frames,)\).
img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image respectively.
dtype (torch.dtype, optional): Data type of the deformation field tensor. For accuracy reasons, torch.float64 is recommended. Defaults to torch.float32.
device (torch.device, optional): Device on which the deformation field tensor is stored. Defaults to torch.device('cpu').
- Attributes:
self.func (Callable[[float], torch.tensor]): Function of one parameter (time) that returns a tensor of shape \((3,3)\) representing an affine homogeneous transformation matrix.
self.field (torch.tensor): Deformation field \(u\) of shape \((n\_frames,h,w,2)\).
self.time_vector (torch.tensor): Vector of time points at which the function is evaluated to generate the deformation field.
self.n_frames (int): Number of frames in the video.
self.img_shape (tuple): Shape of the image to be warped, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the image respectively.
self.img_h (int): Height of the image to be warped, in pixels.
self.img_w (int): Width of the image to be warped, in pixels.
self.align_corners (bool): Always True. This argument is passed to the functions torch.nn.functional.grid_sample() and torch.nn.functional.affine_grid() so that the coordinates -1 and 1 refer to the centers of the corner pixels, consistent with the convention stated above.
- Example 1: Progressive scaling
>>> import torch
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def scaling(t):
...     scale = 1 - t/10
...     return torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])
>>>
>>> time_vector = torch.linspace(0, 1, 10)
>>> def_field = AffineDeformationField(scaling, time_vector, (64, 64))
>>> print(def_field.n_frames)
10
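Since scaling returns the matrix of \(u\), each output pixel \((x, y)\) reads the reference at \((s\,x, s\,y)\) with \(s = 1 - t/10\). With time_vector = torch.linspace(0, 1, 10), \(s\) decreases from 1 to 0.9, so later frames read from a region closer to the centre and the image appears progressively magnified rather than shrunk. The plain-Python check below reproduces these numbers, writing the time points as `k / 9`, which matches `torch.linspace(0, 1, 10)` up to float32 rounding.

```python
# Time points of torch.linspace(0, 1, 10), written in plain Python.
times = [k / 9 for k in range(10)]
scales = [1 - t / 10 for t in times]

# Reference position read by the bottom-right output pixel (1, 1) in the last frame.
corner = (scales[-1] * 1.0, scales[-1] * 1.0)

print(scales[0], round(scales[-1], 12))  # 1.0 0.9
print(tuple(round(c, 12) for c in corner))  # (0.9, 0.9)
```

In the last frame, the reference content lying between 0.9 and 1 in normalized units is therefore no longer visible.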
- Example 2: Counter-clockwise rotation at 1 Hz
>>> import torch
>>> import math
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def rotation(t):
...     angle = 2 * math.pi * t  # 1 Hz rotation
...     c, s = math.cos(angle), math.sin(angle)
...     return torch.tensor([[c, s, 0], [-s, c, 0], [0, 0, 1]], dtype=torch.float64)
>>>
>>> time_vector = torch.linspace(0, 1, 30)  # 30 frames for 1 second
>>> def_field = AffineDeformationField(rotation, time_vector, (128, 128))
>>> print(def_field.field.shape)
torch.Size([30, 128, 128, 2])
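Two properties of this example can be checked without spyrit: every matrix is a pure rotation (determinant 1), and because torch.linspace(0, 1, 30) includes both endpoints, the first frame (angle 0) and the last frame (angle \(2\pi\)) are the same identity transform. The plain-Python sketch below verifies both; `times` approximates the float32 values of `torch.linspace`, and `det2` is a hypothetical helper.

```python
import math


def rotation(t):
    """Same matrix as in Example 2, returned as nested lists."""
    angle = 2 * math.pi * t  # 1 Hz rotation
    c, s = math.cos(angle), math.sin(angle)
    return [[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]]


def det2(m):
    """Determinant of the 2x2 linear part of a homogeneous affine matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


times = [k / 29 for k in range(30)]  # values of torch.linspace(0, 1, 30)
mats = [rotation(t) for t in times]

all_rotations = all(abs(det2(m) - 1) < 1e-12 for m in mats)
first, last = mats[0], mats[-1]
same_first_last = max(abs(first[i][j] - last[i][j]) for i in range(3) for j in range(3)) < 1e-12

print(all_rotations, same_first_last)  # True True
```

With endpoints included, the angular step is \(360/29 \approx 12.4\) degrees. If 30 distinct frames covering exactly one turn are wanted instead, time points such as torch.arange(30) / 30 exclude the endpoint and give a step of 12 degrees.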
Methods
forward(img[, n0, n1, mode]): Generates a video from a batch of 2D images according to the deformation field \(u\).