"""
Measurement operators, static and dynamic.
There are six classes contained in this module, each representing a different
type of measurement operator. Three of them are static, i.e. they are used to
simulate measurements of still images, and three are dynamic, i.e. they are used
to simulate measurements of moving objects, represented as a sequence of images.
The inheritance tree is as follows::
          Linear        ------->      DynamicLinear
            |                               |
            V                               V
       LinearSplit                  DynamicLinearSplit
            |                               |
            V                               V
       HadamSplit2d                 DynamicHadamSplit2d
"""
import warnings
from typing import Any, Union
from collections.abc import Iterable
# import memory_profiler as mprof
import torch
import torch.nn as nn
from spyrit.core.warp import DeformationField
import spyrit.core.torch as spytorch
# =============================================================================
[docs]
class Linear(nn.Module):
r"""
Simulates linear measurements
.. math::
m =\mathcal{N}\left(Hx\right),
where :math:`\mathcal{N} \colon\, \mathbb{R}^M \to \mathbb{R}^M` represents a noise operator (e.g., Gaussian),
:math:`H\in\mathbb{R}^{M\times N}` is the acquisition matrix, :math:`x \in \mathbb{R}^N` is the signal of interest,
:math:`M` is the number of measurements, and :math:`N` is the dimension of the signal.
.. important::
The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array
(e.g, an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`).
Args:
:attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator)
with shape :math:`(M, N)`. Only real values are supported.
:attr:`meas_shape` (tuple, optional): Shape of the underlying
multi-dimensional array :math:`X`. Must be a tuple of integers
:math:`(N_1, ... ,N_k)` such that :math:`\prod_k N_k = N`. If not, an
error is raised. Defaults to None.
:attr:`meas_dims` (tuple, optional): Dimensions of :math:`X` the
acquisition matrix applies to. Must be a tuple with the same length as
:attr:`meas_shape`. If not, an error is raised. Defaults to the last
dimensions of the multi-dimensional array :math:`X` (e.g., `(-2,-1)`
when `len(meas_shape) == 2`).
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
Defaults to `torch.nn.Identity()`.
Attributes:
:attr:`H` (:class:`torch.tensor`): (Learnable) measurement matrix of shape
:math:`(M, N)` initialized as :math:`H`.
:attr:`meas_shape` (tuple): Shape of the underlying
multi-dimensional array :math:`X`.
:attr:`meas_dims` (tuple): Dimensions the acquisition matrix applies to.
:attr:`meas_ndim` (int): Number of dimensions the
acquisition matrix applies to. This is `len(meas_dims)`
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
:attr:`M` (int): Number of measurements :math:`M`.
Example: (to be updated!)
Example 1:
>>> H = torch.rand([400, 1600])
>>> meas_op = Linear(H)
>>> print(meas_op)
Linear(
(noise_model): Identity()
)
"""
def __init__(
    self,
    H: torch.tensor,
    meas_shape: Union[int, torch.Size, Iterable[int]] = None,
    meas_dims: Union[int, torch.Size, Iterable[int]] = None,
    *,
    noise_model: nn.Module = nn.Identity(),
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    r"""Build a linear measurement operator from the matrix :math:`H`.

    Args:
        :attr:`H` (:class:`torch.tensor`): Measurement matrix of shape
        :math:`(M, N)`.

        :attr:`meas_shape` (int or iterable of int, optional): Shape of the
        measured array. Defaults to ``H.shape[-1]`` (a single dimension).

        :attr:`meas_dims` (int or iterable of int, optional): Dimensions of
        the input the matrix applies to. Defaults to the last
        ``len(meas_shape)`` dimensions.

        :attr:`noise_model` (:class:`torch.nn.Module`, optional): Module
        applied to the measurements in :meth:`forward`.

        :attr:`dtype` (:class:`torch.dtype`, optional): Data type the
        matrix is cast to.

        :attr:`device` (:class:`torch.device`, optional): Device the matrix
        is moved to.

    Raises:
        ValueError: If :attr:`H` is not 2D, if :attr:`meas_shape` and
        :attr:`meas_dims` have different lengths, or if the number of
        columns of :attr:`H` differs from the number of elements in
        :attr:`meas_shape`.
    """
    super().__init__()

    # check the rank first: everything below relies on H being (M, N)
    if H.ndim != 2:
        raise ValueError("matrix must have 2 dimensions")

    if meas_shape is None:
        meas_shape = H.shape[-1]
    if isinstance(meas_shape, int):
        meas_shape = [meas_shape]
    self.meas_shape = torch.Size(meas_shape)

    if meas_dims is None:
        meas_dims = list(range(-len(self.meas_shape), 0))
    if isinstance(meas_dims, int):
        meas_dims = [meas_dims]
    self.meas_dims = torch.Size(meas_dims)

    if len(self.meas_shape) != len(self.meas_dims):
        raise ValueError("meas_shape and meas_dims must have the same length")

    # additional attributes
    self.M = H.shape[0]
    self.meas_ndim = len(self.meas_dims)
    self.N = self.meas_shape.numel()
    self.last_dims = tuple(range(-self.meas_ndim, 0))  # for permutations

    if H.shape[1] != self.N:
        raise ValueError(
            f"The number of columns in the matrix ({H.shape[1]}) does "
            + f"not match the number of measured items ({self.N}) "
            + f"in the measurement shape {self.meas_shape}."
        )

    # don't store H if we use a HadamSplit (H is a computed property there)
    if not isinstance(self, HadamSplit2d):
        # Cast/move the tensor BEFORE wrapping it: calling `.to()` on a
        # Parameter yields a plain Tensor whenever a conversion actually
        # happens, and a plain Tensor is not registered as a parameter.
        self.H = nn.Parameter(
            H.to(dtype=dtype, device=device), requires_grad=False
        )

    self.noise_model = noise_model

    # define the available matrices for reconstruction
    self._available_pinv_matrices = ["H"]
    self._selected_pinv_matrix = "H"  # select default here (no choice)
@property
def device(self) -> torch.device:
    """Device on which the measurement matrices live.

    Returns:
        :class:`torch.device`: The device of :attr:`self.H`.

    Raises:
        RuntimeError: If the object also has an :attr:`A` matrix and it is
        on a different device than :attr:`self.H`.
    """
    h_device = self.H.device
    # objects without an `A` attribute (non-split) compare H with itself
    a_device = getattr(self, "A", self.H).device
    if h_device != a_device:
        raise RuntimeError(
            f"device undefined, H and A are on different device (found {self.H.device} and {self.A.device} respectively)"
        )
    return h_device
@property
def dtype(self) -> torch.dtype:
    """Data type of the measurement matrices.

    Returns:
        :class:`torch.dtype`: The dtype of :attr:`self.H`.

    Raises:
        RuntimeError: If the object also has an :attr:`A` matrix whose dtype
        differs from that of :attr:`self.H`.
    """
    h_dtype = self.H.dtype
    # objects without an `A` attribute (non-split) compare H with itself
    a_dtype = getattr(self, "A", self.H).dtype
    if h_dtype != a_dtype:
        raise RuntimeError(
            f"dtype undefined, H and A are of different dtype (found {self.H.dtype} and {self.A.dtype} respectively)"
        )
    return h_dtype
@property
def matrix_to_inverse(self) -> str:
    """Name of the matrix currently selected for pseudo-inversion.

    Returns:
        str: The attribute name stored in ``self._selected_pinv_matrix``.
    """
    selected_name = self._selected_pinv_matrix
    return selected_name
@property
def get_matrix_to_inverse(self) -> torch.tensor:
    """Matrix currently selected for pseudo-inversion.

    Returns:
        The attribute of this object whose name is stored in
        ``self._selected_pinv_matrix``.
    """
    selected_name = self._selected_pinv_matrix
    return getattr(self, selected_name)
[docs]
def set_matrix_to_inverse(self, matrix_name: str) -> None:
    """Select which matrix is used for pseudo-inversion.

    Args:
        :attr:`matrix_name` (str): Name of the matrix to select. It must be
        one of the names listed in ``self._available_pinv_matrices``.

    Raises:
        KeyError: If :attr:`matrix_name` is not an available matrix.
    """
    if matrix_name not in self._available_pinv_matrices:
        # `_available_pinv_matrices` is a list: format it directly (calling
        # `.keys()` on it would raise AttributeError and hide this error).
        raise KeyError(
            f"Matrix {matrix_name} not available for pinv. Available matrices: {self._available_pinv_matrices}"
        )
    self._selected_pinv_matrix = matrix_name
[docs]
def measure(self, x: torch.tensor) -> torch.tensor:
    r"""Compute the noiseless measurements :math:`m = Hx`.

    The measured dimensions of :attr:`x` are flattened with
    :meth:`vectorize`, then the result is multiplied by :attr:`self.H`.

    .. note::
        No noise is applied here. :meth:`forward` additionally applies
        :attr:`self.noise_model` to the result.

    Args:
        :attr:`x` (:class:`torch.tensor`): Batch of signals. The dimensions
        indexed by :attr:`self.meas_dims` must match
        :attr:`self.meas_shape`.

    Returns:
        :class:`torch.tensor`: Measurements of shape :math:`(*, M)`, where
        :math:`*` stands for the dimensions of :attr:`x` that are not in
        :attr:`self.meas_dims`.

    Example:
        >>> H = torch.randn(10, 15)
        >>> meas_op = Linear(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op.measure(x).shape)
        torch.Size([3, 4, 10])

        >>> H = torch.randn(10, 60)
        >>> meas_op = Linear(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op.measure(x).shape)
        torch.Size([3, 10])
    """
    flat_signal = self.vectorize(x)
    # contract the flattened signal axis (n) against the columns of H
    return torch.einsum("mn,...n->...m", self.H, flat_signal)
[docs]
def forward(self, x: torch.tensor):
    r"""Simulate noisy measurements :math:`m = \mathcal{N}(Hx)`.

    This chains :meth:`measure` and :attr:`self.noise_model`.

    .. note::
        To compute the noiseless measurements only, use :meth:`measure`.

    Args:
        :attr:`x` (:class:`torch.tensor`): Batch of signals. The dimensions
        indexed by :attr:`self.meas_dims` must match
        :attr:`self.meas_shape`.

    Returns:
        The output of :attr:`self.noise_model` applied to the measurements
        returned by :meth:`measure`.

    Example:
        >>> H = torch.randn(10, 15)
        >>> meas_op = Linear(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op(x).shape)
        torch.Size([3, 4, 10])

        >>> H = torch.randn(10, 60)
        >>> meas_op = Linear(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op(x).shape)
        torch.Size([3, 10])
    """
    noiseless = self.measure(x)
    return self.noise_model(noiseless)
def unvectorize(self, input: torch.tensor) -> torch.tensor:
    r"""Unflatten the measured dimensions.

    The last dimension is expanded into :attr:`self.meas_shape`, and the
    expanded dimensions are then moved to the positions given by
    :attr:`self.meas_dims`.

    Args:
        :attr:`input` (:class:`torch.tensor`): Tensor of shape
        :math:`(*, N)` where :math:`*` denotes any batch size.

    Returns:
        :class:`torch.tensor`: Tensor whose dimensions
        :attr:`self.meas_dims` have shape :attr:`self.meas_shape`.

    See also:
        :meth:`vectorize` for the opposite operation.

    Example:
        >>> import spyrit.core.meas as meas
        >>> matrix = torch.randn(10, 60)
        >>> meas_op = meas.Linear(matrix, meas_shape=(12, 5), meas_dims=(-1,-3))
        >>> x = torch.randn(3, 7, 60)
        >>> print(meas_op.unvectorize(x).shape)
        torch.Size([3, 5, 7, 12])
    """
    # NOTE(review): this class defines `unvectorize` a second time further
    # down (after `vectorize`); the later definition is the one Python keeps.
    batch_shape = input.shape[:-1]
    expanded = input.reshape(*batch_shape, *self.meas_shape)
    if self.meas_dims == self.last_dims:
        # measured dimensions are already the trailing ones
        return expanded
    return torch.movedim(expanded, self.last_dims, self.meas_dims)
[docs]
def adjoint(self, m: torch.tensor, unvectorize=False):
    r"""Apply the adjoint of :math:`H`, i.e. compute :math:`x = H^T m`.

    Args:
        :attr:`m` (:class:`torch.tensor`): Batch of measurements of shape
        :math:`(*, M)`.

        :attr:`unvectorize` (:obj:`bool`, optional): If True, the result is
        passed through :meth:`unvectorize` after the multiplication.
        Defaults to False.

    Returns:
        :class:`torch.tensor`: Batch of signals. With
        :attr:`unvectorize` False the shape is :math:`(*, N)`; with
        :attr:`unvectorize` True the dimensions :attr:`self.meas_dims` have
        shape :attr:`self.meas_shape`.

    Example:
        Example 1: (3, 4) measurements of length 10 produce (3, 4) signals
        of length 15.

        >>> H = torch.randn(10, 15)
        >>> meas_op = Linear(H)
        >>> m = torch.randn(3, 4, 10)
        >>> print(meas_op.adjoint(m).shape)
        torch.Size([3, 4, 15])

        Example 2: 3 measurements of length 10 produce 3 signals of
        length 60, or of shape (15, 4) with ``unvectorize=True``.

        >>> H = torch.randn(10, 60)
        >>> meas_op = Linear(H, meas_shape=(15, 4))
        >>> m = torch.randn(3, 10)
        >>> print(meas_op.adjoint(m).shape)
        torch.Size([3, 60])
        >>> print(meas_op.adjoint(m, unvectorize=True).shape)
        torch.Size([3, 15, 4])
    """
    # contract the measurement axis (m) against the rows of H
    signal = torch.einsum("mn,...m->...n", self.H, m)
    return self.unvectorize(signal) if unvectorize else signal
[docs]
def vectorize(self, input: torch.tensor) -> torch.tensor:
    r"""Flatten the measured dimensions.

    The dimensions listed in :attr:`self.meas_dims` are moved to the end of
    the tensor and collapsed into a single last dimension.

    Args:
        :attr:`input` (:class:`torch.tensor`): Tensor whose dimensions
        :attr:`self.meas_dims` have shape :attr:`self.meas_shape`.

    Returns:
        :class:`torch.tensor`: Tensor of shape :math:`(*, N)` where
        :math:`*` denotes the dimensions of the input that are not in
        :attr:`self.meas_dims`.

    See also:
        :meth:`unvectorize` for the opposite operation.

    Example:
        >>> import spyrit.core.meas as meas
        >>> matrix = torch.randn(10, 60)
        >>> meas_op = meas.Linear(matrix, meas_shape=(12, 5), meas_dims=(-1,-3))
        >>> x = torch.randn(3, 5, 7, 12)
        >>> print(meas_op.vectorize(x).shape)
        torch.Size([3, 7, 60])
    """
    trailing = input
    if self.meas_dims != self.last_dims:
        # bring every measured dimension to the end, in meas_dims order
        trailing = torch.movedim(trailing, self.meas_dims, self.last_dims)
    batch_shape = trailing.shape[: -self.meas_ndim]
    return trailing.reshape(*batch_shape, self.N)
[docs]
def unvectorize(self, input: torch.tensor) -> torch.tensor:
    r"""Unflatten the measured dimensions.

    First the last dimension is reshaped into :attr:`self.meas_shape`; the
    resulting dimensions are then moved to the positions given by
    :attr:`self.meas_dims`.

    Args:
        :attr:`input` (:class:`torch.tensor`): Tensor of shape
        :math:`(*, N)` where :math:`*` denotes any batch size.

    Returns:
        :class:`torch.tensor`: Tensor whose dimensions
        :attr:`self.meas_dims` have shape :attr:`self.meas_shape`.

    See also:
        :meth:`vectorize` for the opposite operation.

    Example:
        >>> import spyrit.core.meas as meas
        >>> matrix = torch.randn(10, 60)
        >>> meas_op = meas.Linear(matrix, meas_shape=(12, 5), meas_dims=(-1,-3))
        >>> x = torch.randn(3, 7, 60)
        >>> print(meas_op.unvectorize(x).shape)
        torch.Size([3, 5, 7, 12])
    """
    leading_shape = input.shape[:-1]
    unflattened = input.reshape(*leading_shape, *self.meas_shape)
    needs_move = self.meas_dims != self.last_dims
    if needs_move:
        # send the trailing dimensions back to their meas_dims positions
        unflattened = torch.movedim(unflattened, self.last_dims, self.meas_dims)
    return unflattened
# =============================================================================
# =============================================================================
[docs]
class LinearSplit(Linear):
r"""
Simulate linear measurements by splitting an acquisition matrix
:math:`H\in \mathbb{R}^{M\times N}` that contains negative values.
In practice, only positive values can be implemented using a DMD.
Therefore, we acquire
.. math::
y =\mathcal{N}\left(Ax\right),
where :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` represents a noise operator (e.g., Gaussian), :math:`A \in \mathbb{R}_+^{2M\times N}` is the acquisition matrix that contains positive DMD patterns, :math:`x \in \mathbb{R}^N` is the signal of interest, :math:`2M` is the number of DMD patterns, and :math:`N` is the dimension of the signal.
Given a matrix :math:`H`, we define the positive DMD patterns :math:`A` from the positive and negative components of :math:`H`. In practice, the even rows of :math:`A` contain the positive components of :math:`H`, while the odd rows of :math:`A` contain the negative components of :math:`H`
.. math::
\begin{cases}
A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\
A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H).
\end{cases}
.. note::
:math:`H_{+}` and :math:`H_{-}` are such that :math:`H_{+} - H_{-} = H`.
.. important::
The vector :math:`x \in \mathbb{R}^N` represents a multi-dimensional array (e.g, an image :math:`X \in \mathbb{R}^{N_1 \times N_2}` with :math:`N = N_1 \times N_2`).
Args:
:attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator)
with shape :math:`(M, N)`. Only real values are supported.
:attr:`meas_shape` (tuple, optional): Shape of the underlying
multi-dimensional array :math:`X`. Must be a tuple of integers
:math:`(N_1, ... ,N_k)` such that :math:`\prod_k N_k = N`. If not, an
error is raised. Defaults to None.
:attr:`meas_dims` (tuple, optional): Dimensions of :math:`X` the
acquisition matrix applies to. Must be a tuple with the same length as
:attr:`meas_shape`. If not, an error is raised. Defaults to the last
dimensions of the multi-dimensional array :math:`X` (e.g., `(-2,-1)`
when `len(meas_shape) == 2`).
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`. Defaults to = `torch.nn.Identity`.
Attributes:
:attr:`A` (:class:`torch.tensor`): (Learnable) positive measurement
matrix of shape :math:`(2M, N)` initialized as :math:`A`.
:attr:`H` (:class:`torch.tensor`): (Learnable) measurement matrix of shape
:math:`(M, N)` initialized as :math:`H`.
:attr:`meas_shape` (tuple): Shape of the underlying
multi-dimensional array :math:`X`.
:attr:`meas_dims` (tuple): Dimensions the acquisition matrix applies to.
:attr:`meas_ndim` (int): Number of dimensions the
acquisition matrix applies to. This is `len(meas_dims)`
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
:attr:`M` (int): Number of measurements :math:`M`.
Examples:
Example 1: (3, 4) signals of length 15 are measured with an acquisition matrix of shape (10, 15). This produces (3, 4) measurements of length 20.
>>> import torch
>>> import spyrit.core.meas as meas
>>> H = torch.randn(10, 15)
>>> meas_op = meas.LinearSplit(H)
>>> x = torch.randn(3, 4, 15)
>>> y = meas_op(x)
>>> print(y.shape)
torch.Size([3, 4, 20])
Example 2: 3 signals of length (15, 4) are measured with an acquisition matrix of shape (10, 60). This produces 3 measurements of length 20. The acquisition matrix applies to both dimensions -2 and -1.
>>> import torch
>>> import spyrit.core.meas as meas
>>> H = torch.randn(10, 60)
>>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
>>> x = torch.randn(3, 15, 4)
>>> y = meas_op(x)
>>> print(y.shape)
torch.Size([3, 20])
>>> print(meas_op.meas_dims)
torch.Size([-2, -1])
"""
def __init__(
    self,
    H,
    meas_shape=None,
    meas_dims=None,
    *,
    noise_model=nn.Identity(),
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    r"""Build a split measurement operator from the matrix :math:`H`.

    All arguments are forwarded to the parent constructor. The positive
    matrix :attr:`A` is then built from :attr:`self.H` by interleaving its
    positive part (even rows) and its negative part (odd rows).

    Args:
        :attr:`H` (:class:`torch.tensor`): Measurement matrix of shape
        :math:`(M, N)`.

        :attr:`meas_shape`, :attr:`meas_dims`, :attr:`noise_model`,
        :attr:`dtype`, :attr:`device`: Passed unchanged to the parent
        constructor.
    """
    super().__init__(
        H,
        meas_shape,
        meas_dims,
        noise_model=noise_model,
        dtype=dtype,
        device=device,
    )

    # positive and negative components of the stored matrix
    H_pos = nn.functional.relu(self.H)
    H_neg = nn.functional.relu(-self.H)

    # (M, 2N) -> (2M, N): row 2i holds H_pos[i], row 2i+1 holds H_neg[i]
    side_by_side = torch.cat([H_pos, H_neg], 1)
    interleaved = side_by_side.reshape(2 * self.M, self.N)

    # built from self.H, so it inherits the dtype and device of self.H
    self.A = nn.Parameter(interleaved, requires_grad=False)

    # matrices available for reconstruction, and the default choice
    self._available_pinv_matrices = ["H", "A"]
    self._selected_pinv_matrix = "H"
[docs]
def measure(self, x: torch.tensor):
    r"""Compute the noiseless split measurements :math:`y = Ax`.

    The measured dimensions of :attr:`x` are flattened with
    :meth:`vectorize`, then the result is multiplied by :attr:`self.A`.

    Args:
        :attr:`x` (:class:`torch.tensor`): Signal whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        :class:`torch.tensor`: Measurements whose last dimension has the
        number of rows of :attr:`self.A`.

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 15)
        >>> meas_op = meas.LinearSplit(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op.measure(x).shape)
        torch.Size([3, 4, 20])

        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op.measure(x).shape)
        torch.Size([3, 20])
    """
    flat_signal = self.vectorize(x)
    # contract the flattened signal axis (n) against the columns of A
    return torch.einsum("mn,...n->...m", self.A, flat_signal)
[docs]
def measure_H(self, x: torch.tensor):
    r"""Compute the noiseless measurements :math:`m = Hx`.

    Unlike :meth:`measure`, which uses :attr:`self.A`, this delegates to the
    parent class implementation of ``measure``.

    Args:
        :attr:`x` (:class:`torch.tensor`): Signal whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        The value returned by the parent class ``measure``.

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 15)
        >>> meas_op = meas.LinearSplit(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op.measure_H(x).shape)
        torch.Size([3, 4, 10])

        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op.measure_H(x).shape)
        torch.Size([3, 10])
    """
    measurements = super().measure(x)
    return measurements
[docs]
def adjoint(self, y: torch.tensor, unvectorize=False):
    r"""Apply the adjoint of :math:`A`, i.e. compute :math:`x = A^T y`.

    Args:
        :attr:`y` (:class:`torch.tensor`): Batch of measurements whose last
        dimension matches the number of rows of :attr:`self.A`.

        :attr:`unvectorize` (:obj:`bool`, optional): If True, the result is
        passed through :meth:`unvectorize` after the multiplication.
        Defaults to False.

    Returns:
        :class:`torch.tensor`: Batch of signals whose last dimension has
        the number of columns of :attr:`self.A` (before any unvectorizing).

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 15)
        >>> meas_op = meas.LinearSplit(H)
        >>> y = torch.randn(3, 4, 20)
        >>> print(meas_op.adjoint(y).shape)
        torch.Size([3, 4, 15])

        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> y = torch.randn(3, 20)
        >>> print(meas_op.adjoint(y).shape)
        torch.Size([3, 60])
    """
    # contract the measurement axis (m) against the rows of A
    signal = torch.einsum("mn,...m->...n", self.A, y)
    return self.unvectorize(signal) if unvectorize else signal
[docs]
def adjoint_H(self, m: torch.tensor, unvectorize=False):
    r"""Apply the adjoint of :math:`H`, i.e. compute :math:`x = H^T m`.

    Unlike :meth:`adjoint`, which uses :attr:`self.A`, this delegates to the
    parent class implementation of ``adjoint``.

    Args:
        :attr:`m` (:class:`torch.tensor`): Batch of measurements.

        :attr:`unvectorize` (:obj:`bool`, optional): Forwarded to the
        parent class ``adjoint``. Defaults to False.

    Returns:
        The value returned by the parent class ``adjoint``.

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> m = torch.randn(3, 10)
        >>> print(meas_op.adjoint_H(m).shape)
        torch.Size([3, 60])
        >>> print(meas_op.adjoint_H(m, unvectorize=True).shape)
        torch.Size([3, 15, 4])
    """
    signal = super().adjoint(m, unvectorize=unvectorize)
    return signal
[docs]
def forward(self, x: torch.tensor):
    r"""Simulate noisy split measurements :math:`y = \mathcal{N}(Ax)`.

    This delegates to the parent class ``forward``.

    Args:
        :attr:`x` (:class:`torch.tensor`): Signal whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        The value returned by the parent class ``forward``.

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 15)
        >>> meas_op = meas.LinearSplit(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op(x).shape)
        torch.Size([3, 4, 20])

        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op(x).shape)
        torch.Size([3, 20])
    """
    noisy = super().forward(x)
    return noisy
[docs]
def forward_H(self, x: torch.tensor):
    r"""Simulate noisy measurements :math:`m = \mathcal{N}(Hx)`.

    This chains :meth:`measure_H` and :attr:`self.noise_model`.

    Args:
        :attr:`x` (:class:`torch.tensor`): Signal whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        The output of :attr:`self.noise_model` applied to the measurements
        returned by :meth:`measure_H`.

    Examples:
        >>> import spyrit.core.meas as meas
        >>> H = torch.randn(10, 15)
        >>> meas_op = meas.LinearSplit(H)
        >>> x = torch.randn(3, 4, 15)
        >>> print(meas_op.forward_H(x).shape)
        torch.Size([3, 4, 10])

        >>> H = torch.randn(10, 60)
        >>> meas_op = meas.LinearSplit(H, meas_shape=(15, 4))
        >>> x = torch.randn(3, 15, 4)
        >>> print(meas_op.forward_H(x).shape)
        torch.Size([3, 10])
    """
    noiseless = self.measure_H(x)
    return self.noise_model(noiseless)
# =============================================================================
# =============================================================================
[docs]
class HadamSplit2d(LinearSplit):
r"""Simulate 2D Hadamard split acquisitions.
Considering the acquisition of :math:`2M` square DMD patterns of size :math:`h`, it computes
.. math::
y =\mathcal{N}\left(\mathcal{S}\left(AXA^T\right)\right),
where :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` represents a noise operator (e.g., Gaussian), :math:`\mathcal{S} \colon\, \mathbb{R}^{2h\times 2h} \to \mathbb{R}^{2M}` is a subsampling operator, :math:`A \in \mathbb{R}_+^{2h\times h}` is the acquisition matrix that contains the positive and negative components of a Hadamard matrix, :math:`X \in \mathbb{R}^{h\times h}` is the (2D) image.
1. The matrix :math:`A` is obtained by splitting a Hadamard matrix :math:`H\in\mathbb{R}^{h\times h}` such that :math:`A[0::2, :] = H_{+}` and :math:`A[1::2, :] = H_{-}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
.. note::
:math:`H_{+} - H_{-} = H`.
2. The subsampling operator keeps the pixels that correspond to the :math:`M` largest values in the order matrix :math:`O\in\mathbb{R}^{h \times h}`.
.. note::
Subsampling applies to :math:`H_{+}XH_{+}^T` and :math:`H_{-}XH_{-}^T` the same way, independently.
.. note::
The operator :math:`\mathcal{S}` returns a vector. In the case :math:`M=h^2` (no subsampling), :math:`\mathcal{S}` is the vectorization operator.
Args:
:attr:`h` (int): Image size :math:`h`. Must be a power of 2.
:attr:`order` (:class:`torch.tensor`, optional): Order matrix :math:`O` that defines the measurements to keep. The first component of :math:`y` will correspond to the index where :attr:`order` is the highest.
:attr:`fast` (bool, optional): Whether to use the fast Hadamard transform
algorithm. If False, it uses matrix-vector products. Defaults to True.
:attr:`reshape_output` (bool, optional): Whether reshape the output of adjoint and pinv methods to images. If False, output are vectors.
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
Defaults to `torch.nn.Identity()`.
:attr:`dtype` (:class:`torch.dtype`, optional): Data type of the measurement
matrix. Defaults to `torch.float32`.
:attr:`device` (:obj:`torch.device`, optional): Device of the measurement matrix.
Defaults to `torch.device("cpu")`.
.. note::
The argument :attr:`order` is particularly useful when rearranging the
measurements by decreasing variance. The variance matrix can simply be
put as `order`.
Attributes:
:attr:`H` (:class:`torch.tensor`): The 2D measurement matrix given by :math:`H\otimes H`.
:attr:`A` (:class:`torch.tensor`): The 2D acquisition matrix given by :math:`A\otimes A`.
:attr:`M` (int): Number of measurements :math:`M`.
:attr:`N` (int): Number of pixels in the image equal to :math:`h^2`.
:attr:`meas_shape` (torch.Size): Shape of the measurement patterns. Is
equal to :math:`(h, h)`.
:attr:`meas_dims` (torch.Size): Dimensions of the image the acquisition
matrix applies to. Is equal to `(-2, -1)`.
:attr:`H_static` (:class:`torch.tensor`): alias for :attr:`H`.
:attr:`H_pinv` (:class:`torch.tensor`, optional): The learnable pseudo inverse
measurement matrix :math:`H^\dagger` of shape :math:`(N, M)`.
:attr:`order` (:class:`torch.tensor`): Order matrix :math:`O`. It
is used by :func:`~spyrit.core.torch.sort_by_significance()`. Defaults to rectangular order (e.g., linear indices).
:attr:`indices` (:class:`torch.tensor`): Indices used to reorder the measurement vector. It is used by the method :meth:`reindex()`.
Example:
>>> import spyrit.core.meas as meas
>>> h = 32
>>> meas_op = meas.HadamSplit2d(h, 400)
>>> print(meas_op.H1d.shape)
torch.Size([32, 32])
>>> print(meas_op.M)
400
"""
def __init__(
    self,
    h: int,
    M: int = None,
    order: torch.tensor = None,
    fast: bool = True,
    reshape_output: bool = False,
    *,
    noise_model=nn.Identity(),
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    """Set up a split 2D Hadamard operator acting on ``h x h`` images.

    Args:
        h: Side length of the (square) image.
        M: Number of measurements to keep. ``None`` keeps all ``h**2``.
        order: Significance map used to rank the patterns (largest value
            first, ties kept in their original order). ``None`` uses a
            constant map, i.e. the natural order.
        fast: Use the separable transform instead of matrix products.
        reshape_output: Stored as :attr:`reshape_output`; not used here.
        noise_model: Noise module forwarded to the parent constructor.
        dtype: Data type of the 1D Hadamard matrix.
        device: Device of the 1D Hadamard matrix.
    """
    n_pixels = h**2

    # Skip LinearSplit.__init__ on purpose and run the next constructor in
    # the MRO (avoids setting A). The matrix passed here is an uninitialised
    # placeholder: H and A are exposed as properties built from H1d.
    super(LinearSplit, self).__init__(
        torch.empty(n_pixels, n_pixels),
        (h, h),
        (-2, -1),
        noise_model=noise_model,
        dtype=dtype,
        device=device,
    )

    # 1D Walsh-Hadamard matrix; the 2D operator is its Kronecker square.
    walsh_1d = nn.Parameter(spytorch.walsh_matrix(h), requires_grad=False)
    self.H1d = walsh_1d.to(dtype=dtype, device=device)

    # Overrides the value of M set by the parent constructor.
    self.M = n_pixels if M is None else M

    self.order = torch.ones(h, h) if order is None else order
    # Rank patterns by decreasing significance. self.device reads H1d,
    # so this must come after H1d has been assigned.
    ranking = torch.argsort(-self.order.flatten(), stable=True)
    self.indices = ranking.to(dtype=torch.int32, device=self.device)

    self.fast = fast
    self.reshape_output = reshape_output
@property
def dtype(self) -> torch.dtype:
    """Data type of the operator, read from the 1D matrix :attr:`H1d`."""
    walsh_1d = self.H1d
    return walsh_1d.dtype
@property
def device(self) -> torch.device:
    """Device of the operator, read from the 1D matrix :attr:`H1d`."""
    walsh_1d = self.H1d
    return walsh_1d.device
@property
def H(self):
    """Measurement matrix: the first :attr:`M` reordered rows of ``H1d ⊗ H1d``.

    The full Kronecker product is rebuilt on every access, its rows are
    reordered with :meth:`reindex`, and only the first ``M`` rows are kept.
    """
    reordered = self.reindex(torch.kron(self.H1d, self.H1d), "rows", False)
    return reordered[: self.M, :]
@property
def A(self):
    """Split acquisition matrix of shape ``(2 * M, N)``.

    Rows alternate between the positive part and the negative part of each
    row of :attr:`H`: row ``2i`` is ``relu(H[i])`` and row ``2i + 1`` is
    ``relu(-H[i])``.
    """
    hadamard = self.H
    positive = nn.functional.relu(hadamard)
    negative = nn.functional.relu(-hadamard)
    # (M, 2N) side by side, then reshaped so the two parts interleave.
    side_by_side = torch.cat([positive, negative], 1)
    return side_by_side.reshape(2 * self.M, self.N)
@property
def matrix_to_inverse(self):
    """Matrix to be inverted by reconstruction code; here it is :attr:`H`."""
    hadamard = self.H
    return hadamard
[docs]
def reindex(
    self, x: torch.tensor, axis: str = "rows", inverse_permutation: bool = False
) -> torch.tensor:
    """Reorder a tensor along one axis using :attr:`self.indices`.

    This is a thin wrapper around :func:`~spyrit.core.torch.reindex()`,
    which defines the exact permutation convention. The index tensor is
    moved to the device of :attr:`x` before the call.

    Args:
        :attr:`x` (:class:`torch.tensor`): The tensor to reorder.

        :attr:`axis` (str, optional): Axis to reorder along, ``'rows'`` or
        ``'cols'``. Default is ``'rows'``.

        :attr:`inverse_permutation` (bool, optional): Whether to apply the
        inverse of the permutation. Default is False.

    Returns:
        :class:`torch.tensor`: The reordered tensor.
    """
    permutation = self.indices.to(x.device)
    return spytorch.reindex(x, permutation, axis, inverse_permutation)
[docs]
def measure(self, x: torch.tensor) -> torch.tensor:
    r"""Simulate noiseless measurements from the split matrix A.

    It computes

    .. math::
        y =\mathcal{S}\left(AXA^T\right),

    where :math:`\mathcal{S} \colon\, \mathbb{R}^{2h\times 2h} \to \mathbb{R}^{2M}`
    is the subsampling operator, :math:`A \colon\, \mathbb{R}_+^{2h\times h}`
    is the acquisition matrix that contains the positive and negative
    components of the 2D Hadamard patterns, :math:`X \in \mathbb{R}^{h\times h}`
    is the (2D) image, :math:`2M` is the number of DMD patterns, and
    :math:`h` is the image size.

    When :attr:`self.fast` is True the computation is delegated to
    :meth:`fast_measure`; otherwise the parent class implementation is used.

    Args:
        :attr:`x` (:class:`torch.tensor`): Image :math:`X` whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        Measurement vector :math:`y \in \mathbb{R}^{2M}`.

    Examples:
        Example 1: No subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h)
        >>> x = torch.empty(10, h, h).uniform_(0, 1)
        >>> y = meas_op.measure(x)
        >>> print(y.shape)
        torch.Size([10, 2048])

        Example 2: With subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h, 49)
        >>> x = torch.empty(8, 2, h, h).uniform_(0, 1)
        >>> y = meas_op.measure(x)
        >>> print(y.shape)
        torch.Size([8, 2, 98])
    """
    if not self.fast:
        return super().measure(x)
    return self.fast_measure(x)
[docs]
def measure_H(self, x: torch.tensor):
    r"""Simulate noiseless measurements from the matrix H.

    It computes

    .. math::
        m =\mathcal{S}\left(HXH^T\right),

    where :math:`\mathcal{S} \colon\, \mathbb{R}^{h\times h} \to \mathbb{R}^{M}`
    is the subsampling operator, :math:`H \colon\, \mathbb{R}^{h\times h}`
    is the Hadamard matrix, and :math:`X \in \mathbb{R}^{h\times h}` is the
    (2D) image.

    When :attr:`self.fast` is True the computation is delegated to
    :meth:`fast_measure_H`; otherwise the parent class implementation is used.

    Args:
        :attr:`x` (:class:`torch.tensor`): Image :math:`X` whose dimensions
        :attr:`self.meas_dims` must have shape :attr:`self.meas_shape`.

    Returns:
        Measurement vector :math:`m \in \mathbb{R}^{M}`.

    Examples:
        Example 1: No subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h)
        >>> x = torch.empty(h, h).uniform_(0, 1)
        >>> y = meas_op.measure_H(x)
        >>> print(y.shape)
        torch.Size([1024])

        Example 2: With subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h, 49)
        >>> x = torch.empty(8, 2, h, h).uniform_(0, 1)
        >>> y = meas_op.measure_H(x)
        >>> print(y.shape)
        torch.Size([8, 2, 49])
    """
    if not self.fast:
        return super().measure_H(x)
    return self.fast_measure_H(x)
[docs]
def adjoint_H(self, m: torch.tensor, unvectorize=False) -> torch.tensor:
    r"""Apply the adjoint of the matrix H.

    On the fast path the adjoint is obtained from :meth:`fast_pinv` by
    multiplying its result by :attr:`self.N`, which cancels the division
    by ``N`` done there. Otherwise the parent class implementation is used.

    Args:
        :attr:`m` (:class:`torch.tensor`): Measurement :math:`m` of length
        :attr:`self.M`.

        :attr:`unvectorize` (bool): If True, the result is returned as an
        image; if False (default) it is returned as a vector.

    Returns:
        Image vector :math:`x \in \mathbb{R}^{h^2}` (or the corresponding
        image when :attr:`unvectorize` is True).

    Examples:
        Example 1: No subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h)
        >>> m = torch.empty(10, h*h).uniform_(0, 1)
        >>> x = meas_op.adjoint_H(m)
        >>> print(x.shape)
        torch.Size([10, 1024])

        Example 2: With subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h, M = 32, 49
        >>> meas_op = meas.HadamSplit2d(h, M)
        >>> m = torch.empty(8, 2, M).uniform_(0, 1)
        >>> x = meas_op.adjoint_H(m)
        >>> print(x.shape)
        torch.Size([8, 2, 1024])
    """
    if not self.fast:
        return super().adjoint_H(m, unvectorize)
    # fast_pinv expects the opposite flag ('vectorize').
    as_vector = not unvectorize
    return self.fast_pinv(m, as_vector) * self.N
[docs]
def fast_measure(self, x: torch.tensor) -> torch.tensor:
    r"""Simulate noiseless measurements from the split matrix A (fast path).

    The positive and negative parts of a :math:`\pm 1` Hadamard pattern sum
    to the all-ones pattern, so the split measurements follow from the
    Hadamard coefficients :math:`Hx` and the total intensity
    :math:`s = \sum_i x_i`:

    .. math::
        y_+ = (s + Hx) / 2, \qquad y_- = (s - Hx) / 2.

    Args:
        :attr:`x` (:class:`torch.tensor`): Image whose last two dimensions
        are the spatial dimensions.

    Returns:
        :class:`torch.tensor`: Measurements of shape ``(*, 2 * M)`` in which
        positive and negative parts alternate (``y_+[0], y_-[0], y_+[1], ...``),
        matching the row order of :attr:`A`.
    """
    Hx = self.fast_measure_H(x)
    # Sum the pixels directly rather than reading Hx[..., 0]: after the
    # reordering set by `order`, the first kept coefficient is not
    # necessarily the one of the all-ones pattern. The last two dimensions
    # are used because fast_measure_H transforms those.
    x_sum = x.sum(dim=(-2, -1)).unsqueeze(-1)
    y_pos = (x_sum + Hx) / 2
    y_neg = (x_sum - Hx) / 2
    new_shape = y_pos.shape[:-1] + (2 * self.M,)
    # Stacking on a trailing axis interleaves the two parts once flattened.
    return torch.stack([y_pos, y_neg], -1).reshape(new_shape)
[docs]
def fast_measure_H(self, x: torch.tensor) -> torch.tensor:
    r"""Simulate noiseless measurements from the matrix H (fast path).

    The image is transformed with the 1D matrix :attr:`H1d` through
    :func:`spyrit.core.torch.mult_2d_separable`, vectorized, and the
    coefficients are gathered in the order given by :attr:`self.indices`.
    Only the first :attr:`self.M` coefficients are kept.

    Args:
        :attr:`x` (:class:`torch.tensor`): Image to measure.

    Returns:
        :class:`torch.tensor`: Measurements whose last dimension has length
        :attr:`self.M`.
    """
    x = spytorch.mult_2d_separable(self.H1d, x)
    x = self.vectorize(x)
    # Move the indices to the data's device, as reindex() does, so the call
    # still works if the data lives on another device than the stored
    # indices. Gathering only the first M avoids a full-size intermediate.
    kept = self.indices[: self.M].to(x.device)
    return x.index_select(dim=-1, index=kept)
[docs]
def fast_pinv(self, m: torch.tensor, vectorize=False) -> torch.tensor:
    r"""Apply the pseudo inverse of H.

    Args:
        :attr:`m` (:class:`torch.tensor`): Measurement :math:`m` of length
        :attr:`self.M`.

        :attr:`vectorize` (bool): Whether to apply the :meth:`vectorize`
        method after computing the pseudo inverse. Defaults to False.

    Returns:
        :class:`torch.tensor`: The image :math:`x`, or the vectorized image
        of length :attr:`self.N` when :attr:`vectorize` is True.

    .. note::
        We use the separability of the 2D Hadamard transform. Only
        multiplications with the "1D" Hadamard matrix (i.e.,
        :attr:`self.H1d`) are required. If the number of measurements is
        smaller than the number of pixels, the measurement vector is
        zero-padded.

    Examples:
        Example 1: No subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h = 32
        >>> meas_op = meas.HadamSplit2d(h)
        >>> m = torch.empty(10, h*h).uniform_(0, 1)
        >>> x = meas_op.fast_pinv(m)
        >>> print(x.shape)
        torch.Size([10, 32, 32])

        Example 2: With subsampling

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h, M = 32, 49
        >>> meas_op = meas.HadamSplit2d(h, M)
        >>> m = torch.empty(8, 2, M).uniform_(0, 1)
        >>> x = meas_op.fast_pinv(m)
        >>> print(x.shape)
        torch.Size([8, 2, 32, 32])

        Example 3: Output are vectors, not images

        >>> import torch
        >>> import spyrit.core.meas as meas
        >>> h, M = 32, 49
        >>> meas_op = meas.HadamSplit2d(h, M)
        >>> m = torch.empty(8, 2, M).uniform_(0, 1)
        >>> x = meas_op.fast_pinv(m, vectorize=True)
        >>> print(x.shape)
        torch.Size([8, 2, 1024])
    """
    n_missing = self.N - self.M
    if n_missing != 0:
        # Zero-fill the coefficients that were not measured.
        padding = torch.zeros(*m.shape[:-1], n_missing, device=m.device)
        m = torch.cat((m, padding), -1)

    # Reorder the coefficients, then reshape them into an image.
    image = self.unvectorize(self.reindex(m, "cols", False))
    image = spytorch.mult_2d_separable(self.H1d, image) / self.N
    return self.vectorize(image) if vectorize else image
[docs]
def fast_H_pinv(self) -> torch.tensor:
    r"""Return the pseudo inverse of H, computed as ``H.T / N``."""
    hadamard = self.H
    return hadamard.T / self.N
# =============================================================================
[docs]
class DynamicLinear(Linear):
r"""Simulates linear measurements of a moving scene
.. math::
m = \mathcal{N}\left( \text{diag}(H x_{t=1, ..., M})\right),
where :math:`H\in\mathbb{R}^{M \times N}` is the acquisition matrix,
:math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal signal of interest,
:math:`M` is both the number of measurements and the number of frames,
:math:`N` is the dimension of the signal within the field of view,
:math:`\text{diag}\colon\, \mathbb{R}^{M \times M} \to \mathbb{R}^M` extracts the diagonal of its input, and
:math:`\mathcal{N} \colon\, \mathbb{R}^M \to \mathbb{R}^M` represents a noise operator (e.g., Gaussian).
.. warning::
The current implementation only supports 2D spatial dimensions (i.e., images).
Consequently, meas_shape and img_shape must be tuples of two integers.
.. warning::
For each call, there must be **exactly** as many frames in :math:`x` as
there are measurements in the linear operator used to initialize the class.
Args:
:attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator)
with shape :math:`(M, N)`. Only real values are supported.
:attr:`time_dim` (int): dimension index in the input tensor :math:`x` that
corresponds to time (i.e., the frames dimension).
:attr:`meas_shape` (tuple, optional): Shape of the measurement patterns.
Must be a tuple of two integers representing the height and width of the
patterns. If not specified, the shape is supposed to be a square image.
If not, an error is raised. Defaults to None.
:attr:`meas_dims` (tuple, optional): Dimensions of :math:`x_{t=1, ..., M}` the
acquisition matrix applies to. Must be a tuple with the same length as
:attr:`meas_shape`. If not, an error is raised. Defaults to the last
dimensions of the multi-dimensional array :math:`x_{t=1, ..., M}` (e.g., `(-2,-1)`
when `len(meas_shape)=2`).
:attr:`img_shape` (tuple, optional): Shape of the image. Must be a tuple
of two integers representing the height and width of the image. If not
specified, the shape is taken as equal to `meas_shape`. Setting this
value is particularly useful when using an extended field of view [Maitre2024_2]_.
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
Defaults to `torch.nn.Identity()`.
:attr:`white_acq` (torch.tensor, optional): Eventual spatial gain resulting from
detector inhomogeneities and used for dynamic flat-field correction. It can be
determined from a "white acquisition" without any object. If None, no correction is
applied. Must have :attr:`self.meas_shape` shape.
Attributes:
:attr:`M` (int): Number of measurements.
:attr:`N` (int): Number of pixels in the field of view.
:attr:`L` (int): Number of pixels in the extended field of view.
:attr:`meas_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the field of view.
:attr:`img_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the extended field of view.
:attr:`H` (:class:`torch.tensor`): Static measurement matrix of shape
:math:`(M, N)` initialized as :math:`H`.
:attr:`H_dyn` (torch.tensor): Dynamic measurement matrix :math:`H_{\rm{dyn}}` of shape
:math:`(M, L)`. Must be set using the method :meth:`build_dynamic_forward` before being accessed.
Example:
>>> import torch
>>> from spyrit.core.meas import DynamicLinear
>>>
>>> x = torch.rand([1, 400, 3, 50, 50]) # dummy RGB video with 400 frames of size 50x50
>>> H = torch.rand([400, 40*40]) # dummy static measurement matrix
>>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50))
>>> print(meas_op)
DynamicLinear(
(noise_model): Identity()
)
References:
[Maitre2024_2]_ Maitre, T., Bretin, E., Phan, R., Ducros, N., & Sdika, M. (2024, October).
Dynamic single-pixel imaging on an extended field of view without warping the patterns. In International
Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 275-284).
Cham: Springer Nature Switzerland. DOI: 10.1007/978-3-031-72104-5_27
[Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
"""
def __init__(
    self,
    H: torch.tensor,
    time_dim: int,
    meas_shape: Union[int, torch.Size, Iterable[int]] = None,
    meas_dims: Union[int, torch.Size, Iterable[int]] = None,
    img_shape: Union[int, torch.Size, Iterable[int]] = None,
    *,
    noise_model: nn.Module = nn.Identity(),
    white_acq: torch.tensor = None,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    """Initialise the dynamic operator on top of the static one.

    The static part (``H``, ``meas_shape``, ``meas_dims``, noise model,
    dtype and device) is handled by the parent constructor. This
    constructor then records the time dimension, the optional white
    acquisition and the image shape over the extended field of view.

    Raises:
        RuntimeError: If ``time_dim`` is one of the measurement dimensions.
        NotImplementedError: If ``meas_shape`` is not two-dimensional.
    """
    super().__init__(
        H,
        meas_shape,
        meas_dims,
        noise_model=noise_model,
        dtype=dtype,
        device=device,
    )
    self.time_dim = time_dim
    self.white_acq = white_acq

    if self.time_dim in self.meas_dims:
        raise RuntimeError(
            f"The time dimension must not be in the measurement dimensions. Found {self.time_dim} in {self.meas_dims}."
        )
    if len(self.meas_shape) != 2:
        raise NotImplementedError(
            "Currently only 2D spatial dimensions are supported."
        )

    # Without an explicit image shape, the image is the field of view.
    if img_shape is None:
        self.img_shape = self.meas_shape
    else:
        self.img_shape = img_shape
    self.img_h, self.img_w = self.img_shape  # for legacy

    # Pixel counts of the field of view (N) and of the full image (L).
    self.N = int(torch.prod(torch.tensor(self.meas_shape)))
    self.L = int(torch.prod(torch.tensor(self.img_shape)))

    # Matrices available for reconstruction, and the default selection.
    self._available_pinv_matrices = ["H_dyn"]
    self._selected_pinv_matrix = "H_dyn"
@property
def recon_mode(self) -> str:
    """Interpolation mode used for reconstruction.

    Reads the private attribute ``_recon_mode``, which is assigned by
    :meth:`build_dynamic_forward`.
    """
    mode = self._recon_mode
    return mode
@property
def H_dyn(self) -> torch.tensor:
    """Dynamic measurement matrix H_dyn.

    Raises:
        AttributeError: If the matrix has not been built yet.
    """
    # NOTE(review): on an nn.Module, an AttributeError raised inside a
    # property may be replaced by Module.__getattr__'s generic message, so
    # the text below might not reach the caller -- to be confirmed.
    try:
        matrix = self._param_H_dyn.data
    except AttributeError as e:
        raise AttributeError(
            "The dynamic measurement matrix H_dyn has not been set yet. "
            + "Please call build_dynamic_forward() before accessing the attribute H_dyn."
        ) from e
    return matrix
[docs]
def measure(self, x):
    r"""Simulates noiseless measurements.

    .. math::
        m = \text{diag}(H x_{t=1, ..., M}),

    where :math:`H \in \mathbb{R}^{M \times N}` is the acquisition matrix,
    :math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal signal of interest,
    :math:`M` is both the number of measurements and frames,
    :math:`N` is the dimension of the signal in the field of view, and
    :math:`\text{diag}\colon\, \mathbb{R}^{M \times M} \to \mathbb{R}^M` extracts the diagonal of its input.

    .. note::
        This method does not degrade measurements with noise.
        To do so, see :func:`~spyrit.core.meas.DynamicLinear.forward()`

    Args:
        :attr:`x` (:class:`torch.tensor`): A batch of temporal signals whose time
        dimension matches :attr:`self.time_dim`, and whose measured dimensions
        match :attr:`self.meas_dims`.

    Returns:
        :class:`torch.tensor`: A batch of measurements of shape :math:`(*, M)` where * denotes
        all the dimensions of the input tensor that are not included in :attr:`self.meas_dims`.

    Example:
        >>> import torch
        >>> from spyrit.core.meas import DynamicLinear
        >>>
        >>> x = torch.rand([1, 400, 3, 50, 50]) # dummy RGB video with 400 frames of size 50x50
        >>> H = torch.rand([400, 40*40]) # dummy static measurement matrix
        >>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50))
        >>> m = meas_op.measure(x) # simulate noiseless dynamic measurements
        >>> print(m.shape)
        torch.Size([1, 3, 400])
    """
    # Keep only the central field of view seen by the patterns.
    cropped = spytorch.center_crop(x, self.meas_shape)
    # Per the original comment, vectorize leaves time as the second-to-last
    # dimension; the einsum below relies on that layout.
    frames = self.vectorize(cropped)
    # Index m is both the pattern index and the frame index, so pattern m
    # is applied to frame m only (the diagonal of H x).
    return torch.einsum("mn,...mn->...m", self.H, frames)
[docs]
def forward(self, x: torch.tensor) -> torch.tensor:
r"""Simulates noisy dynamic measurements.
.. math::
m = \mathcal{N}\left(\text{diag}(H x_{t=1, ..., M})\right),
where :math:`H \in \mathbb{R}^{M \times N}` is the acquisition matrix,
:math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal signal of interest,
:math:`M` is both the number of measurements and frames,
:math:`N` is the dimension of the signal in the field of view, and
:math:`\text{diag}\colon\, \mathbb{R}^{M \times M} \to \mathbb{R}^M` extracts the diagonal of its input.
.. note::
This method degrades measurements with noise.
To compute :math:`Hx_{t=1, ..., M}` only, see :func:`~spyrit.core.meas.DynamicLinear.measure()`.
Args:
:attr:`x` (:class:`torch.tensor`): A batch of temporal signals whose time
dimension matches :attr:`self.time_dim`, and measured dimensions matches :attr:`self.meas_dims`.
Returns:
:class:`torch.tensor`: A batch of measurement of shape :math:`(*, M)` where * denotes
all the dimensions of the input tensor that are not included in :attr:`self.meas_dims`.
Example:
>>> import torch
>>> from spyrit.core.meas import DynamicLinear
>>> from spyrit.core.noise import Poisson
>>>
>>> x = torch.rand([1, 400, 3, 50, 50]) # dummy RGB video with 400 frames of size 50x50
>>> H = torch.rand([400, 40*40]) # dummy static measurement matrix
>>> alpha = 5 # noise level
>>> noise_op = Poisson(alpha=alpha, g=1/alpha)
>>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
>>> print(meas_op)
DynamicLinear(
(noise_model): Poisson()
)
>>>
>>> m = meas_op(x) # simulate noisy dynamic measurements
>>> print(m.shape)
torch.Size([1, 3, 400])
"""
x = self.measure(x)
x = self.noise_model(x)
return x
[docs]
def build_dynamic_forward(
    self,
    motion: DeformationField,
    mode: str = "bilinear",
    warping: str = "image",
    verbose: bool = False,
) -> None:
    r"""Builds the dynamic forward operator :math:`H_{\rm{dyn}}`.

    .. math::
        \text{diag}(H x_{t=1, ..., M}) = H_{\rm{dyn}} x,

    where
    :math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal signal of interest,
    :math:`H \in \mathbb{R}^{M \times N}` is the static acquisition matrix,
    :math:`x \in \mathbb{R}^L` is the reference frame defined over an extended field-of-view, and
    :math:`H_{\rm{dyn}} \in \mathbb{R}^{M \times L}` is the dynamic forward operator that compensates the motion.

    The dynamic measurement matrix :math:`H_{\rm{dyn}}` is obtained by **motion-compensation**
    to a reference time, leveraging a known deformation field.
    The output is stored in the attribute :attr:`self.H_dyn`.

    .. important::
        There are two ways of building the dynamic matrix, namely :attr:`warping='pattern'` or :attr:`warping='image'`.

        When :attr:`warping='pattern'`, the input deformation field :attr:`motion` needs to be the *inverse*
        deformation field that compensates the motion.

        When :attr:`warping='image'`, the input deformation field :attr:`motion` needs to be the *direct*
        deformation field that induces the motion.

        **Reminder**: When looking at the images vectors as continuous functions from :math:`\mathbb{R}^2` to :math:`\mathbb{R}`,
        we define the **direct** deformation as the function :math:`u \colon \mathbb{Z}^3 \mapsto \mathbb{R}^2` such that,
        for :math:`k \in \{1, ..., M\}` and :math:`(i, j) \in \mathbb{Z}^2`,

        .. math::
            x_{t=k}(i, j) = x_{t=1}(u(t=k, i, j))

        The **inverse** deformation field is defined as :math:`v=u^{-1}`.

    .. note::
        Warping sharp patterns introduces a bias in the model due to interpolation artifacts.
        We recommend to exploit the image regularity by setting :attr:`warping='image'`.

    .. note::
        The static patterns are never modified by this method: the optional
        gain :attr:`self.white_acq` is applied to a copy of them.

    Args:
        :attr:`motion` (DeformationField): Deformation field representing the
        scene motion. Need to pass the direct deformation field when
        :attr:`warping` is set to 'image', and the inverse deformation field when
        :attr:`warping` is set to 'pattern'.

        :attr:`mode` (str): Interpolation mode for constructing the dynamic matrix. Defaults to 'bilinear'.

        :attr:`warping` (str): Choose between 'image' or 'pattern'. This parameter decides whether to warp
        the patterns or the (unknown) image to recover when building the dynamic measurement matrix.
        Defaults to 'image'. Booleans are still accepted but deprecated
        (True means 'pattern', False means 'image').

        :attr:`verbose` (bool): If True, progress messages are printed. Defaults to False.

    Raises:
        ValueError: If :attr:`warping` is neither 'image' nor 'pattern'.
        RuntimeError: If the image shape or the device of :attr:`motion`
            differs from the one of the measurement operator.

    Returns:
        None. The dynamic measurement matrix is stored in the attribute :attr:`self.H_dyn`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinear
        >>>
        >>> def_field = DeformationField(torch.rand([400, 50, 50, 2]) * 2 - 1) # dummy deformation field with 400 frames
        >>> H = torch.rand([400, 40*40]) # dummy static measurement matrix
        >>>
        >>> alpha = 5 # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> print(meas_op.H_dyn.shape)
        torch.Size([400, 2500])

    References:
        [Maitre2024_2]_ Maitre, T., Bretin, E., Phan, R., Ducros, N., & Sdika, M. (2024, October).
        Dynamic single-pixel imaging on an extended field of view without warping the patterns. In International
        Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 275-284).
        Cham: Springer Nature Switzerland. DOI: 10.1007/978-3-031-72104-5_27

        [Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
        Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
    """
    # Deprecate boolean 'warping' values: accept only 'image' or 'pattern' going forward.
    if isinstance(warping, bool):
        warnings.warn(
            "Passing a boolean for 'warping' is deprecated and will be removed in a future release. "
            "Please pass the string 'image' or 'pattern' instead. "
            f"Interpreting {warping!r} as {'pattern' if warping else 'image'}.",
            DeprecationWarning,
        )
        warping = "pattern" if warping else "image"

    if not isinstance(warping, str) or warping not in ("image", "pattern"):
        raise ValueError("warping must be either 'image' or 'pattern'")

    if self.img_shape != motion.img_shape:
        raise RuntimeError(
            "The measurement operator img_shape must be the same as the motion field."
        )
    if self.device != motion.device:
        raise RuntimeError(
            "The device of the motion and the measurement operator must be the same."
        )

    # store the method and mode in attribute
    self._recon_mode = mode
    self._recon_warping = warping

    # Drop any previously built matrices. The warning is only reached when
    # both the matrix and its pseudo-inverse existed; a missing attribute
    # simply means there was nothing to delete.
    try:
        del self._param_H_dyn
        del self._param_H_dyn_pinv
        warnings.warn(
            "The dynamic measurement matrix pseudo-inverse H_pinv has "
            + "been deleted. Please call self.build_dynamic_forward_pinv() to "
            + "recompute it.",
            UserWarning,
        )
    except AttributeError:
        pass

    n_frames = motion.n_frames

    # get deformation field from motion
    # scale from [-1;1] x [-1;1] to [0;width-1] x [0;height-1]
    # The last axis of the field holds (x, y): it is passed as-is to
    # grid_sample below and indexed as (x, y) in the 'image' branch. The
    # scale is therefore (width - 1, height - 1), not img_shape - 1, which
    # is ordered (height, width) and only coincides for square images.
    scale_factor = (torch.tensor([self.img_w, self.img_h]) - 1).to(self.device)
    def_field = (motion.field + 1) / 2 * scale_factor

    if isinstance(self, DynamicLinearSplit):
        meas_pattern = self.A
    else:
        meas_pattern = self.H

    if self.white_acq is not None:
        # for eventual spatial gain
        # Multiply out of place: the tensor returned by self.H / self.A may
        # share storage with the matrix kept on the module (not visible from
        # here), and an in-place product would then rescale the static
        # patterns at every call. The cast keeps the dtype an in-place
        # product would have produced.
        gain = self.white_acq.ravel().unsqueeze(0)
        meas_pattern = (meas_pattern * gain).to(meas_pattern.dtype)

    if warping == "image":
        # The interpolation kernel is a square stencil of kernel_size**2
        # grid points around each sampled location (the original sketches
        # showed a 2x2 stencil for bilinear and a 4x4 one for bicubic);
        # its size is read from the output of self._spline.
        kernel_size = self._spline(torch.tensor([0]), mode).shape[1]
        kernel_width = kernel_size - 1
        kernel_n_pts = kernel_size**2

        # Memory optimization: Clear CUDA cache before large allocations
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # PART 1: SEPARATE THE INTEGER AND DECIMAL PARTS OF THE FIELD
        # _________________________________________________________________
        # crop def_field to keep only measured area
        # moveaxis because crop expects (h,w) as last dimensions
        if verbose:
            print("Part 1: separating integer and decimal parts of the field")

        def_field = spytorch.center_crop(
            def_field.moveaxis(-1, 0), self.meas_shape
        ).moveaxis(
            0, -1
        )  # shape (n_frames, meas_h, meas_w, 2)

        # coordinate of top-left closest corner
        def_field_floor = def_field.floor().to(torch.int64)
        # shape (n_frames, meas_h, meas_w, 2)

        # compute decimal part in x y direction
        dx, dy = torch.split((def_field - def_field_floor), [1, 1], dim=-1)
        dx, dy = dx.squeeze(-1), dy.squeeze(-1)
        # dx.shape = dy.shape = (n_frames, meas_h, meas_w)

        del def_field
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # evaluate the spline at the decimal part
        dxy = torch.einsum(
            "iajk,ibjk->iabjk", self._spline(dy, mode), self._spline(dx, mode)
        ).reshape(n_frames, kernel_n_pts, self.N)
        # shape (n_frames, kernel_n_pts, meas_h*meas_w)

        # Memory optimization: explicitly delete large intermediate tensors
        del dx, dy
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # PART 2: FLATTEN THE INDICES
        # _________________________________________________________________
        # we consider an expanded grid (img_h+k)x(img_w+k), where k is
        # (kernel_width). This allows each part of the (kernel_size^2)-
        # point grid to contribute to the interpolation.
        if verbose:
            print("Part 2: flattening the indices")

        # get coordinate of point _00
        def_field_00 = def_field_floor - (kernel_size // 2 - 1)
        del def_field_floor
        # shift the grid for phantom rows/columns
        def_field_00 += kernel_width

        # create a mask indicating if either of the 2 indices is out of bounds
        # (w,h) because the def_field is in (x,y) coordinates
        maxs = torch.tensor(
            [self.img_w + kernel_width, self.img_h + kernel_width],
            device=self.device,
        )
        mask = torch.logical_or(
            (def_field_00 < 0).any(dim=-1), (def_field_00 >= maxs).any(dim=-1)
        )  # shape (n_frames, meas_h, meas_w)

        # trash index receives all the out-of-bounds indices
        trash = (maxs[0] * maxs[1]).to(torch.int64).to(self.device)

        # if the indices are out of bounds, we put the trash index
        # otherwise we put the flattened index (y*w + x)
        flattened_indices = torch.where(
            mask,
            trash,
            def_field_00[..., 0]
            + def_field_00[..., 1] * (self.img_w + kernel_width),
        ).reshape(n_frames, self.N)

        del def_field_00, mask
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # PART 3: WARP H MATRIX WITH FLATTENED INDICES
        # _________________________________________________________________
        # One weighted copy of the patterns per kernel point
        if verbose:
            print("Part 3: building H_dyn matrix with flattened indices")

        meas_dxy = meas_pattern.reshape(n_frames, 1, self.N).to(dxy.dtype) * dxy
        del dxy, meas_pattern

        # The fold operator is the same for every frame, so build it once.
        fold = nn.Fold(
            output_size=self.img_shape,
            kernel_size=(kernel_size, kernel_size),
            padding=kernel_width,
        )

        # Memory optimization: Check if we need chunked processing
        # (+1 column for the trash index)
        sparse_size = (self.img_h + kernel_width) * (self.img_w + kernel_width) + 1
        max_memory_per_tensor = 2e8  # ~200MB limit per tensor
        expected_size = (
            n_frames * kernel_n_pts * sparse_size * meas_dxy.element_size()
        )

        if expected_size > max_memory_per_tensor:
            if verbose:
                print(
                    f"Using chunked processing to avoid OOM (tensor is expected to be {expected_size/1e9:.2f} GB)"
                )

            # Process in smaller chunks (at least one frame per chunk)
            chunk_size = max(
                1,
                int(
                    max_memory_per_tensor
                    / (kernel_n_pts * sparse_size * meas_dxy.element_size())
                ),
            )
            if verbose:
                print(f"Processing {n_frames} frames in chunks of {chunk_size}")

            H_dyn_chunks = []
            for i in range(0, n_frames, chunk_size):
                end_idx = min(i + chunk_size, n_frames)
                chunk_frames = end_idx - i
                if verbose:
                    print(
                        f"Processing chunk {i//chunk_size + 1}/{(n_frames + chunk_size - 1)//chunk_size}: frames {i} to {end_idx-1}"
                    )

                # Create smaller tensor for this chunk
                meas_dxy_sorted_chunk = torch.zeros(
                    (chunk_frames, kernel_n_pts, sparse_size),
                    dtype=meas_dxy.dtype,
                    device=self.device,
                )
                # add at flattened_indices the values of meas_dxy for this chunk
                meas_dxy_sorted_chunk.scatter_add_(
                    2,
                    flattened_indices[i:end_idx]
                    .unsqueeze(1)
                    .expand_as(meas_dxy[i:end_idx]),
                    meas_dxy[i:end_idx],
                )
                # drop last column (trash)
                meas_dxy_sorted_chunk = meas_dxy_sorted_chunk[:, :, :-1]

                # FOLD THE MATRIX for this chunk
                H_dyn_chunk = fold(meas_dxy_sorted_chunk).reshape(
                    chunk_frames, self.L
                )
                H_dyn_chunks.append(
                    H_dyn_chunk.clone()
                )  # Clone to ensure memory is copied

                # Clean up chunk memory
                del meas_dxy_sorted_chunk, H_dyn_chunk
                if self.device.type == "cuda":
                    torch.cuda.empty_cache()

            # Concatenate all chunks
            H_dyn = torch.cat(H_dyn_chunks, dim=0)
            del H_dyn_chunks
            if verbose:
                print("Chunked processing completed successfully")
        else:
            if verbose:
                print("Using standard processing (tensor fits in memory)")

            # Create a larger H_dyn that will be folded
            meas_dxy_sorted = torch.zeros(
                (n_frames, kernel_n_pts, sparse_size),
                dtype=meas_dxy.dtype,
                device=self.device,
            )
            # add at flattened_indices the values of meas_dxy
            meas_dxy_sorted.scatter_add_(
                2, flattened_indices.unsqueeze(1).expand_as(meas_dxy), meas_dxy
            )
            # drop last column (trash)
            meas_dxy_sorted = meas_dxy_sorted[:, :, :-1]

            # PART 4: FOLD THE MATRIX
            # _____________________________________________________________
            H_dyn = fold(meas_dxy_sorted).reshape(n_frames, self.L)

            # Memory optimization: Clean up after folding
            del meas_dxy_sorted

        # Clean up remaining variables
        del flattened_indices, meas_dxy
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    elif warping == "pattern":
        print(
            "Be careful to use the inverse deformation field when warping patterns."
        )

        # Memory optimization: Clear cache before warping operations
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # Determinant of the scaled field, one value per pixel, flattened
        # per frame; it weights the warped patterns below.
        det = self._calc_det(def_field)
        det = det.reshape((det.shape[0], -1))

        meas_pattern = meas_pattern.reshape(
            meas_pattern.shape[0], 1, self.meas_shape[0], self.meas_shape[1]
        )

        # Zero-pad the patterns so that they sit at the centre of the image
        meas_pattern_ext = torch.zeros(
            (meas_pattern.shape[0], 1, self.img_shape[0], self.img_shape[1]),
            dtype=motion.field.dtype,  # Use correct dtype from start
            device=self.device,
        )
        amp_max_h = (self.img_shape[0] - self.meas_shape[0]) // 2
        amp_max_w = (self.img_shape[1] - self.meas_shape[1]) // 2
        meas_pattern_ext[
            :,
            :,
            amp_max_h : self.meas_shape[0] + amp_max_h,
            amp_max_w : self.meas_shape[1] + amp_max_w,
        ] = meas_pattern
        del meas_pattern

        H_dyn = nn.functional.grid_sample(
            meas_pattern_ext,
            motion.field,
            mode=mode,
            padding_mode="zeros",
            align_corners=True,
        )

        # Memory optimization: Clean up before final computation
        del meas_pattern_ext
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        H_dyn = H_dyn.reshape((H_dyn.shape[0], -1)) * det
        del det

    self._param_H_dyn = nn.Parameter(H_dyn, requires_grad=False).to(
        self.device
    )  # store in _param_H_dyn

    # Memory optimization: Clean up H_dyn variable (data is now in parameter)
    del H_dyn
    if self.device.type == "cuda":
        torch.cuda.empty_cache()
        # The figure below is CUDA-specific, so it is only reported there.
        if verbose:
            print(
                f"Final memory after storing H_dyn: {torch.cuda.memory_allocated()/1024**3:.2f} GB"
            )
def _calc_det(self, def_field):
    r"""Computes the determinant of the Jacobian of a deformation field.

    It is used for building the dynamic matrix with pattern warping.
    Spatial derivatives are approximated with forward finite differences
    (:func:`torch.diff`). Along each axis the last difference is repeated
    once so that every derivative keeps the spatial size of the input.

    Args:
        :attr:`def_field` (:class:`torch.tensor`): a tensor of shape
        (t, h, w, 2), where t, h, w can be any dimensions.

    Returns:
        :class:`torch.tensor`: The determinant for each frame of the
        deformation field, it has shape (t, h, w).
    """

    def padded_diff(v, dim):
        # Forward difference along `dim`, extended by replicating the last
        # difference so the result has the same size as `v` along `dim`.
        d = torch.diff(v, dim=dim)
        last = [slice(None)] * d.ndim
        last[dim] = slice(-1, None)
        return torch.cat([d, d[tuple(last)]], dim=dim)

    # Memory optimization: Clear cache before computation
    if self.device.type == "cuda":
        torch.cuda.empty_cache()

    # The two components of the field, each of shape (t, h, w)
    v1, v2 = def_field[:, :, :, 0], def_field[:, :, :, 1]

    # "dy" differentiates along dim 1 (h), "dx" along dim 2 (w)
    dy_v1 = padded_diff(v1, 1)
    dx_v1 = padded_diff(v1, 2)
    dy_v2 = padded_diff(v2, 1)
    dx_v2 = padded_diff(v2, 2)
    del v1, v2

    # 2x2 determinant, computed independently at every pixel of every frame
    det = dx_v1 * dy_v2 - dx_v2 * dy_v1

    # Release the partial derivatives before returning the cache
    del dx_v1, dy_v1, dx_v2, dy_v2
    if self.device.type == "cuda":
        torch.cuda.empty_cache()
    return det
[docs]
def measure_H_dyn(self, x: torch.tensor) -> torch.tensor:
    r"""Simulates noiseless dynamic measurements with the dynamic matrix

    .. math::
        m = H_{\rm{dyn}} x

    where :math:`H_{\rm{dyn}} \in \mathbb{R}^{M \times L}` is the dynamic
    acquisition matrix, :math:`x \in \mathbb{R}^L` is the reference signal
    of interest, :math:`M` is the number of measurements, and :math:`L` is
    the dimension of the signal (with extended FOV).

    .. warning::
        This supposes the dynamic measurement matrix :math:`H_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.
        An error will be raised otherwise.

    Args:
        :attr:`x` (torch.tensor): Batch of reference (static) signals. The
        dimensions indexed by :attr:`self.meas_dims` must match the image
        shape :attr:`self.img_shape`.

    Returns:
        torch.tensor: Measurement of the input signal. It has shape (..., M).

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinear
        >>>
        >>> def_field = DeformationField(torch.rand([400, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 400 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinear(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> m = meas_op.measure_H_dyn(x)  # simulate noiseless dynamic measurements from dynamic matrix
        >>> print(m.shape)
        torch.Size([1, 3, 400])
    """
    # No cropping here: H_dyn is defined over the extended field of view.
    flat = self.vectorize(x)
    # Contract the pixel axis of every batch element against each row of H_dyn.
    return torch.einsum("mn,...n->...m", self.H_dyn, flat)
[docs]
def forward_H_dyn(self, x: torch.tensor) -> torch.tensor:
    r"""Simulates noisy dynamic measurements with the dynamic matrix

    .. math::
        m = \mathcal{N}\left(H_{\rm{dyn}} x \right)

    where :math:`H_{\rm{dyn}} \in \mathbb{R}^{M \times L}` is the dynamic
    acquisition matrix, :math:`x \in \mathbb{R}^L` is the reference signal
    of interest, :math:`M` is the number of measurements, :math:`L` is the
    dimension of the signal (with extended FOV), and :math:`\mathcal{N}`
    is the noise operator :attr:`self.noise_model`.

    .. warning::
        This supposes the dynamic measurement matrix :math:`H_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.
        An error will be raised otherwise.

    Args:
        :attr:`x` (torch.tensor): Batch of reference (static) signals. The
        dimensions indexed by :attr:`self.meas_dims` must match the image
        shape :attr:`self.img_shape`.

    Returns:
        torch.tensor: Noisy measurement of the input signal. It has shape
        :math:`(*, M)` where :math:`*` denotes all the dimensions that are
        not included in :attr:`self.meas_dims`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinear
        >>>
        >>> def_field = DeformationField(torch.rand([400, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 400 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinear(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> m = meas_op.forward_H_dyn(x)  # simulate noisy dynamic measurements from dynamic matrix
        >>> print(m.shape)
        torch.Size([1, 3, 400])
    """
    # No cropping here: H_dyn is defined over the extended field of view.
    flat = self.vectorize(x)
    clean = torch.einsum("mn,...n->...m", self.H_dyn, flat)
    # Corrupt the noiseless measurements with the configured noise model.
    return self.noise_model(clean)
[docs]
def adjoint(self, m: torch.tensor, unvectorize=False) -> torch.tensor:
    r"""Apply adjoint of matrix :math:`H_{\rm{dyn}}`.

    It computes

    .. math::
        x = H_{\rm{dyn}}^\top m,

    where :math:`H_{\rm{dyn}}^\top \in\mathbb{R}^{L \times M}` is the
    adjoint of the dynamic acquisition matrix and
    :math:`m \in \mathbb{R}^M` is the measurement vector.

    .. warning::
        This supposes the dynamic measurement matrix :math:`H_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.
        An error will be raised otherwise.

    Args:
        :attr:`m` (:class:`torch.tensor`): A batch of measurement
        :math:`m` of shape :math:`(*, M)` where :math:`*` denotes any
        number of leading (batch) dimensions.

        :attr:`unvectorize` (bool, optional): If :obj:`True`, the result
        is passed through :meth:`unvectorize()` before being returned.
        Defaults to :obj:`False`.

    Returns:
        :class:`torch.tensor`: A batch of signals :math:`x`.
        If :attr:`unvectorize` is :obj:`False`, :math:`x` has shape
        :math:`(*, L)` where :math:`*` is the same as for :attr:`m` and
        :math:`L` is the number of columns of :math:`H_{\rm{dyn}}`. If
        :attr:`unvectorize` is :obj:`True`, the last dimension is
        unflattened by :meth:`unvectorize()`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinear
        >>>
        >>> def_field = DeformationField(torch.rand([400, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 400 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinear(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinear(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> m = meas_op(x_motion)  # simulate noisy dynamic measurements
        >>> H_dyn_adj_x = meas_op.adjoint(m)  # apply adjoint of dynamic measurement matrix
        >>> print(H_dyn_adj_x.shape)
        torch.Size([1, 3, 2500])
    """
    # Contract the measurement axis: each batch element becomes H_dyn^T m.
    back_projected = torch.einsum("mn,...m->...n", self.H_dyn, m)
    if not unvectorize:
        return back_projected
    return self.unvectorize(back_projected)
[docs]
def vectorize(self, input: torch.tensor) -> torch.tensor:
    r"""Flatten the measured dimensions.

    The time dimension (:attr:`self.time_dim`) and the measured dimensions
    (:attr:`self.meas_dims`) are moved to the end of the tensor, time
    first. The last :attr:`self.meas_ndim` dimensions are then collapsed
    into a single one, so that the time dimension ends up in the
    second-to-last position.

    Input:
        input (:class:`torch.tensor`): A tensor holding a time dimension at
        :attr:`self.time_dim` and measured dimensions at
        :attr:`self.meas_dims`. The measured dimensions may have shape
        :attr:`self.meas_shape` or :attr:`self.img_shape`; the flattened
        size is inferred.

    Output:
        :class:`torch.tensor`: A tensor of shape :math:`(*, T, K)` where
        :math:`*` denotes the remaining dimensions of the input, :math:`T`
        is the size of the time dimension and :math:`K` is the product of
        the sizes of the collapsed dimensions.

    See also:
        For the opposite operation use :meth:`unvectorize()`.
    """
    # Where time + measured dims currently are, and where they must end up.
    n_trailing = len(self.meas_shape) + 1
    source = torch.Size([self.time_dim, *self.meas_dims])
    destination = torch.Size(list(range(-n_trailing, 0)))

    # Skip the permutation when the layout is already the expected one.
    if source != destination:
        input = torch.movedim(input, source, destination)

    # Using -1 (rather than self.N) lets the same code handle inputs defined
    # on meas_shape as well as on img_shape.
    kept_shape = input.shape[: -self.meas_ndim]
    return input.reshape(*kept_shape, -1)
[docs]
def unvectorize(self, input: torch.tensor) -> torch.tensor:
    r"""Unflatten the measured dimensions.

    The last dimension is expanded into :attr:`self.meas_shape` when its
    size is :attr:`self.N`, or into :attr:`self.img_shape` when its size is
    :attr:`self.L` (the :attr:`self.N` test is done first). The trailing
    dimensions are then moved back to the positions given by
    :attr:`self.time_dim` and :attr:`self.meas_dims`.

    Input:
        :class:`input` (:class:`torch.tensor`): A tensor of shape
        (:attr:`*, self.N`) or (:attr:`*, self.L`) where * denotes any
        batch size.

    Output:
        :class:`torch.tensor`: A tensor whose dimensions given by
        :attr:`self.meas_dims` have shape :attr:`self.meas_shape` or
        :attr:`self.img_shape`.

    Raises:
        ValueError: If the last dimension of input is different from
        :attr:`self.N` or :attr:`self.L`

    See also:
        For the opposite operation use :meth:`vectorize()`.
    """
    # Pick the target spatial shape from the size of the flattened axis.
    flat_size = input.shape[-1]
    if flat_size == self.N:
        spatial_shape = self.meas_shape
    elif flat_size == self.L:
        spatial_shape = self.img_shape
    else:
        raise ValueError("Input of unvectorize has unexpected size in its last dim")

    # Expand the flattened axis back into its spatial dimensions.
    leading_shape = input.shape[:-1]
    input = input.reshape(*leading_shape, *spatial_shape)

    # Undo the permutation applied by vectorize(), only when one is needed.
    n_trailing = len(spatial_shape) + 1
    original_dims = torch.Size([self.time_dim, *self.meas_dims])
    trailing_dims = torch.Size(list(range(-n_trailing, 0)))
    if original_dims != trailing_dims:
        input = torch.movedim(input, trailing_dims, original_dims)
    return input
def _spline(self, dx, mode):
    """
    Evaluates the interpolation weights associated with the fractional
    offsets dx: 2 weights for "bilinear", 4 for "bicubic" and "schaum".
    The weights are stacked along a new dimension inserted at position 1
    and the result is moved to self.device.

    dx must be between 0 and 1.

    Shapes
        dx: (n_frames, meas_h, meas_w)
        out: (n_frames, {2,4}, meas_h, meas_w)

    Raises
        NotImplementedError: if mode is not one of "bilinear", "bicubic"
        or "schaum".
    """
    if mode == "bilinear":
        weights = (1 - dx, dx)
    elif mode == "bicubic":
        weights = (
            (1 - dx) ** 3 / 6,
            2 / 3 - dx**2 * (2 - dx) / 2,
            2 / 3 - (1 - dx) ** 2 * (1 + dx) / 2,
            dx**3 / 6,
        )
    elif mode == "schaum":
        weights = (
            dx / 6 * (dx - 1) * (2 - dx),
            (1 - dx / 2) * (1 - dx**2),
            (1 + (dx - 1) / 2) * (1 - (dx - 1) ** 2),
            1 / 6 * (dx + 1) * dx * (dx - 1),
        )
    else:
        raise NotImplementedError(
            f"The mode {mode} is invalid, please choose bilinear, "
            + "bicubic or schaum."
        )
    # One stack for all modes: the weights become a new axis at position 1.
    return torch.stack(weights, dim=1).to(self.device)
# =============================================================================
[docs]
class DynamicLinearSplit(DynamicLinear):
# =========================================================================
r"""
Simulates linear measurements of a moving scene by splitting an acquisition matrix
:math:`H \in \mathbb{R}^{M \times N}` that contains negative values.
In practice, only positive values can be implemented using a DMD.
Therefore, we acquire
.. math::
y = \mathcal{N}\left(\text{diag}(A x_{t=1,..., 2M})\right),
where :math:`A \colon\, \mathbb{R}_+^{2M\times N}` is the acquisition
matrix that contains positive DMD patterns,
:math:`x_{t=1,..., 2M} \in \mathbb{R}^{N \times 2M}` is the temporal signal of interest,
:math:`2M` is both the number of DMD patterns (positives and negatives)
and the number of frames,
:math:`N` is the dimension of the signal within the field of view,
:math:`\text{diag}\colon\, \mathbb{R}^{2M \times 2M} \to \mathbb{R}^{2M}`
extracts the diagonal of its input, and
:math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}`
represents a noise operator (e.g., Gaussian).
Given a matrix :math:`H`, we define the positive DMD patterns :math:`A`
from the positive and negative components :math:`H`.
In practice, the even rows of :math:`A` contain the positive components of :math:`H`,
while odd rows of :math:`A` contain the negative components of :math:`H`.
.. math::
\begin{cases}
A[0::2, :] = H_{+}, \text{ with } H_{+} = \max(0,H),\\
A[1::2, :] = H_{-}, \text{ with } H_{-} = \max(0,-H).
\end{cases}
.. note::
:math:`H_{+}` and :math:`H_{-}` are such that :math:`H_{+} - H_{-} = H`.
.. warning::
For each call, there must be **exactly** twice as many images in :math:`x` as
there are measurements in the linear operator :math:`H`.
Args:
:attr:`H` (:class:`torch.tensor`): measurement matrix (linear operator)
with shape :math:`(M, N)`. Only real values are supported.
:attr:`time_dim` (int): dimension index in the input tensor :math:`x` that corresponds
to time (i.e., the frames dimension).
:attr:`meas_shape` (tuple, optional): Shape of the measurement patterns.
Must be a tuple of two integers representing the height and width of the
patterns. If not specified, the shape is supposed to be a square image.
If not, an error is raised. Defaults to None.
:attr:`meas_dims` (tuple, optional): Dimensions of :math:`x_{t=1, ..., M}` the
acquisition matrix applies to. Must be a tuple with the same length as
:attr:`meas_shape`. If not, an error is raised. Defaults to the last
dimensions of the multi-dimensional array :math:`x_{t=1, ..., M}` (e.g., `(-2,-1)`
when `len(meas_shape)=2`).
:attr:`img_shape` (tuple, optional): Shape of the image. Must be a tuple
of two integers representing the height and width of the image. If not
specified, the shape is taken as equal to `meas_shape`. Setting this
value is particularly useful when using an extended field of view [Maitre2024_2]_.
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
Defaults to `torch.nn.Identity()`.
:attr:`white_acq` (torch.tensor, optional): Eventual spatial gain resulting from
detector inhomogeneities and used for dynamic flat-field correction. It can be
determined from a "white acquisition" without any object. If None, no correction is
applied. Must have :attr:`self.meas_shape` shape.
:attr:`dtype` (:class:`torch.dtype`, optional): Data type of the measurement
matrix. Defaults to `torch.float32`.
:attr:`device` (:obj:`torch.device`, optional): Device of the measurement matrix.
Defaults to `torch.device("cpu")`.
Attributes:
:attr:`M` (int): Number of (pos, neg) measurements.
:attr:`N` (int): Number of pixels in the field of view.
:attr:`L` (int): Number of pixels in the extended field of view.
:attr:`meas_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the field of view.
:attr:`img_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the extended field of view.
:attr:`H` (:class:`torch.tensor`): Static measurement matrix of shape
:math:`(M, N)` initialized as :math:`H`.
:attr:`A` (:class:`torch.tensor`): Splitted static measurement matrix of shape
:math:`(2M, N)` initialized as :math:`A`.
:attr:`H_dyn` (torch.tensor): Differential dynamic measurement matrix :math:`H_{\rm{dyn}}` of shape
:math:`(M, L)`. Must be set using the :meth:`build_dynamic_forward` method before being accessed.
:attr:`A_dyn` (torch.tensor): Splitted dynamic measurement matrix :math:`A_{\rm{dyn}}` of shape
:math:`(2M, L)`. Must be set using the :meth:`build_dynamic_forward` method before being accessed.
Example:
>>> import torch
>>> from spyrit.core.meas import DynamicLinearSplit
>>>
>>> x = torch.rand([1, 2*400, 3, 50, 50]) # dummy RGB video with 800 frames of size 50x50
>>> H = torch.rand([400, 40*40]) # dummy static measurement matrix
>>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50))
>>> print(meas_op)
DynamicLinearSplit(
(noise_model): Identity()
)
References:
[Maitre2024_2]_ Maitre, T., Bretin, E., Phan, R., Ducros, N., & Sdika, M. (2024, October).
Dynamic single-pixel imaging on an extended field of view without warping the patterns. In International
Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 275-284).
Cham: Springer Nature Switzerland. DOI: 10.1007/978-3-031-72104-5_27
[Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
"""
def __init__(
    self,
    H: torch.tensor,
    time_dim: int,
    meas_shape: Union[int, torch.Size, Iterable[int]] = None,
    meas_dims: Union[int, torch.Size, Iterable[int]] = None,
    img_shape: Union[int, torch.Size, Iterable[int]] = None,
    *,
    noise_model: nn.Module = nn.Identity(),
    white_acq: torch.tensor = None,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    # Every argument is handed unchanged to the DynamicLinear constructor;
    # self.H, self.M and self.N used below are expected to be available
    # once it returns.
    super().__init__(
        H,
        time_dim,
        meas_shape,
        meas_dims,
        img_shape,
        noise_model=noise_model,
        white_acq=white_acq,
        dtype=dtype,
        device=device,
    )

    # Positive and negative parts of H, both non-negative: H = H_pos - H_neg.
    h_pos = nn.functional.relu(self.H)
    h_neg = nn.functional.relu(-self.H)

    # Concatenating side by side gives rows [H_pos[i] | H_neg[i]]; reshaping
    # to (2M, N) interleaves them: even rows hold H_pos, odd rows hold H_neg.
    interleaved = torch.cat([h_pos, h_neg], dim=1).reshape(2 * self.M, self.N)

    # A inherits dtype and device from self.H (presumably already cast by
    # the parent constructor).
    self.A = nn.Parameter(interleaved, requires_grad=False)

    # Matrices that can be used for reconstruction, and the default choice.
    self._selected_pinv_matrix = "H_dyn"
    self._available_pinv_matrices = ["H_dyn", "A_dyn"]
@property
def A_dyn(self) -> torch.tensor:
    """Splitted dynamic measurement matrix computed with the call to
    build_dynamic_forward.

    Returns the data stored in `self._param_H_dyn` as is, i.e. with the
    positive and negative rows still interleaved.

    Raises:
        AttributeError: if `self._param_H_dyn` does not exist yet, i.e.
        build_dynamic_forward() has not been called.
    """
    try:
        return self._param_H_dyn.data
    except AttributeError as e:
        # The message names A_dyn, the attribute actually being accessed
        # (it previously referred to H_dyn / H_dyn_diff).
        raise AttributeError(
            "The dynamic measurement matrix A_dyn has not been set yet. "
            + "Please call build_dynamic_forward() before accessing the attribute A_dyn."
        ) from e
@property
def H_dyn(self) -> torch.tensor:
    """Differential dynamic measurement matrix.

    It is obtained from the splitted dynamic matrix stored in
    `self._param_H_dyn` by subtracting each odd row from the preceding
    even row, i.e., `A_dyn[0] - A_dyn[1]`, `A_dyn[2] - A_dyn[3]`, etc.
    The result therefore has half as many rows as `A_dyn`.

    Raises:
        AttributeError: if `self._param_H_dyn` does not exist yet, i.e.
        build_dynamic_forward() has not been called.
    """
    try:
        split = self._param_H_dyn.data
    except AttributeError as e:
        # The message names H_dyn, the attribute actually being accessed
        # (it previously referred to a non-existent H_dyn_diff).
        raise AttributeError(
            "The dynamic measurement matrix H_dyn has not been set yet. "
            + "Please call build_dynamic_forward() before accessing the attribute H_dyn."
        ) from e
    # even rows minus odd rows
    return split[::2] - split[1::2]
[docs]
def measure(self, x: torch.tensor) -> torch.tensor:
    r"""Simulates noiseless dynamic measurements from matrix A.

    It acquires

    .. math::
        y = \text{diag}(A x_{t=1, ..., 2M}),

    where :math:`A \in \mathbb{R}_+^{2M\times N}` is the acquisition matrix
    that contains positive DMD patterns,
    :math:`x \in \mathbb{R}^{N \times 2M}` is the temporal signal of interest,
    :math:`2M` is the number of DMD patterns and the number of frames,
    :math:`N` is the dimension of the signal, and
    :math:`\text{diag}\colon\, \mathbb{R}^{2M \times 2M} \to \mathbb{R}^{2M}`
    extracts the diagonal of its input. In other words, the k-th pattern
    is applied to the k-th frame only.

    Given a matrix :math:`H \in \mathbb{R}^{M\times N}`, we define the
    positive DMD patterns :math:`A` from the positive and negative
    components of :math:`H`.

    .. note::
        The acquisition matrix :math:`A` is given by :attr:`self.A`.

    Args:
        :attr:`x` (:class:`torch.tensor`): Batch of temporal signals
        :math:`x` whose time dimension matches :attr:`self.time_dim` and
        whose measured dimensions match :attr:`self.meas_dims`. The frames
        are center-cropped to :attr:`self.meas_shape` before measurement.

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`y` of length
        :attr:`2\*self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> x = torch.rand([1, 2*400, 3, 50, 50])  # dummy RGB video with 800 frames of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinearSplit(
          (noise_model): Poisson()
        )
        >>>
        >>> y = meas_op.measure(x)  # simulate noiseless dynamic measurements
        >>> print(y.shape)
        torch.Size([1, 3, 800])
    """
    # Restrict every frame to the field of view covered by the patterns.
    cropped = spytorch.center_crop(x, self.meas_shape)
    # After vectorize() the time axis is second-to-last, pixels are last.
    frames = self.vectorize(cropped)
    # Index m is shared by A and the frames: it counts both the patterns and
    # the frames, so pattern m only ever meets frame m.
    return torch.einsum("mn,...mn->...m", self.A, frames)
[docs]
def measure_H(self, x: torch.tensor) -> torch.tensor:
    r"""Simulates noiseless dynamic measurements from matrix H.

    It acquires

    .. math::
        m = \text{diag}(H x_{t=1, ..., M}),

    where :math:`H \in \mathbb{R}^{M\times N}` is the measurement matrix
    (that may contain negative values),
    :math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal
    signal obtained from averaging the positive and negative frames of
    :math:`x_{t=1, ..., 2M}`,
    :math:`M` is the number of DMD patterns,
    :math:`N` is the dimension of the signal, and
    :math:`\text{diag}\colon\, \mathbb{R}^{M \times M} \to \mathbb{R}^{M}`
    extracts the diagonal of its input.

    The averaging is done here; the measurement itself is delegated to the
    parent class :meth:`measure` method.

    .. note::
        The acquisition matrix :math:`H` is given by :attr:`self.H`.

    .. note::
        Here the number of frames is 2M and the number of measurements is M.

    Args:
        :attr:`x` (:class:`torch.tensor`): Batch of temporal signals
        :math:`x` whose time dimension matches :attr:`self.time_dim` and
        whose measured dimensions match :attr:`self.meas_dims`.

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`m` of length
        :attr:`self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> x = torch.rand([1, 2*400, 3, 50, 50])  # dummy RGB video with 800 frames of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinearSplit(
          (noise_model): Poisson()
        )
        >>>
        >>> m = meas_op.measure_H(x)  # simulate noiseless dynamic measurements from matrix H
        >>> print(m.shape)
        torch.Size([1, 3, 400])
    """
    # Bring time to the front so frame pairs can be sliced on axis 0.
    time_first = x.movedim(self.time_dim, 0)
    # Average each even-indexed frame with the odd-indexed frame after it.
    paired_mean = (time_first[::2] + time_first[1::2]) / 2
    # Put the (now halved) time axis back where the caller had it.
    averaged = paired_mean.movedim(0, self.time_dim)
    return super().measure(averaged)
[docs]
def build_dynamic_forward(
    self,
    motion: DeformationField,
    mode: str = "bilinear",
    warping: str = "image",
    verbose: bool = False,
) -> None:
    r"""Builds the dynamic forward operator :math:`A_{\rm{dyn}}`.

    .. math::
        \text{diag}(A x_{t=1, ..., 2M}) = A_{\rm{dyn}} x,

    where
    :math:`x_{t=1, ..., 2M} \in \mathbb{R}^{N \times 2M}` is the temporal signal of interest,
    :math:`A \in \mathbb{R}^{2M \times N}` is the splitted static acquisition matrix,
    :math:`x \in \mathbb{R}^L` is the reference frame defined over an extended field-of-view, and
    :math:`A_{\rm{dyn}} \in \mathbb{R}^{2M \times L}` is the splitted dynamic forward operator
    that compensates the motion.

    The dynamic measurement matrix :math:`A_{\rm{dyn}}` is obtained by
    **motion-compensation** to a reference time, leveraging a known
    deformation field. The output is stored in the attribute
    :attr:`self.A_dyn`.

    This method only forwards its arguments to the parent class
    implementation; it is redefined so that the documentation describes
    splitted measurements.

    .. important::
        There are two ways of building the dynamic matrix, namely
        :attr:`warping='pattern'` or :attr:`warping='image'`.

        When :attr:`warping='pattern'`, the input deformation field
        :attr:`motion` needs to be the *inverse* deformation field that
        compensates the motion.

        When :attr:`warping='image'`, the input deformation field
        :attr:`motion` needs to be the *direct* deformation field that
        induces the motion.

        **Reminder**: When looking at the image vectors as continuous
        functions from :math:`\mathbb{R}^2` to :math:`\mathbb{R}`, we define
        the **direct** deformation as the function
        :math:`u \colon \mathbb{Z}^3 \mapsto \mathbb{R}^2` such that,
        for :math:`k \in \{1, ..., 2M\}` and :math:`(i, j) \in \mathbb{Z}^2`,

        .. math::
            x_{t=k}(i, j) = x_{t=1}(u(t=k, i, j))

        The *inverse* deformation field is defined as :math:`v=u^{-1}`.

    .. note::
        Warping sharp patterns introduces a bias in the model due to
        interpolation artifacts. We recommend to exploit the image
        regularity by setting :attr:`warping='image'`.

    .. note::
        When working with splitted measurements, it is common practice to
        exploit the problem's linearity by using a differential measurement
        strategy. This allows to eliminate ambient light and dark current
        offsets. The attribute :attr:`H_dyn` applies the differential
        strategy **after** motion compensation to avoid an additional
        error term [ref journal].

    Args:
        :attr:`motion` (DeformationField): Deformation field representing the
        scene motion. Need to pass the deformation field when
        :attr:`warping` is set to 'image', and the inverse deformation field
        when :attr:`warping` is set to 'pattern'.

        :attr:`mode` (str): Interpolation mode for constructing the dynamic
        matrix. Defaults to 'bilinear'.

        :attr:`warping` (str): Choose between 'image' or 'pattern'. This
        parameter decides whether to warp the patterns or the (unknown)
        image to recover when building the dynamic measurement matrix.
        Defaults to 'image'.

        :attr:`verbose` (bool): Forwarded as is to the parent implementation
        (presumably enables diagnostic printing there). Defaults to False.

    Returns:
        None. The dynamic measurement matrix is stored in the attribute
        :attr:`self.A_dyn`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> def_field = DeformationField(torch.rand([800, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 800 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinearSplit(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> print(meas_op.A_dyn.shape)
        torch.Size([800, 2500])
        >>> print(meas_op.H_dyn.shape)
        torch.Size([400, 2500])

    References:
        [Maitre2024_2]_ Maitre, T., Bretin, E., Phan, R., Ducros, N., & Sdika, M. (2024, October).
        Dynamic single-pixel imaging on an extended field of view without warping the patterns. In International
        Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 275-284).
        Cham: Springer Nature Switzerland. DOI: 10.1007/978-3-031-72104-5_27

        [Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
        Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
    """
    # Pure delegation: arguments stay positional and in the same order.
    parent_build = super().build_dynamic_forward
    parent_build(motion, mode, warping, verbose)
[docs]
def adjoint(self, y: torch.tensor, unvectorize=False):
    r"""Apply adjoint of matrix :math:`A_{\rm{dyn}}`.

    It computes

    .. math::
        x = A_{\rm{dyn}}^\top y,

    where :math:`A_{\rm{dyn}} \in \mathbb{R}^{2M\times L}` is the
    dynamic acquisition matrix (that may contain negative values due to
    warping) and :math:`y \in \mathbb{R}^{2M}` is a measurement vector.

    .. warning::
        This supposes the dynamic measurement matrix :math:`A_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.
        An error will be raised otherwise.

    .. note::
        The acquisition matrix :math:`A_{\rm{dyn}}` is given by
        :attr:`self.A_dyn`. It may contain negative values due to warping.

    Args:
        :attr:`y` (:class:`torch.tensor`): Batch of measurements :math:`y`
        of shape :math:`(*, 2M)`, where :math:`*` denotes any number of
        leading (batch) dimensions.

        :attr:`unvectorize` (bool, optional): If :obj:`True`, the result
        is passed through :meth:`unvectorize()` before being returned.
        Defaults to :obj:`False`.

    Returns:
        :class:`torch.tensor`: A batch of signals :math:`x` with shape
        :math:`(*, L)` where :math:`*` is the same as for :attr:`y`, or
        its unflattened version when :attr:`unvectorize` is :obj:`True`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> def_field = DeformationField(torch.rand([800, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 800 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinearSplit(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> y = meas_op(x_motion)  # simulate noisy dynamic measurements
        >>> A_dyn_adj_x = meas_op.adjoint(y)  # apply adjoint of dynamic measurement matrix
        >>> print(A_dyn_adj_x.shape)
        torch.Size([1, 3, 2500])
    """
    # Contract the measurement axis: each batch element becomes A_dyn^T y.
    back_projected = torch.einsum("mn,...m->...n", self.A_dyn, y)
    if not unvectorize:
        return back_projected
    return self.unvectorize(back_projected)
[docs]
def adjoint_H_dyn(self, m: torch.tensor, unvectorize=False):
    r"""Apply adjoint of matrix :math:`H_{\rm{dyn}}`.

    It computes

    .. math::
        x = H_{\rm{dyn}}^\top m,

    where :math:`H_{\rm{dyn}} \in \mathbb{R}^{M \times L}` is the
    dynamic acquisition matrix (that may contain negative values) and
    :math:`m \in \mathbb{R}^M` is a measurement vector.

    The computation is delegated to the parent class :meth:`adjoint`
    method.

    .. warning::
        This supposes the dynamic measurement matrix :math:`H_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.
        An error will be raised otherwise.

    .. note::
        The acquisition matrix :math:`H_{\rm{dyn}}` is given by
        :attr:`self.H_dyn`.

    Args:
        :attr:`m` (:class:`torch.tensor`): Batch of measurements :math:`m`
        of shape :math:`(*, M)`, where :math:`*` denotes any number of
        leading (batch) dimensions.

        :attr:`unvectorize` (bool, optional): Forwarded to the parent
        :meth:`adjoint`. Defaults to :obj:`False`.

    Returns:
        A batch of signals :math:`x`. If :attr:`unvectorize` is
        :obj:`False`, :math:`x` has shape :math:`(*, L)` where :math:`*`
        is the same as for :attr:`m`. If :attr:`unvectorize` is
        :obj:`True`, :math:`x` is reshaped such that the dimensions
        :attr:`self.meas_dims` have shape :attr:`self.img_shape`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> def_field = DeformationField(torch.rand([800, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 800 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> x_motion = def_field(x)  # dummy video obtained by warping x with def_field
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>>
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> print(meas_op)
        DynamicLinearSplit(
          (noise_model): Poisson()
        )
        >>>
        >>> meas_op.build_dynamic_forward(def_field)
        >>> m = meas_op.measure_H(x_motion)  # simulate noiseless dynamic measurements from matrix H
        >>> H_dyn_adj_x = meas_op.adjoint_H_dyn(m)  # apply adjoint of dynamic measurement matrix (with differential strategy)
        >>> print(H_dyn_adj_x.shape)
        torch.Size([1, 3, 2500])
    """
    # This class redefines adjoint() for A_dyn, so go through super()
    # explicitly to reach the parent version.
    parent_adjoint = super().adjoint
    return parent_adjoint(m, unvectorize=unvectorize)
[docs]
def forward(self, x: torch.tensor) -> torch.tensor:
    r"""Acquire noisy dynamic measurements with the split matrix :math:`A`.

    The simulated acquisition is

    .. math::
        m = \mathcal{N}\left(\text{diag}(A x_{t=1, ..., 2M})\right),

    with the following notation:

    * :math:`A \in \mathbb{R}_+^{2M\times N}` holds the positive DMD patterns.
      It is defined from the positive and negative components of a matrix
      :math:`H \in \mathbb{R}^{M\times N}` and is available as :attr:`self.A`.
    * :math:`x \in \mathbb{R}^{N \times 2M}` is the temporal signal:
      :math:`2M` frames (one per DMD pattern), each of dimension :math:`N`.
    * :math:`\text{diag}\colon\, \mathbb{R}^{2M \times 2M} \to \mathbb{R}^{2M}`
      extracts the diagonal of its input.
    * :math:`\mathcal{N} \colon\, \mathbb{R}^{2M} \to \mathbb{R}^{2M}` is a
      noise operator (e.g., Gaussian).

    Args:
        :attr:`x` (:class:`torch.tensor`): Video signal :math:`x`. Its
        dimension :attr:`self.time_dim` must be of size :attr:`2 * self.M`;
        its dimensions :attr:`self.meas_dims` hold the frames (of shape
        :attr:`img_shape` in the example below).

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`m` of length
        :attr:`2\*self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>> from spyrit.core.noise import Poisson
        >>>
        >>> x = torch.rand([1, 800, 3, 50, 50])  # dummy RGB video with 800 frames of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> y = meas_op(x)  # simulate noisy dynamic measurements
        >>> print(y.shape)
        torch.Size([1, 3, 800])
    """
    # The parent implementation is reused unchanged; per the original
    # author's note this is valid because this class redefines `measure`.
    noisy_measurements = super().forward(x)
    return noisy_measurements
[docs]
def forward_H(self, x: torch.tensor) -> torch.tensor:
    r"""Acquire noisy dynamic measurements with the matrix :math:`H`.

    The simulated acquisition is

    .. math::
        m = \mathcal{N}\left(\text{diag}(H x_{t=1, ..., M})\right),

    with the following notation:

    * :math:`H \in \mathbb{R}^{M\times N}` is the measurement matrix (it may
      contain negative values), available as :attr:`self.H`.
    * :math:`x_{t=1, ..., M} \in \mathbb{R}^{N \times M}` is the temporal
      signal obtained by averaging the positive and negative frames of
      :math:`x_{t=1, ..., 2M}`.
    * :math:`\text{diag}\colon\, \mathbb{R}^{M \times M} \to \mathbb{R}^{M}`
      extracts the diagonal of its input.
    * :math:`\mathcal{N} \colon\, \mathbb{R}^{M} \to \mathbb{R}^{M}` is a
      noise operator (e.g., Gaussian).

    .. note::
        The input holds :math:`2M` frames whereas :math:`M` measurements are
        returned.

    Args:
        :attr:`x` (:class:`torch.tensor`): Video signal :math:`x`. Its
        dimension :attr:`self.time_dim` must be of size :attr:`2 * self.M`;
        its dimensions :attr:`self.meas_dims` hold the frames (of shape
        :attr:`img_shape` in the example below).

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`m` of length
        :attr:`self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>> from spyrit.core.noise import Poisson
        >>>
        >>> x = torch.rand([1, 800, 3, 50, 50])  # dummy RGB video with 800 frames of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> m = meas_op.forward_H(x)  # simulate noisy dynamic measurements
        >>> print(m.shape)
        torch.Size([1, 3, 400])
    """
    # Noiseless measurement first, then the noise model on top of it.
    return self.noise_model(self.measure_H(x))
[docs]
def measure_A_dyn(self, x: torch.tensor) -> torch.tensor:
    r"""Apply the split dynamic matrix to a reference signal, without noise.

    .. math::
        y = A_{\rm{dyn}} x

    where :math:`A_{\rm{dyn}} \in \mathbb{R}^{2 M \times L}` is the dynamic
    acquisition matrix, :math:`x \in \mathbb{R}^L` is the reference signal,
    :math:`M` is the number of measurements and :math:`L` is the dimension
    of the signal (with extended FOV).

    .. warning::
        The dynamic matrix :math:`A_{\rm{dyn}}` (:attr:`self.A_dyn`) must
        have been set with :meth:`build_dynamic_forward()` beforehand.

    Args:
        :attr:`x` (torch.tensor): Batch of reference (static) signals. The
        dimensions indexed by :attr:`self.meas_dims` must have shape
        :attr:`self.img_shape`.

    Returns:
        torch.tensor: Measurements of shape :math:`(*, 2M)`, i.e. one value
        per row of :math:`A_{\rm{dyn}}`, where :math:`*` denotes the
        dimensions left after vectorization.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> def_field = DeformationField(torch.rand([800, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 800 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> meas_op.build_dynamic_forward(def_field)
        >>> y = meas_op.measure_A_dyn(x)  # noiseless measurements from A_dyn
        >>> print(y.shape)
        torch.Size([1, 3, 800])
    """
    # No cropping here: A_dyn is defined over the extended field of view.
    signal_vec = self.vectorize(x)
    # Matrix-vector product over the last dimension, batched over the others.
    return torch.einsum("mn,...n->...m", self.A_dyn, signal_vec)
[docs]
def forward_A_dyn(self, x: torch.tensor) -> torch.tensor:
    r"""Simulates noisy dynamic measurements with the split dynamic matrix

    .. math::
        y = \mathcal{N}\left(A_{\rm{dyn}} x \right)

    where :math:`A_{\rm{dyn}} \in \mathbb{R}^{2 M \times L}` is the dynamic
    acquisition matrix, :math:`x \in \mathbb{R}^L` is the reference signal
    of interest, :math:`M` is the number of measurements, :math:`L` is the
    dimension of the signal (with extended FOV) and :math:`\mathcal{N}` is
    the noise operator :attr:`self.noise_model`.

    .. warning::
        This supposes the dynamic measurement matrix :math:`A_{\rm{dyn}}`
        has been set using the :meth:`build_dynamic_forward()` method.

    Args:
        :attr:`x` (torch.tensor): Batch of reference (static) signals. The
        dimensions indexed by :attr:`self.meas_dims` must have shape
        :attr:`self.img_shape`.

    Returns:
        torch.tensor: Noisy measurements of shape :math:`(*, 2M)` where
        :math:`*` denotes all the dimensions that are not included in
        :attr:`self.meas_dims`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.warp import DeformationField
        >>> from spyrit.core.meas import DynamicLinearSplit
        >>>
        >>> def_field = DeformationField(torch.rand([800, 50, 50, 2]) * 2 - 1)  # dummy deformation field with 800 frames
        >>> x = torch.rand([1, 3, 50, 50])  # dummy RGB reference image of size 50x50
        >>> H = torch.rand([400, 40*40])  # dummy static measurement matrix
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicLinearSplit(H, time_dim=1, meas_shape=(40, 40), img_shape=(50, 50), noise_model=noise_op)
        >>> meas_op.build_dynamic_forward(def_field)
        >>> y = meas_op.forward_A_dyn(x)  # noisy measurements from A_dyn
        >>> print(y.shape)
        torch.Size([1, 3, 800])
    """
    # Delegate to measure_A_dyn instead of duplicating its body, mirroring
    # forward_H -> measure_H; an overridden measure_A_dyn is then honoured.
    noiseless = self.measure_A_dyn(x)
    return self.noise_model(noiseless)
# =============================================================================
[docs]
class DynamicHadamSplit2d(DynamicLinearSplit):
# =========================================================================
r""" Simulate 2D Hadamard split acquisitions of a moving scene.
We perform the acquisition of :math:`2M` square DMD patterns of size :math:`h` by exploiting the Kronecker structure of the 2D Hadamard matrix.
Each measurement is acquired as, for :math:`k \in \{1, ..., 2M\}`:
.. math::
y_k = \mathcal{N}\left( \sum_{i, j} A_{1d}[r_k, i] x_{t=k}[i, j] A_{1d}[j, c_k] \right),
where
:math:`A_{1d} \in \mathbb{R}_+^{2h\times h}` contains the positive and negative components of a 1d Hadamard matrix,
:math:`x_{t=k} \in \mathbb{R}^{h \times h}` is :math:`k^{\rm{th}}` frame of the video,
:math:`(r_k, c_k) = (\left \lfloor k / h \right\rfloor, k \bmod h)` are the row and column indices of the 1d Hadamard matrix used to generate the 2d Hadamard pattern used at time :math:`t=k`,
:math:`\mathcal{N} \colon\, \mathbb{R} \to \mathbb{R}` represents a noise operator (e.g., Gaussian).
.. important::
Only the forward methods benefit from the fast Hadamard transform algorithm, the adjoint methods do not because
the dynamic forward operator (:math:`H_{\rm{dyn}}` or :math:`A_{\rm{dyn}}`) does not have Kronecker structure.
.. note::
The splitting of the :math:`k^{\rm{th}}` 2D pattern into its positive and negative parts is given by splitting 1D patterns as:
.. math::
H[k, :]^{+} = H_{1d}^{+}[r_k, :] \otimes H_{1d}^{+}[:, c_k] + H_{1d}^{-}[r_k, :] \otimes H_{1d}^{-}[:, c_k] \\
H[k, :]^{-} = H_{1d}^{+}[r_k, :] \otimes H_{1d}^{-}[:, c_k] + H_{1d}^{-}[r_k, :] \otimes H_{1d}^{+}[:, c_k]
Args:
:attr:`time_dim` (int): dimension index in the input tensor :math:`x` that corresponds
to time (i.e., the frames dimension).
:attr:`h` (int): Image size :math:`h`. Must be a power of 2.
:attr:`M` (int): Number of (pos, neg) measurements. If None, it is set to :math:`h^2` (no subsampling).
:attr:`order` (:class:`torch.tensor`, optional): Order matrix :math:`O` that defines the measurements to keep. The first component of :math:`y` will correspond to the index where :attr:`order` is the highest.
:attr:`fast` (bool, optional): Whether to use the fast Hadamard transform
algorithm (i.e. exploit the kronecker structure). If False, it uses matrix-vector products. Defaults to True.
:attr:`reshape_output` (bool, optional): Whether to reshape the output of the adjoint and pinv methods to images. If False, the outputs are vectors.
:attr:`img_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the extended field of view. If None, is set to :math:`(h, h)`.
:attr:`noise_model` (see :mod:`spyrit.core.noise`): Noise model :math:`\mathcal{N}`.
Defaults to `torch.nn.Identity()`.
:attr:`white_acq` (torch.tensor, optional): Optional spatial gain resulting from
detector inhomogeneities and used for dynamic flat-field correction. It can be
determined from a "white acquisition" without any object. If None, no correction is
applied. Must have :attr:`self.meas_shape` shape.
:attr:`dtype` (:class:`torch.dtype`, optional): Data type of the measurement
matrix. Defaults to `torch.float32`.
:attr:`device` (:obj:`torch.device`, optional): Device of the measurement matrix.
Defaults to `torch.device("cpu")`.
.. note::
The argument :attr:`order` is particularly useful when rearranging the
measurements by decreasing variance. The variance matrix can simply be
put as `order`.
Attributes:
:attr:`M` (int): Number of (pos, neg) measurements.
:attr:`N` (int): Number of pixels in the field of view.
:attr:`L` (int): Number of pixels in the extended field of view.
:attr:`meas_shape` (tuple): Shape of the measurements patterns. It is equal to :math:`(h, h)`.
:attr:`meas_dims` (torch.Size): Dimensions of the image the acquisition
matrix applies to. Is equal to `(-2, -1)`.
:attr:`img_shape` (tuple): Shape of the underlying multi-dimensional
array :math:`x` over the extended field of view.
:attr:`H1d` (:class:`torch.tensor`): Static 1D Hadamard matrix of shape
:math:`(h, h)`.
:attr:`H` (:class:`torch.tensor`): Static 2D Hadamard matrix of shape
:math:`(M, N)` given by :math:`H_{1d} \otimes H_{1d}`.
:attr:`A` (:class:`torch.tensor`): Splitted static 2d Hadamard matrix of shape
:math:`(2M, N)` given by :math:`A_{1d} \otimes A_{1d}`.
:attr:`H_dyn` (torch.tensor): Differential dynamic Hadamard matrix :math:`H_{\rm{dyn}}` of shape
:math:`(M, L)`. Must be set using the :meth:`build_dynamic_forward` method before being accessed.
:attr:`A_dyn` (torch.tensor): Split dynamic Hadamard matrix :math:`A_{\rm{dyn}}` of shape
:math:`(2M, L)`. Must be set using the :meth:`build_dynamic_forward` method before being accessed.
:attr:`order` (:class:`torch.tensor`): Order matrix :math:`O`. It
is used by :func:`~spyrit.core.torch.sort_by_significance()`. Defaults to rectangular order (e.g., linear indices).
:attr:`indices` (:class:`torch.tensor`): Indices used to reorder the measurement vector. It is used by the method :meth:`reindex()`.
Example:
>>> import torch
>>> from spyrit.core.meas import DynamicHadamSplit2d
>>>
>>> order = torch.rand([32,32])
>>> # acquisition with 2 * 32 ** 2 splitted Hadamard patterns of size 32x32.
>>> meas_op = DynamicHadamSplit2d(time_dim=1, h=32, M=32**2, order=order, img_shape=(40, 40))
>>> print(meas_op)
DynamicHadamSplit2d(
(noise_model): Identity()
)
Reference:
[Maitre2024_2]_ Maitre, T., Bretin, E., Phan, R., Ducros, N., & Sdika, M. (2024, October).
Dynamic single-pixel imaging on an extended field of view without warping the patterns. In International
Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 275-284).
Cham: Springer Nature Switzerland. DOI: 10.1007/978-3-031-72104-5_27
[Maitre2026]_ (Submitted to TIP) Maitre, T., Bretin, E., Mahieu-Williame, L., Phan, R., Sdika, M., & Ducros, N. (2025).
Dual-arm motion-compensated single-pixel imaging. HAL Id: hal-05068181
"""
def __init__(
    self,
    time_dim: int,
    h: int,
    M: int = None,
    order: torch.tensor = None,
    fast: bool = True,
    reshape_output: bool = False,
    img_shape: Union[int, torch.Size, Iterable[int]] = None,
    *,
    noise_model: nn.Module = nn.Identity(),
    white_acq: torch.tensor = None,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
):
    """Build the operator from the pattern size :attr:`h`.

    The parent constructor receives a placeholder measurement matrix: the
    matrices :attr:`H` and :attr:`A` of this class are properties computed
    on demand from the 1D Hadamard matrix :attr:`H1d`. See the class
    docstring for the meaning of each argument.
    """
    meas_dims = (-2, -1)
    meas_shape = (h, h)
    if M is None:
        M = h**2  # no subsampling
    # Kept before the parent constructor call, as in the original ordering.
    self.h = h

    # Bypass DynamicLinearSplit.__init__ (avoids setting A) and call the
    # next constructor in the MRO with a dummy, uninitialized H.
    super(DynamicLinearSplit, self).__init__(
        torch.empty(h**2, h**2, dtype=dtype, device=device),  # dummy H
        time_dim,
        meas_shape,
        meas_dims,
        img_shape,
        noise_model=noise_model,
        white_acq=white_acq,
        dtype=dtype,
        device=device,
    )

    if order is None:
        order = torch.ones(h, h)

    # 1D version of H. Convert first, then wrap: calling .to() on an
    # nn.Parameter with a different dtype/device returns a plain tensor,
    # which would not be registered as a parameter of the module.
    H1d = spytorch.walsh_matrix(h).to(dtype=dtype, device=device)
    self.H1d = nn.Parameter(H1d, requires_grad=False)

    # Override M: the parent constructor only saw the dummy h**2 x h**2
    # matrix (presumably deriving M from it — confirm against the parent).
    self.M = M
    self.order = order
    # Stable sort by decreasing order value: the first kept pattern is the
    # one whose entry in `order` is the highest.
    self.indices = torch.argsort(-order.flatten(), stable=True).to(
        dtype=torch.int32, device=self.device
    )
    self.fast = fast
    self.reshape_output = reshape_output
@property
def dtype(self) -> torch.dtype:
    """Data type of the operator, read from the 1D matrix :attr:`H1d`."""
    h1d = self.H1d
    return h1d.dtype
@property
def device(self) -> torch.device:
    """Device of the operator, read from the 1D matrix :attr:`H1d`."""
    h1d = self.H1d
    return h1d.device
@property
def H(self):
    """Static 2D Hadamard matrix restricted to the first :attr:`M` patterns.

    The full matrix is the Kronecker product of :attr:`H1d` with itself;
    its rows are reordered with :meth:`reindex` and only the first
    :attr:`M` rows are returned. The matrix is rebuilt at every access.
    """
    # NOTE(review): the original carried an unexplained "# !!!!" marker
    # right before the row truncation.
    kron_2d = torch.kron(self.H1d, self.H1d)
    return self.reindex(kron_2d, "rows", False)[: self.M, :]
@property
def A(self):
    """Split version of :attr:`H`, of shape ``(2 * M, N)``.

    The positive and negative parts of each row of :attr:`H` are
    interleaved: ``[pos_0, neg_0, pos_1, neg_1, ...]``.
    """
    hadamard = self.H
    positive_part = nn.functional.relu(hadamard)
    negative_part = nn.functional.relu(-hadamard)
    # Side-by-side concatenation gives rows [pos_k | neg_k]; reshaping to
    # (2M, N) then alternates positive and negative patterns.
    side_by_side = torch.cat((positive_part, negative_part), dim=1)
    return side_by_side.reshape(2 * self.M, self.N)
@property
def matrix_to_inverse(self):
    """Matrix exposed for inversion: the static matrix :attr:`H`."""
    target = self.H
    return target
[docs]
def reindex(
    self, x: torch.tensor, axis: str = "rows", inverse_permutation: bool = False
) -> torch.tensor:
    """Reorder a tensor along one axis using :attr:`self.indices`.

    This is a thin wrapper around :func:`~spyrit.core.torch.reindex()`;
    refer to that function for the exact permutation convention and for
    the meaning of the inverse permutation. The indices are moved to the
    device of :attr:`x` before the call.

    Args:
        x (:class:`torch.tensor`): The tensor to reorder.
        axis (str, optional): Axis to reorder along, passed through to the
            helper ('rows' or 'cols' according to its documentation).
            Default is 'rows'.
        inverse_permutation (bool, optional): Whether to apply the inverse
            of the permutation. Default is False.

    Returns:
        :class:`torch.tensor`: The tensor reordered along the given axis.
    """
    indices_on_device = self.indices.to(x.device)
    return spytorch.reindex(x, indices_on_device, axis, inverse_permutation)
[docs]
def measure(self, x: torch.tensor) -> torch.tensor:
    r"""Noiseless split Hadamard measurements of a video.

    For :math:`k \in \{1, ..., 2M\}`, the measurement is

    .. math::
        y_k = \sum_{i, j} A_{1d}[r_k, i] x_{t=k}[i, j] A_{1d}[j, c_k],

    where :math:`A_{1d} \in \mathbb{R}_+^{2h\times h}` contains the positive
    and negative components of a 1d Hadamard matrix,
    :math:`x_{t=k} \in \mathbb{R}^{h \times h}` is the :math:`k^{\rm{th}}`
    frame of the video, and :math:`(r_k, c_k)` are the row and column
    indices of the 1d Hadamard matrix that generate the 2d pattern used at
    time :math:`t=k`.

    When :attr:`self.fast` is true, the input is center-cropped to
    :attr:`self.meas_shape`, its time and measurement dimensions are moved
    to positions ``(1, -2, -1)`` and the separable implementation
    :meth:`_fast_measure` is applied. Otherwise the parent implementation
    is used.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicHadamSplit2d
        >>>
        >>> x = torch.rand([1, 2 * 32**2, 3, 40, 40])  # dummy RGB video with 2 * 32**2 frames of size 40x40
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicHadamSplit2d(time_dim=1, h=32, M=32**2, img_shape=(40, 40), noise_model=noise_op)
        >>> y = meas_op.measure(x)  # simulate noiseless dynamic measurements
        >>> print(y.shape)
        torch.Size([1, 3, 2048])
    """
    if not self.fast:
        # Generic (non-separable) implementation of the parent class.
        return super().measure(x)

    x = spytorch.center_crop(x, self.meas_shape)
    # _fast_measure expects time at position 1 and the frame in the last
    # two positions; move the dimensions only when they are not there yet.
    source_dims = torch.Size([self.time_dim, *self.meas_dims])
    target_dims = torch.Size([1, -2, -1])
    if source_dims != target_dims:
        x = torch.movedim(x, source_dims, target_dims)
    return self._fast_measure(x)
[docs]
def measure_H(self, x: torch.tensor):
    r"""Noiseless Hadamard measurements of a video, from the matrix H.

    For :math:`k \in \{1, ..., M\}`, the measurement is

    .. math::
        m_k = \sum_{i, j} H_{1d}[r_k, i] x_{t=k}[i, j] H_{1d}[j, c_k],

    where :math:`H_{1d} \in \mathbb{R}^{h\times h}` is the 1d Hadamard
    matrix, :math:`x_{t=k} \in \mathbb{R}^{h \times h}` is the
    :math:`k^{\rm{th}}` frame of the video, and :math:`(r_k, c_k)` are the
    row and column indices of the 1d Hadamard matrix that generate the 2d
    pattern used at time :math:`t=k`.

    When :attr:`self.fast` is true, consecutive frame pairs along
    :attr:`self.time_dim` are averaged, the result is center-cropped to
    :attr:`self.meas_shape`, its time and measurement dimensions are moved
    to positions ``(1, -2, -1)`` and :meth:`_fast_measure_H` is applied.
    Otherwise the parent implementation is used.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicHadamSplit2d
        >>>
        >>> x = torch.rand([1, 2 * 32**2, 3, 40, 40])  # dummy RGB video with 2 * 32**2 frames of size 40x40
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicHadamSplit2d(time_dim=1, h=32, M=32**2, img_shape=(40, 40), noise_model=noise_op)
        >>> m = meas_op.measure_H(x)  # noiseless dynamic measurements from matrix H
        >>> print(m.shape)
        torch.Size([1, 3, 1024])
    """
    if not self.fast:
        # Generic (non-separable) implementation of the parent class.
        return super().measure_H(x)

    # Average each (even, odd) pair of frames along the time dimension.
    time_first = x.movedim(self.time_dim, 0)
    pair_mean = (time_first[::2] + time_first[1::2]) / 2
    x = pair_mean.movedim(0, self.time_dim)

    x = spytorch.center_crop(x, self.meas_shape)
    # _fast_measure_H expects time at position 1 and the frame in the last
    # two positions; move the dimensions only when they are not there yet.
    source_dims = torch.Size([self.time_dim, *self.meas_dims])
    target_dims = torch.Size([1, -2, -1])
    if source_dims != target_dims:
        x = torch.movedim(x, source_dims, target_dims)
    return self._fast_measure_H(x)
def _fast_measure(self, x: torch.tensor) -> torch.tensor:
    r"""Separable implementation of the noiseless split measurements.

    For :math:`k \in \{1, ..., 2M\}`, the measurement is

    .. math::
        y_k = \sum_{i, j} A_{1d}[r_k, i] x_{t=k}[i, j] A_{1d}[j, c_k],

    where :math:`A_{1d}` contains the positive and negative components of
    the 1d Hadamard matrix :attr:`H1d` and :math:`x_{t=k}` is the
    :math:`k^{\rm{th}}` frame. The row and column indices come from the
    first :attr:`M` entries of :attr:`self.indices`: each entry ``p`` gives
    ``(p // h, p % h)``.

    The einsum subscripts require a 5-dimensional input laid out as
    ``(b, t, c, h, w)``. Even frames are measured with the positive 2d
    patterns and odd frames with the negative ones. The output has shape
    ``(b, c, 2 * M)`` with interleaved values ``[pos0, neg0, pos1, ...]``.
    """
    kept = self.indices[: self.M]
    # 2D coordinates of each kept pattern in the Hadamard sampling map
    r_idx = torch.div(kept, self.h, rounding_mode="floor")
    c_idx = torch.remainder(kept, self.h)

    rows_1d = self.H1d[r_idx, :]  # shape (M, h)
    cols_1d = self.H1d[:, c_idx]  # shape (h, M)

    # Positive / negative parts of the 1D patterns
    relu = nn.functional.relu
    rows_p, rows_n = relu(rows_1d), relu(-rows_1d)
    cols_p, cols_n = relu(cols_1d), relu(-cols_1d)

    # Split 2D patterns in terms of the 1D parts:
    #   H_pos = rows_p (x) cols_p + rows_n (x) cols_n
    #   H_neg = rows_p (x) cols_n + rows_n (x) cols_p
    frames_pos = x[:, ::2]
    frames_neg = x[:, 1::2]
    subscripts = "th,btchw,wt->bct"
    y_pos = torch.einsum(subscripts, rows_p, frames_pos, cols_p) + torch.einsum(
        subscripts, rows_n, frames_pos, cols_n
    )
    y_neg = torch.einsum(subscripts, rows_p, frames_neg, cols_n) + torch.einsum(
        subscripts, rows_n, frames_neg, cols_p
    )

    # Interleave: (b, c, M, 2) -> (b, c, 2*M) = [pos0, neg0, pos1, neg1, ...]
    paired = torch.stack([y_pos, y_neg], dim=-1)
    return paired.reshape(*paired.shape[:-2], 2 * self.M)
def _fast_measure_H(self, x: torch.tensor) -> torch.tensor:
    r"""Separable implementation of the noiseless measurements from H.

    For :math:`k \in \{1, ..., M\}`, the measurement is

    .. math::
        m_k = \sum_{i, j} H_{1d}[r_k, i] x_{t=k}[i, j] H_{1d}[j, c_k],

    where :math:`H_{1d} \in \mathbb{R}^{h\times h}` is the 1d Hadamard
    matrix :attr:`H1d` and :math:`x_{t=k}` is the :math:`k^{\rm{th}}` frame.
    The row and column indices come from the first :attr:`M` entries of
    :attr:`self.indices`: each entry ``p`` gives ``(p // h, p % h)``.

    The einsum subscripts require a 5-dimensional input laid out as
    ``(b, t, c, h, w)``; the output has shape ``(b, c, t)``.
    """
    kept = self.indices[: self.M]
    # 2D coordinates of each kept pattern in the Hadamard sampling map
    r_idx = torch.div(kept, self.h, rounding_mode="floor")
    c_idx = torch.remainder(kept, self.h)

    rows_1d = self.H1d[r_idx, :]  # shape (M, h)
    cols_1d = self.H1d[:, c_idx]  # shape (h, M)

    # One row/column pair per frame: contract over the two spatial axes.
    return torch.einsum("th,btchw,wt->bct", rows_1d, x, cols_1d)
[docs]
def forward(self, x: torch.tensor) -> torch.tensor:
    r"""Noisy split Hadamard measurements of a video.

    For :math:`k \in \{1, ..., 2M\}`, the measurement is

    .. math::
        y_k = \mathcal{N}\left( \sum_{i, j} A_{1d}[r_k, i] x_{t=k}[i, j] A_{1d}[j, c_k] \right),

    where :math:`A_{1d} \in \mathbb{R}_+^{2h\times h}` contains the positive
    and negative components of a 1d Hadamard matrix,
    :math:`x_{t=k} \in \mathbb{R}^{h \times h}` is the :math:`k^{\rm{th}}`
    frame of the video, :math:`(r_k, c_k)` are the row and column indices
    of the 1d Hadamard matrix that generate the 2d pattern used at time
    :math:`t=k`, and :math:`\mathcal{N}` is the noise operator.

    Args:
        :attr:`x` (:class:`torch.tensor`): Video signal :math:`x`. Its
        dimension :attr:`self.time_dim` must be of size :attr:`2 * self.M`;
        its dimensions :attr:`self.meas_dims` hold the frames (of shape
        :attr:`img_shape` in the example below).

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`m` of length
        :attr:`2\*self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicHadamSplit2d
        >>>
        >>> x = torch.rand([1, 2 * 32**2, 3, 40, 40])  # dummy RGB video with 2 * 32**2 frames of size 40x40
        >>> alpha = 5  # noise level
        >>> noise_op = Poisson(alpha=alpha, g=1/alpha)
        >>> meas_op = DynamicHadamSplit2d(time_dim=1, h=32, M=32**2, img_shape=(40, 40), noise_model=noise_op)
        >>> y = meas_op(x)  # simulate noisy dynamic measurements
        >>> print(y.shape)
        torch.Size([1, 3, 2048])
    """
    # The parent implementation is reused unchanged; per the original
    # author's note this is valid because this class redefines `measure`.
    noisy_measurements = super().forward(x)
    return noisy_measurements
[docs]
def forward_H(self, x: torch.tensor) -> torch.tensor:
    r"""Noisy Hadamard measurements of a video, from the matrix H.

    For :math:`k \in \{1, ..., M\}`, the measurement is

    .. math::
        m_k = \mathcal{N}\left( \sum_{i, j} H_{1d}[r_k, i] x_{t=k}[i, j] H_{1d}[j, c_k] \right),

    where :math:`H_{1d} \in \mathbb{R}^{h\times h}` is the 1d Hadamard
    matrix, :math:`x_{t=k} \in \mathbb{R}^{h \times h}` is the
    :math:`k^{\rm{th}}` frame of the video, :math:`(r_k, c_k)` are the row
    and column indices of the 1d Hadamard matrix that generate the 2d
    pattern used at time :math:`t=k`, and :math:`\mathcal{N}` is the noise
    operator.

    Args:
        :attr:`x` (:class:`torch.tensor`): Video signal :math:`x`. Its
        dimension :attr:`self.time_dim` must be of size :attr:`2 * self.M`;
        its dimensions :attr:`self.meas_dims` hold the frames (of shape
        :attr:`img_shape` in the example below).

    Returns:
        :class:`torch.tensor`: Measurement vector :math:`m` of length
        :attr:`self.M`.

    Example:
        >>> import torch
        >>> from spyrit.core.noise import Poisson
        >>> from spyrit.core.meas import DynamicHadamSplit2d
        >>>
        >>> x = torch.rand([1, 2 * 32**2, 3, 40, 40])  # dummy RGB video with 2 * 32**2 frames of size 40x40
        >>> noise_op = torch.nn.Identity()  # Poisson cannot be used here: measurements from H can be negative
        >>> meas_op = DynamicHadamSplit2d(time_dim=1, h=32, M=32**2, img_shape=(40, 40), noise_model=noise_op)
        >>> m = meas_op.forward_H(x)  # simulate noisy dynamic measurements from matrix H
        >>> print(m.shape)
        torch.Size([1, 3, 1024])
    """
    # The parent implementation is reused unchanged; per the original
    # author's note this is valid because this class redefines `measure_H`.
    noisy_measurements = super().forward_H(x)
    return noisy_measurements