spyrit.core.warp.ElasticDeformation

class spyrit.core.warp.ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation, dtype=torch.float32)[source]

Bases: DeformationField

Defines and stores a moving elastic deformation producing a flag-like effect.

This class relies on the random generation of torchvision's torchvision.transforms.v2.ElasticTransform. It generates several frames of static elastic deformation using the torchvision class, then smooths these frames in the time domain to create a continuous animation. The deformation field is generated at instantiation and stored in the attribute field.

The spatial magnitude of the displacements is controlled by the parameter alpha, the spatial smoothness of the displacements is controlled by the parameter sigma, and the time-domain smoothness is controlled by the parameter n_interpolation.

Note

The spatial smoothing and time-domain smoothing are applied after the displacements of magnitude alpha are generated. As a result, the actual spatial displacement magnitude may be significantly lower than the one specified by alpha.
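The effect described in this note can be illustrated with a minimal sketch in plain PyTorch. It is not the torchvision or spyrit implementation: the uniform draw in [-alpha, alpha], the kernel radius of 4 sigma, and the reflect padding are all assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

alpha, sigma = 10.0, 3.0  # illustrative values, not defaults of the class
h, w = 64, 64

# Raw random displacements, drawn uniformly in [-alpha, alpha] (assumption)
raw = alpha * (2 * torch.rand(1, 1, h, w) - 1)

# Normalized 2D Gaussian kernel built from a 1D kernel (radius of 4*sigma is an assumption)
radius = int(4 * sigma)
x = torch.arange(-radius, radius + 1, dtype=torch.float32)
k1d = torch.exp(-0.5 * (x / sigma) ** 2)
k1d = k1d / k1d.sum()
k2d = torch.outer(k1d, k1d)[None, None]  # shape (1, 1, K, K)

# Spatial smoothing: every output value is a weighted average of raw values
padded = F.pad(raw, (radius, radius, radius, radius), mode="reflect")
smoothed = F.conv2d(padded, k2d)

print("max |raw|     :", raw.abs().max().item())
print("max |smoothed|:", smoothed.abs().max().item())
```

Because each smoothed value is a weighted average of neighbouring random values, its peak magnitude cannot exceed that of the raw field, and in practice it is much smaller.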

Note

The parameters alpha, sigma, and n_interpolation are defined at initialization and cannot be changed after instantiation.

Args:

alpha (float): Magnitude of displacements. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.

sigma (float): Smoothness of displacements in the spatial domain. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.

img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.

n_frames (int): Number of frames in the animation.

n_interpolation (int): Period in frames of the time-domain interpolation. Every n_interpolation frames, a 2D elastic transform is randomly generated. Between these frames, the deformation field is initially equal to the identity. A truncated Gaussian smoothing of length equal to 3 times n_interpolation (to capture a real-looking movement between 3 points in 2D space) and with a standard deviation of \(\frac{3}{4}\) n_interpolation is then applied to the deformation field along the time axis.

dtype (torch.dtype): Data type of the tensors. Default is torch.float32.
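The time-domain interpolation described for n_interpolation can be sketched as follows. This is an illustration under stated assumptions, not the class's actual code: the keyframes are plain Gaussian noise rather than ElasticTransform outputs, the kernel is normalized to sum to 1, and the sequence is zero-padded at its ends. Only the kernel length (3 times n_interpolation) and standard deviation (3/4 of n_interpolation) come from the description above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n_frames, n_interpolation = 20, 5  # illustrative values
h, w = 8, 8

# Displacements over time; zero displacement stands for the identity
disp = torch.zeros(n_frames, h, w, 2)

# Every n_interpolation frames, place a random 2D displacement (stand-in keyframe)
key_idx = torch.arange(0, n_frames, n_interpolation)
disp[key_idx] = torch.randn(len(key_idx), h, w, 2)

# Truncated Gaussian kernel: length 3*n_interpolation, std 0.75*n_interpolation
length = 3 * n_interpolation
std = 0.75 * n_interpolation
t = torch.arange(length, dtype=torch.float32) - (length - 1) / 2
kernel = torch.exp(-0.5 * (t / std) ** 2)
kernel = kernel / kernel.sum()  # normalization is an assumption

# Convolve along the time axis, one independent 1D signal per pixel and component
signal = disp.permute(1, 2, 3, 0).reshape(-1, 1, n_frames)
smooth = F.conv1d(signal, kernel.view(1, 1, -1), padding="same")
smooth = smooth.reshape(h, w, 2, n_frames).permute(3, 0, 1, 2)

print(smooth.shape)  # same layout as disp: (n_frames, h, w, 2)
```

After smoothing, the frames located between two keyframes are no longer the identity: each one blends the neighbouring keyframes, which is what produces a continuous motion.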

Attributes:

field (torch.Tensor): The deformation field as a tensor of shape \((n\_frames,h,w,2)\).

img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.

n_frames (int): Number of frames in the animation.

alpha (float): Magnitude of displacements.

sigma (float): Smoothness of displacements in the spatial domain.

n_interpolation (int): Period in frames of the time-domain interpolation.

ElasticTransform (torchvision.transforms.v2.ElasticTransform): The random generator of static elastic deformation, with parameters alpha and sigma.

Methods

det()

forward(img[, n0, n1, mode])

Warps a batch of 2D images with the stored inverse deformation field \(u\).

grid_sample(img_frames, inverse_grid_frames, ...)

Used to warp frames of 2D images with a deformation field.
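As a rough illustration of what warping frames with a field of shape \((n\_frames,h,w,2)\) involves, the sketch below calls torch.nn.functional.grid_sample directly. It assumes the field stores sampling positions in the normalized \([-1,1]\) coordinates expected by that function, with align_corners=True; this page does not state the convention used by the class, so treat both as assumptions.

```python
import torch
import torch.nn.functional as F

n_frames, h, w = 3, 4, 5

# One small test image, repeated for every frame: shape (n_frames, 1, h, w)
img = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)
frames = img.repeat(n_frames, 1, 1, 1)

# Identity sampling grid of shape (n_frames, h, w, 2), values in [-1, 1]
theta = torch.eye(2, 3).unsqueeze(0).repeat(n_frames, 1, 1)
identity = F.affine_grid(theta, size=(n_frames, 1, h, w), align_corners=True)

# Sampling at the identity grid returns the input frames
same = F.grid_sample(frames, identity, mode="bilinear", align_corners=True)

# Sampling one pixel to the right: each output pixel takes the value of its right neighbour
shifted_grid = identity.clone()
shifted_grid[..., 0] += 2.0 / (w - 1)  # one pixel in normalized x coordinates
shifted = F.grid_sample(frames, shifted_grid, mode="bilinear", align_corners=True)
```

Since the grid gives, for each output pixel, the position to read from in the input, it acts as an inverse deformation: reading one pixel to the right moves the image content one pixel to the left.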