"""
Preprocessing operators applying affine transformations to the measurements.
There are two classes in this module: :class:`DirectPoisson` and
:class:`SplitPoisson`. The first one is used for direct measurements (i.e.
without splitting the measurement matrix in its positive and negative parts),
while the second one is used for split measurements.
"""
from typing import Union, Tuple
import torch
import torch.nn as nn
from spyrit.core.meas import LinearSplit, HadamSplit # , Linear
# =============================================================================
[docs]
class DirectPoisson(nn.Module):
r"""
Preprocess the raw data acquired with a direct measurement operator assuming
Poisson noise. It also compensates for the affine transformation applied
to the images to get positive intensities.
It computes :math:`m = \frac{2}{\alpha}y - H1` and the variance
:math:`\sigma^2 = 4\frac{y}{\alpha^{2}}`, where :math:`y = Hx` are obtained
using a direct linear measurement operator (see :mod:`spyrit.core.Linear`),
:math:`\alpha` is the image intensity, and 1 is the all-ones vector.
Args:
:attr:`alpha`: maximun image intensity :math:`\alpha` (in counts)
:attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`)
Example:
>>> H = torch.rand([400,32*32])
>>> meas_op = Linear(H)
>>> prep_op = DirectPoisson(1.0, meas_op)
"""
def __init__(self, alpha: float, meas_op):
super().__init__()
self.alpha = alpha
self.meas_op = meas_op
self.M = meas_op.M
self.N = meas_op.N
self.h = meas_op.h
self.w = meas_op.w
self.max = nn.MaxPool2d((self.h, self.w))
# self.register_buffer("H_ones", meas_op(torch.ones((1, self.N))))
# generate H_ones on the fly as it is memmory intensive and easy to compute
# ?? Why does it returns float64 ??
@property
def H_ones(self):
return self.meas_op.H.sum(dim=-1).to(self.device)
@property
def device(self):
return self.meas_op.device
[docs]
def forward(self, x: torch.tensor) -> torch.tensor:
r"""
Preprocess measurements to compensate for the affine image normalization
It computes :math:`\frac{2}{\alpha}x - H1`, where H1 represents the
all-ones vector.
Args:
:attr:`x`: batch of measurement vectors
Shape:
x: :math:`(B, M)` where :math:`B` is the batch dimension
meas_op: the number of measurements :attr:`meas_op.M` should match
:math:`M`.
Output: :math:`(B, M)`
Example:
>>> x = torch.rand([10,400], dtype=torch.float)
>>> H = torch.rand([400,32*32])
>>> meas_op = Linear(H)
>>> prep_op = DirectPoisson(1.0, meas_op)
>>> m = prep_op(x)
>>> print(m.shape)
torch.Size([10, 400])
"""
# normalize
# H_ones = self.H_ones.expand(x.shape[0], self.M)
x = 2 * x / self.alpha - self.H_ones.to(x.dtype).expand(x.shape)
return x
[docs]
def sigma(self, x: torch.tensor) -> torch.tensor:
r"""Estimates the variance of raw measurements
The variance is estimated as :math:`\frac{4}{\alpha^2} x`
Args:
:attr:`x`: batch of measurement vectors
Shape:
:attr:`x`: :math:`(B,M)` where :math:`B` is the batch dimension
Output: :math:`(B, M)`
Example:
>>> x = torch.rand([10,400], dtype=torch.float)
>>> v = prep_op.sigma(x)
>>> print(v.shape)
torch.Size([10, 400])
"""
# *4 to account for the image normalized [-1,1] -> [0,1]
return 4 * x / (self.alpha**2)
[docs]
def denormalize_expe(
self, x: torch.tensor, beta: torch.tensor, h: int = None, w: int = None
) -> torch.tensor:
r"""Denormalize images from the range [-1;1] to the range [0; :math:`\beta`]
It computes :math:`m = \frac{\beta}{2}(x+1)`, where
:math:`\beta` is the normalization factor, that can be different for each
image in the batch.
Args:
- :attr:`x` (torch.tensor): Batch of images
- :attr:`beta` (torch.tensor): Normalization factor. It should have
the same shape as the batch.
- :attr:`h` (int, optional): Image height. If None, it is
deduced from the shape of :attr:`x`. Defaults to None.
- :attr:`w` (int): Image width. If None, it is deduced from the
shape of :attr:`x`. Defaults to None.
Shape:
- :attr:`x`: :math:`(*, h, w)` where :math:`*` indicates any batch
dimensions
- :attr:`beta`: :math:`(*)` or :math:`(1)` if the same for all
images
- :attr:`h`: int
- :attr:`w`: int
- Output: :math:`(*, h, w)`
Example:
>>> x = torch.rand([10, 1, 32,32], dtype=torch.float)
>>> beta = 9*torch.rand([10])
>>> y = split_op.denormalize_expe(x, beta, 32, 32)
>>> print(y.shape)
torch.Size([10, 1, 32, 32])
"""
if h is None:
h = x.shape[-2]
if w is None:
w = x.shape[-1]
if beta.numel() == 1:
beta = beta.expand(x.shape)
else:
# Denormalization
beta = beta.reshape(*beta.shape, 1, 1)
beta = beta.expand((*beta.shape[:-2], h, w))
return (x + 1) / 2 * beta
[docs]
def unsplit(self, x: torch.tensor, mode: str = "diff") -> torch.tensor:
"""Unsplits measurements by combining odd and even indices.
The parameter `mode` can be either 'diff' or 'sum'. The first one
computes the difference between the even and odd indices, while the
second one computes the sum.
Args:
x (torch.tensor): Measurements, can have any shape.
mode (str): 'diff' or 'sum'. If 'diff', the difference between the
even and odd indices is computed. If 'sum', the sum is computed.
Defaults to 'diff'.
Returns:
torch.tensor: The input tensor with the even and odd indices
of the last dimension combined (either by difference or sum).
"""
if mode == "diff":
return x[..., 0::2] - x[..., 1::2]
elif mode == "sum":
return x[..., 0::2] + x[..., 1::2]
else:
raise ValueError("mode should be either 'diff' or 'sum'")
# =============================================================================
[docs]
class SplitPoisson(DirectPoisson):
r"""
Preprocess the raw data acquired with a split measurement operator assuming
Poisson noise. It also compensates for the affine transformation applied
to the images to get positive intensities.
It computes
.. math::
m = \frac{y_{+}-y_{-}}{\alpha} - H1
and the variance
.. math::
\sigma^2 = \frac{2(y_{+} + y_{-})}{\alpha^{2}},
where :math:`y_{+} = H_{+}x` and :math:`y_{-} = H_{-}x` are obtained using a
split measurement operator (see :mod:`spyrit.core.LinearSplit`),
:math:`\alpha` is the image intensity, and 1 is the all-ones vector.
Args:
alpha (float): maximun image intensity :math:`\alpha` (in counts)
:attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`)
Example:
>>> H = torch.rand([400,32*32])
>>> meas_op = LinearSplit(H)
>>> split_op = SplitPoisson(10, meas_op)
Example 2:
>>> Perm = torch.rand([32,32])
>>> meas_op = HadamSplit(400, 32, Perm)
>>> split_op = SplitPoisson(10, meas_op)
"""
def __init__(self, alpha: float, meas_op):
super().__init__(alpha, meas_op)
@property
def even_index(self):
return range(0, 2 * self.M, 2)
@property
def odd_index(self):
return range(1, 2 * self.M, 2)
# @property
# def H_ones(self):
# return self.unsplit(super().H_ones, mode="diff")
[docs]
def forward(self, x: torch.tensor) -> torch.tensor:
r"""
Preprocess to compensates for image normalization and splitting of the
measurement operator.
It computes :math:`\frac{x[0::2]-x[1::2]}{\alpha} - H1`
Args:
:attr:`x`: batch of measurement vectors
Shape:
x: :math:`(*, 2M)` where :math:`*` indicates one or more dimensions
meas_op: the number of measurements :attr:`meas_op.M` should match
:math:`M`.
Output: :math:`(*, M)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> H = torch.rand([400,32*32])
>>> meas_op = LinearSplit(H)
>>> split_op = SplitPoisson(10, meas_op)
>>> m = split_op(x)
>>> print(m.shape)
torch.Size([10, 400])
Example 2:
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> Perm = torch.rand([32,32])
>>> meas_op = HadamSplit(400, 32, Perm)
>>> split_op = SplitPoisson(10, meas_op)
>>> m = split_op(x)
>>> print(m.shape)
torch.Size([10, 400])
"""
# s = x.shape[:-1] + torch.Size([self.M]) # torch.Size([*,M])
# H_ones = self.H_ones.expand(s)
return super().forward(self.unsplit(x, mode="diff"))
[docs]
def forward_expe(
self, x: torch.tensor, meas_op: Union[LinearSplit, HadamSplit]
) -> Tuple[torch.tensor, torch.tensor]:
r"""Preprocess to compensate for image normalization and splitting of the
measurement operator.
It computes
.. math:: m = \frac{x[0::2]-x[1::2]}{\alpha},
where :math:`\alpha = \max H^\dagger (x[0::2]-x[1::2])`.
.. note::
Contrary to :meth:`~forward`, the image intensity :math:`\alpha`
is estimated from the pseudoinverse of the unsplit measurements. This
method is typically called for the reconstruction of experimental
measurements, while :meth:`~forward` is called in simulations.
The method returns a tuple containing both :math:`m` and :math:`\alpha`
Args:
:attr:`x`: batch of measurement vectors
:attr:`meas_op`: measurement operator (required to estimate
:math:`\alpha`)
Output (:math:`m`, :math:`\alpha`): preprocess measurement and estimated
intensities.
Shape:
x: :math:`(B, 2M)` where :math:`B` is the batch dimension
meas_op: the number of measurements :attr:`meas_op.M` should match
:math:`M`.
:math:`m`: :math:`(B, M)`
:math:`\alpha`: :math:`(B)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> Perm = torch.rand([32,32])
>>> meas_op = HadamSplit(400, 32, Perm)
>>> split_op = SplitPoisson(10, meas_op)
>>> m, alpha = split_op.forward_expe(x, meas_op)
>>> print(m.shape)
>>> print(alpha.shape)
torch.Size([10, 400])
torch.Size([10])
"""
x = self.unsplit(x, mode="diff")
# estimate alpha
x_pinv = meas_op.pinv(x)
alpha = self.max(x_pinv).squeeze(-1) # shape is now (b, c, 1)
# normalize
alpha = alpha.expand(x.shape)
x = torch.div(x, alpha)
x = 2 * x - self.H_ones.expand(x.shape)
alpha = alpha[..., 0] # shape is (b, c)
return x, alpha
[docs]
def sigma(self, x: torch.tensor) -> torch.tensor:
r"""Estimates the variance of raw measurements
The variance is estimated as :math:`\frac{4}{\alpha^2} (x[0::2]+x[1::2])`
Args:
:attr:`x`: batch of images in the Hadamard domain
Shape:
- Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions
- Output: :math:`(*, M)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> v = split_op.sigma(x)
>>> print(v.shape)
torch.Size([10, 400])
"""
return super().sigma(self.unsplit(x, mode="sum"))
[docs]
def set_expe(self, gain=1.0, mudark=0.0, sigdark=0.0, nbin=1.0):
r"""
Sets experimental parameters of the sensor
Args:
- :attr:`gain` (float): gain (in count/electron)
- :attr:`mudark` (float): average dark current (in counts)
- :attr:`sigdark` (float): standard deviation or dark current (in counts)
- :attr:`nbin` (float): number of raw bin in each spectral channel (if input x results from the sommation/binning of the raw data)
Example:
>>> split_op.set_expe(gain=1.6)
>>> print(split_op.gain)
1.6
"""
self.gain = gain
self.mudark = mudark
self.sigdark = sigdark
self.nbin = nbin
[docs]
def sigma_expe(self, x: torch.tensor) -> torch.tensor:
r"""
Estimates the variance of the measurements that are compensated for
splitting but **NOT** for image normalization
Args:
:attr:`x`: Batch of images in the Hadamard domain.
Shape:
Input: :math:`(B,2*M)` where :math:`B` is the batch dimension
Output: :math:`(B, M)`
Example:
>>> x = torch.rand([10,2*32*32], dtype=torch.float)
>>> split_op.set_expe(gain=1.6)
>>> v = split_op.sigma_expe(x)
>>> print(v.shape)
torch.Size([10, 400])
"""
x = self.unsplit(x, mode="sum")
x = (
self.gain * (x - 2 * self.nbin * self.mudark)
+ 2 * self.nbin * self.sigdark**2
)
x = 4 * x # to get the cov of an image in [-1,1], not in [0,1]
return x
[docs]
def sigma_from_image(
self, x: torch.tensor, meas_op: Union[LinearSplit, HadamSplit]
) -> torch.tensor:
r"""
Estimates the variance of the preprocessed measurements corresponding
to images through a measurement operator
The variance is estimated as
:math:`\frac{4}{\alpha} \{(Px)[0::2] + (Px)[1::2]\}`
Args:
:attr:`x`: Batch of images
:attr:`meas_op`: Measurement operator
Shape:
:attr:`x`: :math:`(*,N)`
:attr:`meas_op`: An operator such that :attr:`meas_op.N` :math:`=N`
and :attr:`meas_op.M` :math:`=M`
Output: :math:`(*, M)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> Perm = torch.rand([32,32])
>>> meas_op = HadamSplit(400, 32, Perm)
>>> split_op = SplitPoisson(10, meas_op)
>>> v = split_op.sigma_from_image(x, meas_op)
>>> print(v.shape)
torch.Size([10, 400])
"""
x = meas_op(x)
x = self.unsplit(x, mode="sum")
x = 4 * x / self.alpha # here alpha should not be squared
return x
# ==============================================================================
[docs]
class SplitPoissonRaw(SplitPoisson):
# ==============================================================================
r"""
Preprocess the raw data acquired with a split measurement operator assuming
Poisson noise. It also compensates for the affine transformation applied
to the images to get positive intensities.
It computes the differential measurements
.. math::
m = \frac{y_{+}-y_{-}}{\alpha} - H1
and the corresponding variance
.. math::
\sigma^2 = \frac{2(y_{+} + y_{-})}{\alpha^{2}},
where :math:`y_{+} = H_{+}x` and :math:`y_{-} = H_{-}x` are obtained using a
split measurement operator (see :mod:`spyrit.core.LinearSplit`),
:math:`\alpha` is a normalisation factor, and 1 is the all-ones vector. This
class also estimates the normalisation factor :math:`\alpha`.
.. note::
Contrary to :class:`SplitPoisson`, the estimation of the normalisation
factor is based on the mean of the raw measurement, **not** on the
pseudo inverse of the differential mesurements.
Args:
alpha (float): maximun image intensity :math:`\alpha` (in counts)
:attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`)
Example:
>>> H = torch.rand([400,32*32])
>>> meas_op = LinearSplit(H)
>>> split_op = SplitPoissonRaw(10, meas_op)
"""
def __init__(self, alpha: float, meas_op):
super().__init__(alpha, meas_op)
[docs]
def forward_expe(
self, x: torch.tensor, meas_op: Union[LinearSplit, HadamSplit], dim=-1
) -> Tuple[torch.tensor, torch.tensor]:
r"""Preprocess to compensate for image normalization and splitting of the
measurement operator.
.. note::
Contrary to :meth:`~forward`, the image intensity :math:`\alpha`
is estimated from the raw measurements. This
method is typically called for the reconstruction of experimental
measurements, while :meth:`~forward` is called in simulations.
Args:
:attr:`x`: batch of measurement vectors with shape :math:`(*, 2M)`
:attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`).
The number of measurements :attr:`meas_op.M` should be equal to
:math:`M`.
:attr:`dim`: dimensions where the max of the pseudo inverse is
computed. Defaults to -1 (i.e., last dimension).
Output:
Preprocessed measurements :math:`m` with shape :math:`(*, M)`.
Estimated intensities :math:`\alpha` with shape :math:`(*)`.
Example:
>>> H = torch.rand([400,32*32])
>>> meas = LinearSplit(H)
>>> split = SplitPoissonRaw(10, meas_op)
>>> x = torch.rand([10,2*400], dtype=torch.float)
>>> split.set_expe()
>>> m, alpha = split.forward_expe(x, meas)
>>> print(m.shape)
>>> print(alpha.shape)
torch.Size([10, 400])
torch.Size([1])
"""
# estimate intensity (in counts)
z = x[..., self.even_index] + x[..., self.odd_index]
mu = torch.mean(z, dim, keepdim=True)
alpha = (2 / self.N) * (mu - 2 * self.mudark) / self.gain
# alternative based on the variance
# var = torch.var(z, dim, keepdim=True)
# alpha_2 = (2/self.N)*(var - 2*self.sigdark**2)/self.gain**2
# gain = (var - 2*self.sigdark**2)/(mu - 2*self.mudark)
# Alternative where all rows of an image have the same normalization
alpha = torch.amax(alpha, -2, keepdim=True)
# intensity x gain (in counts)
norm = alpha * self.gain
# unsplit
x = x[..., self.even_index] - x[..., self.odd_index]
# normalize
x = x / norm
x = 2 * x - self.H_ones.to(x.dtype)
return x, norm # or alpha? Double check
[docs]
def sigma_expe(self, x: torch.tensor) -> torch.tensor:
r"""
Estimates the variance of the measurements that are compensated for
splitting but **NOT** for image normalization
Args:
:attr:`x`: Raw measurements with shape :math:`(*, 2M)`.
Output:
Variance with shape :math:`(*, M)`.
Example:
>>> x = torch.rand([10,2*32*32], dtype=torch.float)
>>> split_op.set_expe(gain=1.6)
>>> v = split_op.sigma_expe(x)
>>> print(v.shape)
torch.Size([10, 400])
"""
# Input shape (b*c, 2*M)
# output shape (b*c, M)
x = x[..., self.even_index] + x[..., self.odd_index]
x = (
self.gain * (x - 2 * self.nbin * self.mudark)
+ 2 * self.nbin * self.sigdark**2
)
x = 4 * x # to get the cov of an image in [-1,1], not in [0,1]
return x
[docs]
def denormalize_expe(
self, x: torch.tensor, beta: torch.tensor, h: int = None, w: int = None
) -> torch.tensor:
return (x + 1) / 2 * beta