Source code for spyrit.misc.dual_arm

"""
Module designed for the dual-arm single-pixel camera.

This module contains several classes:
    * :class:`KeyPoints`: determines the key points between a CMOS camera and a single-pixel camera.
    * :class:`ComputeHomography`: computes the homography matrix between the two camera views.
    * :class:`MotionFieldProjector`: reads the CMOS motion fields from Nifti files and projects them to the single-pixel camera point of view using the computed homography.

Practical examples of usage can be found in the `spyrit-examples <https://github.com/openspyrit/spyrit-examples/tree/dynamic_tip>`_ repository.

In particular, the following scripts in the ``2025_dynamic_TIP`` folder treat experimental data acquired with the dual-arm single-pixel camera and use the classes
from this module for calibration and motion estimation: ``fig_07.py``, ``fig_08.py``, ``fig_09_10.py``, ``fig_11_ablation_channels.py``, ``fig_11_spectra.py``, and
``fig_12.py``.

References:
    [Maitre2024_1]_ Maitre, T., Bretin, E., Mahieu-Williame, L., Sdika, M., & Ducros, N. (2024, May).
    Hybrid single-pixel camera for dynamic hyperspectral imaging. In 2024 IEEE International Symposium
    on Biomedical Imaging (ISBI) (pp. 1-5). IEEE. DOI:10.1109/ISBI56570.2024.10635884

    [Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
    Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
"""

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import warnings
from dataclasses import dataclass
from typing import Tuple, Optional, Union
from pathlib import Path
import torch.nn as nn
import time
import math

try:
    import cv2
except ImportError:
    warnings.warn(
        "Please install OpenCV to use the dual-arm module (necessary for defining keypoints), e.g. via 'pip install opencv-python'."
    )

from spyrit.misc.disp import torch2numpy
from spyrit.core.meas import HadamSplit2d
from spyrit.misc.statistics import Cov2Var
from spyrit.core.warp import DeformationField

from spyrit.misc.load_data import read_acquisition
from spyrit.misc.disp import get_frame


@dataclass
class _MouseState:
    """State container for mouse interactions."""

    x: int = 0
    y: int = 0
    img: Optional[np.ndarray] = None


# Global state for mouse callbacks (necessary for OpenCV callback system)
_cmos_state = _MouseState()
_sp_state = _MouseState()


def _draw_circle(event: int, x: int, y: int, flags: int, param) -> None:
    """Mouse callback for CMOS image interaction."""
    global _cmos_state
    if event == cv2.EVENT_LBUTTONDBLCLK and _cmos_state.img is not None:
        cv2.circle(_cmos_state.img, (x, y), 2, (255, 0, 0), -1)
        _cmos_state.x, _cmos_state.y = x, y


def _draw_circle_2(event: int, x: int, y: int, flags: int, param) -> None:
    """Mouse callback for single-pixel camera image interaction."""
    global _sp_state
    if event == cv2.EVENT_LBUTTONDBLCLK and _sp_state.img is not None:
        cv2.circle(_sp_state.img, (x, y), 1, (255, 0, 0), -1)
        _sp_state.x, _sp_state.y = x, y


