Source code for spyrit.misc.load_data

# -----------------------------------------------------------------------------
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 15 17:06:19 2020

@author: crombez
"""

import glob
import math
import os
from pathlib import Path
from typing import List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch

from spyrit.misc.matrix_tools import Sum_coll


def Files_names(Path, name_type):
    """
    List the base names of the files matching a glob pattern, oldest first.

    Args:
        Path (str): Prefix of the glob pattern, typically a directory path
            ending with a path separator. It is concatenated with
            `name_type` as-is; no separator is inserted.

        name_type (str): Glob pattern appended to `Path`, e.g. ``"*.png"``.

    Returns:
        list[str]: Base names (directory part removed) of the matching files,
        sorted by increasing modification time. Empty if nothing matches.
    """
    # NOTE: the parameter name `Path` hides `pathlib.Path` inside this
    # function; it is kept so that existing keyword calls keep working.
    matches = sorted(glob.glob(Path + name_type), key=os.path.getmtime)
    return [os.path.basename(match) for match in matches]
def load_data_recon_3D(Path_files, list_files, Nl, Nc, Nh):
    """
    Build a 3D stack of difference images from consecutive pairs of files.

    For each index ``k`` in ``range(Nh)``, the images ``list_files[2 * k]`` and
    ``list_files[2 * k + 1]`` are read, each is rotated with ``np.rot90``
    (one quarter turn), and their difference (first minus second) is stored
    in ``Data[:, :, k]``.

    Args:
        Path_files (str): Prefix prepended as-is to every file name
            (typically a directory path ending with a path separator).

        list_files (list[str]): File names; at least ``2 * Nh`` entries are
            read.

        Nl (int): Size of the first axis of the output.

        Nc (int): Size of the second axis of the output.

        Nh (int): Number of image pairs, i.e. size of the third axis.

    Returns:
        np.ndarray: Float array of shape ``(Nl, Nc, Nh)``.
    """
    Data = np.zeros((Nl, Nc, Nh))
    for k in range(Nh):
        # plt.imread may return unsigned integer arrays (non-PNG formats).
        # Subtracting those directly wraps around instead of producing
        # negative values, so both images are converted to float first.
        first = np.rot90(
            np.asarray(plt.imread(Path_files + list_files[2 * k]), dtype=np.float64)
        )
        second = np.rot90(
            np.asarray(
                plt.imread(Path_files + list_files[2 * k + 1]), dtype=np.float64
            )
        )
        Data[:, :, k] = first - second
    return Data
# Load the hSPIM data and compress the spectral dimension to perform the reconstruction for every wavelength (lambda).
# Old convention: the data set has to be arranged in such a way that the positive part of each Hadamard pattern comes first.
def load_data_Comp_1D_old(Path_files, list_files, Nh, Nl, Nc):
    """
    Load pairs of images and store the difference of their `Sum_coll` results.

    Old convention: for each pair, the even-indexed file comes first and the
    odd-indexed file is subtracted from it (``list_files[2k]`` minus
    ``list_files[2k + 1]``).

    Each image is read with ``plt.imread``, rotated with ``np.rot90(..., 3)``
    and passed to ``Sum_coll(image, Nl, Nc)``.
    NOTE(review): `Sum_coll` is defined in `spyrit.misc.matrix_tools`; it
    presumably sums over the columns and returns a vector of length `Nl` --
    confirm against its implementation.

    Args:
        Path_files (str): Prefix prepended as-is to every file name.

        list_files (list[str]): File names; at least ``2 * Nh`` entries are
            read.

        Nh (int): Number of image pairs (columns of the output).

        Nl (int): Number of rows of the output, forwarded to `Sum_coll`.

        Nc (int): Forwarded to `Sum_coll`.

    Returns:
        np.ndarray: Array of shape ``(Nl, Nh)``.
    """
    Data = np.zeros((Nl, Nh))
    for pair in range(Nh):
        even_name = list_files[2 * pair]
        odd_name = list_files[2 * pair + 1]

        even_image = np.rot90(np.array(plt.imread(Path_files + even_name)), 3)
        even_sum = Sum_coll(even_image, Nl, Nc)

        odd_image = np.rot90(np.array(plt.imread(Path_files + odd_name)), 3)
        odd_sum = Sum_coll(odd_image, Nl, Nc)

        Data[:, pair] = even_sum - odd_sum
    return Data
# Load the hSPIM data and compress the spectral dimension to perform the reconstruction for every wavelength (lambda).
# New convention: the data set has to be arranged in such a way that the negative part of each Hadamard pattern comes first.
def load_data_Comp_1D_new(Path_files, list_files, Nh, Nl, Nc):
    """
    Load pairs of images and store the difference of their `Sum_coll` results.

    New convention: for each pair, the odd-indexed file is the minuend and
    the even-indexed file is subtracted from it (``list_files[2k + 1]`` minus
    ``list_files[2k]``).

    Each image is read with ``plt.imread``, rotated with ``np.rot90(..., 3)``
    and passed to ``Sum_coll(image, Nl, Nc)``.
    NOTE(review): `Sum_coll` is defined in `spyrit.misc.matrix_tools`; it
    presumably sums over the columns and returns a vector of length `Nl` --
    confirm against its implementation.

    Args:
        Path_files (str): Prefix prepended as-is to every file name.

        list_files (list[str]): File names; at least ``2 * Nh`` entries are
            read.

        Nh (int): Number of image pairs (columns of the output).

        Nl (int): Number of rows of the output, forwarded to `Sum_coll`.

        Nc (int): Forwarded to `Sum_coll`.

    Returns:
        np.ndarray: Array of shape ``(Nl, Nh)``.
    """
    Data = np.zeros((Nl, Nh))
    for pair in range(Nh):
        even_name = list_files[2 * pair]
        odd_name = list_files[2 * pair + 1]

        # The odd-indexed image is read first, as the minuend.
        odd_image = np.rot90(np.array(plt.imread(Path_files + odd_name)), 3)
        odd_sum = Sum_coll(odd_image, Nl, Nc)

        even_image = np.rot90(np.array(plt.imread(Path_files + even_name)), 3)
        even_sum = Sum_coll(even_image, Nl, Nc)

        Data[:, pair] = odd_sum - even_sum
    return Data
def download_girder(
    server_url: str,
    hex_ids: Union[str, list[str]],
    local_folder: str,
    file_names: Union[str, list[str], None] = None,
    gc_type="file",
):
    """
    Downloads data from a Girder server and saves it locally.

    This function first creates the local folder if it does not exist. Then, it
    connects to the Girder server and gets the names of the items whose name
    is not provided. For each item, it checks whether a path with that name
    already exists in the local folder. If not, it downloads the item.

    Args:
        server_url (str): The URL of the Girder server.

        hex_ids (str or list[str]): The hexadecimal id of the item(s) to
            download. If a list is provided, the items are downloaded in the
            same order and are saved in the same folder.

        local_folder (str): The path to the local folder where the items will
            be saved. If it does not exist, it will be created.

        file_names (str or list[str], optional): The name(s) under which the
            item(s) are saved. If a list is provided, it must have the same
            length as `hex_ids`. Each element equal to `None` will be replaced
            by the name of the item on the server. If None, all the names are
            obtained from the server. Default is None. All names include the
            extension.

        gc_type (str, optional): The type of Girder item to download. Must be
            either "file" or "folder". Default is "file".

    Raises:
        AssertionError: If `gc_type` is neither "file" nor "folder".

        ValueError: If the number of file names provided does not match the
            number of ids to download.

    Returns:
        str or list[str]: The absolute path of the downloaded (or already
        present) item when exactly one id is given, otherwise the list of
        absolute paths, in the order of `hex_ids`.
    """
    # leave import in function, so that the module can be used without
    # girder_client
    import girder_client

    assert gc_type in ["file", "folder"], "gc_type must be 'file' or 'folder'"

    # create the local folder if needed; exist_ok covers the case where it
    # appears between the check and the creation
    if not os.path.exists(local_folder):
        print("Local folder not found, creating it... ", end="")
        os.makedirs(local_folder, exist_ok=True)
        print("done.")

    # connect to the server
    gc = girder_client.GirderClient(apiUrl=server_url)

    # wrap single strings in lists (isinstance also covers str subclasses,
    # which would otherwise be iterated character by character)
    if isinstance(hex_ids, str):
        hex_ids = [hex_ids]
    if file_names is None:
        file_names = [None] * len(hex_ids)
    elif isinstance(file_names, str):
        file_names = [file_names]

    if len(file_names) != len(hex_ids):
        raise ValueError("There must be as many file names as hex ids.")

    abs_paths = []

    # for each item, check if it exists locally and download if necessary
    for hex_id, name in zip(hex_ids, file_names):
        if name is None:
            # ask the server for the item name
            if gc_type == "file":
                name = gc.getFile(hex_id)["name"]
            elif gc_type == "folder":
                name = gc.getFolder(hex_id)["name"]

        target = os.path.join(local_folder, name)

        if not os.path.exists(target):
            # "\r" lets the "done." message overwrite the progress line
            print(f"Downloading {name}... ", end="\r")
            if gc_type == "file":
                gc.downloadFile(hex_id, target)
            elif gc_type == "folder":
                gc.downloadFolderRecursive(hex_id, target)
            print(f"Downloading {name}... done.")
        else:
            print("File already exists at", target)

        abs_paths.append(os.path.abspath(target))

    return abs_paths[0] if len(abs_paths) == 1 else abs_paths
def read_acquisition(
    data_root: Path, data_folder: str, data_file_prefix: str
) -> Tuple[dict, np.ndarray]:
    """
    Read acquisition data and metadata from experimental files.

    The metadata is read from
    ``<data_root>/<data_folder>/<data_file_prefix>_metadata.json`` and the
    measurements from the ``"spectral_data"`` entry of
    ``<data_root>/<data_folder>/<data_file_prefix>_spectraldata.npz``.

    Args:
        data_root: Root directory of the data (anything accepted by
            ``pathlib.Path``).

        data_folder: Folder containing the data, relative to `data_root`.

        data_file_prefix: Prefix of the data files.

    Returns:
        Tuple of (acquisition_parameters, measurement_data). The first item is
        whatever the SPAS metadata reader returns; the second is the array
        stored under ``"spectral_data"``.

    Raises:
        ImportError: If the SPAS package cannot be imported.

        FileNotFoundError: If metadata or data files are not found.
    """
    try:
        from spas.metadata2 import read_metadata, read_metadata_2arms
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Single-pixel acquisition software (SPAS) package is required to read metadata. Please install it (see https://github.com/openspyrit/spas)."
        ) from err

    data_root = Path(data_root)
    acquisition_dir = data_root / data_folder

    # Read metadata
    meta_path = acquisition_dir / f"{data_file_prefix}_metadata.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"Metadata file not found: {meta_path}")

    try:
        metadata = read_metadata_2arms(meta_path)
    except Exception:
        # Deliberate best-effort: any failure of the two-arm reader triggers
        # the single-arm reader, whose own errors propagate to the caller.
        print("Falling back to single-arm metadata reader")
        metadata = read_metadata(meta_path)

    # Read spectral data
    data_path = acquisition_dir / f"{data_file_prefix}_spectraldata.npz"
    if not data_path.exists():
        raise FileNotFoundError(f"Spectral data file not found: {data_path}")

    # np.load keeps the .npz archive open until it is closed; the array is
    # fully read on access, so the archive can be released right away.
    with np.load(data_path) as raw:
        meas = raw["spectral_data"]

    return metadata, meas
def generate_synthetic_tumors(
    x: torch.Tensor,
    tumor_params: List[dict],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Add synthetic Gaussian tumors to a tensor of shape (batch, n_wav, h, w).

    Each tumor is an anisotropic 2D Gaussian, optionally rotated, that is
    added to the selected channels of every batch element.

    Args:
        :attr:`x` (torch.Tensor): Input tensor of shape (batch, n_wav, h, w).

        :attr:`tumor_params` (List[dict]): One dict per tumor, with keys:

            - :attr:`center`: (row, col) position of the tumor center
            - :attr:`sigma_x`: Standard deviation along the Gaussian's first
              principal axis (the column direction when the angle is 0)
            - :attr:`sigma_y`: Standard deviation along the Gaussian's second
              principal axis (the row direction when the angle is 0)
            - :attr:`amplitude`: Peak value of the tumor
            - :attr:`channels` (optional): Channel indices receiving the
              tumor. If missing or None, all channels are used.
            - :attr:`angle` (optional): Rotation angle in degrees, expressed
              in the (x = column, y = row) index frame. Default is 0.

    Returns:
        Tuple of (tumors, x_with_tumors):

        - :attr:`tumors` (torch.Tensor): Same shape as `x`, containing only
          the tumor contributions.
        - :attr:`x_with_tumors` (torch.Tensor): ``x + tumors`` clamped to
          [0, 1].
    """
    _, n_wav, n_rows, n_cols = x.shape
    tumors = torch.zeros_like(x, dtype=x.dtype, device=x.device)

    # Pixel coordinates: `rows` varies along axis 0, `cols` along axis 1
    rows, cols = torch.meshgrid(
        torch.arange(n_rows, dtype=x.dtype, device=x.device),
        torch.arange(n_cols, dtype=x.dtype, device=x.device),
        indexing="ij",
    )

    for params in tumor_params:
        center = params["center"]
        # Standard deviations are floored at 1e-8 to avoid dividing by zero
        std_x = max(float(params["sigma_x"]), 1e-8)
        std_y = max(float(params["sigma_y"]), 1e-8)
        peak = float(params["amplitude"])
        selected = params.get("channels")
        angle_rad = math.radians(float(params.get("angle", 0.0)))

        # Offsets from the tumor center (center is given as (row, col))
        dx = cols - float(center[1])
        dy = rows - float(center[0])

        # Express the offsets in the Gaussian's principal axes: R(-theta)
        cos_t = math.cos(angle_rad)
        sin_t = math.sin(angle_rad)
        u = cos_t * dx + sin_t * dy
        v = -sin_t * dx + cos_t * dy

        exponent = u**2 / (2 * std_x**2) + v**2 / (2 * std_y**2)
        blob = peak * torch.exp(-exponent)

        if selected is None:
            selected = list(range(n_wav))

        # Broadcast the (h, w) blob over the batch and the selected channels
        tumors[:, selected, :, :] += blob[None, None, :, :]

    x_with_tumors = torch.clamp(x + tumors, 0.0, 1.0)
    return tumors, x_with_tumors