.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery/tuto_06_a_warp.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_gallery_tuto_06_a_warp.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_tuto_06_a_warp.py:


06.a. Deformation fields
========================

.. _tuto_06a_warp:

This tutorial demonstrates how to create and apply deformation fields to
simulate motion in images using the SpyRIT library. It is based on the
:mod:`spyrit.core.warp` submodule.

.. image:: ../fig/tuto06a_network.png
   :width: 600
   :align: center
   :alt: Overview of the dynamic pipeline

Given a reference image :math:`x(t_0, :, :)` and a deformation field
:math:`u(t, :, :)`, the submodule computes the motion video :math:`x(t, :, :)`
by sampling the reference image at the positions given by the deformation field:

.. math::

    x(t, :, :) = x(t_0, u(t, :, :)).

Topics covered:

- Creating affine deformation fields (translation, rotation, scaling)
- Creating elastic deformation fields for realistic motion
- Visualizing deformed image sequences

.. GENERATED FROM PYTHON SOURCE LINES 32-44

.. code-block:: Python


    import torch
    import torchvision
    import matplotlib.pyplot as plt
    import math
    from pathlib import Path

    from spyrit.misc.statistics import transform_norm
    from spyrit.misc.load_data import download_girder
    from spyrit.core.warp import AffineDeformationField, ElasticDeformation


.. GENERATED FROM PYTHON SOURCE LINES 45-46

Set parameters:

.. GENERATED FROM PYTHON SOURCE LINES 46-65


.. code-block:: Python


    thumbnail = True  # True for displaying the motion as a thumbnail, False for a video visualization

    n = 64  # size of the FOV side in pixels
    img_size = 88  # full image side's size in pixels
    n_frames = 50  # number of frames in the dynamic sequence

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    dtype = torch.float64
    simu_interp = "bilinear"  # interpolation mode for motion simulation
    time_dim = 1  # time dimension index in tensors

    fov_shape = (n, n)
    img_shape = (img_size, img_size)
    amp_max = (img_shape[0] - fov_shape[0]) // 2  # margin (pixels) between FOV and image border


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Using device: cpu


.. GENERATED FROM PYTHON SOURCE LINES 66-67

Load an image from the Tomoradio warehouse.

.. GENERATED FROM PYTHON SOURCE LINES 67-95

.. code-block:: Python


    # Download an RGB brain surface image.
    url_tomoradio = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
    data_root = Path("../data/data_online/2025_dynamic")  # local path to data
    imgs_path = data_root / Path("images/")
    id_files = ["69248e3204d23f6e964b16b7"]  # brain_surface_colorized.png
    try:
        download_girder(url_tomoradio, id_files, imgs_path)
    except Exception as e:
        print("Unable to download from the Tomoradio warehouse")
        print(e)

    # Create a transform for natural images to normalized image tensors
    transform = transform_norm(img_size=img_size)
    batch_size = 1

    # Create dataset and loader
    dataset = torchvision.datasets.ImageFolder(root=data_root, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    img, _ = dataloader.dataset[0]
    x = img.unsqueeze(0).to(dtype=dtype, device=device)
    print(f"Shape of input images: {x.shape}")

    x = (x - x.min()) / (x.max() - x.min())  # normalize to [0, 1]
    n_wav = x.shape[1]  # number of channels (3 for RGB)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Local folder not found, creating it... done.
    Downloading brain_surface_colorized.png...
    Downloading brain_surface_colorized.png... done.
    Shape of input images: torch.Size([1, 3, 88, 88])


.. GENERATED FROM PYTHON SOURCE LINES 96-97

Plot the reference image.

.. GENERATED FROM PYTHON SOURCE LINES 97-108

.. code-block:: Python


    x_plot = x.moveaxis(1, -1).squeeze().cpu().numpy()  # channels last for matplotlib
    plt.imshow(x_plot)
    if n_wav == 1:
        plt.colorbar(fraction=0.046, pad=0.04)
    plt.title("Reference image")
    plt.axis("off")
    plt.show()


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_001.png
   :alt: Reference image
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 109-122

Affine deformation
##################

We consider three affine deformations:

1. Translation (diagonal motion)
2. Rotation (spinning motion)
3. Surface-preserving scaling (pulsating motion)

.. important::

   SpyRIT uses normalized coordinates in :math:`[-1, 1]`. To convert a
   displacement from pixels to normalized units, use
   :math:`d_{\mathrm{norm}} = 2\, d_{\mathrm{pix}} / N`, where :math:`N` is the
   image side in pixels.

.. GENERATED FROM PYTHON SOURCE LINES 125-127

1. Translation (diagonal motion)
--------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 127-158

.. code-block:: Python


    T = 1000  # time of a period
    time_vector = torch.linspace(0, 2 * T, n_frames)


    def translation(t):
        """Translation transformation - diagonal movement."""
        d_pix_tot = 10  # amplitude of translation in pixels
        assert d_pix_tot < amp_max, "Translation amplitude too large for image size!"
        d_normalized = 2 * d_pix_tot / img_size  # convert to normalized coordinates
        # Normalized displacement per time unit (time_vector spans 2T)
        d_pix_unit = d_normalized / (2 * T)
        tx = d_pix_unit * t
        ty = -d_pix_unit * t
        return torch.tensor(
            [
                [1, 0, tx],
                [0, 1, ty],
                [0, 0, 1],
            ],
            dtype=dtype,
        )


    def_field = AffineDeformationField(
        translation, time_vector, img_shape, dtype=dtype, device=device
    )


.. GENERATED FROM PYTHON SOURCE LINES 159-161

Simulate motion
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 161-168

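
SpyRIT's :class:`~spyrit.core.warp.AffineDeformationField` performs the warping internally, and its implementation is not shown in this tutorial. As an illustration only, the sketch below assumes a sampling convention like PyTorch's, where each output pixel :math:`p` reads the reference image at :math:`u(t, p)` as in the equation at the top of the page. It converts a one-pixel displacement to normalized units and applies a pure translation with ``torch.nn.functional.affine_grid`` and ``torch.nn.functional.grid_sample``. The helpers ``pixels_to_normalized`` and ``warp_affine`` are hypothetical and are not part of SpyRIT.

.. code-block:: Python

    import torch
    import torch.nn.functional as F


    def pixels_to_normalized(d_pix, size):
        """Hypothetical helper: displacement in pixels -> normalized [-1, 1] units."""
        return 2 * d_pix / size


    def warp_affine(img, theta, mode="bilinear"):
        """Hypothetical helper: warp img (N, C, H, W) with affine matrices theta (N, 2, 3).

        Each output location p samples the input at theta @ [p, 1], that is
        x_out(p) = x_in(u(p)). Locations falling outside the image read zeros.
        """
        grid = F.affine_grid(theta, img.shape, align_corners=False)
        return F.grid_sample(
            img, grid, mode=mode, padding_mode="zeros", align_corners=False
        )


    W = 8
    # Toy image whose value is its column index: every row reads 0, 1, ..., 7
    img = torch.arange(W, dtype=torch.float64).repeat(1, 1, W, 1)  # shape (1, 1, 8, 8)

    tx = pixels_to_normalized(1, W)  # one pixel = 2 / 8 = 0.25 normalized units
    theta = torch.tensor([[[1.0, 0.0, tx], [0.0, 1.0, 0.0]]], dtype=torch.float64)

    out = warp_affine(img, theta)
    print(out[0, 0, 0])  # approximately 1, 2, ..., 7, then 0 (read from outside the image)

Under this convention, a positive ``tx`` moves the content toward smaller column indices, because each output pixel reads from one column further right. With the translation above, ``d_pix_tot = 10`` pixels on the 88-pixel image corresponds to ``2 * 10 / 88``, about 0.227 normalized units.
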

.. code-block:: Python


    x_motion = def_field(x, 0, n_frames, mode=simu_interp)
    x_motion = x_motion.moveaxis(time_dim, 1)
    print("x_motion.shape:", x_motion.shape)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    x_motion.shape: torch.Size([1, 50, 3, 88, 88])


.. GENERATED FROM PYTHON SOURCE LINES 169-171

Display deformation within the FOV
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 171-232

.. code-block:: Python


    if thumbnail:
        # plot a few frames as thumbnails
        n_frames_display = 15  # stride between displayed frames
        n_rows, n_cols = 1, 4
        plt.figure(figsize=(12, 3))
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames or frame >= n_rows * n_cols:
                break
            plt.subplot(n_rows, n_cols, frame + 1)
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
            )  # central FOV crop of frame n_frame
            plt.title("frame %d" % (n_frame), fontsize=18)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    else:
        # show motion as video with IPython display
        from IPython.display import clear_output

        n_frames_display = 5  # stride between displayed frames
        x_min, x_max = x_motion.min().item(), x_motion.max().item()
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames:
                break
            plt.close()
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
                vmin=x_min,
                vmax=x_max,
            )  # central FOV crop of frame n_frame
            plt.suptitle("frame %d" % (n_frame), fontsize=16)
            plt.pause(0.1)
            clear_output(wait=True)


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_002.png
   :alt: frame 0, frame 15, frame 30, frame 45
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 233-235

2. Rotation (spinning motion)
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 235-257

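
The rotation defined next uses :math:`\theta(t) = 2\pi t / T`, while ``time_vector`` spans :math:`[0, 2T]` with ``n_frames`` samples. The sequence therefore covers two full turns. The short plain-Python check below does not need SpyRIT. It computes the angle step between consecutive frames and confirms that every rotation matrix has unit determinant, so it moves pixels without changing areas. ``rotation_2x2`` is a hypothetical helper that mirrors the upper-left block of the tutorial's matrix.

.. code-block:: Python

    import math

    T = 1000  # period, as in the tutorial
    n_frames = 50

    # Same sampling as torch.linspace(0, 2 * T, n_frames)
    times = [2 * T * k / (n_frames - 1) for k in range(n_frames)]


    def rotation_2x2(t):
        """Hypothetical helper: linear part of the tutorial's rotation matrix at time t."""
        theta = 2 * math.pi * t / T
        return [[math.cos(theta), -math.sin(theta)], [math.sin(theta), math.cos(theta)]]


    dets = [m[0][0] * m[1][1] - m[0][1] * m[1][0] for m in map(rotation_2x2, times)]
    step_deg = math.degrees(2 * math.pi * (times[1] - times[0]) / T)
    total_turns = (times[-1] - times[0]) / T

    print(f"angle step between frames: {step_deg:.2f} degrees")  # 14.69 degrees
    print(f"total rotation: {total_turns:.0f} turns")  # 2 turns

The thumbnails further down show frames 0, 5, 10 and 15, which correspond to rotations of about 0, 73, 147 and 220 degrees.
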

.. code-block:: Python


    T = 1000  # time of a period
    time_vector = torch.linspace(0, 2 * T, n_frames)


    def rotation(t):
        """Rotation transformation - spinning motion."""
        theta = 2 * math.pi * t / T  # One full rotation per period T
        return torch.tensor(
            [
                [math.cos(theta), -math.sin(theta), 0],
                [math.sin(theta), math.cos(theta), 0],
                [0, 0, 1],
            ],
            dtype=dtype,
        )


    def_field = AffineDeformationField(
        rotation, time_vector, img_shape, dtype=dtype, device=device
    )


.. GENERATED FROM PYTHON SOURCE LINES 258-260

Simulate motion
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 260-267

.. code-block:: Python


    x_motion = def_field(x, 0, n_frames, mode=simu_interp)
    x_motion = x_motion.moveaxis(time_dim, 1)
    print("x_motion.shape:", x_motion.shape)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    x_motion.shape: torch.Size([1, 50, 3, 88, 88])


.. GENERATED FROM PYTHON SOURCE LINES 268-270

Display deformation within the FOV
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 270-332

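
Every display cell crops the central ``n x n`` field of view (FOV) out of each ``img_size x img_size`` frame. The margin ``amp_max`` is what lets the translation example assert ``d_pix_tot < amp_max``: the FOV then only reads pixels that lie inside the image. The quick plain-Python check below reuses the tutorial's values. It also shows that the two slicing forms used in this tutorial, ``amp_max : img_size - amp_max`` and ``amp_max : n + amp_max``, select the same pixels here because ``img_size - n`` is even.

.. code-block:: Python

    n = 64  # FOV side in pixels, as in the tutorial
    img_size = 88  # full image side in pixels
    amp_max = (img_size - n) // 2  # margin between the FOV and the image border

    pixels = range(img_size)
    fov_affine = pixels[amp_max : img_size - amp_max]  # slicing used for the affine motions
    fov_elastic = pixels[amp_max : n + amp_max]  # slicing used for the elastic motion
    print(amp_max, len(fov_affine), fov_affine == fov_elastic)  # 12 64 True

    # With the 10-pixel translation of the first example, the FOV still reads
    # only pixel indices that exist in the image.
    d_pix_tot = 10
    lowest_read = fov_affine[0] - d_pix_tot  # 2
    highest_read = fov_affine[-1] + d_pix_tot  # 85
    print(0 <= lowest_read and highest_read < img_size)  # True
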

.. code-block:: Python


    if thumbnail:
        # plot a few frames as thumbnails
        n_frames_display = 5  # stride between displayed frames
        n_rows, n_cols = 1, 4
        plt.figure(figsize=(12, 3))
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames or frame >= n_rows * n_cols:
                break
            plt.subplot(n_rows, n_cols, frame + 1)
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
            )  # central FOV crop of frame n_frame
            plt.title("frame %d" % (n_frame), fontsize=18)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    else:
        # show motion as video with IPython display
        from IPython.display import clear_output

        n_frames_display = 5  # stride between displayed frames
        x_min, x_max = x_motion.min().item(), x_motion.max().item()
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames:
                break
            plt.close()
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
                vmin=x_min,
                vmax=x_max,
            )  # central FOV crop of frame n_frame
            plt.suptitle("frame %d" % (n_frame), fontsize=16)
            plt.pause(0.1)
            clear_output(wait=True)


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_003.png
   :alt: frame 0, frame 5, frame 10, frame 15
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 333-335

3. Surface-preserving scaling (pulsating motion)
------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 335-361

.. code-block:: Python


    T = 1000  # time of a period
    time_vector = torch.linspace(0, 2 * T, n_frames)


    def s(t):
        """Scale factor along y; x is scaled by 1 / s(t) so that the area is preserved."""
        a = 0.2  # relative amplitude of the scaling (dimensionless)
        return 1 + a * math.sin(t * 2 * math.pi / T)


    def pulsation(t):
        """Surface-preserving transformation - pulsating motion."""
        return torch.tensor(
            [
                [1 / s(t), 0, 0],
                [0, s(t), 0],
                [0, 0, 1],
            ],
            dtype=dtype,
        )


    def_field = AffineDeformationField(
        pulsation, time_vector, img_shape, dtype=dtype, device=device
    )

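
The pulsation matrix scales one axis by :math:`s(t)` and the other by :math:`1/s(t)`. Its determinant is therefore :math:`s(t) \cdot 1/s(t) = 1` at every time, which is what makes the motion surface-preserving: the shape stretches and squeezes while its area stays constant. The plain-Python check below (no SpyRIT needed) evaluates :math:`s(t)` on the same time samples as ``time_vector``.

.. code-block:: Python

    import math

    T = 1000  # period, as in the tutorial
    n_frames = 50
    a = 0.2  # same relative amplitude as in s(t)


    def s(t):
        return 1 + a * math.sin(t * 2 * math.pi / T)


    # Same sampling as torch.linspace(0, 2 * T, n_frames)
    times = [2 * T * k / (n_frames - 1) for k in range(n_frames)]
    scales = [s(t) for t in times]
    dets = [(1 / s(t)) * s(t) for t in times]  # determinant of the 2x2 linear part

    print(f"s(t) over the frames: min {min(scales):.3f}, max {max(scales):.3f}")
    print("all determinants equal 1:", all(abs(d - 1) < 1e-12 for d in dets))

Because the frames sample the sine only at discrete times, the extreme values of :math:`s(t)` come close to, but do not exactly reach, :math:`1 \pm 0.2`.
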

.. GENERATED FROM PYTHON SOURCE LINES 362-364

Simulate motion
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 364-371

.. code-block:: Python


    x_motion = def_field(x, 0, n_frames, mode=simu_interp)
    x_motion = x_motion.moveaxis(time_dim, 1)
    print("x_motion.shape:", x_motion.shape)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    x_motion.shape: torch.Size([1, 50, 3, 88, 88])


.. GENERATED FROM PYTHON SOURCE LINES 372-374

Display deformation within the FOV
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 374-436

.. code-block:: Python


    if thumbnail:
        # plot a few frames as thumbnails
        n_frames_display = 5  # stride between displayed frames
        n_rows, n_cols = 1, 4
        plt.figure(figsize=(12, 3))
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames or frame >= n_rows * n_cols:
                break
            plt.subplot(n_rows, n_cols, frame + 1)
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
            )  # central FOV crop of frame n_frame
            plt.title("frame %d" % (n_frame), fontsize=18)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    else:
        # show motion as video with IPython display
        from IPython.display import clear_output

        n_frames_display = 5  # stride between displayed frames
        x_min, x_max = x_motion.min().item(), x_motion.max().item()
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames:
                break
            plt.close()
            plt.imshow(
                x_motion[
                    0,
                    n_frame,
                    :,
                    amp_max : img_size - amp_max,
                    amp_max : img_size - amp_max,
                ]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy(),
                cmap="gray",
                vmin=x_min,
                vmax=x_max,
            )  # central FOV crop of frame n_frame
            plt.suptitle("frame %d" % (n_frame), fontsize=16)
            plt.pause(0.1)
            clear_output(wait=True)


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_004.png
   :alt: frame 0, frame 5, frame 10, frame 15
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_004.png
   :class: sphx-glr-single-img

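
Each affine example returns a :math:`3 \times 3` homogeneous matrix per time step, so more complex motions could in principle be built by multiplying such matrices. The sketch below is an assumption and is not run by this tutorial. It composes the tutorial's rotation with a hypothetical constant shift, ``shift``, in plain PyTorch and checks the result on one point. For a sampling point :math:`p`, the product ``A @ B`` applies ``B`` first and ``A`` second, and the order matters. Whether :class:`~spyrit.core.warp.AffineDeformationField` accepts such a composed callable is inferred from how ``translation`` and ``rotation`` are passed above. It has not been verified against the SpyRIT documentation.

.. code-block:: Python

    import math

    import torch

    dtype = torch.float64
    T = 1000  # period, as in the tutorial


    def rotation(t):
        """Same homogeneous rotation matrix as in the tutorial."""
        theta = 2 * math.pi * t / T
        return torch.tensor(
            [
                [math.cos(theta), -math.sin(theta), 0],
                [math.sin(theta), math.cos(theta), 0],
                [0, 0, 1],
            ],
            dtype=dtype,
        )


    def shift(t, d=0.1):
        """Hypothetical constant shift of d normalized units along x."""
        return torch.tensor([[1, 0, d], [0, 1, 0], [0, 0, 1]], dtype=dtype)


    def composite(t):
        """Hypothetical composite: rotate the sampling point, then shift it."""
        return shift(t) @ rotation(t)


    p = torch.tensor([1.0, 0.0, 1.0], dtype=dtype)  # point (1, 0) in homogeneous form
    q = composite(T / 4) @ p  # quarter turn maps (1, 0) to (0, 1), then x += 0.1
    print(q)  # approximately (0.1, 1.0, 1.0)

Swapping the factors (``rotation(t) @ shift(t)``) shifts first and then rotates, which sends the same point to approximately :math:`(0, 1.1)` instead.
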

.. GENERATED FROM PYTHON SOURCE LINES 437-446

Random elastic deformation
##########################

Elastic deformation produces non-parametric motion, which can mimic soft-tissue
deformation or fluid flow. It is controlled by three parameters:

- :attr:`magnitude_amp`: magnitude of the deformations (in pixels)
- :attr:`smoothness`: spatial correlation of the field (higher values give smoother fields)
- :attr:`n_interpolation`: number of keyframes used for temporal interpolation

Note that :attr:`magnitude_amp` does not directly give the typical displacement:
with the values below, the generated field has a standard deviation of about
5.77 pixels in this run.

.. GENERATED FROM PYTHON SOURCE LINES 446-465

.. code-block:: Python


    magnitude_amp = 500  # Magnitude in pixels
    smoothness = 5  # Spatial smoothness parameter
    n_interpolation = 3  # Temporal interpolation points

    def_field = ElasticDeformation(
        magnitude_amp,
        smoothness,
        img_shape,
        n_frames,
        n_interpolation,
        dtype=dtype,
        device=device,
    )

    elastic_std = def_field.compute_field_std()
    print(
        f"Generated random elastic deformation field has an std of {elastic_std:.2f} pixels."
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Generated random elastic deformation field has an std of 5.77 pixels.


.. GENERATED FROM PYTHON SOURCE LINES 466-467

Simulate motion
~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 467-473

.. code-block:: Python


    x_motion = def_field(x, 0, n_frames, mode=simu_interp)
    x_motion = x_motion.moveaxis(time_dim, 1)
    print("x_motion.shape:", x_motion.shape)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    x_motion.shape: torch.Size([1, 50, 3, 88, 88])


.. GENERATED FROM PYTHON SOURCE LINES 474-476

Display deformation within the FOV
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 476-523

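
SpyRIT generates the elastic field internally, and this tutorial does not show its algorithm. Before displaying the frames, a rough mental model may help. One common way to obtain a smooth random displacement field is to blur white noise with a Gaussian kernel and rescale it. A few such random keyframes are then interpolated linearly in time. Presumably the kernel width plays a role similar to ``smoothness``, and the keyframes a role similar to ``n_interpolation``, but that is an assumption. The sketch below follows this recipe. Every name in it is hypothetical, and it is not SpyRIT's implementation.

.. code-block:: Python

    import torch
    import torch.nn.functional as F


    def gaussian_kernel1d(sigma):
        """Hypothetical helper: normalized 1D Gaussian kernel truncated at 3 sigma."""
        radius = int(3 * sigma)
        x = torch.arange(-radius, radius + 1, dtype=torch.float64)
        k = torch.exp(-(x**2) / (2 * sigma**2))
        return k / k.sum()


    def smooth_random_keyframes(n_keyframes, size, sigma, std_pix, seed=0):
        """Hypothetical: random (K, 2, size, size) displacement keyframes in pixels."""
        gen = torch.Generator().manual_seed(seed)
        noise = torch.randn(
            n_keyframes * 2, 1, size, size, generator=gen, dtype=torch.float64
        )
        k = gaussian_kernel1d(sigma)
        r = k.numel() // 2
        # Separable Gaussian blur: along the width, then along the height.
        # Reflect padding keeps the spatial size unchanged.
        noise = F.conv2d(F.pad(noise, (r, r, 0, 0), mode="reflect"), k.view(1, 1, 1, -1))
        noise = F.conv2d(F.pad(noise, (0, 0, r, r), mode="reflect"), k.view(1, 1, -1, 1))
        noise = noise / noise.std() * std_pix  # set the overall standard deviation
        return noise.view(n_keyframes, 2, size, size)


    def interpolate_in_time(keyframes, n_frames):
        """Hypothetical: linearly interpolate (K, 2, H, W) keyframes to (n_frames, 2, H, W)."""
        K, C, H, W = keyframes.shape
        flat = keyframes.permute(1, 2, 3, 0).reshape(1, C * H * W, K)
        out = F.interpolate(flat, size=n_frames, mode="linear", align_corners=True)
        return out.reshape(C, H, W, n_frames).permute(3, 0, 1, 2)


    keyframes = smooth_random_keyframes(n_keyframes=3, size=32, sigma=2.0, std_pix=5.0)
    displacement = interpolate_in_time(keyframes, n_frames=50)
    print(displacement.shape)  # torch.Size([50, 2, 32, 32])

In this sketch the displacement is expressed in pixels. Multiplying by ``2 / size`` would convert it to the normalized units used by SpyRIT, following the conversion rule given in the affine section.
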

.. code-block:: Python


    n_frames_display = 5  # stride between displayed frames
    if thumbnail:
        # plot a few frames as thumbnails
        n_rows, n_cols = 1, 4
        plt.figure(figsize=(12, 3))
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames or frame >= n_rows * n_cols:
                break
            plt.subplot(n_rows, n_cols, frame + 1)
            x_frame = (
                x_motion[0, n_frame, :, amp_max : n + amp_max, amp_max : n + amp_max]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy()
            )
            plt.imshow(x_frame, cmap="gray")  # central FOV crop of frame n_frame
            plt.title("frame %d" % (n_frame), fontsize=18)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    else:
        # show motion as video with IPython display
        from IPython.display import clear_output

        x_min, x_max = x_motion.min().item(), x_motion.max().item()
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames:
                break
            plt.close()
            x_frame = (
                x_motion[0, n_frame, :, amp_max : n + amp_max, amp_max : n + amp_max]
                .moveaxis(0, -1)
                .view(*fov_shape, n_wav)
                .cpu()
                .numpy()
            )
            plt.imshow(x_frame, cmap="gray", vmin=x_min, vmax=x_max)
            plt.suptitle("frame %d" % (n_frame), fontsize=16)
            if n_wav == 1:  # a colorbar is only meaningful for single-channel images
                plt.colorbar(fraction=0.046, pad=0.04)
            plt.pause(0.01)
            clear_output(wait=True)


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_005.png
   :alt: frame 0, frame 5, frame 10, frame 15
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_005.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 524-526

Visualize the deformation field with quiver arrows
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 526-590

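
The quiver plots draw one arrow at every ``step``-th pixel. Each arrow is the displacement between the sampling position stored in ``def_field.field`` and the pixel's own position, both in normalized coordinates. The vertical components are negated because image rows increase downward while the plot's y axis points up. The sketch below performs the same bookkeeping on a hypothetical synthetic field in which every pixel samples 3 pixels to its right. It then converts the arrows back to pixels by multiplying by ``img_size / 2``. The field layout ``(n_frames, H, W, 2)`` is inferred from the indexing in the plotting code, not from the SpyRIT documentation.

.. code-block:: Python

    import torch

    img_size = 88
    n_frames = 4  # small hypothetical sequence

    # Identity grid in normalized coordinates, built as in the quiver cell
    interval = torch.linspace(0, img_size - 1, img_size, dtype=torch.float64)
    x1, x2 = torch.meshgrid(interval, interval, indexing="xy")  # x1: columns, x2: rows
    x1, x2 = x1 / img_size * 2 - 1, x2 / img_size * 2 - 1

    # Hypothetical field: every pixel samples 3 pixels to its right, in every frame
    shift_norm = 2 * 3 / img_size
    field = torch.stack([x1 + shift_norm, x2], dim=-1).expand(n_frames, -1, -1, -1)

    # Displacement arrows in normalized units, as passed to plt.quiver
    u = field[..., 0] - x1
    v = field[..., 1] - x2

    # Back to pixels: one normalized unit spans img_size / 2 pixels
    u_pix, v_pix = u * img_size / 2, v * img_size / 2
    print(field.shape, u_pix.mean().item(), v_pix.abs().max().item())
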

.. code-block:: Python


    # Pixel grid in normalized coordinates (x1 along columns, x2 along rows)
    interval = torch.linspace(0, img_size - 1, img_size, dtype=torch.float64)
    x1, x2 = torch.meshgrid(interval, interval, indexing="xy")
    x1, x2 = x1 / img_size * 2 - 1, x2 / img_size * 2 - 1
    x1, x2 = x1.cpu().numpy(), x2.cpu().numpy()
    field = def_field.field.cpu().numpy()  # field[frame, row, col] holds an (x, y) sampling position

    n_frames_display = 5  # stride between displayed frames
    if thumbnail:
        # plot a few frames
        plt.figure(figsize=(12, 3))
        n_rows, n_cols = 1, 4
        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames or frame >= n_rows * n_cols:
                break
            plt.subplot(n_rows, n_cols, frame + 1)
            step = 6  # change this to plot fewer or more arrows
            # Arrows show the displacement (field minus identity grid); y is negated
            # because image rows increase downward.
            plt.quiver(
                x1[::step, ::step],
                -x2[::step, ::step],
                (field[n_frame, ::step, ::step, 0] - x1[::step, ::step]),
                -(field[n_frame, ::step, ::step, 1] - x2[::step, ::step]),
                angles="xy",
                scale_units="xy",
                scale=1,
            )
            plt.title("frame %d" % (n_frame), fontsize=18)
            # Make axes square so quiver arrows reflect image aspect ratio
            ax = plt.gca()
            ax.set_aspect("equal", adjustable="box")
            ax.set_xlim([-1, 1])
            ax.set_ylim([-1, 1])
            ax.set_xticks([-1, 0, 1])
            ax.set_yticks([-1, 0, 1])
        plt.tight_layout()
        plt.show()

    else:
        # show motion as video with IPython display
        from IPython.display import clear_output

        for frame in range(n_frames):
            n_frame = n_frames_display * frame
            if n_frame >= n_frames:
                break
            plt.figure(figsize=(6, 6))
            step = 6  # change this to plot fewer or more arrows
            plt.quiver(
                x1[::step, ::step],
                -x2[::step, ::step],
                (field[n_frame, ::step, ::step, 0] - x1[::step, ::step]),
                -(field[n_frame, ::step, ::step, 1] - x2[::step, ::step]),
                angles="xy",
                scale_units="xy",
                scale=1,
            )
            plt.suptitle("frame %d" % n_frame, fontsize=16)
            # Make axes square so quiver arrows reflect image aspect ratio
            plt.gca().set_aspect("equal", adjustable="box")
            plt.pause(0.01)
            clear_output(wait=True)


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_006.png
   :alt: frame 0, frame 5, frame 10, frame 15
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_006.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 591-593

Plot a single frame of the deformation field; this figure is used as the
gallery thumbnail.

.. GENERATED FROM PYTHON SOURCE LINES 593-613

.. code-block:: Python


    n_frame = 30
    plt.figure(figsize=(2, 2))
    step = 6  # change this to plot fewer or more arrows
    plt.quiver(
        x1[::step, ::step],
        -x2[::step, ::step],
        (field[n_frame, ::step, ::step, 0] - x1[::step, ::step]),
        -(field[n_frame, ::step, ::step, 1] - x2[::step, ::step]),
        angles="xy",
        scale_units="xy",
        scale=1,
    )
    # Make axes square so quiver arrows reflect image aspect ratio
    ax = plt.gca()
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])
    plt.show()


.. image-sg:: /gallery/images/sphx_glr_tuto_06_a_warp_007.png
   :alt: tuto 06 a warp
   :srcset: /gallery/images/sphx_glr_tuto_06_a_warp_007.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.554 seconds)


.. _sphx_glr_download_gallery_tuto_06_a_warp.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tuto_06_a_warp.ipynb <tuto_06_a_warp.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tuto_06_a_warp.py <tuto_06_a_warp.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tuto_06_a_warp.zip <tuto_06_a_warp.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_