Source code for spyrit.core.warp
"""
This module contains classes that are used to warp images according to
a deformation field. Let :math:`t_0 \in \mathbb{R}_+`,
:math:`f(t_0, x, y): \mathbb{R}^2 \mapsto \mathbb{R}^2` be a reference scene
and :math:`u(t, x, y): \mathbb{R}^3 \mapsto \mathbb{R}^2` be a deformation field.
These classes compute the moving scene:
.. math::
f(t, x, y) = f(t_0, u(t, x, y))
.. note::
These classes rely on backward mapping to perform the warping.
.. note::
These classes stores the deformation field :math:`u`.
"""
import warnings
import math
import torch
import torch.nn as nn
from torchvision.transforms import v2
# =============================================================================
[docs]
class DeformationField(nn.Module):
# =========================================================================
r"""Stores a discrete deformation field :math:`u` of shape :math:`(n\_frames,h,w,2)`.
The deformations has :math:`n\_frames` is the number of frames, and its height and width
are denoted by :math:`h` and :math:`w`. The last dimension contains the x
and y coordinates of the deformation field w.r.t the reference time :math:`t_0`.
.. math::
f(t, x, y) = f(t_0, u(t, x, y))
where :math:`f(t_0, x, y)` is the reference image and :math:`u(t, x, y)` is the *deformation field*.
Forward call generates a video warping the input image according to the deformation field :math:`u`.
.. important::
The coordinates are given in the range [-1, 1]. When referring to a
pixel, its position is the position of its center. The position
(-1, -1) corresponds to the center of the top-left pixel.
Args:
:attr:`field` (torch.tensor): Deformation field :math:`u` of shape :math:`(n\_frames,h,w,2)`,
where :math:`n\_frames` is the number of deformation frames, :math:`h` and :math:`w`
are the height and width of the deformation field.
For accuracy reasons, it is recommended the dtype to be `torch.float64`.
Attributes:
:attr:`self.field` (torch.tensor): Deformation field :math:`u` of shape :math:`(n\_frames,h,w,2)`.
:attr:`self.n_frames` (int): Number of frames in the video.
:attr:`self.img_shape` (tuple): Shape of the image to be warped, i.e.
:math:`(h,w)`, where :math:`h` and :math:`w` are the height and width
of the image respectively.
:attr:`img_h` (int): Height of the image to be warped in pixels.
:attr:`img_w` (int): Width of the image to be warped in pixels.
:attr:`self.align_corners` (bool): Always True. This argument is passed
to the functions :func:`torch.nn.functional.grid_sample` and
:func:`torch.nn.functional.affine_grid` to ensure the corners of the
image are aligned with the corners of the grid.
Example:
Storing a 90 degrees counter-clockwise rotation for 2x2 image.
>>> import torch
>>> from spyrit.core.warp import DeformationField
>>>
>>> u = torch.tensor([[[[1, -1], [1, 1]], [[-1, -1], [-1, 1]]]])
>>> field = DeformationField(u)
>>> print(field.field)
tensor([[[[ 1, -1],
[ 1, 1]],
<BLANKLINE>
[[-1, -1],
[-1, 1]]]])
>>> print(field.field.shape)
torch.Size([1, 2, 2, 2])
"""
def __init__(self, field: torch.tensor):
super().__init__()
self._align_corners = True
self._device_tracker = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
# field is None if AffineDeformationField is used
if field is not None:
# store as nn.Parameter and ensure proper device/dtype handling
self._field = nn.Parameter(field.detach().clone(), requires_grad=False)
# Move device tracker to same device as field
self._device_tracker = nn.Parameter(
torch.tensor([0.0], device=field.device, dtype=field.dtype),
requires_grad=False,
)
self.warn_range = False # warn the user if the field goes beyond +/-2
# self._warn_field()
@property
def align_corners(self) -> bool:
return self._align_corners
@property
def n_frames(self) -> int:
return self._field.shape[0]
@property
def img_shape(self) -> tuple:
return self._field.shape[1:3]
@property
def img_h(self) -> int:
return self._field.shape[1]
@property
def img_w(self) -> int:
return self._field.shape[2]
@property
def field(self) -> torch.tensor:
return self._field.data
@property
def device(self) -> torch.device:
return self._device_tracker.device
@property
def dtype(self) -> torch.dtype:
"""Get the dtype of the deformation field."""
return (
self._field.dtype if hasattr(self, "_field") else self._device_tracker.dtype
)
[docs]
def forward(
self,
img: torch.tensor,
n0: int = 0,
n1: int = None,
mode: str = "bilinear",
) -> torch.tensor:
r"""
Generates a video from a batch of 2D images according to the
*deformation field* :math:`u`.
The deformation is taken between the frames :math:`n0` (included)
and :math:`n1` (excluded).
Args:
:attr:`img` (torch.tensor): Batch of 2D images to deform of shape :math:`(c, h, w)`
or :math:`(b, c, h, w)`, where:math:`b` is the number of images in the batch,
:math:`c` is the number of channels, and :math:`h` and :math:`w` are the height
and width of the images.
:attr:`n0` (int, optional): The index of the first frame to use in the *deformation
field*. Defaults to 0.
:attr:`n1` (int, optional): The index of the first frame to exclude in the *deformation
field*. If None, the last available frame is used. Defaults to None.
:attr:`mode` (str, optional): The interpolation mode to use. It must be one of the following:
'nearest', 'bilinear', 'bicubic', 'biquintic'. The `nearest`,
`bilinear`, and `bicubic` modes are directly supported by the
function :func:`torch.nn.functional.grid_sample`. The `biquintic` mode
relies on scikit-image. Defaults to 'bilinear'.
.. note::
If using mode='bicubic' or mode='biquintic', the warped image may
contain values outside the original range.
.. note::
If :math:`n0 < n1`, :attr:`field` is sliced as follows: ``field[n0:n1, :, :, :]``
.. note::
If :math:`n0 > n1`, :attr:`field` is sliced"backwards".
The first frame of the warped animation corresponds to
the index :math:`n0`, and the last frame corresponds to the index
:math:`n1+1`. This behavior is identical to slicing a list with a
step of -1.
.. note::
If the number of pixels is different in the image and the field,
the field is interpolated to match the image size (see the
behavior of :func:`torch.nn.functional.grid_sample`).
.. note::
If the input image three dimensional, a batch dimension is added in the first dimension.
Returns:
:attr:`output` (torch.tensor):
The deformed batch of 2D images of shape :math:`(|n1-n0|, c, h, w)`
or :math:`(b, |n1-n0|, c, h, w)` depending on the input shape, where each
image in the batch is deformed according to the *deformation
field* :math:`u`.
Example:
Rotating a 2x2 grayscale image by 90 degrees counter-clockwise, using one frame:
>>> import torch
>>> from spyrit.core.warp import DeformationField
>>>
>>> u = torch.tensor([[[[ 1., -1.], [ 1., 1.]], [[-1., -1.], [-1., 1.]]]])
>>> field = DeformationField(u)
>>> image = torch.tensor([0., 0.3, 0.7, 1.]).view(1, 1, 2, 2)
>>> print(image)
tensor([[[[0.0000, 0.3000],
[0.7000, 1.0000]]]])
>>> deformed_image = field(image, 0, 1)
>>> print(deformed_image)
tensor([[[[[0.3000, 1.0000],
[0.0000, 0.7000]]]]])
"""
if img.ndim == 3:
img = img.unsqueeze(0)
no_batch = True
else:
no_batch = False
# check that the input is shaped (b, c, h, w)
b, c, h, w = img.shape
if n1 is None:
n1 = self.n_frames
# get the right slice of the deformation field
n_frames = abs(n1 - n0)
if n1 < n0:
sel_grid_frames = torch.flip(self.field[n1 + 1 : n0 + 1, :, :, :], [0])
else:
sel_grid_frames = self.field[n0:n1, :, :, :]
# Ensure the grid is on the same device as the input image
sel_grid_frames = sel_grid_frames.to(device=img.device)
# img has current shape (b, c, h, w), make it (n_frames, b*c, h, w)
# because grid_sample will create the frames in the batch dimension
img_frames = img.reshape(1, b * c, h, w).expand(n_frames, -1, -1, -1)
# Ensure dtype compatibility for grid_sample
warped_frames = self._grid_sample(img_frames, sel_grid_frames, mode)
# has shape (n_frames, b*c, h, w), make it (b, n_frames, c, h, w)
warped_frames = warped_frames.reshape(n_frames, b, c, h, w).moveaxis(0, 1)
if no_batch:
return warped_frames.squeeze(0)
return warped_frames
def _grid_sample(self, img_frames, grid_frames, mode):
"""Warp frames of 2D images with a deformation field.
Each image of the collection will get a different deformation.
This function matches the behavior of nn.functional.grid_sample.
Inputs:
:attr:`img_frames` (torch.tensor): batch of 2D images of shape
`(n_frames, c, h, w)`, where `n_frames` is the number of frames in
the animation, `c` is the number of channels, and `h` and `w` are
the height and width of the image respectively.
:attr:`grid_frames` (torch.tensor): batch of deformation fields of
shape `(n_frames, h, w, 2)`, indicating the pixel coordinates of the
original image that are displayed in the warped image.
:attr:`mode` (str): The interpolation mode to use. It must be one of
the following: 'nearest', 'bilinear', 'bicubic', 'biquintic'. If either
`nearest`, `bilinear`, or `bicubic`, it is directly passed to the
function :func:`torch.nn.functional.grid_sample`. If `biquintic`,
it is passed to the package scikit-image, which requires skimage and numpy.
Returns:
:attr:`out` (torch.tensor): The deformed batch of 2D images of shape
`(n_frames, c, h, w)`. Each image in the batch is deformed according
to the *deformation field* :math:`u` contained in the attribute
:attr:`field`.
"""
if mode == "biquintic":
import skimage
import numpy as np
n_frames, c, h, w = img_frames.shape
out = np.empty((n_frames, c, h, w))
# use scikit-image's order 5 warp. This implies:
# putting the origin pixel coordinate (x,y) at dimension 0, not 3
# using numpy instead of pytorch
grid_frames = grid_frames.moveaxis(-1, 0).cpu().numpy()
# changing from 'xy' notation to 'ij'
grid_frames = grid_frames[::-1, :, :, :]
# rescaling from [-1, 1] to [0, height-1] (same for width)
grid_frames = (
(grid_frames + 1)
/ 2
* np.array([self.img_h, self.img_w]).reshape(2, 1, 1, 1)
)
# use 2 for loops, faster than 5D warp (because 5D interpolation)
for frame in range(n_frames):
grid = grid_frames[:, frame, :, :]
for channel in range(c):
out[frame, channel, :, :] = skimage.transform.warp(
img_frames[frame, channel, :, :].cpu().numpy(),
grid,
order=5,
clip=False,
)
return torch.from_numpy(out).to(
device=img_frames.device, dtype=img_frames.dtype
)
else:
# Ensure both tensors are on the same device and compatible dtypes
if img_frames.device != grid_frames.device:
grid_frames = grid_frames.to(device=img_frames.device)
# For grid_sample, we need to ensure the grid is float32 or float64
if grid_frames.dtype not in [torch.float32, torch.float64]:
grid_frames = grid_frames.float()
out = nn.functional.grid_sample(
img_frames,
grid_frames,
mode=mode,
padding_mode="zeros",
align_corners=self.align_corners,
)
return out # has shape (n_frames, c, h, w)
def _warn_field(self):
# using float64 is preferred for accuracy
if self.field.dtype == torch.float32:
if self.__class__ == DeformationField:
msg = "Consider using float64 when storing the deformation field for greater accuracy."
if self.__class__ == AffineDeformationField:
msg = "Consider using float64 when defining the output type of the affine transformation matrix :attr:`func` for greater accuracy."
warnings.warn(msg, UserWarning)
# if the field goes bayond +/-2, warn the user
if self.warn_range and (self.field.abs() > 2).any():
msg = "The deformation field goes beyond the range [-2;2], everything mapped outside [-1;1] will not be visible. Suppress this warning by setting self.warn_range = False."
warnings.warn(msg, UserWarning)
def _attributeslist(self):
a = [
("field shape", self.field.shape),
("n_frames", self.n_frames),
("img_shape", self.img_shape),
]
return a
def __repr__(self):
s_begin = f"{self.__class__.__name__}(\n "
s_fill = "\n ".join([f"({k}): {v}" for k, v in self._attributeslist()])
s_end = "\n )"
return s_begin + s_fill + s_end
def __eq__(self, other) -> bool:
if isinstance(other, DeformationField):
return bool((self.field == other.field).all())
return False
def __hash__(self) -> int:
return hash(self.field)
# =============================================================================
[docs]
class AffineDeformationField(DeformationField):
# =========================================================================
r"""Stores and applies affine deformation fields defined by transformation matrices.
This class generates video sequences by warping images according to time-varying affine
transformations. It constructs a discrete *deformation field* :math:`u` from
a user-defined function that returns 3x3 affine transformation matrices at different time points.
The forward call generates a video warping the input image according to the deformation field :math:`u = v^{-1}`.
.. math::
f(t, x, y) = f(t_0, u(t, x, y))
where :math:`f(t_0, x, y)` is the reference image and :math:`u(t, x, y)` is the *deformation field*.
.. important::
The coordinates are given in the range [-1, 1]. When referring to a
pixel, its position is the position of its center. The position
(-1, -1) corresponds to the center of the top-left pixel.
.. note::
The image size is requested upon construction, but the warping can be done
with images of different sizes. The grid is simply interpolated to match
the image size. It is also possible to change the image size after
construction by setting the attribute :attr:`img_shape`, or the attributes
:attr:`img_h` and :attr:`img_w`.
Args:
:attr:`func` (Callable[[float], torch.tensor]): Function of one parameter (time) that returns a tensor
of shape :math:`(3,3)` representing an affine homogeneous transformation matrix. This matrix corresponds
to the *deformation field* :math:`u`.
:attr:`time_vector` (torch.tensor): Vector of time points at which the transformation function is evaluated
to generate the deformation field. Shape :math:`(n\_frames,)`.
:attr:`img_shape` (tuple): Shape of the image to be warped, i.e. :math:`(h,w)`, where :math:`h` and :math:`w`
are the height and width of the image respectively.
:attr:`dtype` (torch.dtype, optional): Data type of the deformation field tensor.
For accuracy reasons, it is recommended to use `torch.float64`.
Defaults to `torch.float32`.
:attr:`device` (torch.device, optional): Device on which the deformation field tensor is stored.
Defaults to `torch.device('cpu')`.
Attributes:
:attr:`self.func` (Callable[[float], torch.tensor]): Function of one parameter (time) that returns a tensor of shape
:math:`(3,3)` representing an affine homogeneous transformation matrix.
:attr:`self.field` (torch.tensor):*Deformation field* :math:`u` of shape :math:`(n\_frames,h,w,2)`.
:attr:`self.time_vector` (torch.tensor): Vector of time points at which the function is evaluated to generate
the deformation field.
:attr:`self.n_frames` (int): Number of frames in the video.
:attr:`self.img_shape` (tuple): Shape of the image to be warped, i.e. :math:`(h,w)`, where :math:`h` and :math:`w`
are the height and width of the image respectively.
:attr:`self.img_h` (int): Height of the image to be warped in pixels.
:attr:`self.img_w` (int): Width of the image to be warped in pixels.
:attr:`self.align_corners` (bool): Always True. This argument is passed to the functions
:func:`torch.nn.functional.grid_sample` and :func:`torch.nn.functional.affine_grid`
to ensure the corners of the image are aligned with the corners of the grid.
Example 1: Progressive scaling
>>> import torch
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def scaling(t):
... scale = 1 - t/10
... return torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])
>>>
>>> time_vector = torch.linspace(0, 1, 10)
>>> def_field = AffineDeformationField(scaling, time_vector, (64, 64))
>>> print(def_field.n_frames)
10
Example 2: Rotation counter-clockwise at 1Hz frequency
>>> import torch
>>> import math
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def rotation(t):
... angle = 2 * math.pi * t # 1Hz rotation
... c, s = math.cos(angle), math.sin(angle)
... return torch.tensor([[c, s, 0], [-s, c, 0], [0, 0, 1]], dtype=torch.float64)
>>>
>>> time_vector = torch.linspace(0, 1, 30) # 30 frames for 1 second
>>> def_field = AffineDeformationField(rotation, time_vector, (128, 128))
>>> print(def_field.field.shape)
torch.Size([30, 128, 128, 2])
"""
def __init__(
self,
func,
time_vector: torch.tensor,
img_shape: tuple,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> None:
self._align_corners = True
self.func = func
self.time_vector = time_vector
field = self._generate_grid_frames(
img_shape, time_vector, func, dtype=dtype, device=device
)
super().__init__(field)
def _generate_grid_frames(
self,
grid_shape: tuple,
time_vector: torch.tensor,
func,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
) -> torch.tensor:
r"""Generates the deformation field
This function is called by the constructor to generate the deformation field as
a tensor of shape:math:`(n\_frames, h, w, 2)` from the affine transformation
matrix at the desired time points. It is not meant to be called directly.
Args:
t0 (float): First time at which the deformation field is
computed.
t1 (float): Last time at which the deformation field is
computed.
n_frames (int): Number of frames in the video.
grid_shape (tuple): shape of the 2D grid to be generated. Must be a
tuple of the form (h, w), where h and w are respectively the height
and width of the image to be warped.
Returns:
torch.tensor: The deformation field :math:`u` as a tensor of shape
:math:`(n\_frames, h, w, 2)`.
"""
# get a batch of matrices of shape (n_frames, 2, 3)
mat_frames = torch.stack(
[func(t.item())[:2, :] for t in time_vector] # need only the first 2 rows
)
# Ensure matrices are on the correct device and dtype
mat_frames = mat_frames.to(dtype=dtype, device=device)
# use them to generate the grid
grid_frames = nn.functional.affine_grid(
mat_frames,
torch.Size((len(time_vector), 1, *grid_shape)), # n_channels has no effect
align_corners=self._align_corners,
)
return grid_frames
[docs]
def forward(
self,
img: torch.tensor,
n0: int = 0,
n1: int = None,
mode: str = "bilinear",
) -> torch.tensor:
r"""Generates a video from a batch of 2D images according to the *deformation field* :math:`u`.
The deformation is taken between the frames :math:`n0` (included)
and :math:`n1` (excluded).
Args:
:attr:`img` (torch.tensor):
Batch of 2D images to deform of shape :math:`(c, h, w)` or :math:`(b, c, h, w)`, where
:math:`b` is the number of images in the batch, :math:`c` is the
number of channels, and :math:`h` and :math:`w`
are the height and width of the images.
:attr:`n0` (int, optional):
The index of the first frame to use in the *deformation
field*. Defaults to 0.
:attr:`n1` (int, optional):
The index of the first frame to exclude in the *deformation
field*. If None, the last available frame is used. Defaults to None.
:attr:`mode` (str, optional):
The interpolation mode to use. It must be one of the following:
'nearest', 'bilinear', 'bicubic', 'biquintic'. The `nearest`,
`bilinear`, and `bicubic` modes are directly supported by the
function :func:`torch.nn.functional.grid_sample`. The `biquintic` mode
relies on scikit-image. Defaults to 'bilinear'.
.. note::
If using mode='bicubic' or mode='biquintic', the warped image may
contain values outside the original range.
.. note::
If :math:`n0 < n1`, :attr:`field` is sliced
as follows: ``field[n0:n1, :, :, :]``
.. note::
If :math:`n0 > n1`, :attr:`field` is sliced
"backwards". The first frame of the warped animation corresponds to
the index :math:`n0`, and the last frame corresponds to the index
:math:`n1+1`. This behavior is identical to slicing a list with a
step of -1.
.. note::
If the number of pixels is different in the image and the field,
the field is interpolated to match the image size (see the
behavior of :func:`torch.nn.functional.grid_sample`).
Returns:
:attr:`output` (torch.tensor):
The deformed batch of 2D images of shape :math:`(|n1-n0|, c, h, w)`
or :math:`(b, |n1-n0|, c, h, w)` depending on the input shape, where each
image in the batch is deformed according to the *deformation
field* :math:`u`.
**Example 1:** Progressive scaling
>>> import torch
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def scaling(t):
... scale = 1 - t/10
... return torch.tensor([[scale, 0, 0], [0, scale, 0], [0, 0, 1]])
>>>
>>> time_vector = torch.linspace(0, 1, 10)
>>> def_field = AffineDeformationField(scaling, time_vector, (64, 64))
>>> images = torch.randn(16, 1, 64, 64) # Batch of 16 grayscale image
>>> scaled_video = def_field(images)
>>> print(scaled_video.shape)
torch.Size([16, 10, 1, 64, 64])
**Example 2:** Rotation counter-clockwise at 1Hz frequency
>>> import torch
>>> import math
>>> from spyrit.core.warp import AffineDeformationField
>>>
>>> def rotation(t):
... angle = 2 * math.pi * t # 1Hz rotation
... c, s = math.cos(angle), math.sin(angle)
... return torch.tensor([[c, s, 0], [-s, c, 0], [0, 0, 1]], dtype=torch.float64)
>>>
>>> time_vector = torch.linspace(0, 1, 30) # 30 frames for 1 second
>>> def_field = AffineDeformationField(rotation, time_vector, (128, 128), dtype=torch.float64)
>>> image = torch.randn(3, 128, 128).to(dtype=torch.float64) # a single RGB image
>>> rotated_video = def_field(image)
>>> print(rotated_video.shape)
torch.Size([30, 3, 128, 128])
"""
return super().forward(img, n0, n1, mode)
# =============================================================================
[docs]
class ElasticDeformation(DeformationField):
r"""Generates and stores a random elastic deformation where each pixel is
sampled from an uniform distribution and then smoothed in space and time.
This class inherits from the random generation of TorchVision's
:class:`torchvision.transforms.v2.ElasticTransform`. It will generate several
frames of static elastic deformation using the torchvision class, and then
smooth out these frames in the time domain to create a continuous motion.
The deformation field is generated at instanciation and stored as a class
attribute.
The spatial magnitude of the displacements is controlled by the parameter
:attr:`alpha`, the spatial smoothness of the displacements is controlled by
the parameter :attr:`sigma`, and the time-domain smoothness is controlled by
the parameter :attr:`n_interpolation`.
.. note::
The spatial and temporal smoothing are done **after** the
displacements of magnitude :attr:`alpha` are generated. This means that
the actual spatial displacement magnitude can be significantly lower
than then one specified by :attr:`alpha`. To get the actual standard
deviation of the deformation field, call the :meth:`compute_field_std` method.
.. note::
The parameters :attr:`alpha`, :attr:`sigma`, and :attr:`n_interpolation`
are defined at initialization and cannot be changed after instanciation.
Args:
alpha (float): Magnitude of displacements. This argument is passed to
the constructor of :class:`torchvision.transforms.v2.ElasticTransform`.
sigma (float): Smoothness of displacements in the spatial domain. This
argument is passed to the constructor of
:class:`torchvision.transforms.v2.ElasticTransform`.
img_shape (tuple): Shape of the deformation field, i.e. :math:`(h,w)`,
where :math:`h` and :math:`w` are the height and width of the field
respectively.
n_frames (int): Number of frames in the video.
n_interpolation (int): Period in frames of the time-domain interpolation.
Every :attr:`n_interpolation` frames, a 2D elastic transform is randomly
generated. Between these frames, the deformation field is equal to the
identity. A truncated gaussian smoothing of length equal to 3 times
:attr:`n_interpolation` (to capture a real-looking movement between 3
points in 2D space) and with a standard deviation of :math:`\frac{3}{4}`
:attr:`n_interpolation` is applied to the deformation field.
dtype (torch.dtype): Data type of the tensors. Default is torch.float32.
device (torch.device): Device on which the tensors are stored. Default is CPU.
Attributes:
:attr:`field` (torch.tensor): The deformation field as a tensor of shape
:math:`(n\_frames,h,w,2)`.
:attr:`img_shape` (tuple): Shape of the deformation field, i.e. :math:`(h,w)`,
where :math:`h` and :math:`w` are the height and width of the field
respectively.
:attr:`n_frames` (int): Number of frames in the animation.
:attr:`alpha` (float): Magnitude of displacements.
:attr:`sigma` (float): Smoothness of displacements in the spatial domain.
:attr:`n_interpolation` (int): Period in frames of the time-domain interpolation.
:attr:`ElasticTransform` (torchvision.transforms.v2.ElasticTransform): The
random generator of static elastic deformation, with parameters :attr:`alpha`
and :attr:`sigma`.
Example:
>>> import torch
>>> from spyrit.core.warp import ElasticDeformation
>>>
>>> alpha = 100
>>> sigma = 5
>>> img_shape = (64, 64)
>>> n_frames = 30
>>> n_interpolation = 5
>>> def_field = ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation)
>>> print(def_field.field.shape)
torch.Size([30, 64, 64, 2])
"""
def __init__(
self,
alpha,
sigma,
img_shape,
n_frames,
n_interpolation,
dtype=torch.float32,
device=torch.device("cpu"),
):
field = self._generate_grid_frames(
img_shape, n_frames, n_interpolation, alpha, sigma, dtype, device
)
field = field.to(dtype=dtype, device=device)
super().__init__(field)
# Set additional attributes (after init)
self.alpha = alpha
self.sigma = sigma
self.n_interpolation = n_interpolation
self.ElasticTransform = v2.ElasticTransform(alpha, sigma)
[docs]
def compute_field_std(self):
r"""Computes the theoretical standard deviation (in pixels) of the deformation field."""
sigma_t = 3 * self.n_interpolation / 4
var_dz = 1 / 3
var_gdz = var_dz / (4 * math.pi * self.sigma**2)
std = self.alpha * (var_gdz / (2 * math.pi**0.5 * sigma_t)) ** 0.5
return std
def _generate_grid_frames(
self, img_shape, n_frames, n_interpolation, alpha, sigma, dtype, device
):
r"""Generates the frames of the elastic deformation field of shape :math:`(n_frames, h, w, 2)`."""
# create base frame between -1 and 1
base_frame_i = torch.linspace(-1, 1, img_shape[0], dtype=dtype, device=device)
base_frame_j = torch.linspace(-1, 1, img_shape[1], dtype=dtype, device=device)
# shape (h, w, 2)
base_frame = torch.stack(
torch.meshgrid(base_frame_i, base_frame_j, indexing="xy"), dim=-1
)
window_width = n_interpolation * 3
elastic_frames_to_generate = 1 + int(math.ceil(n_frames / n_interpolation))
total_frames_after_conv = 1 + (elastic_frames_to_generate - 1) * n_interpolation
# account for the window width
total_frames_to_generate = total_frames_after_conv + window_width - 1
grid = base_frame.repeat(total_frames_to_generate, 1, 1, 1)
for i in range(total_frames_to_generate // n_interpolation):
# generate a random field - create tensor on correct device
dummy_input = torch.empty([1, *img_shape], device=device, dtype=dtype)
# Note: ElasticTransform needs to be created with the correct parameters
elastic_transform = v2.ElasticTransform(alpha, sigma)
# displacement = elastic_transform._get_params(dummy_input)["displacement"][0, :, :, :] # old
displacement = elastic_transform.make_params([dummy_input])["displacement"][
0, :, :, :
]
grid[i * n_interpolation] += displacement.to(device=device, dtype=dtype)
# Define Gaussian convolution operator
Conv = nn.Conv1d(1, 1, window_width, bias=False, padding=0)
gaussian_window = torch.signal.windows.gaussian(
window_width, std=window_width / 4, dtype=dtype, device=device
) # , std=self.sigma_time)
gaussian_window /= gaussian_window.sum()
Conv.weight = nn.Parameter(gaussian_window.view(1, 1, -1), requires_grad=False)
# Move Conv to correct device
Conv = Conv.to(device=device)
# reshape, convolute, reshape back
grid = grid.permute(1, 2, 3, 0) # put time in the last dimension
grid = grid.reshape(-1, 1, total_frames_to_generate) # (h*w*2, 1, n_frames)
grid = Conv(grid)
grid = grid.reshape(*img_shape, 2, total_frames_after_conv)
grid = grid.permute(3, 0, 1, 2) # (n_frames, h, w, 2)
# truncate to the desired number of frames
return grid[:n_frames, ...]