Source code for spyrit.misc.disp

# -----------------------------------------------------------------------------
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import torch
import math
import numbers
from pathlib import Path
from typing import Tuple, Union

from spyrit.misc.color import wavelength_to_colormap
from spyrit.core.warp import DeformationField


[docs] def display_vid(video, fps, title="", colormap=plt.cm.gray): """ video is a numpy array of shape [nb_frames, 1, nx, ny] """ plt.ion() nb_frames, channels, nx, ny = video.shape fig = plt.figure() ax = fig.add_subplot(1, 1, 1) for i in range(nb_frames): current_frame = video[i, 0, :, :] plt.imshow(current_frame, cmap=colormap) plt.title(title) divider = make_axes_locatable(ax) cax = plt.axes([0.85, 0.1, 0.075, 0.8]) plt.colorbar(cax=cax) plt.show() plt.pause(fps) plt.ioff()
[docs] def display_rgb_vid(video, fps, title=""): """ video is a numpy array of shape [nb_frames, 3, nx, ny] """ plt.ion() nb_frames, channels, nx, ny = video.shape fig = plt.figure() ax = fig.add_subplot(1, 1, 1) for i in range(nb_frames): current_frame = video[i, :, :, :] current_frame = np.moveaxis(current_frame, 0, -1) plt.imshow(current_frame) plt.title(title) plt.show() plt.pause(fps) plt.ioff()
[docs] def fitPlots(N, aspect=(16, 9)): width = aspect[0] height = aspect[1] area = width * height * 1.0 factor = (N / area) ** (1 / 2.0) cols = math.floor(width * factor) rows = math.floor(height * factor) rowFirst = width < height while rows * cols < N: if rowFirst: rows += 1 else: cols += 1 rowFirst = not (rowFirst) return rows, cols
[docs] def Multi_plots( img_list, title_list, shape, suptitle="", colormap=plt.cm.gray, axis_off=True, aspect=(16, 9), savefig="", fontsize=14, ): [rows, cols] = shape plt.figure() plt.suptitle(suptitle, fontsize=16) if (len(img_list) < rows * cols) or (len(title_list) < rows * cols): for k in range(max(rows * cols - len(img_list), rows * cols - len(title_list))): img_list.append(np.zeros((64, 64))) title_list.append("") for k in range(rows * cols): ax = plt.subplot(rows, cols, k + 1) ax.imshow(img_list[k], cmap=colormap) ax.set_title(title_list[k], fontsize=fontsize) if axis_off: plt.axis("off") if savefig: plt.savefig(savefig, bbox_inches="tight") plt.show()
[docs] def compare_video_frames( vid_list, nb_disp_frames, title_list, suptitle="", colormap=plt.cm.gray, aspect=(16, 9), savefig="", fontsize=14, ): rows = len(vid_list) cols = nb_disp_frames plt.figure(figsize=aspect) plt.suptitle(suptitle, fontsize=16) for i in range(rows): for j in range(cols): k = (j + 1) + (i) * (cols) i # print(k) ax = plt.subplot(rows, cols, k) # print("i = {}, j = {}".format(i,j)) ax.imshow(vid_list[i][0, j, 0, :, :], cmap=colormap) ax.set_title(title_list[i][j], fontsize=fontsize) plt.axis("off") if savefig: plt.savefig(savefig, bbox_inches="tight") plt.show()
[docs] def torch2numpy(torch_tensor): return torch_tensor.cpu().detach().numpy()
[docs] def uint8(dsp): x = (dsp - np.amin(dsp)) / (np.amax(dsp) - np.amin(dsp)) * 255 x = x.astype("uint8") return x
[docs] def imagesc( Img, title="", colormap=None, show=False, figsize=None, fig=None, ax=None, cbar_pos=None, title_fontsize=16, **kwargs, ): """ Display image data with scaled colors, a colormap, and a colorbar, similar to MATLAB's `imagesc` function. This function acts as a wrapper around `matplotlib.pyplot.imshow` with custom handling for the colormap and colorbar placement. Args: :attr:`Img` (array-like): The 2D array or image data to be displayed. :attr:`title` (str, optional): The title for the plot. Defaults to an empty string. :attr:`colormap` (str, int, or Colormap, optional): The colormap to use. - If **None** (default), uses Matplotlib's default **'gray'** colormap. - If **str**, it should be a valid Matplotlib colormap name (e.g., 'plasma', 'jet', 'viridis'). - If **int** or **float**, it is treated as a wavelength (in nm) and is passed to the function `wavelength_to_colormap(colormap, gamma=0.6)` from `spyrit.misc.color` to generate a custom colormap. - If a **Matplotlib Colormap object**, it is used directly. :attr:`show` (bool, optional): If **True** (default), calls `plt.show()` to display the plot. :attr:`figsize` (tuple, optional): A tuple (width, height) specifying the figure size in inches. Passed to `plt.figure()`. Defaults to None. :attr:`cbar_pos` (str, optional): Position of the colorbar. - If **"bottom"**, the colorbar is placed horizontally below the image. - If **None** (default) or any other value, the colorbar is placed vertically to the right of the image. :attr:`title_fontsize` (int, optional): Font size for the plot title. Defaults to 16. **kwargs: Additional keyword arguments. - :attr:`gamma` (float, optional): The gamma correction factor when `colormap` is a wavelength (numeric). Defaults to 0.6. Returns: None: The function primarily displays the plot via Matplotlib. """ if colormap is None: colormap = plt.cm.gray elif isinstance(colormap, numbers.Number): if "gamma" in kwargs: gamma = kwargs["gamma"] else: gamma = 0.6 colormap = wavelength_to_colormap(colormap, gamma=gamma) if fig is None: fig = plt.figure(figsize=figsize) if ax is None: ax = fig.add_subplot(1, 1, 1) pos = ax.imshow(Img, cmap=colormap) if cbar_pos == "bottom": fig.colorbar(pos, ax=ax, location="bottom", orientation="horizontal") else: fig.colorbar(pos, ax=ax, location="right", orientation="vertical") ax.set_title(title, fontsize=title_fontsize) # fig.tight_layout() # it raises warnings in some cases if show is True: plt.show()
[docs] def imagecomp( Img1, Img2, suptitle="", title1="", title2="", colormap1=plt.cm.gray, colormap2=plt.cm.gray, show=False, ): f, (ax1, ax2) = plt.subplots(1, 2) im1 = ax1.imshow(Img1, cmap=colormap1) ax1.set_title(title1) cax = plt.axes([0.43, 0.3, 0.025, 0.4]) plt.colorbar(im1, cax=cax) plt.suptitle(suptitle, fontsize=16) # im2 = ax2.imshow(Img2, cmap=colormap2) ax2.set_title(title2) cax = plt.axes([0.915, 0.3, 0.025, 0.4]) plt.colorbar(im2, cax=cax) plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9) if show: plt.show()
[docs] def imagepanel( Img1, Img2, Img3, Img4, suptitle="", title1="", title2="", title3="", title4="", colormap1=plt.cm.gray, colormap2=plt.cm.gray, colormap3=plt.cm.gray, colormap4=plt.cm.gray, show=False, ): fig, axarr = plt.subplots(2, 2, figsize=(20, 10)) plt.suptitle(suptitle, fontsize=16) im1 = axarr[0, 0].imshow(Img1, cmap=colormap1) axarr[0, 0].set_title(title1) cax = plt.axes([0.4, 0.54, 0.025, 0.35]) plt.colorbar(im1, cax=cax) im2 = axarr[0, 1].imshow(Img2, cmap=colormap2) axarr[0, 1].set_title(title2) cax = plt.axes([0.90, 0.54, 0.025, 0.35]) plt.colorbar(im2, cax=cax) im3 = axarr[1, 0].imshow(Img3, cmap=colormap3) axarr[1, 0].set_title(title3) cax = plt.axes([0.4, 0.12, 0.025, 0.35]) plt.colorbar(im3, cax=cax) im4 = axarr[1, 1].imshow(Img4, cmap=colormap4) axarr[1, 1].set_title(title4) cax = plt.axes([0.9, 0.12, 0.025, 0.35]) plt.colorbar(im4, cax=cax) plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9) if show: plt.show()
[docs] def plot(x, y, title="", xlabel="", ylabel="", color="black"): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) plt.plot(x, y, color=color) plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.show()
[docs] def add_colorbar(mappable, position="right", size="5%"): """ Example: f, axs = plt.subplots(1, 2) im = axs[0].imshow(img1, cmap='gray') add_colorbar(im) im = axs[0].imshow(img2, cmap='gray') add_colorbar(im) """ if position == "bottom": orientation = "horizontal" else: orientation = "vertical" last_axes = plt.gca() ax = mappable.axes fig = ax.figure divider = make_axes_locatable(ax) cax = divider.append_axes(position, size="5%", pad=0.05) cbar = fig.colorbar(mappable, cax=cax, orientation=orientation) plt.sca(last_axes) return cbar
[docs] def noaxis(axs): if type(axs) is np.ndarray: for ax in axs: ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) else: axs.get_xaxis().set_visible(False) axs.get_yaxis().set_visible(False)
[docs] def string_mean_std(x, prec=3): return "{:.{p}f} +/- {:.{p}f}".format(np.mean(x), np.std(x), p=prec)
[docs] def histogram(s): count, bins, ignored = plt.hist(s, 30, density=True) plt.show()
[docs] def vid2batch(root, img_dim, start_frame, end_frame): from imutils.video import FPS import imutils try: import cv2 except ImportError: raise ImportError( "cv2 is required for vid2batch. Please install OpenCV (e.g., via 'pip install opencv-python')." ) stream = cv2.VideoCapture(root) fps = FPS().start() frame_nb = 0 output_batch = torch.zeros(1, end_frame - start_frame, 1, img_dim, img_dim) while True: grabbed, frame = stream.read() if not grabbed: break frame_nb += 1 if (frame_nb >= start_frame) & (frame_nb < end_frame): frame = cv2.resize(frame, (img_dim, img_dim)) frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB) output_batch[0, frame_nb - start_frame, 0, :, :] = torch.Tensor( frame[:, :, 1] ) return output_batch
[docs] def pre_process_video(video, crop_patch, kernel_size): try: import cv2 except ImportError: raise ImportError( "cv2 is required for pre_process_video. Please install OpenCV (e.g., via 'pip install opencv-python')." ) batch_size, seq_length, c, h, w = video.shape batched_frames = video.reshape(batch_size * seq_length * c, h, w) output_batch = torch.zeros(batched_frames.shape) for i in range(batch_size * seq_length * c): img = torch2numpy(batched_frames[i, :, :]) img[crop_patch] = 0 median_frame = cv2.medianBlur(img, kernel_size) output_batch[i, :, :] = torch.Tensor(median_frame) output_batch = output_batch.reshape(batch_size, seq_length, c, h, w) return output_batch
[docs] def contrib_map( H_dyn: np.ndarray, n: int, save_figs: bool = False, path_fig: Union[str, Path] = "", show_fig: bool = True, ) -> np.ndarray: """ Generate a contribution map showing measurement pattern coverage. Args: H_dyn: Dynamic measurement matrix. n: Image size (assuming square images). save_figs: Whether to save the figure. path_fig: Path to save the figure. show_fig: Whether to display the figure. Returns: Contribution map as numpy array. """ _, L = H_dyn.shape l = int(np.sqrt(L)) if l * l != L: raise ValueError(f"Matrix dimension L={L} is not a perfect square") # Create contribution map contrib = np.where(H_dyn != 0, 1, 0) contrib_image = np.sum(contrib, axis=0) contrib_image = contrib_image.reshape((l, l)) / (n**2) contrib_image = np.rot90(contrib_image, 2) # Account for rotation in SP POV # Create visualization plt.figure(figsize=(8, 6)) plt.imshow(contrib_image, cmap="hot") cbar = plt.colorbar(fraction=0.047, pad=0.01, format="%.1f") cbar.ax.tick_params(labelsize=15) plt.title("Contribution Map", fontsize=16) if save_figs and path_fig: plt.axis("off") plt.savefig(str(path_fig), bbox_inches="tight", dpi=300) if show_fig: plt.show() else: plt.close() return contrib_image
[docs] def error_map( img1: np.ndarray, img2: np.ndarray, save_figs: bool = False, path_fig: Union[str, Path] = "", show_fig: bool = True, title: str = "Error Map", ) -> np.ndarray: """ Compute and visualize error map between two images. Args: img1: First image. img2: Second image. save_figs: Whether to save the figure. path_fig: Path to save the figure. show_fig: Whether to display the figure. title: Title for the plot. Returns: Error map as numpy array. Raises: ValueError: If images have different shapes. """ if img1.shape != img2.shape: raise ValueError(f"Image shapes don't match: {img1.shape} vs {img2.shape}") # Normalize images to [0, 1] img1_normalized = ( (img1 - img1.min()) / (img1.max() - img1.min()) if img1.max() != img1.min() else np.zeros_like(img1) ) img2_normalized = ( (img2 - img2.min()) / (img2.max() - img2.min()) if img2.max() != img2.min() else np.zeros_like(img2) ) error = img1_normalized - img2_normalized plt.figure(figsize=(8, 6)) plt.imshow(error, cmap="Spectral") cbar = plt.colorbar(fraction=0.047, pad=0.01, format="%.2f") cbar.ax.tick_params(labelsize=15) plt.title(title, fontsize=16) plt.axis("off") if save_figs and path_fig: plt.savefig(str(path_fig), bbox_inches="tight", dpi=300) if show_fig: plt.show() else: plt.close() return error
[docs] def blue_box( f: np.ndarray, amp_max: int = 0, box_color: Tuple[int, int, int] = (0, 0, 255) ) -> np.ndarray: """ Add a colored box overlay to an image for visualization purposes. Args: f: Input grayscale or RGB image. amp_max: Offset from image borders for the box. box_color: RGB color for the box (default: blue). Returns: RGB image with colored box overlay. Raises: ValueError: If amp_max is too large for the image size. """ if amp_max * 2 >= min(f.shape[0], f.shape[1]): raise ValueError(f"amp_max={amp_max} is too large for image shape {f.shape}") # Normalize to 8-bit and create RGB image f_normalized = ( (f - f.min()) / (f.max() - f.min()) if f.max() != f.min() else np.zeros_like(f) ) f_255 = (f_normalized * 255).astype(np.uint8) if len(f.shape) == 2: f_rgb = np.stack([f_255] * 3, axis=2) elif len(f.shape) == 3: f_rgb = f_255 if amp_max == 0: return f_rgb r, g, b = box_color # Draw box borders # Top and bottom borders f_rgb[amp_max, amp_max:-amp_max] = [r, g, b] f_rgb[-amp_max - 1, amp_max:-amp_max] = [r, g, b] # Left and right borders f_rgb[amp_max:-amp_max, amp_max] = [r, g, b] f_rgb[amp_max:-amp_max, -amp_max - 1] = [r, g, b] return f_rgb
[docs] def get_frame(movie_path: Union[str, Path], frame_number: int = 0) -> np.ndarray: """ Extract a specific frame from a video file. Args: movie_path: Path to the video file. frame_number: Frame number to extract (0-indexed). Returns: Grayscale frame as numpy array. Raises: FileNotFoundError: If video file doesn't exist. ValueError: If frame cannot be read or video is invalid. """ try: import cv2 except ImportError: raise ImportError( "cv2 is required for get_frame. Please install OpenCV (e.g., via 'pip install opencv-python')." ) movie_path = Path(movie_path) if not movie_path.exists(): raise FileNotFoundError(f"Video file not found: {movie_path}") cap = cv2.VideoCapture(str(movie_path)) if not cap.isOpened(): raise ValueError(f"Error opening video file: {movie_path}") # Get total frame count for validation total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if frame_number >= total_frames: cap.release() raise ValueError( f"Frame {frame_number} not available. Video has {total_frames} frames." ) # Set frame position directly (more efficient than iterating) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) ret, frame = cap.read() cap.release() if not ret: raise ValueError(f"Could not read frame {frame_number} from video") # Convert to grayscale if len(frame.shape) == 3: gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) else: gray_frame = frame return gray_frame
[docs] def save_motion_video(x_motion, out_path, amp_max=0, fps=820): r""" Save :attr:`x_motion` of shape (1, n_frames, c, h, w) as a video file. """ try: import cv2 except ImportError: raise ImportError( "cv2 is required for save_motion_video. Please install OpenCV (e.g., via 'pip install opencv-python')." ) out_path = Path(out_path) out_path.parent.mkdir(parents=True, exist_ok=True) n_batch, n_frames, n_wav, h, w = x_motion.shape h_crop, w_crop = torch.tensor((h, w)) - 2 * amp_max h_crop, w_crop = int(h_crop.item()), int(w_crop.item()) if n_batch > 1: raise ValueError(f"save_motion_video expects a single batch, got {n_batch}") if n_wav != 1 and n_wav != 3: raise ValueError(f"save_motion_video expects 1 or 3 channels, got {n_wav}") fourcc = cv2.VideoWriter_fourcc(*"MJPG") writer = cv2.VideoWriter(str(out_path), fourcc, fps, (w_crop, h_crop), True) if not writer.isOpened(): raise RuntimeError("cv2 VideoWriter failed to open") for t in range(n_frames): frame_wide = x_motion[0, t].moveaxis(0, -1).cpu().numpy() mn, mx = frame_wide.min(), frame_wide.max() frame = frame_wide[amp_max : h_crop + amp_max, amp_max : w_crop + amp_max] if mx > mn: frame8 = ((frame - mn) / (mx - mn) * 255.0).astype("uint8") else: frame8 = (frame * 0).astype("uint8") if n_wav == 1: frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_GRAY2BGR) else: frame_bgr = cv2.cvtColor(frame8, cv2.COLOR_RGB2BGR) writer.write(frame_bgr) writer.release() print(f"Saved motion video to {out_path}")
[docs] def save_field_video( def_field, out_path, n_frames, step=5, fps=30, figsize=(6, 6), dpi=200, scale=1, fs=16, amp_max=0, box_color="blue", box_linewidth=2, ): r""" Save deformation field as a quiver video. Args: :attr:`def_field` (DeformationField): Deformation field object (see the :mod:`spyrit.core.warp` module). :attr:`out_path` (Path-like): Path-like to save the output video (mp4). :attr:`n_frames` (int): Number of frames in the output video. :attr:`step` (int): Step size for quiver grid (spatial subsampling). :attr:`fps` (int): Frames per second for the output video. :attr:`figsize` (tuple): Figure size for the quiver plot (width, height). :attr:`dpi` (int): Output DPI for the saved video/frames. :attr:`scale` (float): Quiver scale parameter for scaling the length of arrows. :attr:`amp_max` (int, optional): Number of pixels inset from each border to draw a blue box (e.g. the SPC FOV). If 0, no box is drawn. :attr:`box_color` (str, optional): matplotlib color for the box edge. :attr:`box_linewidth` (float, optional): width of the box edge line. """ out_path = Path(out_path) out_path.parent.mkdir(parents=True, exist_ok=True) # Basic validation if int(n_frames) < 1: raise ValueError("n_frames must be >= 1") if int(step) < 1: raise ValueError("step must be >= 1") if fps <= 0: raise ValueError("fps must be > 0") if not isinstance(def_field, DeformationField): raise ValueError("def_field must be a DeformationField object") # Load and validate field field_full = torch2numpy(def_field.field) n_patterns = def_field.n_frames l = int(def_field.img_shape[0]) n = l - 2 * amp_max # Validate amp_max if amp_max < 0: raise ValueError("amp_max must be >= 0") if amp_max * 2 >= l: raise ValueError(f"amp_max={amp_max} is too large for image size {l}") # Temporal indices: guarantee exactly n_frames samples indices = np.linspace(0, n_patterns - 1, int(n_frames)).round().astype(int) # Spatial subsampling using absolute pixel coordinates (0 .. l-1) interval = np.arange(0, l) x1_sp, x2_sp = np.meshgrid(interval, interval, indexing="xy") x1 = x1_sp[::step, ::step] x2 = x2_sp[::step, ::step] field = field_full[indices][:, ::step, ::step, :] # normalize to absolute pixel coordinates scale_factor = np.array(def_field.img_shape) - 1 field = (field + 1) / 2 * scale_factor # rotate field to 180 degrees to match our SP convention (np.rot90 only reorders samples) field[..., 0] = (l - 1) - field[..., 0] field[..., 1] = (l - 1) - field[..., 1] x1 = (l - 1) - x1 x2 = (l - 1) - x2 # Prepare figure and initial quiver (use absolute pixel coords) fig_q, ax_q = plt.subplots(figsize=figsize) ax_q.set_aspect("equal") ax_q.set_xlim([0, l - 1]) ax_q.set_ylim([0, l - 1]) ax_q.invert_yaxis() # displacement vectors: target - source (in pixel coordinates) U0 = field[0, ..., 0] - x1 V0 = field[0, ..., 1] - x2 Q = ax_q.quiver(x1, x2, U0, V0, angles="xy", scale_units="xy", scale=scale) # Draw blue box (if requested) in absolute pixel coordinates if amp_max and amp_max > 0: from matplotlib.patches import Rectangle x_min = amp_max x_max = l - amp_max - 1 y_min = amp_max y_max = l - amp_max - 1 width = x_max - x_min height = y_max - y_min rect = Rectangle( (x_min, y_min), width, height, fill=False, edgecolor=box_color, linewidth=box_linewidth, ) ax_q.add_patch(rect) # Try writing MP4 try: writer = animation.FFMpegWriter(fps=fps) with writer.saving(fig_q, str(out_path), dpi=dpi): for i, idx in enumerate(indices): U = field[i, ..., 0] - x1 V = field[i, ..., 1] - x2 Q.set_UVC(U, V) ax_q.set_title(f"frame {idx}", fontsize=fs) writer.grab_frame() plt.close(fig_q) except Exception as e: raise RuntimeError(f"FFMpegWriter failed: {e}")