spyrit.core.warp.ElasticDeformation

class spyrit.core.warp.ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation, dtype=torch.float32, device=device(type='cpu'))[source]

Bases: DeformationField

Generates and stores a random elastic deformation in which each pixel displacement is sampled from a uniform distribution and then smoothed in space and time.

This class relies on the random generator of TorchVision's torchvision.transforms.v2.ElasticTransform. It uses that class to generate several frames of static elastic deformation, and then smooths these frames in the time domain to create continuous motion. The deformation field is generated at instantiation and stored as a class attribute.

The spatial magnitude of the displacements is controlled by the parameter alpha, the spatial smoothness of the displacements is controlled by the parameter sigma, and the time-domain smoothness is controlled by the parameter n_interpolation.
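The per-keyframe step can be sketched in plain PyTorch rather than with torchvision. Uniform noise in [-1, 1] is blurred with a Gaussian of standard deviation sigma and scaled by alpha. One such static field is placed every n_interpolation frames, with zero displacement (the identity) in between. This is a minimal sketch under stated assumptions. The helper names `blur2d` and `keyframe_field` are hypothetical. torchvision's ElasticTransform also normalizes alpha by the image size and picks its own kernel size, so the magnitudes will not match the library exactly. The temporal smoothing step is deliberately left out here.

```python
import torch
import torch.nn.functional as F


def blur2d(d: torch.Tensor, sigma: float) -> torch.Tensor:
    """Separable Gaussian blur of a (N, 1, h, w) tensor, with reflect padding (hypothetical helper)."""
    r = int(3 * sigma)  # truncate the kernel at 3 standard deviations
    x = torch.arange(-r, r + 1, dtype=d.dtype)
    g = torch.exp(-0.5 * (x / sigma) ** 2)
    g = g / g.sum()
    d = F.conv2d(F.pad(d, (r, r, 0, 0), mode="reflect"), g.view(1, 1, 1, -1))  # blur along width
    d = F.conv2d(F.pad(d, (0, 0, r, r), mode="reflect"), g.view(1, 1, -1, 1))  # blur along height
    return d


def keyframe_field(alpha, sigma, img_shape, n_frames, n_interpolation, generator=None):
    """Static elastic displacements every n_interpolation frames, zero elsewhere (assumed model)."""
    h, w = img_shape
    field = torch.zeros(n_frames, h, w, 2)
    for t in range(0, n_frames, n_interpolation):
        # Uniform noise in [-1, 1] for the two displacement components
        noise = torch.rand(2, 1, h, w, generator=generator) * 2 - 1
        disp = alpha * blur2d(noise, sigma)       # (2, 1, h, w), smoothed then scaled
        field[t] = disp[:, 0].permute(1, 2, 0)    # (h, w, 2)
    return field


g = torch.Generator().manual_seed(0)
field = keyframe_field(100, 5, (64, 64), 30, 5, generator=g)
print(field.shape)  # torch.Size([30, 64, 64, 2])
```

Frames 1 to 4, 6 to 9, and so on stay at zero displacement until the time-domain smoothing spreads the keyframes over them.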

Note

The spatial and temporal smoothing are applied after the displacements of magnitude alpha are generated. As a result, the actual spatial displacement magnitude can be significantly lower than the one specified by alpha. To get the standard deviation of the resulting deformation field, call the compute_field_std() method.
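The reduction in magnitude can be checked empirically. The sketch below assumes the same uniform-noise-then-Gaussian-blur model, without torchvision's size normalization, and uses the hypothetical helper `gaussian_blur2d`. It compares the standard deviation of the raw scaled noise with that of the spatially smoothed noise on a single random sample. It measures an empirical value and does not reproduce compute_field_std().

```python
import torch
import torch.nn.functional as F


def gaussian_blur2d(d: torch.Tensor, sigma: float) -> torch.Tensor:
    """2D Gaussian blur of a (N, 1, h, w) tensor with a full outer-product kernel (hypothetical helper)."""
    r = int(3 * sigma)
    x = torch.arange(-r, r + 1, dtype=d.dtype)
    g = torch.exp(-0.5 * (x / sigma) ** 2)
    g = g / g.sum()
    kernel = torch.outer(g, g).view(1, 1, 2 * r + 1, 2 * r + 1)
    return F.conv2d(F.pad(d, (r, r, r, r), mode="reflect"), kernel)


torch.manual_seed(0)
alpha, sigma = 100.0, 5.0
noise = torch.rand(2, 1, 64, 64) * 2 - 1          # uniform displacements in [-1, 1]
raw = alpha * noise                                # before any smoothing
smoothed = alpha * gaussian_blur2d(noise, sigma)   # after spatial smoothing

print(f"std before smoothing: {raw.std().item():.2f}")  # close to alpha / sqrt(3), about 57.7
print(f"std after smoothing:  {smoothed.std().item():.2f}")
```

With sigma = 5 the smoothed value should come out far below the raw one. Averaging over a neighbourhood several pixels wide cancels most of the independent noise, and temporal smoothing reduces it further.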

Note

The parameters alpha, sigma, and n_interpolation are set at initialization and cannot be changed after instantiation.

Args:

alpha (float): Magnitude of displacements. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.

sigma (float): Smoothness of displacements in the spatial domain. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.

img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.

n_frames (int): Number of frames in the video.

n_interpolation (int): Period, in frames, of the time-domain interpolation. Every n_interpolation frames, a 2D elastic deformation is randomly generated; between these frames, the deformation field is equal to the identity. The field is then smoothed along the time axis with a truncated Gaussian kernel of length 3 times n_interpolation (to capture realistic-looking motion between 3 points in 2D space) and standard deviation \(\frac{3}{4}\) n_interpolation.

dtype (torch.dtype): Data type of the tensors. Default is torch.float32.

device (torch.device): Device on which the tensors are stored. Default is CPU.
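The temporal smoothing described for n_interpolation can be sketched as a 1D convolution along the frame axis. The kernel uses the standard deviation 0.75 * n_interpolation and roughly 3 * n_interpolation taps. As an assumption, its length is forced to be odd, 2 * (3 * n_interpolation // 2) + 1, so that it has a well-defined center. The zero padding at both ends of the sequence and the helper names `temporal_kernel` and `smooth_in_time` are also assumptions, not the library's code.

```python
import torch
import torch.nn.functional as F


def temporal_kernel(n_interpolation: int) -> torch.Tensor:
    """Truncated Gaussian with about 3 * n_interpolation taps and std 0.75 * n_interpolation."""
    radius = (3 * n_interpolation) // 2           # odd length 2 * radius + 1 (assumption)
    t = torch.arange(-radius, radius + 1, dtype=torch.float32)
    k = torch.exp(-0.5 * (t / (0.75 * n_interpolation)) ** 2)
    return k / k.sum()                             # normalize so the weights sum to 1


def smooth_in_time(field: torch.Tensor, n_interpolation: int) -> torch.Tensor:
    """Convolve every pixel trace of a (n_frames, h, w, 2) field along the frame axis."""
    n_frames, h, w, c = field.shape
    k = temporal_kernel(n_interpolation).to(field.dtype)
    r = (k.numel() - 1) // 2
    x = field.permute(1, 2, 3, 0).reshape(-1, 1, n_frames)  # (h * w * 2, 1, n_frames)
    x = F.pad(x, (r, r))                                    # zero displacement outside the sequence
    x = F.conv1d(x, k.view(1, 1, -1))                       # symmetric kernel: correlation == convolution
    return x.reshape(h, w, c, n_frames).permute(3, 0, 1, 2)


field = torch.zeros(30, 8, 8, 2)
field[::5] = 1.0                   # non-zero displacement on keyframes only
smooth = smooth_in_time(field, 5)
print(smooth.shape)                # torch.Size([30, 8, 8, 2])
```

Applied to a field that is non-zero only on keyframes, the convolution spreads each keyframe's displacement over its neighbouring frames. This is what turns the static keyframes into continuous motion.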

Attributes:

field (torch.Tensor): The deformation field as a tensor of shape \((n\_frames,h,w,2)\).

img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.

n_frames (int): Number of frames in the video.

alpha (float): Magnitude of displacements.

sigma (float): Smoothness of displacements in the spatial domain.

n_interpolation (int): Period in frames of the time-domain interpolation.

ElasticTransform (torchvision.transforms.v2.ElasticTransform): The random generator of static elastic deformation, with parameters alpha and sigma.

Example:
>>> import torch
>>> from spyrit.core.warp import ElasticDeformation
>>>
>>> alpha = 100
>>> sigma = 5
>>> img_shape = (64, 64)
>>> n_frames = 30
>>> n_interpolation = 5
>>> def_field = ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation)
>>> print(def_field.field.shape)
torch.Size([30, 64, 64, 2])

Methods

compute_field_std()

Computes the theoretical standard deviation (in pixels) of the deformation field.

forward(img[, n0, n1, mode])

Generates a video from a batch of 2D images according to the deformation field \(u\).
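The following sketch roughly illustrates how a stored field of shape \((n\_frames,h,w,2)\) can turn a batch of static images into a video. It uses torch.nn.functional.grid_sample. It assumes that each frame of the field holds absolute sampling positions normalized to [-1, 1], following the grid_sample convention. That convention, the hypothetical helper name `warp_video`, and the output layout (b, n_frames, c, h, w) are assumptions. The sketch warps every frame and does not model the n0 and n1 arguments of forward().

```python
import torch
import torch.nn.functional as F


def warp_video(img: torch.Tensor, field: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
    """Warp a (b, c, h, w) batch with every frame of a (n_frames, fh, fw, 2) field (hypothetical).

    The field is assumed to hold absolute (x, y) sampling positions in [-1, 1].
    Returns a (b, n_frames, c, fh, fw) video.
    """
    b, c, h, w = img.shape
    n_frames, fh, fw, _ = field.shape
    # Pair every image with every frame by flattening (b, n_frames) into one batch axis
    imgs = img.unsqueeze(1).expand(b, n_frames, c, h, w).reshape(b * n_frames, c, h, w)
    grids = field.unsqueeze(0).expand(b, n_frames, fh, fw, 2).reshape(b * n_frames, fh, fw, 2)
    out = F.grid_sample(imgs, grids, mode=mode, padding_mode="zeros", align_corners=False)
    return out.view(b, n_frames, c, fh, fw)


# Identity field: every frame samples each pixel at its own position
h, w, n_frames = 32, 32, 4
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
identity = F.affine_grid(theta, size=(1, 1, h, w), align_corners=False)  # (1, h, w, 2)
field = identity.expand(n_frames, h, w, 2)
img = torch.rand(2, 1, h, w)
video = warp_video(img, field)
print(video.shape)  # torch.Size([2, 4, 1, 32, 32])
```

With the identity field, every frame of the video reproduces the input image. Any smooth displacement added to that identity grid makes the frames move.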