# Source code for spyrit.misc.disp

# -----------------------------------------------------------------------------
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from numpy import linalg as LA
import time
from scipy import signal
from scipy import misc
from scipy import sparse
import torch
import math
import numbers

from spyrit.misc.color import wavelength_to_colormap


def display_vid(video, fps, title="", colormap=plt.cm.gray):
    """
    Play a single-channel video frame by frame, with a colorbar.

    Args:
        video (numpy.ndarray): array of shape [nb_frames, 1, nx, ny]; only
            channel 0 of each frame is displayed.
        fps (float): despite its name, this is the pause (in seconds) between
            two frames, passed directly to `plt.pause`.
        title (str, optional): title drawn above the image.
        colormap (optional): Matplotlib colormap used for display.
            Defaults to `plt.cm.gray`.
    """
    plt.ion()
    nb_frames, _channels, _nx, _ny = video.shape
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # Reserve the right-hand margin for a single colorbar axes created once,
    # instead of stacking a new (overlapping) one on every frame.
    fig.subplots_adjust(right=0.8)
    cax = fig.add_axes([0.85, 0.1, 0.075, 0.8])
    for i in range(nb_frames):
        # Draw into the image/colorbar axes explicitly: relying on pyplot's
        # "current axes" would send later frames into the colorbar axes,
        # because creating the colorbar axes makes it current.
        ax.clear()
        cax.clear()
        image = ax.imshow(video[i, 0, :, :], cmap=colormap)
        ax.set_title(title)
        fig.colorbar(image, cax=cax)
        plt.show()
        plt.pause(fps)
    plt.ioff()
def display_rgb_vid(video, fps, title=""):
    """
    Play an RGB video frame by frame.

    Args:
        video (numpy.ndarray): channel-first array of shape
            [nb_frames, 3, nx, ny].
        fps (float): despite its name, this is the pause (in seconds) between
            two frames, passed directly to `plt.pause`.
        title (str, optional): title drawn above the image.
    """
    plt.ion()
    nb_frames, _channels, _nx, _ny = video.shape
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(title)
    image = None
    for i in range(nb_frames):
        # imshow expects channel-last data for RGB images.
        frame = np.moveaxis(video[i, :, :, :], 0, -1)
        if image is None:
            image = ax.imshow(frame)
        else:
            # Update the pixels in place rather than stacking a new image
            # on the axes for every frame.
            image.set_data(frame)
        plt.show()
        plt.pause(fps)
    plt.ioff()
def fitPlots(N, aspect=(16, 9)):
    """
    Choose a (rows, cols) subplot grid with room for at least N panels.

    The grid starts from the size whose proportions match `aspect`
    (width, height) and whose area is about N. Then, while it is still too
    small, rows and columns are grown alternately. Rows are grown first when
    the aspect is taller than it is wide.

    Args:
        N (int): number of panels to fit.
        aspect (tuple): (width, height) proportions of the target grid.

    Returns:
        tuple: (rows, cols) with rows * cols >= N.
    """
    w = aspect[0]
    h = aspect[1]
    scale = (N / (w * h * 1.0)) ** (1 / 2.0)
    n_cols = math.floor(w * scale)
    n_rows = math.floor(h * scale)
    grow_rows = h > w
    while n_rows * n_cols < N:
        if grow_rows:
            n_rows += 1
        else:
            n_cols += 1
        grow_rows = not grow_rows
    return n_rows, n_cols
