# -*- coding: utf-8 -*-
"""
Created on Mon Dec 2 20:53:59 2024
@author: ducros
"""
import numpy as np
import warnings
from typing import Tuple
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import colorsys
from pathlib import Path
from matplotlib.colors import ListedColormap
# %%
[docs]
def wavelength_to_rgb(
wavelength: float, gamma: float = 0.8
) -> Tuple[float, float, float]:
"""Converts wavelength to RGB.
Based on https://gist.github.com/friendly/67a7df339aa999e2bcfcfec88311abfc.
Itself based on code by Dan Bruton:
http://www.physics.sfasu.edu/astro/color/spectra.html
Args:
wavelength (float):
Single wavelength to be converted to RGB.
gamma (float, optional):
Gamma correction. Defaults to 0.8.
Returns:
Tuple[float, float, float]:
RGB value.
"""
if np.min(wavelength) < 380 or np.max(wavelength) > 750:
warnings.warn("Some wavelengths are not in the visible range [380-750] nm")
if wavelength >= 380 and wavelength <= 440:
attenuation = 0.3 + 0.7 * (wavelength - 380) / (440 - 380)
R = ((-(wavelength - 440) / (440 - 380)) * attenuation) ** gamma
G = 0.0
B = (1.0 * attenuation) ** gamma
elif wavelength >= 440 and wavelength <= 490:
R = 0.0
G = ((wavelength - 440) / (490 - 440)) ** gamma
B = 1.0
elif wavelength >= 490 and wavelength <= 510:
R = 0.0
G = 1.0
B = (-(wavelength - 510) / (510 - 490)) ** gamma
elif wavelength >= 510 and wavelength <= 580:
R = ((wavelength - 510) / (580 - 510)) ** gamma
G = 1.0
B = 0.0
elif wavelength >= 580 and wavelength <= 645:
R = 1.0
G = (-(wavelength - 645) / (645 - 580)) ** gamma
B = 0.0
elif wavelength >= 645 and wavelength <= 750:
attenuation = 0.3 + 0.7 * (750 - wavelength) / (750 - 645)
R = (1.0 * attenuation) ** gamma
G = 0.0
B = 0.0
else:
R = 0.0
G = 0.0
B = 0.0
return R, G, B
[docs]
def wavelength_to_rgb_mat(wav_range, gamma=1):
rgb_mat = np.zeros((len(wav_range), 3))
for i, wav in enumerate(wav_range):
rgb_mat[i, :] = wavelength_to_rgb(wav, gamma)
return rgb_mat
[docs]
def spectral_colorization(M_gray, wav, axis=None):
"""
Colorize the last dimension of an array
Args:
M_gray (np.ndarray): Grayscale array where the last dimension is the
spectral dimension. This is an A-by-C array, where A can indicate multiple
dimensions (e.g., 4-by-3-by-7) and C is the number of spectral channels.
wav (np.ndarray): Wavelenth. This is a 1D array of size C.
axis (None or int or tuple of ints, optional): Axis or axes along which
the grayscale input is normalized. By default, global normalization
across all axes is considered.
Returns:
M_color (np.ndarray): Color array with an extra dimension. This is an A-by-C-by-3 array.
"""
# Normalize to adjust contrast
M_gray_min = M_gray.min(keepdims=True, axis=axis)
M_gray_max = M_gray.max(keepdims=True, axis=axis)
M_gray = (M_gray - M_gray_min) / (M_gray_max - M_gray_min)
#
rgb_mat = wavelength_to_rgb_mat(wav, gamma=1)
M_red = M_gray @ np.diag(rgb_mat[:, 0])
M_green = M_gray @ np.diag(rgb_mat[:, 1])
M_blue = M_gray @ np.diag(rgb_mat[:, 2])
M_color = np.stack((M_red, M_green, M_blue), axis=-1)
return M_color
[docs]
def colorize(im, color, clip_percentile=0.1):
"""
Helper function to create an RGB image from a single-channel image using a
specific color.
"""
# Check that we just have a 2D image
if im.ndim > 2 and im.shape[2] != 1:
raise ValueError("This function expects a single-channel image!")
# Rescale the image according to how we want to display it
im_scaled = im.astype(np.float32) - np.percentile(im, clip_percentile)
im_scaled = im_scaled / np.percentile(im_scaled, 100 - clip_percentile)
print(
f"Norm: min={np.percentile(im, clip_percentile)}, max={np.percentile(im_scaled, 100 - clip_percentile)}"
)
print(f"New: min={im_scaled.min()}, max={im_scaled.max()}")
im_scaled = np.clip(im_scaled, 0, 1)
# Need to make sure we have a channels dimension for the multiplication to work
im_scaled = np.atleast_3d(im_scaled)
# Reshape the color (here, we assume channels last)
color = np.asarray(color).reshape((1, 1, -1))
return im_scaled * color
[docs]
def wavelength_to_colormap(wav, gamma=0.6):
"""
Creates a linear Matplotlib colormap that transitions from black to a specific
color corresponding to a given electromagnetic wavelength.
Args:
wav (float): The wavelength in nanometers (nm) to determine the target
color. Typically, this would be in the visible spectrum range (~380
to 780 nm).
gamma (float, optional): The gamma correction factor applied when
calculating the RGB color from the wavelength. Defaults to 0.6.
Returns:
matplotlib.colors.LinearSegmentedColormap: A custom colormap object
named 'DarkToColor' that spans from black at the low end (0.0) to the
calculated wavelength-based color at the high end (1.0).
Example:
>>> cmap = wavelength_to_colormap(550, gamma=0.8) # Green color at 550nm
>>> print(cmap)
<matplotlib.colors.LinearSegmentedColormap object at ...>
"""
# 'dark_color' is the color at 0.0 (start)
dark_color = (
"black" # You can use any dark color (e.g., '#000033', 'black', 'darkred')
)
# 'target_color' is the color at 1.0 (end)
target_color = wavelength_to_rgb(wav, gamma)
# 2. Create the list of color nodes (tuples of position and color)
# The colormap will transition linearly between these nodes.
color_list = [(0.0, dark_color), (1.0, target_color)]
custom_cmap = LinearSegmentedColormap.from_list("DarkToColor", color_list)
return custom_cmap
[docs]
def generate_colormap(
wavelength: float, img_size: int = 256, gamma: float = 0.8
) -> np.ndarray:
"""Generates colormap for a wavelength.
Args:
wavelength (float):
Single wavelength used for colormap generation.
img_size (int):
Reconstructed image size.
gamma (float):
Gamma correction.
Returns:
np.ndarray:
Array with dimensions (img_size,4). Each column corresponds to the
RGBA values. A stands for alpha or transparency and is currently
set to 1.
"""
saturation = np.arange(0, 1, 1 / img_size)
r, g, b = wavelength_to_rgb(wavelength, gamma)
h, s, v = colorsys.rgb_to_hsv(r, g, b)
# Creating colormap RGBA (A stands for alpha or transparency)
colormap = np.ones((img_size, 4))
for i in range(img_size):
r, g, b = colorsys.hsv_to_rgb(h, v, saturation[i])
colormap[i, 0] = r
colormap[i, 1] = g
colormap[i, 2] = b
return colormap
[docs]
def plot_hs(
strategy,
img,
wav,
suptitle=None,
save_fig=False,
results_root=None,
data_folder=None,
colorbar_format=None,
):
r"""Plot hyperspectral data with wavelength-aware colormaps.
Creates a grid of subplots showing each spectral band with a colormap that
corresponds to the wavelength color. Each band is displayed with a custom
colormap generated from the actual wavelength values.
Args:
:attr:`strategy` (str): Strategy type, either 'slice' or 'bin'. Used for labeling.
:attr:`img` (np.ndarray): 3D numpy array with shape (height, width, n_wav) containing the hyperspectral data.
:attr:`wav` (array-like): Array of wavelength values in nanometers, length n_wav.
:attr:`suptitle` (str, optional): Super title for the entire figure. Defaults to None.
:attr:`save_fig` (bool, optional): Whether to save the figure as PDF. Defaults to False.
:attr:`results_root` (Path or str, optional): Root directory for saving figures. Required if save_fig is True. Defaults to None.
:attr:`data_folder` (Path or str, optional): Data folder name for organizing saved figures. Required if save_fig is True. Defaults to None.
:attr:`colorbar_format` (str, optional): printf-style format string used by matplotlib colorbar to format tick labels (e.g. '%.1f'). Defaults to '%.1f'.
Raises:
ValueError: If save_fig is True but results_root or data_folder is None.
Returns:
None: Displays the plot and optionally saves it.
"""
# Validate save parameters
if save_fig and (results_root is None or data_folder is None):
raise ValueError(
"results_root and data_folder must be provided when save_fig=True"
)
height, width, n_wav = img.shape
n_rows, n_cols = n_wav // 4, 4
ratio = height / width
fig, axes = plt.subplots(
n_rows,
n_cols,
figsize=(3 * n_cols, 3 * n_rows * ratio),
gridspec_kw={"wspace": 0.3, "hspace": 0.05},
)
# Handle single row case
if n_rows == 1:
axes = axes.reshape(1, -1)
for i in range(n_wav):
ax = axes[i // n_cols, i % n_cols]
# Generate spectral-aware colormap for this wavelength
wavelength_nm = float(wav[i]) # Convert from tensor to float
spectral_cmap_data = generate_colormap(wavelength_nm, img_size=height * width)
spectral_cmap = ListedColormap(spectral_cmap_data)
# Display the grayscale image with spectral colormap
im = ax.imshow(img[:, :, i], cmap=spectral_cmap)
ax.set_title(f"{wavelength_nm:.0f} nm")
ax.axis("off")
if save_fig:
path_fig = Path(results_root) / data_folder / Path(f"{n_wav}_slices")
Path(path_fig).mkdir(parents=True, exist_ok=True)
plt.imsave(
path_fig / f"{strategy}_lambda_{int(wavelength_nm)}nm.pdf",
img[:, :, i],
cmap=spectral_cmap,
)
# Add colorbar with spectral colormap, closer to the axis
cax = fig.add_axes(
[
ax.get_position().x1 + 0.005,
ax.get_position().y0,
0.01,
ax.get_position().height,
]
)
if colorbar_format is None:
plt.colorbar(im, cax=cax)
else:
plt.colorbar(im, cax=cax, format=colorbar_format)
# Hide unused subplots
for i in range(n_wav, n_rows * n_cols):
axes[i // n_cols, i % n_cols].axis("off")
if save_fig:
plt.savefig(path_fig / f"hs_{strategy}_{suptitle}.pdf", bbox_inches="tight")
plt.suptitle(suptitle, fontsize=16) if suptitle else None
plt.show()