spyrit.core.warp.ElasticDeformation
- class spyrit.core.warp.ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation, dtype=torch.float32, device=device(type='cpu'))
Bases: DeformationField

Generates and stores a random elastic deformation in which the displacement at each pixel is sampled from a uniform distribution and then smoothed in space and time.
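To make the summary concrete, the NumPy sketch below builds one static displacement frame the way the description suggests: uniform noise scaled by alpha, followed by a Gaussian blur with standard deviation sigma, with one channel per displacement axis. It is an illustration under assumptions, not spyrit's or TorchVision's code. The helper names are hypothetical, the kernel length (about 8 * sigma + 1, forced to be odd) and the reflect padding are choices made for this sketch, and the exact rescaling TorchVision applies to alpha is not reproduced. Comparing the standard deviations before and after blurring also shows why the smoothed displacements end up much smaller than alpha.

```python
import numpy as np


def gaussian_kernel_1d(sigma: float) -> np.ndarray:
    """Normalized 1D Gaussian kernel, truncated at roughly +/- 4 sigma (assumed rule)."""
    size = int(8 * sigma + 1)
    if size % 2 == 0:
        size += 1  # keep the kernel odd so that it is centered
    x = np.arange(size) - size // 2
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()


def blur_2d(a: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur with reflect padding; the output keeps the input shape."""
    k = gaussian_kernel_1d(sigma)
    r = len(k) // 2
    padded = np.pad(a, r, mode="reflect")
    # Filter along rows, then along columns; "valid" mode strips the padding again.
    tmp = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 1, padded)
    return np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 0, tmp)


def static_elastic_frame(alpha, sigma, img_shape, rng):
    """Hypothetical helper: returns (raw, smoothed) displacement frames of shape (h, w, 2)."""
    h, w = img_shape
    raw = alpha * rng.uniform(-1.0, 1.0, size=(2, h, w))  # one channel per axis
    smoothed = np.stack([blur_2d(c, sigma) for c in raw], axis=-1)
    return raw.transpose(1, 2, 0), smoothed


rng = np.random.default_rng(0)
raw, frame = static_elastic_frame(alpha=100.0, sigma=5.0, img_shape=(64, 64), rng=rng)
print(frame.shape)             # (64, 64, 2)
print(raw.std(), frame.std())  # the smoothed std is far below the raw std
```

Because the blur is linear, scaling by alpha before or after smoothing gives the same field; the loss of magnitude comes from averaging many independent noise values.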
This class reuses the random generation of TorchVision's torchvision.transforms.v2.ElasticTransform. It generates several frames of static elastic deformation with the TorchVision class, then smooths these frames in the time domain to create continuous motion. The deformation field is generated at instantiation and stored in the field attribute.

The spatial magnitude of the displacements is controlled by the parameter alpha, their spatial smoothness by the parameter sigma, and their temporal smoothness by the parameter n_interpolation.

Note
The spatial and temporal smoothing are applied after the displacements of magnitude alpha are generated. As a result, the actual spatial displacement magnitude can be significantly lower than the one specified by alpha. To obtain the (theoretical) standard deviation of the resulting deformation field, call the compute_field_std() method.

Note
The parameters alpha, sigma, and n_interpolation are set at initialization and cannot be changed after instantiation.
- Args:
alpha (float): Magnitude of displacements. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.
sigma (float): Smoothness of displacements in the spatial domain. This argument is passed to the constructor of torchvision.transforms.v2.ElasticTransform.
img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.
n_frames (int): Number of frames in the video.
n_interpolation (int): Period, in frames, of the time-domain interpolation. Every n_interpolation frames, a 2D elastic transform is randomly generated; between these key frames, the deformation field is the identity before smoothing. A truncated Gaussian smoothing kernel of length equal to 3 times n_interpolation (to capture a realistic-looking movement between 3 points in 2D space) and with a standard deviation of \(\frac{3}{4}\) n_interpolation is then applied to the deformation field along the time axis.
dtype (torch.dtype): Data type of the tensors. Default is torch.float32.
device (torch.device): Device on which the tensors are stored. Default is CPU.
- Attributes:
field (torch.Tensor): The deformation field as a tensor of shape \((n\_frames,h,w,2)\).
img_shape (tuple): Shape of the deformation field, i.e. \((h,w)\), where \(h\) and \(w\) are the height and width of the field respectively.
n_frames (int): Number of frames in the animation.
alpha (float): Magnitude of displacements.
sigma (float): Smoothness of displacements in the spatial domain.
n_interpolation (int): Period, in frames, of the time-domain interpolation.
ElasticTransform (torchvision.transforms.v2.ElasticTransform): The random generator of static elastic deformations, with parameters alpha and sigma.
- Example:
>>> import torch
>>> from spyrit.core.warp import ElasticDeformation
>>>
>>> alpha = 100
>>> sigma = 5
>>> img_shape = (64, 64)
>>> n_frames = 30
>>> n_interpolation = 5
>>> def_field = ElasticDeformation(alpha, sigma, img_shape, n_frames, n_interpolation)
>>> print(def_field.field.shape)
torch.Size([30, 64, 64, 2])
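The time-domain interpolation described for n_interpolation can be sketched in the same spirit. In the NumPy sketch below, random key frames (stand-ins for smoothed static elastic frames) are placed every n_interpolation frames, the frames in between are left at zero displacement (the identity), and each pixel's time series is convolved with a truncated Gaussian of length 3 * n_interpolation and standard deviation 0.75 * n_interpolation. The function names, the normalization of the kernel to unit sum, and the zero-padded centered convolution are assumptions of this sketch; spyrit's own implementation may differ in these details.

```python
import numpy as np


def temporal_smoothing_kernel(n_interpolation: int) -> np.ndarray:
    """Truncated Gaussian of length 3*n_interpolation and std 0.75*n_interpolation.

    Normalizing it to unit sum is an assumption of this sketch.
    """
    length = 3 * n_interpolation
    t = np.arange(length) - (length - 1) / 2  # time offsets centered on zero
    k = np.exp(-0.5 * (t / (0.75 * n_interpolation)) ** 2)
    return k / k.sum()


def centered_convolve(v: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Zero-padded convolution that returns len(v) samples centered on the kernel."""
    full = np.convolve(v, k)  # length len(v) + len(k) - 1
    start = (len(k) - 1) // 2
    return full[start:start + len(v)]


def interpolate_in_time(keyframes: np.ndarray, n_frames: int, n_interpolation: int) -> np.ndarray:
    """Hypothetical helper: (n_key, h, w, 2) key frames -> (n_frames, h, w, 2) field."""
    key_times = range(0, n_frames, n_interpolation)
    if len(keyframes) != len(key_times):
        raise ValueError("expected one key frame every n_interpolation frames")
    field = np.zeros((n_frames, *keyframes.shape[1:]))  # zero displacement = identity
    for key, t in zip(keyframes, key_times):
        field[t] = key
    k = temporal_smoothing_kernel(n_interpolation)
    # Smooth every pixel's displacement time series independently.
    return np.apply_along_axis(centered_convolve, 0, field, k)


rng = np.random.default_rng(0)
n_frames, n_interpolation, img_shape = 30, 5, (8, 8)
n_key = len(range(0, n_frames, n_interpolation))  # 6 key frames: t = 0, 5, ..., 25
keyframes = rng.uniform(-1.0, 1.0, size=(n_key, *img_shape, 2))  # stand-ins for elastic frames
field = interpolate_in_time(keyframes, n_frames, n_interpolation)
print(field.shape)  # (30, 8, 8, 2)
```

With a unit-sum kernel, each output frame is a weighted average over a few key frames and many identity frames, which is another reason the final displacement magnitude is smaller than alpha.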
Methods
compute_field_std(): Computes the theoretical standard deviation (in pixels) of the deformation field.
forward(img[, n0, n1, mode]): Generates a video from a batch of 2D images according to the deformation field \(u\).
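The forward method is only summarized here. The sketch below shows one common way to turn a displacement field into a video: backward warping, where each output pixel samples the input image at its own position plus the displacement, using bilinear interpolation. It is written in NumPy for a single 2D image, ignores the n0, n1 and mode arguments, and assumes that displacements are expressed in pixels, that the last axis is ordered as (column offset, row offset), and that sample positions are clamped to the image border. The names warp_frame and make_video are hypothetical, and spyrit's actual coordinate and interpolation conventions may differ.

```python
import numpy as np


def warp_frame(img: np.ndarray, disp: np.ndarray) -> np.ndarray:
    """Backward-warp an (h, w) image with an (h, w, 2) pixel displacement (assumed convention)."""
    h, w = img.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Sample positions, clamped so that every lookup stays inside the image.
    x = np.clip(cols + disp[..., 0], 0, w - 1)
    y = np.clip(rows + disp[..., 1], 0, h - 1)
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)
    wx, wy = x - x0, y - y0  # fractional parts used as bilinear weights
    top = (1 - wx) * img[y0, x0] + wx * img[y0, x1]
    bottom = (1 - wx) * img[y1, x0] + wx * img[y1, x1]
    return (1 - wy) * top + wy * bottom


def make_video(img: np.ndarray, field: np.ndarray) -> np.ndarray:
    """Apply each frame of an (n_frames, h, w, 2) field to the same image."""
    return np.stack([warp_frame(img, field[t]) for t in range(field.shape[0])])


img = np.arange(16, dtype=float).reshape(4, 4)
field = np.zeros((3, 4, 4, 2))
field[1, ..., 0] = 1.0  # frame 1: every pixel samples its right-hand neighbour
video = make_video(img, field)
print(video.shape)  # (3, 4, 4)
print(video[1][0])  # [1. 2. 3. 3.]
```

An all-zero displacement frame (the identity) returns the image unchanged, which is a quick sanity check for any warping routine.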