Source code for spyrit.core.warp

"""
Stores deformation fields and warps images.

Contains :class:`DeformationField` and :class:`AffineDeformationField`, a
subclass of the former. These classes are used to warp images according to
a deformation field that is stored as a class attribute. They can be fed
an image (called "*original image*") and will return the warped image
("*deformed image*").

The function that maps the *original image* pixel coordinates to the *deformed
image* pixel coordinates is called the "*deformation field*" and is noted
:math:`v`. The function that maps the pixels of the *deformed image* to the
pixels of the *original image* is called the "*inverse deformation field*" and
is noted :math:`u`. The *deformation field* and the *inverse deformation field*
are related by the equation :math:`v = u^{-1}`.

Here, the two classes use and store the *inverse deformation field* :math:`u`
as a class attribute.
"""

import warnings

import math
import torch
import torch.nn as nn
from torchvision.transforms import v2


# =============================================================================
class DeformationField(nn.Module):
    # =========================================================================
    r"""
    Stores a discrete deformation field as a :math:`(b,h,w,2)` tensor.

    Warps a single image or batch of images according to an *inverse
    deformation field* :math:`u`, i.e. the field that maps the *deformed image*
    pixel coordinates to the *original image* pixel coordinates.

    It is constructed from a tensor of shape :math:`(n\_frames,h,w,2)`, where
    :math:`n\_frames` is the number of frames in the animation, :math:`h` and
    :math:`w` are the number of pixels along the height and width of the image
    respectively. The last dimension contains the x and y coordinates of the
    original image pixel that is displayed in the warped image.

    .. important::
        The coordinates are given in the range [-1;1]. When referring to a
        pixel, its position is the position of its center. The position
        [-1;-1] corresponds to the center of the top-left pixel.

    Args:
        :attr:`field` (torch.tensor): *Inverse deformation field* :math:`u` of
        shape :math:`(n\_frames,H,W,2)`, where :math:`n\_frames` is the number
        of frames in the animation, and :math:`H` and :math:`W` are the height
        and width of the image to be warped. For accuracy reasons, it is
        recommended the dtype to be `torch.float64`. Subclasses that build
        their own field pass ``None``.

    Attributes:
        :attr:`self.field` (torch.tensor): *Inverse deformation field*
        :math:`u` of shape :math:`(n\_frames,h,w,2)`.

        :attr:`self.n_frames` (int): Number of frames in the animation.

        :attr:`self.img_shape` (tuple): Shape of the image to be warped, i.e.
        :math:`(h,w)`, where :math:`h` and :math:`w` are the height and width
        of the image respectively.

        :attr:`img_h` (int): Height of the image to be warped in pixels.

        :attr:`img_w` (int): Width of the image to be warped in pixels.

        :attr:`self.align_corners` (bool): Always True. This argument is passed
        to the functions :func:`torch.nn.functional.grid_sample` and
        :func:`torch.nn.functional.affine_grid` to ensure the corners of the
        image are aligned with the corners of the grid.

        :attr:`self.device` (torch.device): Device on which the module lives.

        :attr:`self.warn_range` (bool): If True, :meth:`_warn_field` warns when
        the field goes beyond the range [-2;2]. Defaults to False.

    **Example 1:** Rotating a 2x2 B&W image by 90 degrees counter-clockwise,
    using one frame

    >>> u = torch.tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])
    >>> field = DeformationField(u)
    >>> print(field.field)
    tensor([[[[ 1, -1], [ 1, 1]], [[-1, -1], [-1, 1]]]])
    >>> print(field.field.shape)
    torch.Size([1, 2, 2, 2])

    **Example 2:** Rotating a 2x2 B&W image by 90 degrees clockwise, using one
    frame

    >>> u = torch.tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
    >>> field = DeformationField(u)
    >>> print(field.field)
    tensor([[[[-1, 1], [-1, -1]], [[ 1, 1], [ 1, -1]]]])
    """

    def __init__(self, field: torch.tensor):
        super().__init__()
        self._align_corners = True
        # dummy parameter whose only role is to follow the module across
        # devices, so that `self.device` is always up to date
        self._device_tracker = nn.Parameter(torch.tensor([0.0]), requires_grad=False)

        # subclasses generate their own field and call this constructor
        # with field=None, so only the base class stores the argument
        if type(self) is DeformationField:
            # store as nn.Parameter so that it follows .to(), .cuda(), etc.
            self._field = nn.Parameter(field, requires_grad=False)

        # set to True to be warned (by `_warn_field`) if the field goes
        # beyond +/-2
        self.warn_range = False

    @property
    def align_corners(self) -> bool:
        return self._align_corners

    @property
    def n_frames(self) -> int:
        return self._field.shape[0]

    @property
    def img_shape(self) -> tuple:
        return self._field.shape[1:3]

    @property
    def img_h(self) -> int:
        return self._field.shape[1]

    @property
    def img_w(self) -> int:
        return self._field.shape[2]

    @property
    def field(self) -> torch.tensor:
        return self._field.data

    @property
    def device(self) -> torch.device:
        return self._device_tracker.device

    def forward(
        self,
        img: torch.tensor,
        n0: int = 0,
        n1: int = None,
        mode: str = "bilinear",
    ) -> torch.tensor:
        r"""Warps a batch of 2D images with the stored *inverse deformation
        field* :math:`u`.

        Deforms the batch of 2D images according to the *inverse deformation
        field* :math:`u` contained in the attribute :attr:`field`, sliced
        between the frames :math:`n0` (included) and :math:`n1` (excluded).

        Args:
            :attr:`img` (torch.tensor): The batch of 2D images to deform of
            shape :math:`(c, h, w)` or :math:`(b, c, h, w)`, where :math:`b` is
            the number of images in the batch, :math:`c` is the number of
            channels (usually 1 or 3), and :math:`h` and :math:`w` are the
            number of pixels along the height and width of the image.

            :attr:`n0` (int, optional): The index of the first frame to use in
            the *inverse deformation field*. Defaults to 0.

            :attr:`n1` (int, optional): The index of the first frame to exclude
            in the *inverse deformation field*. If None, all frames from
            :attr:`n0` to the last one are used. Defaults to None.

            :attr:`mode` (str, optional): The interpolation mode to use. It
            must be one of the following: 'nearest', 'bilinear', 'bicubic',
            'biquintic'. If either `nearest`, `bilinear`, or `bicubic`, it is
            directly passed to :func:`torch.nn.functional.grid_sample`. If
            `biquintic`, the warping is delegated to scikit-image, which
            requires skimage and numpy. Defaults to 'bilinear'.

        .. note::
            If using mode='bicubic' or mode='biquintic', the warped image may
            contain values outside the original range. Please use
            :func:`torch.clamp` to ensure the values are in the correct range.

        .. note::
            If :math:`n0 < n1`, :attr:`field` is sliced as follows:
            ``field[n0:n1, :, :, :]``

        .. note::
            If :math:`n0 > n1`, :attr:`field` is sliced "backwards". The first
            frame of the warped animation corresponds to the index :math:`n0`,
            and the last frame corresponds to the index :math:`n1+1`. This
            behavior is identical to slicing a list with a step of -1.

        .. note::
            The image and the field may have a different number of pixels. The
            warped frames then have the spatial size of the field, noted
            :math:`(h', w')` below.

        Returns:
            :attr:`output` (torch.tensor): The deformed batch of 2D images of
            shape :math:`(|n1-n0|, c, h', w')` or
            :math:`(b, |n1-n0|, c, h', w')` depending on the input shape, with
            the same dtype as :attr:`img`.

        Example 1: Rotating a 2x2 B&W image by 90 degrees counter-clockwise,
        using one frame

        >>> v = torch.tensor([[[[ 1., -1.], [ 1., 1.]], [[-1., -1.], [-1., 1.]]]])
        >>> field = DeformationField(v)
        >>> image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
        >>> deformed_image = field(image, 0, 1)
        >>> print(deformed_image)
        tensor([[[[[0.3000, 1.0000], [0.0000, 0.7000]]]]])
        """
        # accept unbatched input (c, h, w) and remember to undo it at the end
        no_batch = img.ndim == 3
        if no_batch:
            img = img.unsqueeze(0)

        # check that the input is shaped (b, c, h, w)
        b, c, h, w = img.shape

        if n1 is None:
            n1 = self.n_frames
        n_frames = abs(n1 - n0)

        # read the field once: subclasses may compute it on the fly
        field = self.field
        if n1 < n0:
            # backwards animation: frames n0, n0-1, ..., n1+1
            sel_inv_grid_frames = torch.flip(field[n1 + 1 : n0 + 1, :, :, :], [0])
        else:
            sel_inv_grid_frames = field[n0:n1, :, :, :]

        # img has current shape (b, c, h, w), make it (n_frames, b*c, h, w)
        # because grid_sample will create the frames in the batch dimension
        img_frames = img.reshape(1, b * c, h, w).expand(n_frames, -1, -1, -1)

        warped_frames = self.grid_sample(
            img_frames.to(sel_inv_grid_frames.dtype), sel_inv_grid_frames, mode
        ).to(img.dtype)

        # the spatial size of the output is the one of the grid, which may
        # differ from the size of the input image
        out_h, out_w = warped_frames.shape[-2:]
        # (n_frames, b*c, h', w') -> (b, n_frames, c, h', w')
        warped_frames = warped_frames.reshape(n_frames, b, c, out_h, out_w)
        warped_frames = warped_frames.moveaxis(0, 1)

        if no_batch:
            return warped_frames.squeeze(0)
        return warped_frames

    def grid_sample(self, img_frames, inverse_grid_frames, mode):
        """Used to warp frames of 2D images with a deformation field.

        Each image of the collection will get a different deformation. This
        function matches the behavior of nn.functional.grid_sample.

        Inputs:
            :attr:`img_frames` (torch.tensor): batch of 2D images of shape
            `(n_frames, c, h, w)`, where `n_frames` is the number of frames in
            the animation, `c` is the number of channels, and `h` and `w` are
            the height and width of the image respectively.

            :attr:`inverse_grid_frames` (torch.tensor): batch of inverse
            deformation fields of shape `(n_frames, h', w', 2)`, indicating
            the (x, y) coordinates in [-1;1] of the original image that are
            displayed in the warped image.

            :attr:`mode` (str): The interpolation mode to use. It must be one
            of the following: 'nearest', 'bilinear', 'bicubic', 'biquintic'.
            If either `nearest`, `bilinear`, or `bicubic`, it is directly
            passed to :func:`torch.nn.functional.grid_sample`. If `biquintic`,
            scikit-image is used, which requires skimage and numpy.

        Returns:
            :attr:`out` (torch.tensor): The deformed batch of 2D images of
            shape `(n_frames, c, h', w')`, with the device and dtype of
            :attr:`img_frames`.
        """
        if mode != "biquintic":
            out = nn.functional.grid_sample(
                img_frames,
                inverse_grid_frames,
                mode=mode,
                padding_mode="zeros",
                align_corners=self.align_corners,
            ).to(img_frames.dtype)
            return out  # has shape (n_frames, c, h', w')

        # optional dependencies, only needed for this interpolation mode
        import numpy as np
        from skimage.transform import warp

        n_frames, c, h, w = img_frames.shape
        grid_h, grid_w = inverse_grid_frames.shape[1:3]
        # the warped frames take the spatial size of the grid
        out = np.empty((n_frames, c, grid_h, grid_w))

        # use scikit-image's order 5 warp. This implies:
        # putting the origin pixel coordinate (x,y) at dimension 0, not 3
        # using numpy instead of pytorch
        coords = inverse_grid_frames.detach().moveaxis(-1, 0).cpu().numpy()
        # changing from 'xy' notation to 'ij'
        coords = coords[::-1, :, :, :]
        # rescaling from [-1, 1] to [0, height-1] (same for width): with
        # align_corners=True, -1 and +1 are the centers of the first and
        # last pixels of the image that is sampled
        coords = (coords + 1) / 2 * np.array([h - 1, w - 1]).reshape(2, 1, 1, 1)

        # numpy needs host memory, whatever the device of the input
        images = img_frames.detach().cpu().numpy()

        # use 2 for loops, faster than 5D warp (because 5D interpolation)
        for frame in range(n_frames):
            inverse_grid = coords[:, frame, :, :]
            for channel in range(c):
                out[frame, channel, :, :] = warp(
                    images[frame, channel, :, :],
                    inverse_grid,
                    order=5,
                    clip=False,
                )
        return torch.from_numpy(out).to(img_frames.device).to(img_frames.dtype)

    def _warn_field(self):
        """Warns about a low-precision or out-of-range deformation field.

        Emits a UserWarning if the field is stored as float32, and another one
        if :attr:`warn_range` is True and the field goes beyond [-2;2].
        """
        field = self.field

        # using float64 is preferred for accuracy
        if field.dtype == torch.float32:
            if type(self) is DeformationField:
                msg = "Consider using float64 when storing the deformation field for greater accuracy."
            elif isinstance(self, AffineDeformationField):
                msg = "Consider using float64 when defining the output type of the affine transformation matrix :attr:`func` for greater accuracy."
            else:
                # any other subclass: generic advice
                msg = "Consider using float64 for the deformation field for greater accuracy."
            warnings.warn(msg, UserWarning)

        # if the field goes beyond +/-2, warn the user
        if self.warn_range and (field.abs() > 2).any():
            msg = "The deformation field goes beyond the range [-2;2], everything mapped outside [-1;1] will not be visible. Suppress this warning by setting self.warn_range = False."
            warnings.warn(msg, UserWarning)

    def _attributeslist(self):
        """Returns the (name, value) pairs displayed by :meth:`__repr__`."""
        return [
            ("field shape", self.field.shape),
            ("n_frames", self.n_frames),
            ("img_shape", self.img_shape),
        ]

    def __repr__(self):
        s_begin = f"{self.__class__.__name__}(\n "
        s_fill = "\n ".join([f"({k}): {v}" for k, v in self._attributeslist()])
        s_end = "\n )"
        return s_begin + s_fill + s_end

    def __eq__(self, other) -> bool:
        """Two deformation fields are equal if their fields have the same
        shape and the same values."""
        if not isinstance(other, DeformationField):
            return False
        mine, theirs = self.field, other.field
        # comparing fields of different shapes would either broadcast or raise
        if mine.shape != theirs.shape:
            return False
        return bool((mine == theirs.to(mine.device)).all())

    def __hash__(self) -> int:
        # Hash the shape of the field, (n_frames, h, w): it is stable between
        # calls and identical for instances that compare equal, as required
        # by the __eq__/__hash__ contract. Hashing the tensor itself would
        # hash the identity of a temporary object.
        return hash((int(self.n_frames), *(int(s) for s in self.img_shape)))