def Multi_plots(
    img_list,
    title_list,
    shape,
    suptitle="",
    colormap=plt.cm.gray,
    axis_off=True,
    aspect=(16, 9),
    savefig="",
    fontsize=14,
):
    """
    Display a grid of images with individual titles.

    Missing images are padded with 64x64 zero images, and missing titles with
    empty strings. The padding is applied to copies, so the caller's lists
    are never modified. Entries beyond rows * cols are ignored.

    Args:
        img_list (sequence): images to display, in row-major order.
        title_list (sequence of str): one title per image.
        shape (tuple): (rows, cols) of the subplot grid.
        suptitle (str, optional): figure-level title.
        colormap (optional): Matplotlib colormap. Defaults to `plt.cm.gray`.
        axis_off (bool, optional): hide the axes of each panel. Defaults to True.
        aspect (tuple, optional): accepted for API compatibility; currently unused.
        savefig (str, optional): if non-empty, path passed to `plt.savefig`.
        fontsize (int, optional): font size of the panel titles.
    """
    rows, cols = shape
    n_panels = rows * cols
    # Pad copies: appending to the arguments would leak padding back to the
    # caller (and fail outright for tuples).
    images = list(img_list)
    titles = list(title_list)
    images.extend(np.zeros((64, 64)) for _ in range(n_panels - len(images)))
    titles.extend("" for _ in range(n_panels - len(titles)))

    plt.figure()
    plt.suptitle(suptitle, fontsize=16)
    for k in range(n_panels):
        ax = plt.subplot(rows, cols, k + 1)
        ax.imshow(images[k], cmap=colormap)
        ax.set_title(titles[k], fontsize=fontsize)
        if axis_off:
            ax.axis("off")
    if savefig:
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()
def compare_video_frames(
    vid_list,
    nb_disp_frames,
    title_list,
    suptitle="",
    colormap=plt.cm.gray,
    aspect=(16, 9),
    savefig="",
    fontsize=14,
):
    """
    Show the first frames of several videos, one video per row.

    Each row displays `vid_list[row][0, col, 0, :, :]` for
    col in range(nb_disp_frames). A panel is titled `title_list[row][col]`.

    Args:
        vid_list (sequence): videos, each indexed as [batch, frame, channel, x, y].
        nb_disp_frames (int): number of frames (columns) to show per video.
        title_list (sequence of sequences of str): per-panel titles.
        suptitle (str, optional): figure-level title.
        colormap (optional): Matplotlib colormap. Defaults to `plt.cm.gray`.
        aspect (tuple, optional): figure size passed to `plt.figure(figsize=...)`.
        savefig (str, optional): if non-empty, path passed to `plt.savefig`.
        fontsize (int, optional): font size of the panel titles.
    """
    n_rows = len(vid_list)
    n_cols = nb_disp_frames
    plt.figure(figsize=aspect)
    plt.suptitle(suptitle, fontsize=16)
    for row in range(n_rows):
        for col in range(n_cols):
            panel = row * n_cols + col + 1
            ax = plt.subplot(n_rows, n_cols, panel)
            ax.imshow(vid_list[row][0, col, 0, :, :], cmap=colormap)
            ax.set_title(title_list[row][col], fontsize=fontsize)
            ax.axis("off")
    if savefig:
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()
def torch2numpy(torch_tensor):
    """
    Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    first, so tensors that require grad or live on another device are
    accepted.
    """
    detached = torch_tensor.detach()
    return detached.cpu().numpy()
def uint8(dsp):
    """
    Linearly rescale an array to the full [0, 255] range and cast to uint8.

    The minimum maps to 0 and the maximum to 255. Intermediate values are
    truncated (not rounded) by the cast.

    Args:
        dsp (array-like): numeric data of any shape.

    Returns:
        numpy.ndarray: uint8 array with the same shape as `dsp`. A constant
        input (max == min) returns an all-zero array instead of dividing
        0 by 0.
    """
    dsp = np.asarray(dsp)
    lo = np.amin(dsp)
    span = np.amax(dsp) - lo
    if span == 0:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        return np.zeros(dsp.shape, dtype="uint8")
    x = (dsp - lo) / span * 255
    return x.astype("uint8")
