spyrit.core.warp.ElasticDeformation
- class spyrit.core.warp.ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation, dtype=torch.float32)
Bases: DeformationField
Defines and stores a moving elastic deformation producing a flag-like effect.
This class relies on the random generator of TorchVision's torchvision.transforms.v2.ElasticTransform. It generates several frames of static elastic deformation with the TorchVision class, then smooths these frames in the time domain to create a continuous animation. The deformation field is generated at instantiation and stored in the attribute field. The spatial magnitude of the displacements is controlled by the parameter alpha, the spatial smoothness of the displacements by the parameter sigma, and the time-domain smoothness by the parameter n_interpolation.
Note
The spatial smoothing and time-domain smoothing are applied after the displacements of magnitude alpha are generated. This means that the actual spatial displacement magnitude may be significantly lower than the one specified by alpha.
Note
The parameters alpha, sigma, and n_interpolation are defined at initialization and cannot be changed after instantiation.
- Args:
alpha (float): Magnitude of displacements. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.
sigma (float): Smoothness of displacements in the spatial domain. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.
img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.
n_frames (int): Number of frames in the animation.
n_interpolation (int): Period in frames of the time-domain interpolation. Every n_interpolation frames, a 2D elastic transform is randomly generated. Between these frames, the deformation field is equal to the identity. A truncated Gaussian smoothing of length equal to 3 times n_interpolation (to capture a realistic-looking movement between 3 points in 2D space) and with a standard deviation of \(\frac{3}{4}\) times n_interpolation is then applied to the deformation field along the time axis.
dtype (torch.dtype): Data type of the tensors. Default is torch.float32.
- Attributes:
field (torch.tensor): The deformation field as a tensor of shape \((n\_frames,h,w,2)\).
img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.
n_frames (int): Number of frames in the animation.
alpha (float): Magnitude of displacements.
sigma (float): Smoothness of displacements in the spatial domain.
n_interpolation (int): Period in frames of the time-domain interpolation.
ElasticTransform (torchvision.transforms.v2.ElasticTransform): The random generator of static elastic deformation, with parameters alpha and sigma.
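To show how a field of shape \((n\_frames,h,w,2)\) could arise from alpha, sigma, and n_interpolation, here is a NumPy sketch of the generation steps this page describes. It is an illustration under assumptions, not spyrit's implementation: each random keyframe is approximated by uniform noise in [-1, 1], blurred with a Gaussian of standard deviation sigma and scaled by alpha (loosely following how TorchVision's ElasticTransform draws displacements, but in pixel units and without its normalization); the temporal kernel is normalized to sum to 1; the time axis is zero-padded; and all function names are hypothetical.

```python
import numpy as np

# Hypothetical helpers: not spyrit's or TorchVision's API.


def gaussian_kernel_1d(sigma: float) -> np.ndarray:
    """Normalized, odd-length 1D Gaussian kernel of about 8 * sigma + 1 taps."""
    size = int(8 * sigma + 1)
    if size % 2 == 0:
        size += 1
    x = np.arange(size) - size // 2
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def blur_2d(a: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur with zero padding; a must be at least kernel-sized."""
    k = gaussian_kernel_1d(sigma)
    assert min(a.shape) >= k.size, "image smaller than the blur kernel"
    a = np.apply_along_axis(np.convolve, 1, a, k, mode="same")  # blur rows
    return np.apply_along_axis(np.convolve, 0, a, k, mode="same")  # blur columns


def random_keyframe(alpha, sigma, img_shape, rng) -> np.ndarray:
    """Static displacement of shape (h, w, 2), in pixels (assumption)."""
    noise = rng.uniform(-1.0, 1.0, size=(2, *img_shape))  # one map per component
    return alpha * np.stack([blur_2d(c, sigma) for c in noise], axis=-1)


def temporal_kernel(n_interpolation: int) -> np.ndarray:
    """Truncated Gaussian over 3 * n frames with std 0.75 * n, summing to 1."""
    length = 3 * n_interpolation
    t = np.arange(length) - (length - 1) / 2  # frame offsets centered on 0
    weights = np.exp(-0.5 * (t / (0.75 * n_interpolation)) ** 2)
    return weights / weights.sum()


def elastic_animation(alpha, sigma, img_shape, n_frames, n_interpolation, seed=0):
    """Field of shape (n_frames, h, w, 2): a keyframe every n_interpolation
    frames, zero displacement (identity) in between, then smoothing in time."""
    rng = np.random.default_rng(seed)
    field = np.zeros((n_frames, *img_shape, 2))
    for f in range(0, n_frames, n_interpolation):
        field[f] = random_keyframe(alpha, sigma, img_shape, rng)
    weights = temporal_kernel(n_interpolation)
    half = weights.size // 2
    out = np.zeros_like(field)
    # Zero-padded weighted sum over neighboring frames along the time axis.
    for t in range(n_frames):
        for i, wgt in enumerate(weights):
            s = t + i - half  # source frame contributing to output frame t
            if 0 <= s < n_frames:
                out[t] += wgt * field[s]
    return out


field = elastic_animation(alpha=20.0, sigma=3.0, img_shape=(48, 64),
                          n_frames=30, n_interpolation=5)
print(field.shape)  # (30, 48, 64, 2)
```

Because every keyframe is a weighted average of noise bounded by 1 and the frames between keyframes start at zero, the smoothed values in this sketch always stay below alpha in magnitude, in line with the Note on displacement magnitude.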
Methods
det()
forward(img[, n0, n1, mode])
    Warps a batch of 2D images with the stored inverse deformation field \(u\).
grid_sample(img_frames, inverse_grid_frames, ...)
    Used to warp frames of 2D images with a deformation field.
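Warping an image with a displacement field, as forward and grid_sample do, amounts to sampling the image at displaced coordinates, which is consistent with the page describing \(u\) as an inverse deformation field (each output pixel pulls its value from a source location). The NumPy sketch below does this for a single frame with bilinear interpolation and zeros outside the image. It assumes displacements in pixel units and that u[..., 0] is the x (width) component and u[..., 1] the y (height) component; PyTorch's torch.nn.functional.grid_sample uses the same (x, y) ordering but works in normalized coordinates in [-1, 1]. The function warp_bilinear is hypothetical, not spyrit's API.

```python
import numpy as np


def warp_bilinear(img: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Hypothetical helper: sample img (h, w) at (y + u[..., 1], x + u[..., 0]).

    Bilinear interpolation; neighbors outside the image contribute zero,
    similar in spirit to padding_mode="zeros" in grid_sample.
    """
    h, w = img.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    sx = xs + u[..., 0]  # sampling x coordinate (width axis)
    sy = ys + u[..., 1]  # sampling y coordinate (height axis)
    x0 = np.floor(sx).astype(int)
    y0 = np.floor(sy).astype(int)
    fx = sx - x0  # fractional parts act as interpolation weights
    fy = sy - y0
    out = np.zeros((h, w), dtype=float)
    # Accumulate the four neighboring pixels, skipping those outside the image.
    for dy, wy in ((0, 1 - fy), (1, fy)):
        for dx, wx in ((0, 1 - fx), (1, fx)):
            yy, xx = y0 + dy, x0 + dx
            valid = (yy >= 0) & (yy < h) & (xx >= 0) & (xx < w)
            vals = img[np.clip(yy, 0, h - 1), np.clip(xx, 0, w - 1)]
            out += wy * wx * np.where(valid, vals, 0.0)
    return out


img = np.arange(12, dtype=float).reshape(3, 4)
u = np.zeros((3, 4, 2))
u[..., 0] = 1.0  # every output pixel pulls from one pixel to its right
print(warp_bilinear(img, u))  # columns shift left by one; the last column becomes 0
```

Applying such a function to each frame with the matching slice of a \((n\_frames,h,w,2)\) field would animate a static image, which is the role the Methods above play for batches of images.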