# =============================================================================
class AffineDeformationField(DeformationField):
    # =========================================================================
    r"""
    Stores an affine deformation field as a 3x3 matrix.

    Warps a batch of images according to an *inverse affine deformation field*
    :math:`u`, i.e. the field that maps the *deformed image* pixel coordinates
    to the *original image* pixel coordinates.

    It is constructed from a function of one parameter (time) that returns a
    tensor of shape :math:`(3,3)` representing a 2D affine homogeneous
    transformation matrix. The homogeneous transformation matrix corresponds to
    the *inverse deformation field* :math:`u`.

    To warp an image, the affine transformation matrix is evaluated at each
    time of :attr:`time_vector`, one per frame of the animation. The *inverse
    deformation field* :math:`u` is then computed from the affine
    transformation matrices, and the image is warped according to it. The
    field is not stored: it is generated each time :attr:`field` is read.

    The image size is requested upon construction, but the warping can be done
    with images of different sizes. It is also possible to change the image
    size after construction by setting the attribute :attr:`img_shape`, or the
    attributes :attr:`img_h` and :attr:`img_w`.

    .. important::
        The coordinates are given in the range [-1;1]. When referring to a
        pixel, its position is the position of its center. The position
        [-1;-1] corresponds to the center of the top-left pixel.

    Args:
        :attr:`func` (Callable: float -> torch.tensor): Function of one
        parameter (time) that returns a floating point tensor of shape
        :math:`(3,3)` representing a 2D affine homogeneous transformation
        matrix, the *inverse deformation field* :math:`u`.

        :attr:`time_vector` (torch.tensor): 1D collection of the times at
        which :attr:`func` is evaluated, one per frame. A sequence of plain
        Python numbers is also accepted.

        :attr:`img_shape` (tuple): Shape :math:`(h,w)` of the grid to
        generate, i.e. of the image to be warped.

    Attributes:
        :attr:`self.func` (function of one parameter): Function of one
        parameter (time) that returns a tensor of shape :math:`(3,3)`
        representing a 2D affine homogeneous transformation matrix.

        :attr:`self.field` (torch.tensor): *Inverse deformation field*
        :math:`u` of shape :math:`(n\_frames,h,w,2)`.

        :attr:`time_vector` (torch.tensor): List of the times at which the
        function is evaluated to generate the inverse deformation field.

        :attr:`self.n_frames` (int): Number of frames in the animation, i.e.
        the length of :attr:`time_vector`.

        :attr:`self.img_shape` (tuple): Shape of the image to be warped, i.e.
        :math:`(h,w)`. This attribute can be set to change the image size.

        :attr:`img_h` (int): Height of the image to be warped in pixels. This
        attribute can be set to change the image size.

        :attr:`img_w` (int): Width of the image to be warped in pixels. This
        attribute can be set to change the image size.

    Example 1: Progressive zooming **in**

    >>> def u(t):
    ...     return torch.tensor([[1-t/10, 0, 0], [0, 1-t/10, 0], [0, 0, 1]])
    >>> time_vector = torch.tensor([0., 1., 2., 3.])
    >>> field = AffineDeformationField(u, time_vector, (32, 32))

    Example 2: Rotation of an image **counter-clockwise**, at a frequency of
    1Hz

    >>> import numpy as np
    >>> def s(t):
    ...     return np.sin(2*np.pi*t)
    >>> def c(t):
    ...     return np.cos(2*np.pi*t)
    >>> def u(t):
    ...     return torch.tensor([[c(t), s(t), 0], [-s(t), c(t), 0], [0, 0, 1]])
    >>> time_vector = torch.linspace(0, 1, 11)
    >>> field = AffineDeformationField(u, time_vector, (32, 32))
    """

    def __init__(
        self,
        func,
        time_vector: torch.tensor,
        img_shape: tuple,
    ) -> None:
        # nn.Module must be initialised before any attribute is assigned,
        # otherwise a `func` that is itself a module (or a `time_vector` that
        # is a parameter) cannot be stored
        super().__init__(None)
        self._func = func
        self._time_vector = time_vector
        self._img_shape = img_shape

    # override the field property to generate it on the fly
    @property
    def field(self) -> torch.tensor:
        return self._generate_inv_grid_frames(self.img_shape)

    # override inherited properties
    @property
    def n_frames(self) -> int:
        return len(self.time_vector)

    # add a setter for the img_shape
    @property
    def img_shape(self) -> tuple:
        return self._img_shape

    @img_shape.setter
    def img_shape(self, img_shape: tuple):
        self._img_shape = img_shape

    @property
    def img_h(self) -> int:
        return self.img_shape[0]

    @img_h.setter
    def img_h(self, img_h: int):
        self._img_shape = (img_h, self.img_shape[1])

    @property
    def img_w(self) -> int:
        return self.img_shape[1]

    @img_w.setter
    def img_w(self, img_w: int):
        self._img_shape = (self.img_shape[0], img_w)

    # new property
    @property
    def func(self):
        return self._func

    # new property
    @property
    def time_vector(self):
        return self._time_vector

    def _generate_inv_grid_frames(
        self,
        grid_shape: tuple,
    ) -> torch.tensor:
        r"""Generates the inverse deformation field as a tensor of shape
        :math:`(n\_frames, h, w, 2)`.

        This function is called by the property :attr:`field` to generate the
        inverse deformation field from the affine transformation matrix
        evaluated at each time of :attr:`time_vector`. It is not meant to be
        called directly.

        Args:
            grid_shape (tuple): shape of the 2D grid to be generated. Must be a
            tuple of the form (h, w), where h and w are respectively the height
            and width of the image to be warped.

        Returns:
            torch.tensor: The inverse deformation field as a tensor of shape
            :math:`(n\_frames, h, w, 2)`, on the device of the module.
        """
        times = self.time_vector

        # get a batch of matrices of shape (n_frames, 2, 3)
        inv_mat_frames = torch.stack(
            [
                # func takes a plain number; only the first 2 rows are needed
                self.func(t.item() if isinstance(t, torch.Tensor) else t)[:2, :]
                for t in times
            ]
        )

        # use them to generate the grid
        inv_grid_frames = nn.functional.affine_grid(
            inv_mat_frames,
            torch.Size((len(times), 1, *grid_shape)),  # n_channels has no effect
            align_corners=self.align_corners,
        )
        return inv_grid_frames.to(self.device)
