from typing import Union
import warnings
import torch
import numpy as np
from scipy.stats import rankdata
# from /misc/statistics.py
[docs]
def img2mask(Mat: np.ndarray, M: int):
    """Build a binary sampling mask keeping the ``M`` most significant entries.

    Args:
        Mat (np.ndarray):
            N-by-N sampling matrix; larger values mean higher significance.
        M (int):
            How many measurements to keep.

    Returns:
        Mask (np.ndarray):
            N-by-N float array holding 1 where the entry of ``Mat`` ranks
            among the ``M`` most significant ones and 0 elsewhere.
    """
    nx, ny = Mat.shape
    # "ordinal" gives every entry a distinct rank starting at 1; ties are
    # broken by order of appearance in the flattened (row-major) matrix.
    ranks = rankdata(-Mat, method="ordinal").reshape(nx, ny)
    # Entries ranked beyond M are discarded (0.0), all others are kept (1.0).
    return np.where(ranks > M, 0.0, 1.0)
# from /former/_model_Had_DCAN.py
[docs]
def meas2img(meas: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    r"""Returns measurement image from a single measurement vector or from a
    batch of measurement vectors. This function is particularly useful if the
    number of measurements is less than the number of pixels in the image, i.e.
    the image is undersampled: the missing measurements are filled with zeros.

    Args:
        meas : `np.ndarray` with shape :math:`(M)` or :math:`(B, M)` where
            :math:`B` is the batch size and :math:`M \le N^2` is the length of
            the measurement vector. Entry :math:`k` of a measurement vector
            corresponds to the :math:`k`-th most significant entry of `Mat`.
        Mat : `np.ndarray` with shape :math:`(N,N)`. Sampling matrix, where
            high values indicate high significance. It must be the matrix used
            to generate the measurement vector.

    Raises:
        ValueError: If `meas` is neither 1D nor 2D.

    Returns:
        Img : `np.ndarray` with shape :math:`(N,N)` if `meas` is 1D, or
        :math:`(B,N,N)` if `meas` is 2D. Float64 measurement image(s).
    """
    ndim = meas.ndim
    if ndim not in (1, 2):
        raise ValueError(
            f"meas must have shape (M,) or (B, M), got {ndim} dimension(s)."
        )
    # Work on a (B, M) batch in all cases.
    batch = meas.reshape(1, -1) if ndim == 1 else meas
    # Zero-fill the measurements that were not acquired (M <= N^2).
    y_padded = np.zeros((batch.shape[0], Mat.size))
    y_padded[:, : batch.shape[1]] = batch
    # Column k holds the k-th most significant measurement; send it back to
    # its pixel position.
    img = sort_by_significance(y_padded, Mat, axis="cols", inverse_permutation=False)
    if ndim == 1:
        return img.reshape(Mat.shape)
    return img.reshape((-1, *Mat.shape))
[docs]
def meas2img2(meas: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    r"""Return multiple measurement images from multiple measurement vectors.

    It is essentially the same as :func:`meas2img`, but the batch dimension is
    the last one, both for the input `meas` and for the returned images.

    .. warning::
        This function is deprecated. Use `spyrit.misc.sampling.meas2img` instead.
        In meas2img, the batch dimension comes first: (B, M) instead of (M, B)
        for the input, and (B, N, N) instead of (N, N, B) for the output.

    Args:
        meas : `np.ndarray` with shape :math:`(M,B)`
            Set of :math:`B` measurement vectors of length :math:`M \le N^2`.
            A single vector of shape :math:`(M)` is also accepted and yields
            one :math:`(N,N)` image.
        Mat : `np.ndarray` with shape :math:`(N,N)`
            Sampling matrix, where high values indicate high significance.

    Returns:
        Img : `np.ndarray` with shape :math:`(N,N,B)`
            Set of :math:`B` images of shape :math:`(N,N)`
    """
    warnings.warn(
        "This function is deprecated. Use `spyrit.misc.sampling.meas2img` "
        + "instead. Beware the batch dimension has moved.",
        DeprecationWarning,
        stacklevel=2,
    )
    # meas2img expects (B, M) and returns (B, N, N): move the batch axis to the
    # front for the call, then back to the end to keep this function's
    # documented (N, N, B) layout.
    img = meas2img(np.moveaxis(meas, 0, -1), Mat)
    if meas.ndim == 2:
        img = np.moveaxis(img, 0, -1)
    return img
[docs]
def img2meas(Img: np.ndarray, Mat: np.ndarray) -> np.ndarray:
    """Return measurement vector from measurement image.

    Entry ``k`` of the output is the pixel of ``Img`` located at the ``k``-th
    most significant position of ``Mat`` (row-major flattening, ties broken
    by position). This gives the same values as
    ``Permutation_Matrix(Mat) @ np.ravel(Img)`` and undoes :func:`meas2img`:
    ``img2meas(meas2img(meas, Mat), Mat)[:len(meas)]`` recovers ``meas``.

    Args:
        Img (np.ndarray):
            N-by-N measurement image.
        Mat (np.ndarray):
            N-by-N sampling matrix, where high values indicate high significance.

    Returns:
        meas (np.ndarray):
            Measurement vector of length N**2, ordered by decreasing
            significance. Keep its first M entries to emulate M <= N**2
            measurements.
    """
    # Stable sort so that ties keep their raster order, as in
    # sort_by_significance and Permutation_Matrix.
    order = np.argsort(-np.ravel(Mat), kind="stable")
    # Gather (not scatter): meas[k] = Img_flat[order[k]].
    return np.ravel(Img)[order]
[docs]
def Permutation_Matrix(Mat: np.ndarray) -> np.ndarray:
    """
    Returns permutation matrix from sampling matrix

    Row ``k`` of the returned matrix is the one-hot vector selecting the
    ``k``-th most significant entry of ``Mat`` (flattened in row-major order,
    ties broken by position), so that ``Permutation_Matrix(Mat) @ np.ravel(x)``
    reorders ``x`` by decreasing significance.

    Args:
        Mat (np.ndarray):
            N-by-N sampling matrix, where high values indicate high significance.

    Returns:
        P (np.ndarray): N^2-by-N^2 permutation matrix of dtype float64
        (entries are 0 and 1).

    .. note::
        Consider using :func:`sort_by_significance` for increased
        computational performance if using :func:`Permutation_Matrix` to
        reorder a matrix as follows:
        ``y = Permutation_Matrix(Ord) @ Mat``
    """
    # Stable sort so that ties keep their raster order.
    indices = np.argsort(-np.ravel(Mat), kind="stable")
    n = len(indices)
    # Scatter the ones directly instead of fancy-indexing np.eye(n), which
    # would allocate a second n-by-n matrix (n = N^2, so this is O(N^4)).
    P = np.zeros((n, n))
    P[np.arange(n), indices] = 1
    return P
[docs]
def sort_by_significance(
    arr: np.ndarray,
    sig: np.ndarray,
    axis: str = "rows",
    inverse_permutation: bool = False,
    get_indices: bool = False,
) -> Union[np.ndarray, tuple]:
    """
    Returns an array ordered by decreasing significance along the specified
    dimension.

    The significance values are given in the :math:`sig` array. The type of
    the output is the same as the input array :math:`arr`.

    This function is equivalent to (but faster than) building
    :func:`Permutation_Matrix` and multiplying the input array by the
    permutation matrix. More specifically, here are the four possible
    different calls and their equivalent::

        h = 64
        arr = np.random.randn(h, h)
        sig = np.random.randn(h)

        # 1
        y = sort_by_significance(arr, sig, axis='rows', inverse_permutation=False)
        y = Permutation_Matrix(sig) @ arr

        # 2
        y = sort_by_significance(arr, sig, axis='rows', inverse_permutation=True)
        y = Permutation_Matrix(sig).T @ arr

        # 3
        y = sort_by_significance(arr, sig, axis='cols', inverse_permutation=False)
        y = arr @ Permutation_Matrix(sig)

        # 4
        y = sort_by_significance(arr, sig, axis='cols', inverse_permutation=True)
        y = arr @ Permutation_Matrix(sig).T

    .. note::
        :math:`arr` must have the same number of rows or columns as there are
        elements in the flattened :math:`sig` array.

    Args:
        arr (np.ndarray or torch.tensor): Array to be ordered by rows or columns.
            The output's type is the same as this parameter's type.
        sig (np.ndarray or torch.tensor): Array containing the significance values.
        axis (str, optional): Axis along which to order the array. Must be either
            'rows' or 'cols'. Defaults to 'rows'. Ignored if arr is 1D, in which
            case arr is treated as a single row (as with 'cols').
        inverse_permutation (bool, optional): If True, the permutation matrix is
            transposed before being used. This is equivalent to using the inverse
            permutation matrix. Defaults to False.
        get_indices (bool, optional): If True, the function also returns the
            indices of the significance values in decreasing order. Defaults to
            False.

    Shape:
        - arr: :math:`(*, r, c)` or :math:`(c)`, where :math:`(*)` is any number of
          dimensions, and :math:`r` and :math:`c` are the number of rows and
          columns respectively.
        - sig: :math:`(r)` if axis is 'rows' or :math:`(c)` if axis is 'cols' or
          arr is 1D (or any shape that has the same number of elements).
        - Output: :math:`(*, r, c)` or :math:`(c)`

    Returns:
        - If ``get_indices`` is False: the **array** :math:`arr` ordered by
          decreasing significance :math:`sig` along its rows or columns.
        - If ``get_indices`` is True: a tuple ``(array, indices)`` where
          **indices** are the positions of the significance values in
          decreasing order (stable for ties). This is useful if you want to
          reorder other arrays in the same way.
    """
    # compute indices in a stable way (otherwise quicksort messes up the order)
    indices = np.argsort(-sig.flatten(), kind="stable")
    sorted_arr = reindex(arr, indices, axis, inverse_permutation)
    if get_indices:
        return sorted_arr, indices
    return sorted_arr
[docs]
def reindex(
    values: np.ndarray,
    indices: np.ndarray,
    axis: str = "rows",
    inverse_permutation: bool = False,
) -> np.ndarray:
    """Sorts a tensor along a specified axis using the indices tensor.

    The indices tensor contains the new indices of the elements in the values
    tensor. Along the columns (``axis="cols"``) and for flattened or 1D input,
    `values[0]` will be placed at the index `indices[0]`, `values[1]` at
    `indices[1]`, and so on. Along the rows (``axis="rows"``) the convention
    is transposed: output row ``i`` is input row ``indices[i]``. Together,
    these match multiplication by the permutation matrix
    ``P = np.eye(n)[indices]``: ``P @ values`` for rows, ``values @ P`` for
    columns.

    Args:
        values (np.ndarray): Array to sort. Can be 1D, 2D, or any
            multi-dimensional batch of 2D arrays.
        indices (np.ndarray): Array containing the new indices of the elements
            contained in `values`; expected to be a permutation of
            ``range(n)`` where ``n`` is the length of the sorted axis.
        axis (str, optional): The axis to sort along. Must be either 'rows',
            'cols', 'flatten' or None. If 'flatten' or None, `values` is
            flattened before sorting, and then reshaped to its original shape.
            If `values` is 1D, `axis` is not used. Default is 'rows'.
        inverse_permutation (bool, optional): Whether to apply the inverse
            permutation. Default is False.

    Raises:
        ValueError: If `values` has more than one dimension and `axis` is not
            'rows', 'cols', 'flatten' or None.

    Returns:
        np.ndarray: Array ordered by the given indices along
        the specified axis. The type is the same as the input array `values`.

    Example:
        >>> values = np.array([[10, 20, 30], [100, 200, 300]])
        >>> indices = np.array([2, 0, 1])
        >>> reindex(values, indices, axis="cols")
        array([[ 20,  30,  10],
               [200, 300, 100]])
    """
    # perm[j] is the position in `values` of the element that lands at j
    # (gather form of "values[i] goes to indices[i]").
    perm = indices.argsort()

    if axis in (None, "flatten") or values.ndim == 1:
        out_shape = values.shape
        flat = values.flatten()
        if inverse_permutation:
            perm = perm.argsort()
        return flat[perm].reshape(out_shape)

    # cols corresponds to last dimension
    if axis == "cols":
        if inverse_permutation:
            perm = perm.argsort()
        return values[..., perm]

    # rows corresponds to second-to-last dimension; because it is equivalent
    # to sorting along the last dimension of the transposed tensor, the
    # permutation direction is inverted.
    if axis == "rows":
        if not inverse_permutation:
            perm = perm.argsort()
        return values[..., perm, :]

    raise ValueError("Invalid axis. Must be 'rows', 'cols' or 'flatten'.")
[docs]
def reorder(meas: np.ndarray, Perm_acq: np.ndarray, Perm_rec: np.ndarray) -> np.ndarray:
    r"""Reorder measurement vectors from acquisition order to reconstruction order.

    The rows of ``meas`` come in interleaved pairs: rows ``0, 2, 4, ...`` and
    rows ``1, 3, 5, ...`` are each permuted with the same permutation (the code
    comments below refer to them as the positive and negative coefficients).
    Each half is:

    1. brought from the acquisition order (``Perm_acq``) back to the "natural"
       (raster) order at the acquisition resolution;
    2. resized to the reconstruction resolution, either by keeping the
       top-left :math:`N_{rec} \times N_{rec}` block (:math:`N_{rec} < N_{acq}`)
       or by placing the acquired coefficients in the top-left
       :math:`N_{acq} \times N_{acq}` block and zero filling the rest
       (:math:`N_{rec} > N_{acq}`);
    3. reordered with ``Perm_rec``.

    Args:
        meas (np.ndarray):
            Measurements with dimensions (:math:`2N_{acq}^2 \times K_{rep}`),
            i.e. two interleaved sets of :math:`N_{acq}^2` acquired
            coefficients, where :math:`K_{rep}` is the number of acquisition
            repetitions (e.g., wavelength or image batch).
        Perm_acq (np.ndarray):
            Permutation matrix used for acquisition
            (:math:`N_{acq}^2 \times N_{acq}^2` square matrix).
        Perm_rec (np.ndarray):
            Permutation matrix used for reconstruction
            (:math:`N_{rec}^2 \times N_{rec}^2` square matrix).

    Returns:
        (np.ndarray):
            Measurements with dimensions (:math:`2N_{rec}^2 \times K_{rep}`),
            interleaved in the same way as the input.

    .. note::
        If :math:`N_{rec} < N_{acq}`, the input measurement vectors are
        subsampled.
        If :math:`N_{rec} > N_{acq}`, the input measurement vectors are
        filled with zeros.

    .. note::
        Dense permutation matrices of size up to
        :math:`2N^2 \times 2N^2` are built, so memory grows as :math:`N^4`.
    """
    # Dimensions (N.B: images are assumed to be square)
    # N is recovered from the N^2-by-N^2 permutation matrices.
    N_acq = int(Perm_acq.shape[0] ** 0.5)
    N_rec = int(Perm_rec.shape[0] ** 0.5)
    K_rep = meas.shape[1]
    # Acquisition order -> natural order (fill with zeros if necessary)
    if N_rec > N_acq:
        # Square subsampling in the "natural" order
        # Ord_sub ranks the top-left N_acq x N_acq block of the N_rec grid in
        # raster order (values N_acq**2, ..., 1); all other entries are 0 and
        # therefore come last in Permutation_Matrix(Ord_sub).
        Ord_sub = np.zeros((N_rec, N_rec))
        Ord_sub[:N_acq, :N_acq] = -np.arange(-(N_acq**2), 0).reshape(N_acq, N_acq)
        Perm_sub = Permutation_Matrix(Ord_sub)
        # Natural order measurements (N_acq resolution)
        # Even and odd rows of meas are permuted identically.
        Perm_raw = np.zeros((2 * N_acq**2, 2 * N_acq**2))
        Perm_raw[::2, ::2] = Perm_acq.T
        Perm_raw[1::2, 1::2] = Perm_acq.T
        meas = Perm_raw @ meas
        # Zero filling (needed only when reconstruction resolution is higher
        # than acquisition res)
        zero_filled = np.zeros((2 * N_rec**2, K_rep))
        zero_filled[: 2 * N_acq**2, :] = meas
        meas = zero_filled
        # Move the natural-order N_acq coefficients to their raster positions
        # inside the top-left block of the N_rec grid.
        Perm_raw = np.zeros((2 * N_rec**2, 2 * N_rec**2))
        Perm_raw[::2, ::2] = Perm_sub.T
        Perm_raw[1::2, 1::2] = Perm_sub.T
        meas = Perm_raw @ meas
    elif N_rec == N_acq:
        # Same resolution: only the acquisition order has to be undone
        # (the slice keeps all N_acq**2 rows).
        Perm_sub = Perm_acq[: N_rec**2, :].T
    elif N_rec < N_acq:
        # Square subsampling in the "natural" order
        # Ord_sub ranks the top-left N_rec x N_rec block of the N_acq grid
        # first, in raster order.
        Ord_sub = np.zeros((N_acq, N_acq))
        Ord_sub[:N_rec, :N_rec] = -np.arange(-(N_rec**2), 0).reshape(N_rec, N_rec)
        Perm_sub = Permutation_Matrix(Ord_sub)
        # Keep only the rows selecting that top-left block ...
        Perm_sub = Perm_sub[: N_rec**2, :]
        # ... applied after undoing the acquisition order.
        Perm_sub = Perm_sub @ Perm_acq.T
    # Reorder measurements when the reconstruction order is not "natural"
    if N_rec <= N_acq:
        # Get both positive and negative coefficients permutated
        # Perm is (N_rec**2, N_acq**2): acquisition order -> reconstruction order.
        Perm = Perm_rec @ Perm_sub
        Perm_raw = np.zeros((2 * N_rec**2, 2 * N_acq**2))
    elif N_rec > N_acq:
        # meas is already in natural order at N_rec resolution.
        Perm = Perm_rec
        Perm_raw = np.zeros((2 * N_rec**2, 2 * N_rec**2))
    # Apply the same permutation to the even and odd rows.
    Perm_raw[::2, ::2] = Perm
    Perm_raw[1::2, 1::2] = Perm
    meas = Perm_raw @ meas
    return meas