[docs] class KeyPoints(nn.Module): """ Detects and manages keypoints between two camera views. This class provides multiple methods for keypoint detection and matching between a source image (CMOS camera) and a destination image (Single-Pixel Camera). It supports both automatic detection methods (SIFT) and manual placement. Detected keypoints are essential for computing the homography matrix that relates the two camera coordinate systems (see :class:`ComputeHomography`). Args: :attr:`src_img` Source image array from CMOS camera :attr:`dest_img` Destination image array from Single-Pixel Camera :attr:`homo_folder` (optional) Folder where keypoint data and homography matrices are stored """ def __init__( self, src_img: np.ndarray, dest_img: np.ndarray, homo_folder: str = "" ): super().__init__() self.src_img = src_img self.dest_img = dest_img self.homo_folder = homo_folder
[docs] def place_hand_keypoints( self, win_up_factor: int = 10 ) -> Tuple[np.ndarray, np.ndarray]: """ Manually place keypoints on both images using mouse interaction. Args: :attr:`win_up_factor` (optional): Window upscaling factor for the single-pixel image display. Returns: Tuple of (src_points, dest_points) as numpy arrays. """ src_points, dest_points = [], [] global _cmos_state, _sp_state _cmos_state.img = self.src_img.copy() _sp_state.img = np.rot90(self.dest_img, 2).copy() n = self.dest_img.shape[0] print(f"Image size: {n}, upscaling factor: {win_up_factor}") # Setup OpenCV windows cv2.namedWindow("CMOS", cv2.WINDOW_NORMAL) cv2.setMouseCallback("CMOS", _draw_circle) cv2.namedWindow("SPC", cv2.WINDOW_NORMAL) cv2.resizeWindow("SPC", win_up_factor * n, win_up_factor * n) cv2.setMouseCallback("SPC", _draw_circle_2) print("DOUBLE-CLICK and press 'a' to place a point on the CMOS image") print("DOUBLE-CLICK and press 'z' to place a point on the SPC image") print("Press 'q' to quit") while True: cv2.imshow("CMOS", _cmos_state.img) cv2.imshow("SPC", _sp_state.img) key = cv2.waitKey(0) & 0xFF if key == ord("q"): break elif key == ord("a"): point = (_cmos_state.x, _cmos_state.y) print(f"CMOS point: {point}") src_points.append(point) print(f"Current src_points: {src_points}") elif key == ord("z"): point = (_sp_state.x, _sp_state.y) print(f"SPC point: {point}") dest_points.append(point) print(f"Current dest_points: {dest_points}") cv2.destroyAllWindows() # Transform coordinates for SP image (account for rotation) dest_points = np.array(dest_points) dest_points = np.array([n - 1, n - 1]) - dest_points src_points = np.array(src_points) # Save keypoints data_path = self.homo_folder data_path.mkdir(parents=True, exist_ok=True) np.save(data_path / "handmade_dest_kp.npy", dest_points) np.save(data_path / "handmade_src_kp.npy", src_points) return src_points, dest_points
[docs] def find_sift_keypoints(self) -> Tuple[np.ndarray, np.ndarray]: """ Find matching keypoints using SIFT feature detection. Returns: Tuple of (src_points, dest_points) as numpy arrays. """ # Convert to 8-bit images img1 = (255 * self.src_img).astype(np.uint8) img2 = (255 * self.dest_img).astype(np.uint8) ## CODE FROM OPENCV EXAMPLE coord_p1, coord_p2 = [], [] # Initialize SIFT detector sift = cv2.SIFT_create() # Find keypoints and descriptors kp1, des1 = sift.detectAndCompute(img1, None) kp2, des2 = sift.detectAndCompute(img2, None) # BFMatcher with default params bf = cv2.BFMatcher() matches = bf.knnMatch(des1, des2, k=2) # Apply ratio test good = [] for m, n in matches: if m.distance < 0.6 * n.distance: p1 = kp1[m.queryIdx].pt p2 = kp2[m.trainIdx].pt if p1 not in coord_p1 and p2 not in coord_p2: coord_p1.append(p1), coord_p2.append(p2) good.append([m]) # cv.drawMatchesKnn expects list of lists as matches. img3 = cv2.drawMatchesKnn( img1, kp1, img2, kp2, good, None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS, ) img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2GRAY) plt.imshow(img3, cmap="gray"), plt.show() return np.array(coord_p1), np.array(coord_p2)
[docs] def find_shi_tomasi_keypoints( self, max_corners: int = 20, quality_level: float = 0.1, min_distance: int = 2 ) -> Tuple[np.ndarray, np.ndarray]: """ Find corners using Shi-Tomasi corner detector. .. warning:: This method doesn't match keypoints between images yet. Feel free to contribute if you need this functionality. Alternatively, you can use SIFT or manual placement for now. Args: :attr:`max_corners`: Maximum number of corners to detect. :attr:`quality_level`: Quality level for corner detection. :attr:`min_distance`: Minimum distance between corners. Returns: Tuple of (src_points, dest_points) as numpy arrays. """ warnings.warn( "Shi-Tomasi keypoints are not matched between images yet. " "Feel free to contribute if you need this functionality." ) img1 = (255 * self.src_img).astype(np.uint8) img2 = (255 * self.dest_img).astype(np.uint8) corners_cmos = cv2.goodFeaturesToTrack( img1, maxCorners=max_corners, qualityLevel=quality_level, minDistance=min_distance, useHarrisDetector=False, ) corners_spc = cv2.goodFeaturesToTrack( img2, maxCorners=max_corners, qualityLevel=quality_level, minDistance=min_distance, useHarrisDetector=False, ) if corners_cmos is None or corners_spc is None: return np.array([]), np.array([]) src_points = corners_cmos[:, 0, :] dest_points = corners_spc[:, 0, :] return src_points, dest_points
[docs] def external_keypoints(self) -> Tuple[np.ndarray, np.ndarray]: """ Load keypoints from external files. Returns: Tuple of (:attr:`src_points`, :attr:`dest_points`) as numpy arrays. """ data_path = Path("../data/exp_data") / self.homo_folder src_file = data_path / "external_src_kp.npy" dest_file = data_path / "external_dest_kp.npy" if not src_file.exists() or not dest_file.exists(): raise FileNotFoundError( f"External keypoint files not found: {src_file} or {dest_file}" ) src_points = np.load(src_file) dest_points = np.load(dest_file) return src_points, dest_points
[docs] def forward( self, kp_method: str, read_hand_kp: bool = False ) -> Tuple[np.ndarray, np.ndarray]: """ Main method to find keypoints using specified method. Args: :attr:`kp_method`: Method to use ('hand', 'sift', 'shi-tomasi', 'external'). :attr:`read_hand_kp`: Whether to read existing hand-placed keypoints. Returns: Tuple of (:attr:`src_points`, :attr:`dest_points`) as numpy arrays. """ if kp_method == "hand": if read_hand_kp: data_path = self.homo_folder / kp_method src_points = np.load(data_path / "handmade_src_kp.npy") dest_points = np.load(data_path / "handmade_dest_kp.npy") else: src_points, dest_points = self.place_hand_keypoints() elif kp_method == "sift": src_points, dest_points = self.find_sift_keypoints() # print(dest_points[4], src_points[4]) # src_points, dest_points = np.delete(src_points, 4, axis=0), np.delete(dest_points, 4, axis=0) #point aberrant pour le chat # src_points, dest_points = np.delete(src_points, 2, axis=0), np.delete(dest_points, 2, axis=0) elif kp_method == "shi-tomasi": src_points, dest_points = self.find_shi_tomasi_keypoints() elif kp_method == "external": src_points, dest_points = self.external_keypoints() else: raise ValueError( f"Unknown keypoint method: {kp_method}. " f"Supported methods: 'hand', 'sift', 'shi-tomasi', 'external'" ) return src_points, dest_points
[docs] def recalibrate( X: torch.Tensor, shape: Tuple[int, int], homography_inv: torch.Tensor, amp_max: int = 0, ) -> torch.Tensor: """ Recalibrate tensor X using inverse homography transformation. Args: :attr:`X`: Input tensor of shape (batch_size, n_wav, height, width). :attr:`shape`: Target shape (n, m). :attr:`homography_inv`: Inverse homography matrix (3x3). :attr:`amp_max`: Maximum amplitude offset. Returns: Recalibrated tensor of shape (batch_size, n_wav, n, m). """ n, m = shape batch_size, n_wav, height, width = X.shape dtype = X.dtype device = X.device # Create coordinate meshgrid y_coords = torch.linspace(0, n - 1, n, dtype=dtype, device=device) x_coords = torch.linspace(0, m - 1, m, dtype=dtype, device=device) x_grid, y_grid = torch.meshgrid(x_coords, y_coords, indexing="xy") # Apply amplitude offset x_grid = x_grid - amp_max y_grid = y_grid - amp_max # Create homogeneous coordinates ones = torch.ones_like(x_grid) coords_homogeneous = torch.stack([x_grid, y_grid, ones], dim=2) # (n, m, 3) # Apply inverse homography transformation # Reshape for batch matrix multiplication coords_flat = coords_homogeneous.view(-1, 3, 1) # (n*m, 3, 1) # Transform coordinates transformed_coords = torch.bmm( homography_inv.unsqueeze(0).expand(coords_flat.shape[0], -1, -1), coords_flat ) # (n*m, 3, 1) # Convert from homogeneous coordinates transformed_coords = transformed_coords.squeeze(-1) # (n*m, 3) w = transformed_coords[:, 2] x_new = transformed_coords[:, 0] / w y_new = transformed_coords[:, 1] / w # Reshape back to grid format x_new = x_new.view(n, m) y_new = y_new.view(n, m) # Normalize coordinates to [-1, 1] for grid_sample x_new_norm = x_new / (width - 1) * 2 - 1 y_new_norm = y_new / (height - 1) * 2 - 1 # Create grid for grid_sample (note: grid_sample expects (x, y) order) grid = torch.stack((x_new_norm, y_new_norm), dim=2) # (n, m, 2) grid = grid.unsqueeze(0).expand(batch_size, -1, -1, -1) # (batch_size, n, m, 2) # Apply interpolation X_calibrated = nn.functional.grid_sample( X, grid, mode="bilinear", padding_mode="zeros", align_corners=True ) return X_calibrated
[docs] class ComputeHomography(nn.Module): """ Computes the homography between the two arms of the hybrid single-pixel camera using a Direct Linear Transform (DLT) [Maitre2024_1]_. .. note:: By convention, we refer to the CMOS image as the "source" and the single-pixel camera reconstruction as the "destination". Args: :attr:`data_root`: Root directory of the data. :attr:`data_folder`: Folder containing the data. :attr:`data_file_prefix`: Prefix of the data files. :attr:`n`: Size of the reconstructed image. :attr:`n_acq`: Size of the acquisition. """ def __init__( self, data_root: Path, data_folder: str, data_file_prefix: str, n: int, n_acq: int, ): super().__init__() self.data_root = Path(data_root) self.data_folder = data_folder self.data_file_prefix = data_file_prefix self.n = n self.n_acq = n_acq
[docs] def DLT(self, points_source: np.ndarray, points_target: np.ndarray) -> np.ndarray: """ Computes homography using Direct Linear Transform (DLT) method. Requires at least 4 corresponding point pairs. Args: points_source: Source image keypoints (Nx2). points_target: Target image keypoints (Nx2). Returns: 3x3 homography matrix. """ if len(points_source) < 4 or len(points_target) < 4: raise ValueError("At least 4 point correspondences are required for DLT") A = self._construct_A(points_source, points_target) # Solve using SVD _, _, vh = np.linalg.svd(A, full_matrices=True) # Solution is the last column of V (last row of V^T) homography = vh[-1].reshape((3, 3)) return homography / homography[2, 2]
def _construct_A( self, points_source: np.ndarray, points_target: np.ndarray ) -> np.ndarray: """ Construct matrix A for DLT algorithm. Args: :attr:`points_source`: Source image keypoints. :attr:`points_target`: Target image keypoints. Returns: A matrix for SVD decomposition. """ assert ( points_source.shape == points_target.shape ), "Source and target points must have the same shape" num_points = points_source.shape[0] matrices = [] for i in range(num_points): partial_A = self._construct_A_partial(points_source[i], points_target[i]) matrices.append(partial_A) return np.concatenate(matrices, axis=0) def _construct_A_partial(self, point_source, point_target): x, y, z = point_source[0], point_source[1], 1 x_t, y_t, z_t = point_target[0], point_target[1], 1 A_partial = np.array( [ [0, 0, 0, -z_t * x, -z_t * y, -z_t * z, y_t * x, y_t * y, y_t * z], [z_t * x, z_t * y, z_t * z, 0, 0, 0, -x_t * x, -x_t * y, -x_t * z], ] ) return A_partial
[docs] def forward( self, kp_method: str, homo_folder: str = "", read_homography: bool = False, save_homography: bool = True, read_hand_kp: bool = False, snapshot: bool = True, show_calib: bool = False, ) -> torch.Tensor: """ Compute the homography between the CMOS and single pixel cameras. Args: :attr:`kp_method`: Keypoint detection method ('sift', 'hand' or 'external'). :attr:`homo_folder`: Folder for homography data. :attr:`read_homography`: Whether to load existing homography. :attr:`save_homography`: Whether to save computed homography. :attr:`read_hand_kp`: Whether to read existing hand-placed keypoints. :attr:`snapshot`: Whether to use a snapshot or a video for the CMOS data. :attr:`show_calib`: Whether to show calibration visualization. Returns: Computed homography matrix as torch tensor. """ # Create output directory output_dir = self.data_root / homo_folder / kp_method output_dir.mkdir(parents=True, exist_ok=True) # Load CMOS image g_frame0 = self._load_cmos(snapshot) # Load single pixel camera reconstruction f_stat_np = self._load_spc() # Normalize images g_frame0 = (g_frame0 - g_frame0.min()) / (g_frame0.max() - g_frame0.min()) f_stat_np = (f_stat_np - f_stat_np.min()) / (f_stat_np.max() - f_stat_np.min()) # Load or compute homography if read_homography: homography_np = self._load_homography(output_dir) homography = torch.from_numpy(homography_np) else: homography = self._compute_homography( g_frame0, f_stat_np, kp_method, homo_folder, read_hand_kp, save_homography, output_dir, ) # Visualize calibration if requested if show_calib: self._visualize_calibration(g_frame0, f_stat_np, homography) return homography
def _load_cmos(self, snapshot: bool) -> np.ndarray: """Load CMOS camera data from snapshot or video.""" if snapshot: file_path = ( self.data_root / self.data_folder / f"{self.data_file_prefix}_IDScam_before_acq.npy" ) if not file_path.exists(): raise FileNotFoundError(f"CMOS snapshot file not found: {file_path}") return np.load(file_path) else: video_path = ( self.data_root / self.data_folder / f"{self.data_file_prefix}_video.avi" ) if not video_path.exists(): raise FileNotFoundError(f"Video file not found: {video_path}") return get_frame(str(video_path), 0) def _load_spc(self) -> np.ndarray: """Load single pixel camera raw data and reconstruct an image using orthogonality of Hadamard patterns.""" # Load measurement data _, meas = read_acquisition( self.data_root, self.data_folder, self.data_file_prefix ) # Load covariance matrix stat_folder = Path("./stats/") cov_file = stat_folder / f"Cov_{self.n_acq}x{self.n_acq}.npy" if not cov_file.exists(): raise FileNotFoundError(f"Covariance file not found: {cov_file}") Cov_acq = np.load(cov_file) Ord_acq = Cov2Var(Cov_acq) Ord = torch.from_numpy(Ord_acq) # Create measurement operator if self.n < self.n_acq: Ord = Ord[: self.n, : self.n] meas_op_stat = HadamSplit2d( M=self.n**2, h=self.n, order=Ord, dtype=torch.float64 ) meas = meas[: 2 * self.n**2] else: meas_op_stat = HadamSplit2d( M=self.n**2, h=self.n_acq, order=Ord, dtype=torch.float64 ) # Process measurements m = meas[::2, :] - meas[1::2, :] m_pan = np.mean(m, axis=1) m_pan = torch.from_numpy(m_pan.reshape((1, -1))) # Reconstruct image f_stat = meas_op_stat.fast_pinv(m_pan) return torch2numpy(f_stat).reshape((self.n, self.n)) def _load_homography(self, output_dir: Path) -> np.ndarray: """Load existing homography matrix.""" homo_file = output_dir / "homography.npy" if not homo_file.exists(): raise FileNotFoundError(f"Homography file not found: {homo_file}") return np.load(homo_file) def _compute_homography( self, g_frame0: np.ndarray, f_stat_np: np.ndarray, kp_method: str, homo_folder: str, read_hand_kp: bool, save_homography: bool, output_dir: Path, ) -> torch.Tensor: """Compute homography using keypoint detection.""" # Find keypoints kp_finder = KeyPoints(g_frame0, f_stat_np, self.data_root / homo_folder) src_points, dest_points = kp_finder(kp_method, read_hand_kp=read_hand_kp) if len(src_points) == 0 or len(dest_points) == 0: raise ValueError("No keypoints found. Cannot compute homography.") # Visualize keypoints self._visualize_keypoints(g_frame0, f_stat_np, src_points, dest_points) # Compute homography using DLT homography_np = self.DLT(src_points, dest_points) homography = torch.from_numpy(homography_np) # Save homography if requested if save_homography: np.save(output_dir / "homography.npy", homography_np) return homography def _visualize_keypoints( self, g_frame0: np.ndarray, f_stat_np: np.ndarray, src_points: np.ndarray, dest_points: np.ndarray, ) -> None: """Visualize detected keypoints on both images.""" colors = np.random.rand(len(src_points)) # CMOS image keypoints plt.figure(figsize=(12, 5)) plt.subplot(1, 2, 1) plt.imshow(g_frame0, cmap="gray") plt.scatter(src_points[:, 0], src_points[:, 1], c=colors, s=50, marker="o") plt.title("Keypoints on CMOS image", fontsize=16) plt.axis("off") # SP image keypoints plt.subplot(1, 2, 2) plt.imshow(f_stat_np, cmap="gray") plt.scatter(dest_points[:, 0], dest_points[:, 1], c=colors, s=50, marker="o") plt.title("Keypoints on SP image", fontsize=16) plt.axis("off") plt.tight_layout() plt.show() def _visualize_calibration( self, g_frame0: np.ndarray, f_stat_np: np.ndarray, homography: torch.Tensor ) -> None: """Visualize calibration results.""" homography_inv = torch.linalg.inv(homography) # Calibrate CMOS image g_frame0_tensor = torch.from_numpy(g_frame0).unsqueeze(0).unsqueeze(0) img_cmos_calibrated = recalibrate( g_frame0_tensor, (f_stat_np.shape[0], f_stat_np.shape[0]), homography_inv ) # img_cmos_calibrated = tensor2img(img_cmos_calibrated)[:, :, 0] img_cmos_calibrated = torch2numpy(img_cmos_calibrated.moveaxis(1, -1))[ 0, :, :, 0 ] # Create visualization fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(21, 7)) axs[0].imshow(img_cmos_calibrated, cmap="gray") axs[0].set_title("CMOS camera calibrated", fontsize=20) axs[0].axis("off") axs[1].imshow(f_stat_np, cmap="gray") axs[1].set_title("SP reconstruction", fontsize=20) axs[1].axis("off") diff_img = f_stat_np - img_cmos_calibrated im = axs[2].imshow(diff_img, cmap="Spectral") axs[2].set_title("Difference: SP - CMOS", fontsize=20) axs[2].axis("off") fig.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04) plt.tight_layout() plt.show()
@dataclass class _MotionConfig: """Configuration for motion estimation parameters.""" n: int # Pattern size M: int # Number of illumination patterns n_ppg: int # Number of patterns per gate T: float # Total acquisition time frame_ref: int = 0 # Reference frame index dtype: torch.dtype = torch.float64
[docs] class MotionFieldProjector(nn.Module): """ Projects the motion fields from the CMOS camera perspective to the single-pixel camera (SPC) perspective. This class loads pre-computed motion fields from NIfTI files and performs: - Geometric transformation via homography (CMOS to SPC coordinate mapping) - Temporal interpolation to match the SPC acquisition timing for each illumination pattern - Reference frame definition The result is a motion field suitable for dynamic single-pixel imaging applications. Args: :attr:`deform_path`: Path to deformation field files. :attr:`deform_prefix`: Prefix for deformation files. :attr:`n`: Pattern size. :attr:`M`: Number of illumination patterns. :attr:`n_ppg`: Number of patterns per gate. A gate is defined as the set of patterns between two CMOS frames. :attr:`T`: Total acquisition time. :attr:`frame_ref` (optional): Reference frame index. :attr:`homography` (optional): 3x3 homography transformation matrix. :attr:`translation` (optional): Translation offset (x, y). :attr:`dtype` (optional): Data type for computations (torch.float32, torch.float64, etc.). :attr:`device` (optional): Device to use for computations ('cpu', 'cuda', etc.). Raises: FileNotFoundError: If deformation path doesn't exist. ValueError: If homography matrix is not 3x3. """ def __init__( self, deform_path: Union[str, Path], deform_prefix: str, n: int, M: int, n_ppg: int, T: float, frame_ref: int = 0, homography: torch.Tensor = torch.eye(3), translation: Tuple[float, float] = (0.0, 0.0), dtype: Optional[torch.dtype] = torch.float64, device: Optional[Union[str, torch.device]] = torch.device("cpu"), ): super().__init__() self.deform_path = Path(deform_path) if not self.deform_path.exists(): raise FileNotFoundError(f"Deformation path not found: {deform_path}") self.deform_prefix = deform_prefix self.config = _MotionConfig( n=n, M=M, n_ppg=n_ppg, T=T, frame_ref=frame_ref, dtype=dtype ) # Setup device self.dtype = dtype self.device = device # Validate homography matrix if homography.shape != (3, 3): raise ValueError(f"Homography must be 3x3 matrix, got {homography.shape}") self.homography = ( homography.clone().detach().to(dtype=self.config.dtype, device=self.device) ) self.translation = translation # Apply translation to homography self.homography[0, 2] += translation[0] self.homography[1, 2] += translation[1] # Precompute inverse homography self.homography_inv = torch.linalg.inv(self.homography) # Initialize storage for motion fields self.def_field_cmos: Optional[torch.Tensor] = None self.def_field_spc: Optional[torch.Tensor] = None self.u_cmos: Optional[torch.Tensor] = None def _load_deformation_movies( self, warping: str ) -> Tuple[torch.Tensor, int, int, int]: """ Load deformation field movies from NIfTI files. Args: :attr:`warping`: 'pattern' or 'image'. Matches the warping mode used in the Dynamic classes from :mod:`spyrit.core.meas`. Returns: Tuple of (combined_motion_data, width, height, n_frames). Raises: FileNotFoundError: If deformation files are not found. ValueError: If file dimensions are inconsistent. """ try: import nibabel as nib except ImportError: raise ImportError( "nibabel is required to load NIfTI files. Please install it (e.g. via 'pip install nibabel')." ) if warping == "image": mode = "inverse" elif warping == "pattern": mode = "direct" else: raise ValueError( f"Invalid warping mode: {warping}. Use 'image' or 'pattern'." ) movies = [] for i in range(2): file_path = self.deform_path / f"{self.deform_prefix}_{mode}_{i+1}.img" if not file_path.exists(): raise FileNotFoundError(f"Deformation file not found: {file_path}") print(f"Loading {file_path}") movies.append(nib.load(file_path)) # Extract dimensions from header hdr = movies[0].header dims = hdr["dim"] _, width, height, _, n_frames, _, _, _ = dims width, height, n_frames = int(width), int(height), int(n_frames) print(f"Loaded {n_frames} frames of size {width}x{height}") # Load and combine motion data # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # print(f'Using device: {device} for stacking operation') ti = time.time() u1_cmos = torch.from_numpy(movies[0].get_fdata().transpose()).to(self.device) u2_cmos = torch.from_numpy(movies[1].get_fdata().transpose()).to(self.device) u_cmos = torch.stack([u1_cmos, u2_cmos], dim=1).reshape( u1_cmos.shape[0], -1, u1_cmos.shape[2], u1_cmos.shape[3] ) tf = time.time() print(f"Time to load deformation movies: {tf - ti:.2f} seconds") return u_cmos.to(self.config.dtype), width, height, n_frames def _create_coordinate_grids( self, l: int, amp_max: int, width: int, height: int ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create coordinate grids for spatial transformations. Args: :attr:`l`: Grid size for SP coordinates. :attr:`amp_max`: Amplitude offset for the extended field of view. :attr:`width`: CMOS image width. :attr:`height`: CMOS image height. Returns: Tuple of (sp_grid, cmos_grid). """ # SP coordinate grid interval = torch.linspace( 0, l - 1, l, dtype=self.config.dtype, device=self.device ) x1_sp, x2_sp = torch.meshgrid(interval, interval, indexing="xy") x1_sp, x2_sp = x1_sp - amp_max, x2_sp - amp_max # CMOS coordinate grid interval_1 = torch.linspace( 0, width - 1, width, dtype=self.config.dtype, device=self.device ) interval_2 = torch.linspace( 0, height - 1, height, dtype=self.config.dtype, device=self.device ) x1_cmos, x2_cmos = torch.meshgrid(interval_1, interval_2, indexing="xy") return (x1_sp, x2_sp), (x1_cmos, x2_cmos) def _apply_homography_vectorized( self, x1_sp: torch.Tensor, x2_sp: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply homography transformation using vectorized operations. Args: :attr:`x1_sp`: X coordinates in SP space. :attr:`x2_sp`: Y coordinates in SP space. Returns: Tuple of transformed coordinates (x1_new, x2_new). """ # Create homogeneous coordinates ones = torch.ones_like(x1_sp) coords_homogeneous = torch.stack([x1_sp, x2_sp, ones], dim=0) # Reshape for efficient batch matrix multiplication coords_flat = coords_homogeneous.view(3, -1) # (3, l*l) # Transform coordinates transformed_coords = self.homography_inv @ coords_flat # (3, l*l) # Convert from homogeneous coordinates w = transformed_coords[2, :] x1_new = (transformed_coords[0, :] / w).view_as(x1_sp) x2_new = (transformed_coords[1, :] / w).view_as(x2_sp) return x1_new, x2_new
[docs] def estim_motion_from_CMOS(self, warping: str, amp_max: int = 0) -> None: """ Estimate motion field from CMOS camera data. Args: :attr:`warping`: 'pattern' or 'image'. Matches the warping mode used in the Dynamic classes from :mod:`spyrit.core.meas`. :attr:`amp_max`: Amplitude for the extended field of view. Raises: FileNotFoundError: If required files are not found. """ l = self.config.n + 2 * amp_max # Load deformation movies u_cmos, width, height, n_frames = self._load_deformation_movies(warping) self.u_cmos = u_cmos # Create coordinate grids (x1_sp, x2_sp), (x1_cmos, x2_cmos) = self._create_coordinate_grids( l, amp_max, width, height ) # Apply homography transformation (vectorized) x1_new, x2_new = self._apply_homography_vectorized(x1_sp, x2_sp) # Normalize coordinates for grid_sample [-1, 1] x1_new_norm = x1_new / (width - 1) * 2 - 1 x2_new_norm = x2_new / (height - 1) * 2 - 1 # Create sampling grid efficiently (avoid .repeat()) grid = torch.stack((x1_new_norm, x2_new_norm), dim=2) # (l, l, 2) grid = grid.unsqueeze(0).expand( n_frames, -1, -1, -1 ) # More memory efficient than repeat # Create CMOS grid for displacement calculation grid_cmos = torch.stack((x1_cmos, x2_cmos), dim=0) # (2, width, height) grid_cmos = grid_cmos.unsqueeze(0).expand( n_frames, -1, -1, -1 ) # More memory efficient # Calculate displacement and apply interpolation du_cmos = u_cmos - grid_cmos # Move data to GPU for the compute-intensive grid_sample operation device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device} for grid_sample operation") t1 = time.time() if device.type == "cuda": du_cmos_gpu = du_cmos.to(device) grid_gpu = grid.to(device) du = nn.functional.grid_sample( du_cmos_gpu, grid_gpu, mode="bilinear", padding_mode="border", align_corners=True, ).cpu() # Move result back to CPU del du_cmos_gpu, grid_gpu torch.cuda.empty_cache() else: du = nn.functional.grid_sample( du_cmos, grid, mode="bilinear", padding_mode="border", align_corners=True, ) t2 = time.time() print(f"Time to apply grid_sample: {t2 - t1:.2f} seconds") # Efficient grid creation for final coordinates grid2 = torch.stack((x1_new, x2_new), dim=0) grid2 = grid2.unsqueeze(0).expand(n_frames, -1, -1, -1) u = du + grid2 # Apply homography transformation to get SP coordinates (old operator A : x_cmos -> x_sp) u1_sp, u2_sp = self._apply_homography_to_motion(u) # Combine results res = torch.stack((u1_sp, u2_sp), dim=1) # Add identity transformation for first frame frame0 = torch.stack((x1_sp, x2_sp), dim=0).unsqueeze(0) res = torch.cat((frame0, res), dim=0) # cat is more efficient than concatenate # Normalize coordinates to [-1, 1] range res = (res + amp_max) / (l - 1) * 2 - 1 self.def_field_cmos = res
def _apply_homography_to_motion( self, u: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply homography transformation to motion vectors. Args: :attr:`u`: Motion vectors of shape (n_frames, 2, height, width). Returns: Tuple of transformed motion components (u1_sp, u2_sp). """ H = self.homography # Vectorized homography application u1_sp = (H[0, 0] * u[:, 0, :, :] + H[0, 1] * u[:, 1, :, :] + H[0, 2]) / ( H[2, 0] * u[:, 0, :, :] + H[2, 1] * u[:, 1, :, :] + H[2, 2] ) u2_sp = (H[1, 0] * u[:, 0, :, :] + H[1, 1] * u[:, 1, :, :] + H[1, 2]) / ( H[2, 0] * u[:, 0, :, :] + H[2, 1] * u[:, 1, :, :] + H[2, 2] ) return u1_sp, u2_sp
[docs] def def_reference(self) -> None: """ Define reference frame by subtracting reference deformation from all frames. This method normalizes the deformation field so that the reference frame has zero deformation, making all other deformations relative to it. Raises: RuntimeError: If CMOS deformation field is not computed yet. """ if self.def_field_cmos is None: raise RuntimeError( "CMOS deformation field not computed. Call estim_motion_from_CMOS first." ) n_frames, _, l, _ = self.def_field_cmos.shape # Validate reference frame index if self.config.frame_ref >= n_frames: raise ValueError( f"Reference frame {self.config.frame_ref} >= number of frames {n_frames}" ) # Create identity grid for reference on the same device/dtype as the deformation field device = self.def_field_cmos.device interval = torch.linspace(0, l - 1, l, dtype=self.config.dtype, device=device) x1, x2 = torch.meshgrid(interval, interval, indexing="xy") x1_norm = x1 / (l - 1) * 2 - 1 x2_norm = x2 / (l - 1) * 2 - 1 # Calculate reference deformation dx1_ref = self.def_field_cmos[self.config.frame_ref, 0, :, :] - x1_norm dx2_ref = self.def_field_cmos[self.config.frame_ref, 1, :, :] - x2_norm # Subtract reference deformation from all frames self.def_field_cmos[:, 0, :, :] = self.def_field_cmos[:, 0, :, :] - dx1_ref self.def_field_cmos[:, 1, :, :] = self.def_field_cmos[:, 1, :, :] - dx2_ref
[docs] def interpolate_between_frames(self) -> None: """ Interpolate deformation field between frames for SPC acquisition timing. This method creates a temporally dense deformation field that matches the SPC acquisition pattern timing. Raises: RuntimeError: If CMOS deformation field is not computed yet. """ if self.def_field_cmos is None: raise RuntimeError( "CMOS deformation field not computed. Call estim_motion_from_CMOS first." ) n_frames, _, l, _ = self.def_field_cmos.shape # Calculate timing parameters n_hppg = math.ceil( 2 * self.config.M / n_frames ) # Hadamard patterns per gate period n_last_pat = 2 * self.config.M - n_hppg * ( n_frames - 1 ) # Patterns in last frame n_wppg = self.config.n_ppg - n_hppg # White patterns per gate period print( f"Interpolation parameters: n_hppg={n_hppg}, n_last_pat={n_last_pat}, n_wppg={n_wppg}" ) # Calculate total pattern count and timing n_patterns = self.config.n_ppg * (n_frames - 1) + n_wppg + n_last_pat dt = self.config.T / n_patterns t_acq_cmos = n_wppg * dt # Initialize SPC deformation field self.def_field_spc = torch.zeros( (2 * self.config.M, 2, l, l), dtype=self.config.dtype, device=self.device ) print(f"Creating SPC deformation field with {2 * self.config.M} patterns") # Interpolate for each frame for f in range(n_frames): t_f, t_fp1, u_f, u_fp1 = self._get_frame_timing_and_deformation( f, n_frames, dt, t_acq_cmos ) g_beg, g_end = self._get_pattern_indices(f, n_frames, n_wppg, n_last_pat) # Interpolate patterns within this frame interval for k in range(g_beg, g_end): t_k = ( k * dt + dt / 2 ) # add dt / 2 to be at the middle of the pattern exposure # Linear interpolation between frames alpha = (t_k - t_f) / (t_fp1 - t_f) if t_fp1 != t_f else 0.0 interpolated_def = u_f + alpha * (u_fp1 - u_f) pattern_idx = k - (f + 1) * n_wppg if 0 <= pattern_idx < 2 * self.config.M: self.def_field_spc[pattern_idx, :, :, :] = interpolated_def
def _get_frame_timing_and_deformation( self, f: int, n_frames: int, dt: float, t_acq_cmos: float ) -> Tuple[float, float, torch.Tensor, torch.Tensor]: """ Get timing and deformation data for frame interpolation. Args: :attr:`f`: Current frame index. :attr:`n_frames`: Total number of frames. :attr:`dt`: Time step. :attr:`t_acq_cmos`: CMOS acquisition time offset. Returns: Tuple of (t_f, t_fp1, u_f, u_fp1) - timing and deformation data. """ # Calculate frame timings t_f = f * self.config.n_ppg * dt + t_acq_cmos / 2 t_fp1 = (f + 1) * self.config.n_ppg * dt + t_acq_cmos / 2 # Get deformation fields u_f = self.def_field_cmos[f, :, :, :] if f != n_frames - 1: u_fp1 = self.def_field_cmos[f + 1, :, :, :] else: u_fp1 = u_f # Use same deformation for last frame return t_f, t_fp1, u_f, u_fp1 def _get_pattern_indices( self, f: int, n_frames: int, n_wppg: int, n_last_pat: int ) -> Tuple[int, int]: """ Get pattern indices for current frame. Args: :attr:`f`: Current frame index. :attr:`n_frames`: Total number of frames. :attr:`n_wppg`: White patterns per group. :attr:`n_last_pat`: Patterns in last frame. Returns: Tuple of (g_beg, g_end) - start and end of Hadamard (SPC) pattern indices. """ g_beg = f * self.config.n_ppg + n_wppg if f != n_frames - 1: g_end = (f + 1) * self.config.n_ppg else: g_end = g_beg + n_last_pat return g_beg, g_end
[docs] def forward(self, warping: str, amp_max: int = 0) -> torch.Tensor: """ Complete forward pass for motion estimation. Args: :attr:`warping`: 'pattern' or 'image'. Matches the warping mode used in the Dynamic classes from :mod:`spyrit.core.meas`. :attr:`amp_max`: Amplitude of the extended field of view. Returns: SPC deformation field of shape (2*M, l, l, 2). Raises: RuntimeError: If any step fails during processing. """ try: # Step 1: Convert motion from CMOS perspective to SPC print("Step 1: Convert motion from CMOS perspective to SPC...") self.estim_motion_from_CMOS(warping, amp_max=amp_max) # Step 2: Define reference frame print("Step 2: Defining reference frame...") self.def_reference() # Step 3: Interpolate between frames print("Step 3: Interpolating between frames...") self.interpolate_between_frames() print("Motion estimation completed successfully!") return DeformationField(self.def_field_spc.moveaxis(1, -1)) except Exception as e: raise RuntimeError(f"Motion estimation failed: {e}") from e