# =============================================================================
class ElasticDeformation(DeformationField):
    """
    Stores a random elastic deformation field that is smooth in time.

    Random displacement fields are drawn with
    :class:`torchvision.transforms.v2.ElasticTransform` every
    :attr:`n_interpolation` frames, added to a regular grid, and the resulting
    sequence is smoothed along the time axis with a Gaussian window. The field
    is generated once, at construction.

    Attributes:
        :attr:`self.field` (torch.tensor): *Inverse deformation field* of
        shape `(n_frames, h, w, 2)`. Setting it also updates
        :attr:`n_frames`.

        :attr:`self.img_shape` (tuple): Shape `(h, w)` of the grid.

        :attr:`self.n_frames` (int): Number of frames in the animation.

        :attr:`self.n_interpolation` (int): Number of output frames between
        two consecutive random displacement fields.

        :attr:`self.alpha` (float): Displacement magnitude passed to the
        elastic transform, i.e. ten times the constructor argument.
    """

    def __init__(self, alpha, sigma, img_shape, n_frames, n_interpolation):
        """Draws the random elastic deformation field.

        Args:
            alpha (float): Magnitude of displacements. It is multiplied by 10
            before being passed to the elastic transform.

            sigma (float): Smoothness of displacements in the spatial domain

            img_shape (tuple): Shape (h, w) of the grid to generate.

            n_frames (int): Number of frames in the animation.

            n_interpolation (int): Number of frames in the output animation
            between two consecutive input frames. 1 results in no
            interpolation
        """
        super().__init__(None)
        self.alpha = alpha * 10
        self._img_shape = img_shape
        self._n_frames = n_frames
        self.n_interpolation = n_interpolation
        self.ElasticTransform = v2.ElasticTransform(self.alpha, sigma)
        self._field = nn.Parameter(
            self._generate_inv_grid_frames(), requires_grad=False
        )

    @property
    def field(self):
        return self._field.data

    @field.setter
    def field(self, field):
        self._field = nn.Parameter(field, requires_grad=False)
        self.n_frames = field.shape[0]

    @property
    def img_shape(self):
        return self._img_shape

    @img_shape.setter
    def img_shape(self, img_shape):
        self._img_shape = img_shape

    @property
    def n_frames(self):
        return self._n_frames

    @n_frames.setter
    def n_frames(self, n_frames):
        self._n_frames = n_frames

    def _generate_inv_grid_frames(self):
        """Generates the inverse deformation field.

        Returns:
            torch.tensor: The inverse deformation field, of shape
            `(n_frames, h, w, 2)`, where the last dimension holds the (x, y)
            coordinates.
        """
        # create base frame between -1 and 1
        base_frame_i = torch.linspace(-1, 1, self.img_shape[0])  # rows, y
        base_frame_j = torch.linspace(-1, 1, self.img_shape[1])  # columns, x
        grid_y, grid_x = torch.meshgrid(base_frame_i, base_frame_j, indexing="ij")
        # shape (h, w, 2). The last dimension must be ordered (x, y), as
        # expected by grid_sample; stacking (y, x) would transpose the image
        base_frame = torch.stack((grid_x, grid_y), dim=-1)

        window_width = self.n_interpolation * 3
        elastic_frames_to_generate = 1 + math.ceil(
            self.n_frames / self.n_interpolation
        )
        total_frames_after_conv = (
            1 + (elastic_frames_to_generate - 1) * self.n_interpolation
        )
        # account for the window width: the convolution is not padded, so it
        # shortens the sequence by (window_width - 1) frames
        total_frames_to_generate = total_frames_after_conv + window_width - 1

        grid = base_frame.repeat(total_frames_to_generate, 1, 1, 1)

        # only one frame out of n_interpolation gets a random displacement,
        # the frames in between are filled by the temporal smoothing below
        for i in range(total_frames_to_generate // self.n_interpolation):
            # generate a random field
            displacement = self.ElasticTransform._get_params(
                torch.empty([1, *self.img_shape])
            )["displacement"]
            grid[i * self.n_interpolation] += displacement[0, :, :, :]

        # Define Gaussian convolution operator
        Conv = nn.Conv1d(1, 1, window_width, bias=False, padding=0)
        gaussian_window = torch.signal.windows.gaussian(
            window_width, std=window_width / 4
        )
        # normalise so that the smoothing leaves the base grid unchanged
        gaussian_window /= gaussian_window.sum()
        Conv.weight = nn.Parameter(gaussian_window.view(1, 1, -1), requires_grad=False)

        # reshape, convolute, reshape back
        grid = grid.permute(1, 2, 3, 0)  # put time in the last dimension
        grid = grid.reshape(-1, 1, total_frames_to_generate)  # (h*w*2, 1, n_frames)
        grid = Conv(grid)
        grid = grid.reshape(*self.img_shape, 2, total_frames_after_conv)
        grid = grid.permute(3, 0, 1, 2)  # (n_frames, h, w, 2)

        # truncate to the desired number of frames
        return grid[: self.n_frames, ...]