def imagesc(
    Img,
    title="",
    colormap=None,
    show=False,
    figsize=None,
    fig=None,
    ax=None,
    cbar_pos=None,
    title_fontsize=16,
    **kwargs,
):
    """
    Display image data with scaled colors, a colormap, and a colorbar,
    similar to MATLAB's `imagesc` function.

    This function acts as a wrapper around `matplotlib.axes.Axes.imshow` with
    custom handling for the colormap and colorbar placement.

    Args:
        Img (array-like): The 2D array or image data to be displayed.
        title (str, optional): The title for the plot. Defaults to an empty
            string.
        colormap (str, int, float, or Colormap, optional): The colormap to use.
            - If **None** (default), uses the **'gray'** colormap
              (`plt.cm.gray`).
            - If **str**, it should be a valid Matplotlib colormap name
              (e.g., 'plasma', 'jet', 'viridis').
            - If **int** or **float**, it is treated as a wavelength (in nm)
              and is passed to `wavelength_to_colormap(colormap, gamma=gamma)`
              from `spyrit.misc.color` to generate a custom colormap.
            - If a **Matplotlib Colormap object**, it is used directly.
        show (bool, optional): If truthy, calls `plt.show()` after drawing.
            Defaults to False.
        figsize (tuple, optional): (width, height) in inches, passed to
            `plt.figure()`. Used only when a new figure has to be created.
            Defaults to None.
        fig (matplotlib.figure.Figure, optional): Figure to draw into. If None
            and `ax` is given, the figure owning `ax` is used. Otherwise a new
            figure is created.
        ax (matplotlib.axes.Axes, optional): Axes to draw into. If None, a
            single subplot is added to `fig`.
        cbar_pos (str, optional): Position of the colorbar.
            - If **"bottom"**, the colorbar is placed horizontally below the
              image.
            - If **None** (default) or any other value, the colorbar is placed
              vertically to the right of the image.
        title_fontsize (int, optional): Font size for the plot title.
            Defaults to 16.
        **kwargs: Additional keyword arguments.
            - gamma (float, optional): The gamma correction factor used when
              `colormap` is a wavelength (numeric). Defaults to 0.6.
            Other keys are ignored.

    Returns:
        None: The function only draws the plot via Matplotlib.
    """
    if colormap is None:
        colormap = plt.cm.gray
    elif isinstance(colormap, numbers.Number):
        colormap = wavelength_to_colormap(colormap, gamma=kwargs.get("gamma", 0.6))

    if fig is None:
        # Reuse the caller's axes' figure rather than opening an unrelated,
        # empty figure and attaching the colorbar from it.
        fig = ax.figure if ax is not None else plt.figure(figsize=figsize)
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)

    pos = ax.imshow(Img, cmap=colormap)
    if cbar_pos == "bottom":
        fig.colorbar(pos, ax=ax, location="bottom", orientation="horizontal")
    else:
        fig.colorbar(pos, ax=ax, location="right", orientation="vertical")
    ax.set_title(title, fontsize=title_fontsize)
    # fig.tight_layout() is intentionally not called: it raises warnings in
    # some cases.
    if show:
        plt.show()
