from typing import Union
import warnings
import torch
import numpy as np
from scipy.stats import rankdata
from scipy.ndimage import label
from spyrit.core.torch import walsh_matrix_2d
import ptwt
# from /misc/statistics.py
[docs]
def img2mask(Mat: np.ndarray, M: int):
    """Build a binary sampling mask keeping the most significant entries.

    Args:
        Mat (np.ndarray):
            N-by-N sampling matrix, where high values indicate high
            significance.
        M (int):
            Number of measurements to be kept.

    Returns:
        Mask (np.ndarray):
            Float array with the same shape as ``Mat``. It holds 1 where the
            entry of ``Mat`` is among the ``M`` most significant ones and 0
            elsewhere. Equal values are ranked by their order of appearance
            in the flattened matrix (``method="ordinal"``), so exactly ``M``
            entries are kept when ``0 <= M <= Mat.size``.
    """
    n_rows, n_cols = Mat.shape
    # Negating Mat makes rank 1 correspond to its largest value
    ranks = rankdata(-Mat, method="ordinal").reshape(n_rows, n_cols)
    # Everything ranked after the M-th position is discarded
    return np.where(ranks > M, 0.0, 1.0)
# from /former/_model_Had_DCAN.py
[docs]
def meas2img(meas: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    r"""Place measurements back at their location in the sampling matrix.

    Each measurement vector is zero-padded to the number of elements of
    ``Mat``; its k-th value is then written at the (flattened) position of
    the k-th most significant entry of ``Mat``. This is particularly useful
    when the number of measurements is smaller than the number of pixels,
    i.e. when the image is undersampled.

    Args:
        meas : `np.ndarray` with shape :math:`(M)` or :math:`(B, M)` where
        :math:`B` is the batch size and :math:`M` is the length of the
        measurement vector.

        Mat : `np.ndarray` with shape :math:`(N,N)`. Sampling matrix, where
        high values indicate high significance. It must be the matrix used to
        generate the measurement vector.

    Returns:
        Img : `np.ndarray` with shape :math:`(N,N)` for a single measurement
        vector, or :math:`(B,N,N)` for a batch. The values are stored in a
        floating-point buffer created with ``np.zeros``.
    """
    n_dims = meas.ndim
    # Work on a (B, M) batch even when a single vector is given
    batch = meas.reshape(1, -1) if n_dims == 1 else meas

    # Zero-fill the measurements that were not acquired
    padded = np.zeros((batch.shape[0], Mat.size))
    padded[:, : batch.shape[1]] = batch

    images = sort_by_significance(
        padded, Mat, axis="cols", inverse_permutation=False
    )
    # Drop the leading batch axis again when the input was one-dimensional
    out_shape = (-1, *Mat.shape)[2 - n_dims :]
    return images.reshape(out_shape)
[docs]
def meas2img2(meas: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    r"""Return multiple measurement images from multiple measurement vectors.

    Thin deprecated wrapper: the batch axis of ``meas`` is moved to the front
    and the result of `meas2img` is returned unchanged.

    .. warning::
        This function is deprecated. Use `spyrit.misc.sampling.meas2img` instead.
        In meas2img, the batch dimension comes first: (B, M) instead of (M, B).

    Args:
        meas : `np.ndarray` with shape :math:`(M,B)`. Set of :math:`B`
        measurement vectors of length :math:`M \le N^2`.

        Mat : `np.ndarray` with shape :math:`(N,N)`. Sampling matrix, where
        high values indicate high significance.

    Returns:
        Img : `np.ndarray` as returned by `meas2img` for a :math:`(B,M)`
        input, i.e. with shape :math:`(B,N,N)`.

    .. note::
        NOTE(review): earlier documentation of this function announced an
        output of shape :math:`(N,N,B)`; the current implementation returns
        the batch axis first. Confirm which layout callers expect.
    """
    warnings.warn(
        "This function is deprecated. Use `spyrit.misc.sampling.meas2img` "
        + "instead. Beware the batch dimension has moved.",
        DeprecationWarning,
        # Attribute the warning to the caller so that it is reported at the
        # call site (and not filtered out as coming from library code).
        stacklevel=2,
    )
    return meas2img(np.moveaxis(meas, 0, -1), Mat)
[docs]
def img2meas(Img: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    """Return the measurement vector stored in a measurement image.

    The pixels of ``Img`` are read in order of decreasing significance, i.e.
    ``meas[k]`` is the pixel located where ``Mat`` has its k-th largest value.
    This is equivalent to ``Permutation_Matrix(Mat) @ np.ravel(Img)`` and is
    the inverse of :func:`meas2img` for a fully sampled image.

    Args:
        Img (np.ndarray):
            N-by-N measurement image.
        Mat (np.ndarray):
            N-by-N sampling matrix, where high values indicate high significance.

    Returns:
        meas (np.ndarray):
            Measurement vector of length N**2 (all the pixels of ``Img``,
            reordered). Keep the first M entries to obtain the M most
            significant measurements.
    """
    # Stable sort so that equal significances keep their natural order,
    # exactly like Permutation_Matrix and sort_by_significance.
    order = np.argsort(-Mat.flatten(), kind="stable")
    # Gather (not scatter): entry k of the output comes from pixel order[k].
    # Indexing directly avoids the 1-D path of `reindex`, which ignores the
    # axis argument and applies the transposed permutation.
    return np.ravel(Img)[order]
[docs]
def Permutation_Matrix(Mat: np.ndarray) -> np.ndarray:
    """
    Returns permutation matrix from sampling matrix

    Row ``k`` of the output contains a single one, located at the flattened
    index of the k-th most significant entry of ``Mat``. Multiplying a
    flattened array by this matrix therefore sorts it by decreasing
    significance. Equal significances keep their natural order (stable sort).

    Args:
        Mat (np.ndarray):
            N-by-N sampling matrix, where high values indicate high significance.

    Returns:
        P (np.ndarray): N^2-by-N^2 permutation matrix (floating-point zeros
        and ones)

    .. note::
        Consider using :func:`sort_by_significance` for increased
        computational performance if using :func:`Permutation_Matrix` to
        reorder a matrix as follows:

        ``y = Permutation_Matrix(Ord) @ Mat``
    """
    significance = Mat.flatten()
    n_elem = len(significance)
    order = np.argsort(-significance, kind="stable")

    perm = np.zeros((n_elem, n_elem))
    perm[np.arange(n_elem), order] = 1.0
    return perm
[docs]
def sort_by_significance(
    arr: np.ndarray,
    sig: np.ndarray,
    axis: str = "rows",
    inverse_permutation: bool = False,
    get_indices: bool = False,
) -> np.ndarray:
    """
    Reorders an array by decreasing significance along the given dimension.

    The significance values are read from :math:`sig`, which is flattened and
    sorted in decreasing order with a stable sort. The resulting order is then
    applied to :math:`arr` by :func:`reindex`.

    This is a faster equivalent of building :func:`Permutation_Matrix` and
    multiplying the input array by it. The four possible calls and their
    matrix equivalents are::

        h = 64
        arr = np.random.randn(h, h)
        sig = np.random.randn(h)

        # 1
        y = sort_by_significance(arr, sig, axis='rows', inverse_permutation=False)
        y = Permutation_Matrix(sig) @ arr

        # 2
        y = sort_by_significance(arr, sig, axis='rows', inverse_permutation=True)
        y = Permutation_Matrix(sig).T @ arr

        # 3
        y = sort_by_significance(arr, sig, axis='cols', inverse_permutation=False)
        y = arr @ Permutation_Matrix(sig)

        # 4
        y = sort_by_significance(arr, sig, axis='cols', inverse_permutation=True)
        y = arr @ Permutation_Matrix(sig).T

    .. note::
        :math:`arr` must have the same number of rows or columns as there are
        elements in the flattened :math:`sig` array.

    .. note::
        When :math:`arr` is one-dimensional, :func:`reindex` ignores `axis`
        and treats the input as a row vector (cases 3 and 4 above).

    Args:
        arr (np.ndarray): Array to be ordered by rows or columns. Other parts
        of this module also pass torch tensors.

        sig (np.ndarray): Array containing the significance values.

        axis (str, optional): Axis along which to order the array. Must be
        either 'rows' or 'cols'. Defaults to 'rows'.

        inverse_permutation (bool, optional): If True, the permutation matrix
        is transposed before being used. This is equivalent to using the
        inverse permutation matrix. Defaults to False.

        get_indices (bool, optional): If True, the sorting indices are
        returned together with the reordered array. Defaults to False.

    Shape:
        - arr: :math:`(*, r, c)` or :math:`(c)`, where :math:`(*)` is any
          number of dimensions, and :math:`r` and :math:`c` are the number of
          rows and columns respectively.

        - sig: :math:`(r)` if axis is 'rows' or :math:`(c)` if axis is 'cols'
          (or any shape that has the same number of elements).

        - Output: :math:`(*, r, c)` or :math:`(c)`

    Returns:
        The reordered array when `get_indices` is False. Otherwise a tuple
        ``(array, indices)`` where:

        - **Array** is :math:`arr` ordered by decreasing significance
          :math:`sig` along its rows or columns.

        - **Indices** are the positions of the significance values in
          decreasing order. This is useful if you want to reorder other
          arrays in the same way.
    """
    # A stable sort is required: an unstable one would shuffle equal values
    order = np.argsort(-sig.flatten(), kind="stable")
    reordered = reindex(arr, order, axis, inverse_permutation)
    return (reordered, order) if get_indices else reordered
[docs]
def reindex(
    values: np.ndarray,
    indices: np.ndarray,
    axis: str = "rows",
    inverse_permutation: bool = False,
) -> np.ndarray:
    """Reorders an array along one axis according to an index array.

    With ``P`` the permutation matrix whose row ``k`` has a one in column
    ``indices[k]``, the output is:

    - ``values @ P`` for `axis='cols'` (``values[..., k]`` is moved to
      position ``indices[k]`` of the last dimension);
    - ``P @ values`` for `axis='rows'` (row ``k`` of the output is row
      ``indices[k]`` of the input);
    - the same as 'cols' applied to the flattened array for
      `axis='flatten'` or whenever `values` is 1D.

    Setting `inverse_permutation` replaces ``P`` by its transpose.

    Args:
        values (np.ndarray): Array to sort. Can be 1D, 2D, or any
        multi-dimensional batch of 2D arrays.

        indices (np.ndarray): Index array describing the reordering. It is
        expected to be a permutation of ``0 .. len(indices) - 1``.

        axis (str, optional): The axis to sort along. Must be either 'rows',
        'cols' or 'flatten'. With 'flatten', `values` is flattened before
        sorting and then reshaped to its original shape. If `values` is 1D,
        `axis` is not used. Default is 'rows'.

        inverse_permutation (bool, optional): Whether to apply the permutation
        inverse. Default is False.

    Raises:
        ValueError: If `values` has more than one dimension and `axis` is
        not 'rows', 'cols' or 'flatten'.

    Returns:
        np.ndarray: Array ordered by the given indices along
        the specified axis. The type is the same as the input array `values`.

    Example:
        >>> values = np.array([[10, 20, 30], [100, 200, 300]])
        >>> indices = np.array([2, 0, 1])
        >>> reindex(values, indices, axis="cols")
        array([[ 20,  30,  10],
               [200, 300, 100]])
    """
    # scatter[j] is the input position that lands on output position j
    scatter = indices.argsort()

    if axis == "flatten" or values.ndim == 1:
        original_shape = values.shape
        flat = values.flatten()
        selector = scatter.argsort() if inverse_permutation else scatter
        return flat[selector].reshape(original_shape)

    if axis == "cols":
        # columns live on the last dimension
        selector = scatter.argsort() if inverse_permutation else scatter
        return values[..., selector]

    if axis == "rows":
        # rows live on the second-to-last dimension; sorting them amounts to
        # sorting the columns of the transposed array, hence the permutation
        # is used the other way round
        selector = scatter if inverse_permutation else scatter.argsort()
        return values[..., selector, :]

    raise ValueError("Invalid axis. Must be 'rows', 'cols' or 'flatten'.")
[docs]
def _interleave_perm(perm: np.ndarray, shape: tuple) -> np.ndarray:
    """Return a zero matrix of size `shape` with `perm` acting on both the
    even-indexed and the odd-indexed entries."""
    doubled = np.zeros(shape)
    doubled[::2, ::2] = perm
    doubled[1::2, 1::2] = perm
    return doubled


def reorder(meas: np.ndarray, Perm_acq: np.ndarray, Perm_rec: np.ndarray) -> np.ndarray:
    r"""Reorder measurement vectors from acquisition to reconstruction order

    The rows of `meas` are handled as interleaved pairs (even and odd rows,
    described in the code as positive and negative coefficients); the same
    permutation is applied to the even rows and to the odd rows.

    Args:
        meas (np.ndarray):
            Measurements with dimensions
            (:math:`2 N_{acq}^2 \times K_{rep}`), where :math:`K_{rep}` is
            the number of acquisition repetitions (e.g., wavelength or image
            batch).

        Perm_acq (np.ndarray):
            Permutation matrix used for acquisition
            (:math:`N_{acq}^2 \times N_{acq}^2` square matrix).

        Perm_rec (np.ndarray):
            Permutation matrix used for reconstruction
            (:math:`N_{rec}^2 \times N_{rec}^2` square matrix).

    Returns:
        (np.ndarray):
            Measurements with dimensions
            (:math:`2 N_{rec}^2 \times K_{rep}`).

    .. note::
        If :math:`N_{rec} < N_{acq}`, the input measurement vectors are
        subsampled (only the top-left :math:`N_{rec} \times N_{rec}` square
        of coefficients is kept).

        If :math:`N_{rec} > N_{acq}`, the input measurement vectors are
        filled with zeros.
    """
    # Image sides (N.B: images are assumed to be square)
    n_acq = int(Perm_acq.shape[0] ** 0.5)
    n_rec = int(Perm_rec.shape[0] ** 0.5)
    n_rep = meas.shape[1]

    if n_rec > n_acq:
        # Significance map selecting the top-left n_acq x n_acq square, in
        # the "natural" (row-major) order
        order_sub = np.zeros((n_rec, n_rec))
        order_sub[:n_acq, :n_acq] = -np.arange(-(n_acq**2), 0).reshape(n_acq, n_acq)
        perm_sub = Permutation_Matrix(order_sub)

        # Acquisition order -> natural order, at the acquisition resolution
        side_acq = 2 * n_acq**2
        meas = _interleave_perm(Perm_acq.T, (side_acq, side_acq)) @ meas

        # Zero filling up to the reconstruction resolution
        side_rec = 2 * n_rec**2
        padded = np.zeros((side_rec, n_rep))
        padded[:side_acq, :] = meas

        # Spread the acquired square over the larger natural-order grid
        meas = _interleave_perm(perm_sub.T, (side_rec, side_rec)) @ padded

        final_perm = Perm_rec
        final_shape = (side_rec, side_rec)
    else:
        if n_rec == n_acq:
            perm_sub = Perm_acq[: n_rec**2, :].T
        else:
            # Keep only the top-left n_rec x n_rec square, in natural order
            order_sub = np.zeros((n_acq, n_acq))
            order_sub[:n_rec, :n_rec] = -np.arange(-(n_rec**2), 0).reshape(
                n_rec, n_rec
            )
            perm_sub = Permutation_Matrix(order_sub)[: n_rec**2, :] @ Perm_acq.T

        # Acquisition order -> natural order -> reconstruction order at once
        final_perm = Perm_rec @ perm_sub
        final_shape = (2 * n_rec**2, 2 * n_acq**2)

    # Apply the reconstruction ordering to both halves of every pair
    return _interleave_perm(final_perm, final_shape) @ meas
[docs]
def _inverse_scores(values):
    """Map a per-pattern cost to a score: low cost gives a high score."""
    scores = 1 / (values + 1e-8)
    # The first pattern has a zero cost; cap its score to avoid the huge
    # value produced by the (almost) division by zero
    scores[0] = 1
    return scores


def _ranks_from_scores(scores, n):
    """Return the n-by-n ranking of the patterns (0 = highest score)."""
    n_patterns = n**2
    # NOTE(review): torch.argsort is not guaranteed to be stable by default,
    # so patterns with equal scores may be ranked in any order - confirm
    # whether ties (e.g. with the capped first pattern) matter to callers.
    by_decreasing_score = torch.argsort(scores, descending=True)
    ranks = torch.zeros(n_patterns)
    # Inverse permutation: the i-th best pattern receives rank i
    ranks[by_decreasing_score.to(ranks.device)] = torch.arange(
        n_patterns, dtype=ranks.dtype
    )
    return ranks.reshape(n, n)


def _imagenet_variance_map():
    """Return the variance map computed from the 64x64 ImageNet covariance."""
    # The covariance was computed from the ImageNet 2012 dataset and has a
    # size of (64*64, 64*64). It is downloaded from our warehouse.
    from spyrit.misc.load_data import download_girder
    from spyrit.core.torch import Cov2Var

    # url of the warehouse
    url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
    dataId = "672207cbf03a54733161e95d"  # for reconstruction (imageNet, 64)
    data_folder = "./stat/"
    cov_name = "Cov_64x64.pt"
    # download the covariance matrix and get the file path
    file_abs_path = download_girder(url, dataId, data_folder, cov_name)
    try:
        # Load covariance matrix for "variance subsampling"
        Cov = torch.load(file_abs_path, weights_only=True)
        print(f"Cov matrix {cov_name} loaded")
    except (FileNotFoundError, OSError, RuntimeError):
        # Set to the identity if not found for "naive subsampling"
        Cov = torch.eye(64 * 64)
        print(f"Cov matrix {cov_name} not found! Set to the identity")
    return Cov2Var(Cov)


def define_order(n: int, order: str, pdf: bool = False):
    """
    Creation of a Hadamard pattern order

    Parameters
    ----------
    n : int
        Dimension. (Patterns of size n by n). Must be a power of two.
    order : string
        Type of order, one of:

        - 'Sequency': number of sign changes along the first row plus the
          first column of each pattern;
        - 'TV': isotropic total variation of each pattern;
        - 'CC': number of connected components of the positive and of the
          negative parts of each pattern;
        - 'Variance': variance of the Hadamard coefficients, derived from a
          covariance matrix computed on ImageNet 2012 at resolution 64 x 64
          (downloaded on demand; the identity is used if it cannot be
          loaded). Only usable when the map has n**2 entries, i.e. n = 64.

        For the first three orders the score of a pattern is the inverse of
        the quantity above.
    pdf : bool, optional
        If True the function returns a normalised PDF such that the sum of the values of the
        output tensor is equal to one.
        If False the output is the ranking associated to each pattern
        (0 for the highest score). The default is False.

    Raises
    ------
    ValueError
        If n is not a positive integer power of two, if order is unknown, or
        if the 'Variance' statistics do not have n**2 entries.

    Returns
    -------
    torch.tensor
        tensor of size n by n containing the PDF or ranks.
    """
    if not isinstance(n, int) or n <= 0 or (n & (n - 1)) != 0:
        raise ValueError(f"n must be an integer power of 2, got {n}")
    # Validate the order before doing any expensive work
    order_list = ["Sequency", "TV", "CC", "Variance"]
    if order not in order_list:
        raise ValueError(f"Order must be in {order_list}")

    N = n**2

    if order == "Variance":
        # This order does not need the Walsh patterns at all
        Ord_variance = _imagenet_variance_map()
        if Ord_variance.numel() != N:
            raise ValueError(
                f"The 'Variance' order provides {Ord_variance.numel()} values "
                f"but n = {n} requires {N}"
            )
        if pdf is False:
            return _ranks_from_scores(Ord_variance.flatten(), n)
        return (Ord_variance / Ord_variance.sum()).reshape(n, n)

    H = walsh_matrix_2d(n)
    patterns = [H[i].reshape(n, n) for i in range(N)]

    if order == "Sequency":
        freq = torch.zeros(N)
        for i, pattern in enumerate(patterns):
            # Number of sign changes along the first row
            changes_x = torch.sum(torch.diff(pattern, dim=1) != 0, dim=1)[0]
            # Number of sign changes along the first column
            changes_y = torch.sum(torch.diff(pattern, dim=0) != 0, dim=0)[0]
            freq[i] = changes_x + changes_y
        # Give highest score (weight) to lowest frequencies
        scores = _inverse_scores(freq)

    elif order == "TV":
        TV = []
        for pattern in patterns:
            dx = pattern[:, 1:] - pattern[:, :-1]
            dy = pattern[1:, :] - pattern[:-1, :]
            # Horizontal and vertical differences are cropped to a common
            # (n-1) x (n-1) support before taking the gradient magnitude
            TV.append(torch.sqrt(dx[:-1, :] ** 2 + dy[:, :-1] ** 2).sum())
        scores = _inverse_scores(torch.tensor(TV))

    else:  # order == "CC"
        CC_values = torch.zeros(N)
        for i, pattern in enumerate(patterns):
            patt = np.asarray(pattern)
            _, num_pos = label((patt > 0).astype(int))
            _, num_neg = label((patt < 0).astype(int))
            CC_values[i] = num_pos + num_neg
        scores = _inverse_scores(CC_values)

    if pdf is False:
        return _ranks_from_scores(scores, n)
    return (scores / scores.sum()).reshape(n, n)
[docs]
def sampling_map_from_order(order: torch.tensor, M: int):
    """
    Generate a sampling map from a given order (ranking) and number of measurements

    Parameters
    ----------
    order : torch.tensor
        n by n matrix containing the rankings (order) corresponding to each Hadamard pattern.
    M : int
        Number of measurements.

    Raises
    ------
    ValueError
        If the entries of order sum to (about) one or less, which is taken as
        the sign that a PDF was given instead of a ranking, or if M exceeds
        the number of patterns.

    Returns
    -------
    s_map : torch.tensor
        n by n binary sampling map, with the same dtype as order. It is 1
        where the rank is strictly lower than M and 0 elsewhere.
    """
    # A PDF sums to one whereas a ranking of several patterns sums to more
    looks_like_pdf = (order.sum() - 1) < 1e-6
    if looks_like_pdf:
        raise ValueError("order must be a ranking of the patterns not a PDF.")

    n_patterns = order.shape[0] ** 2
    if M > n_patterns:
        raise ValueError(
            "The number of measurements M must be lower or equal than the number of patterns"
        )

    keep = order < M
    return torch.where(keep, torch.ones_like(order), torch.zeros_like(order))
[docs]
def sampling_map_VDS(pdf: torch.tensor, M: int, seed: int = 0):
    """
    Define a VDS sampling scheme that follows a PDF.

    The first pattern (flattened index 0) is always selected. The M - 1
    remaining patterns are drawn without replacement among the other
    patterns, with probabilities proportional to their value in the PDF.

    Parameters
    ----------
    pdf : torch.tensor
        Probability distribution function, of size n by n.
    M : int
        Number of measurements (at least 1).
    seed : int, optional
        Fixed seed for reproducibility. The default is 0. The seed is set
        with ``torch.manual_seed``, i.e. on the global torch random number
        generator.

    Raises
    ------
    ValueError
        If M is lower than 1.

    Returns
    -------
    sampling_map : torch.tensor
        n by n sampling map with ones at the selected patterns and zeros
        elsewhere.
    """
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")

    torch.manual_seed(seed)
    n = pdf.shape[0]
    N = n**2

    flat_map = torch.zeros(N)
    flat_map[0] = 1  # Force the selection of the first pattern

    # torch.multinomial cannot draw zero samples, so only call it when
    # patterns other than the first one are requested
    if M > 1:
        # The first pattern is excluded from the draw, hence the offset of 1
        drawn = torch.multinomial(pdf.reshape(N)[1:], M - 1, replacement=False) + 1
        flat_map[drawn] = 1

    return flat_map.reshape(n, n)
[docs]
def sampling_map_multilevel_VDS(
    pdf: torch.tensor,
    M: int,
    levels: int,
    J: int = 3,
    wave: str = "sym8",
    mode: str = "periodic",
    seed: int = 0,
):
    """
    Generation of a sampling map following a Multilevel VDS sampling scheme

    The n**2 patterns are split into `levels` groups by decreasing
    significance in `pdf`: the first k+1 groups together contain
    (n / 2**(levels - k - 1))**2 patterns. For each group, a weight is
    computed from the largest absolute wavelet coefficients of its Walsh
    patterns; the M measurements are shared between the groups proportionally
    to these weights and then drawn uniformly without replacement inside each
    group.

    Parameters
    ----------
    pdf : torch.tensor
        n by n PDF (or order) used to discriminate the sampling levels.
    M : int
        Total number of measurements.
    levels : int
        Number of sampling levels.
    J : int, optional
        Number of wavelet decomposition levels. The default is 3.
    wave : string, optional
        Wavelet type. The default is 'sym8'.
    mode : string, optional
        Wavelet mode. The default is 'periodic'.
    seed: int, optional
        Fixed seed for reproducibility. Default is 0. The seed is set with
        ``torch.manual_seed``, i.e. on the global torch random number
        generator.

    Returns
    -------
    sampling_map : torch.tensor
        n by n multilevel sampling map (ones at the selected patterns).
    """
    torch.manual_seed(seed)
    n = pdf.shape[0]
    N = n**2
    H = walsh_matrix_2d(n)
    # dwt = DWTForward(J=J, wave=wave, mode=mode)
    lvl_sizes = torch.zeros(levels)  # number of elements in each level
    lvl_maps = torch.zeros(levels, n, n)  # one binary n-by-n map per level
    selected = 0  # Number of elements already selected
    mu_kl = torch.zeros(
        levels, J + 1
    )  # Local coherences per sampling and wavelet levels
    sampling_map = torch.zeros(n, n)
    m_k = torch.zeros(levels)  # Unnormalised weight of each level
    for k in range(levels):
        # Levels 0..k together hold (n / 2**(levels-k-1))**2 patterns; remove
        # what the previous levels already contain to get the size of level k
        lvl_sizes[k] = (n / (2 ** (levels - k - 1))) ** 2
        lvl_sizes[k] -= torch.sum(lvl_sizes[:k])
        # Mark the significance ranks [selected, selected + size) ...
        mask_basis = torch.zeros(N)
        mask_basis[selected : selected + int(lvl_sizes[k])] = 1
        selected += int(lvl_sizes[k])
        # ... and send each rank back to the position of its pattern
        lvl_maps[k] = sort_by_significance(mask_basis, pdf).reshape(n, n)
        H_k = H[
            lvl_maps[k].reshape(N).int() == 1
        ]  # Selection of the patterns in the desired level
        mu_loc = torch.zeros(
            int(lvl_sizes[k]), J + 1
        )  # Local coherences inside each level
        for i in range(int(lvl_sizes[k])):
            # Wavelet decomposition of the i-th pattern of the level, given
            # as a (1, 1, n, n) tensor
            coeffs = ptwt.wavedec2(
                H_k[i].reshape(n, n).unsqueeze(0).unsqueeze(0),
                wavelet=wave,
                mode=mode,
                level=J,
            )
            # presumably coeffs[0] holds the approximation coefficients and
            # coeffs[1:] the detail coefficients per scale - TODO confirm
            # against the ptwt documentation
            mu_loc[i, 0] = torch.max(abs(coeffs[0]))
            for j in range(J):
                # NOTE(review): only coeffs[1] is read here, indexed by
                # 2 - j (negative, hence wrapping, when J > 3) - confirm that
                # this is the intended per-scale coefficient
                mu_loc[i, j + 1] = torch.max(abs(coeffs[1][2 - j]))
        for l in range(J + 1):
            # Largest coefficient over all the patterns of the level, then
            # accumulated with a weight of 2**(l + 1)
            mu_kl[k, l] = torch.max(abs(mu_loc[:, l]))
            m_k[k] += mu_kl[k, l] * 2 ** (l + 1)
    m = m_k / m_k.sum() * M  # Normalise to have a total of M measurements
    m = torch.round(m)
    # Due to the rounding operation there might be slight mismatch that we must
    # fix between m and M
    if int(torch.sum(m)) < M:
        m[0] += M - int(torch.sum(m))
    if int(torch.sum(m)) > M:
        m[levels - 1] -= int(torch.sum(m)) - M
    selected_idx = torch.tensor([])
    for k in range(levels):
        # A level cannot provide more measurements than it has patterns: the
        # excess is passed on to the next level
        # NOTE(review): for the last level (k == levels - 1), m[k + 1] is out
        # of range - confirm that this case cannot happen for valid inputs
        if m[k] > lvl_sizes[k]:
            remaining = m[k] - lvl_sizes[k]
            m[k] = int(lvl_sizes[k])
            m[k + 1] += int(remaining)
        # Set of indices in the level
        level_idx = torch.nonzero(lvl_maps[k].reshape(N), as_tuple=False)
        # Draw uniformly the desired number of indices in this set
        Omega_k_idx = torch.multinomial(
            torch.ones(int(lvl_sizes[k])) / lvl_sizes[k], int(m[k]), replacement=False
        )
        # Apply the mask to select the indices
        Omega_k = level_idx[Omega_k_idx.long()]
        # Concatenate the list of selected indices in all levels
        selected_idx = torch.cat((selected_idx, Omega_k))
    # reshape returns a view of the freshly created map here, so the
    # assignment below writes into sampling_map itself
    sampling_map.reshape(N)[selected_idx.long()] = 1
    return sampling_map
[docs]
def reorder_from_sampling_map(
    meas: np.ndarray, Ord_acq: np.ndarray, s_map: np.ndarray
) -> np.ndarray:
    """
    Reorder split measurements following a sampling map

    Each pattern owns two consecutive rows of the measurement array. The
    pairs are first moved from the acquisition order to the natural order of
    the patterns, then the pairs of the patterns flagged in the sampling map
    are gathered, in natural order.

    Parameters
    ----------
    meas : np.ndarray
        Measurement array of size (2*N,C) with N the number of patterns acquired and C the number of channels.
    Ord_acq : np.ndarray
        (N,) Array containing the indices of the patterns corresponding to each measurement.
    s_map : np.ndarray
        (n,n) array containing the sampling map. Entries equal to 1 flag the
        patterns used for reconstruction.

    Returns
    -------
    meas_rec : np.ndarray
        Reordered measurement array of size (2*N_rec,C), with N_rec the
        number of ones in the sampling map. Patterns that are flagged in the
        map but were not acquired are filled with zeros.
    """
    flat_map = s_map.flatten()
    n_channels = meas.shape[1]

    # Acquisition order -> natural order. The slots of the patterns that have
    # not been acquired stay at zero.
    natural = np.zeros((2 * len(flat_map), n_channels))
    for acq_pos, pattern in enumerate(Ord_acq):
        src, dst = 2 * acq_pos, 2 * pattern
        natural[dst] = meas[src]
        natural[dst + 1] = meas[src + 1]

    # Natural order -> reconstruction order: keep the flagged patterns only
    kept = np.flatnonzero(flat_map == 1)
    reordered = np.zeros((2 * kept.size, n_channels))
    reordered[0::2] = natural[2 * kept]
    reordered[1::2] = natural[2 * kept + 1]
    return reordered