def imagecomp(
    Img1,
    Img2,
    suptitle="",
    title1="",
    title2="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
    show=False,
):
    """
    Display two images side by side, each with its own colorbar.

    Args:
        Img1, Img2 (array-like): left and right images.
        suptitle (str, optional): figure-level title.
        title1, title2 (str, optional): titles of the left and right panels.
        colormap1, colormap2 (optional): colormaps of the two panels.
        show (bool, optional): if truthy, call `plt.show()` at the end.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2)

    # Left panel; its colorbar sits in a hand-placed axes between the panels.
    left_im = left_ax.imshow(Img1, cmap=colormap1)
    left_ax.set_title(title1)
    fig.colorbar(left_im, cax=fig.add_axes([0.43, 0.3, 0.025, 0.4]))
    fig.suptitle(suptitle, fontsize=16)

    # Right panel; its colorbar sits at the right edge of the figure.
    right_im = right_ax.imshow(Img2, cmap=colormap2)
    right_ax.set_title(title2)
    fig.colorbar(right_im, cax=fig.add_axes([0.915, 0.3, 0.025, 0.4]))

    fig.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    if show:
        plt.show()
def imagepanel(
    Img1,
    Img2,
    Img3,
    Img4,
    suptitle="",
    title1="",
    title2="",
    title3="",
    title4="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
    colormap3=plt.cm.gray,
    colormap4=plt.cm.gray,
    show=False,
):
    """
    Display four images in a 2x2 grid, each with its own colorbar.

    Images are placed in row-major order: Img1 top-left, Img2 top-right,
    Img3 bottom-left, Img4 bottom-right.

    Args:
        Img1..Img4 (array-like): images to display.
        suptitle (str, optional): figure-level title.
        title1..title4 (str, optional): per-panel titles.
        colormap1..colormap4 (optional): per-panel colormaps.
        show (bool, optional): if truthy, call `plt.show()` at the end.
    """
    fig, axarr = plt.subplots(2, 2, figsize=(20, 10))
    fig.suptitle(suptitle, fontsize=16)

    # (grid index, image, title, colormap, colorbar axes rectangle)
    panels = (
        ((0, 0), Img1, title1, colormap1, [0.4, 0.54, 0.025, 0.35]),
        ((0, 1), Img2, title2, colormap2, [0.90, 0.54, 0.025, 0.35]),
        ((1, 0), Img3, title3, colormap3, [0.4, 0.12, 0.025, 0.35]),
        ((1, 1), Img4, title4, colormap4, [0.9, 0.12, 0.025, 0.35]),
    )
    for grid_idx, img, panel_title, cmap, cbar_rect in panels:
        panel_ax = axarr[grid_idx]
        mappable = panel_ax.imshow(img, cmap=cmap)
        panel_ax.set_title(panel_title)
        fig.colorbar(mappable, cax=fig.add_axes(cbar_rect))

    fig.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    if show:
        plt.show()
def plot(x, y, title="", xlabel="", ylabel="", color="black"):
    """
    Draw a single line plot of y versus x in a new figure and show it.

    Args:
        x, y (array-like): data to plot.
        title, xlabel, ylabel (str, optional): plot annotations.
        color (optional): line color. Defaults to "black".
    """
    fig = plt.figure()
    axes = fig.add_subplot(1, 1, 1)
    axes.plot(x, y, color=color)
    axes.set_title(title)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    plt.show()
def add_colorbar(mappable, position="right", size="5%"):
    """
    Attach a colorbar to the axes holding `mappable`.

    The colorbar axes is appended next to the image axes with
    `make_axes_locatable`, so it follows the image's extent. The pyplot
    "current axes" is restored afterwards.

    Args:
        mappable: a Matplotlib ScalarMappable, e.g. the return value of
            `imshow`.
        position (str, optional): side on which the colorbar is appended,
            passed to `append_axes`. "top" and "bottom" give a horizontal
            colorbar. Any other value gives a vertical one.
            Defaults to "right".
        size (str or float, optional): thickness of the colorbar axes, passed
            to `append_axes`. Defaults to "5%".

    Returns:
        matplotlib.colorbar.Colorbar: the created colorbar.

    Example:
        f, axs = plt.subplots(1, 2)
        im = axs[0].imshow(img1, cmap='gray')
        add_colorbar(im)
        im = axs[1].imshow(img2, cmap='gray')
        add_colorbar(im)
    """
    orientation = "horizontal" if position in ("top", "bottom") else "vertical"
    last_axes = plt.gca()
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)
    # Honor the caller's `size` (it was previously ignored in favor of "5%").
    cax = divider.append_axes(position, size=size, pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax, orientation=orientation)
    plt.sca(last_axes)
    return cbar
def noaxis(axs):
    """
    Hide the x and y axes of one or several Matplotlib axes.

    Args:
        axs: a single Axes, a list/tuple of Axes, or a NumPy array of Axes of
            any shape (e.g. the 2-D array returned by `plt.subplots(2, 2)`).
    """
    if isinstance(axs, np.ndarray):
        # .flat visits every element, so 2-D grids of axes work too
        # (iterating a 2-D array directly would yield rows, not axes).
        targets = axs.flat
    elif isinstance(axs, (list, tuple)):
        targets = axs
    else:
        targets = (axs,)
    for ax in targets:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
def string_mean_std(x, prec=3):
    """
    Format the mean and standard deviation of `x` as "<mean> +/- <std>".

    Both numbers are printed in fixed-point notation with `prec` decimals.
    The standard deviation is `np.std` (population, ddof=0).
    """
    mean, std = np.mean(x), np.std(x)
    return f"{mean:.{prec}f} +/- {std:.{prec}f}"
def histogram(s):
    """Plot a density-normalized, 30-bin histogram of `s` and show it."""
    plt.hist(s, 30, density=True)
    plt.show()
def vid2batch(root, img_dim, start_frame, end_frame):
    """
    Read a range of frames from a video file into a single-batch tensor.

    Each selected frame is resized to (img_dim, img_dim) and converted with
    `cv2.COLOR_BGR2LAB`. Only channel index 1 of the converted image is kept.

    Frames are numbered from 1 (the first frame read has number 1). Frames
    with start_frame <= number < end_frame are stored at time index
    number - start_frame. With start_frame=0, time index 0 is therefore never
    filled and stays zero.

    Args:
        root (str): video source passed to `cv2.VideoCapture`.
        img_dim (int): side length of the square output frames.
        start_frame (int): first frame number to keep (inclusive).
        end_frame (int): frame number at which to stop (exclusive).

    Returns:
        torch.Tensor: shape (1, end_frame - start_frame, 1, img_dim, img_dim).
        Slots for frames that were not read (e.g. the video is too short)
        remain zero.
    """
    import cv2

    output_batch = torch.zeros(1, end_frame - start_frame, 1, img_dim, img_dim)
    stream = cv2.VideoCapture(root)
    try:
        frame_nb = 0
        # Stop as soon as the last wanted frame (end_frame - 1) has been
        # read instead of decoding the remainder of the file.
        while frame_nb < end_frame - 1:
            grabbed, frame = stream.read()
            if not grabbed:
                break
            frame_nb += 1
            if frame_nb >= start_frame:
                frame = cv2.resize(frame, (img_dim, img_dim))
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
                output_batch[0, frame_nb - start_frame, 0, :, :] = torch.Tensor(
                    frame[:, :, 1]
                )
    finally:
        # Always free the capture handle, even if decoding fails.
        stream.release()
    return output_batch
def pre_process_video(video, crop_patch, kernel_size):
    """
    Zero out a region of every frame, then apply a median filter.

    Args:
        video (torch.Tensor): 5-D tensor of shape
            (batch_size, seq_length, c, h, w). It is not modified.
        crop_patch: index expression applied to each (h, w) frame, e.g. a
            tuple of slices or a boolean mask. The selected pixels are set to 0
            before filtering.
        kernel_size (int): aperture passed to `cv2.medianBlur`. Note that
            OpenCV restricts which dtypes are accepted for larger kernels;
            TODO confirm the expected input dtype.

    Returns:
        torch.Tensor: filtered frames reshaped to
        (batch_size, seq_length, c, h, w). The tensor is allocated with
        `torch.zeros`, i.e. default dtype on the CPU.
    """
    import cv2

    batch_size, seq_length, c, h, w = video.shape
    frames = video.reshape(batch_size * seq_length * c, h, w)
    output_batch = torch.zeros(frames.shape)
    for i in range(frames.shape[0]):
        # Copy: for CPU tensors the NumPy array shares memory with the tensor,
        # so zeroing the patch in place would corrupt the caller's `video`.
        img = torch2numpy(frames[i, :, :]).copy()
        img[crop_patch] = 0
        output_batch[i, :, :] = torch.Tensor(cv2.medianBlur(img, kernel_size))
    return output_batch.reshape(batch_size, seq_length, c, h, w)