# Source code for spyrit.core.prep

"""
Preprocessing operators applying affine transformations to the measurements.

There are two classes in this module: :class:`DirectPoisson` and
:class:`SplitPoisson`. The first one is used for direct measurements (i.e.
without splitting the measurement matrix into its positive and negative parts),
while the second one is used for split measurements.
"""

from typing import Union, Tuple

import torch
import torch.nn as nn

from spyrit.core.meas import LinearSplit, HadamSplit  # , Linear


# ==============================================================================
class DirectPoisson(nn.Module):
    r"""
    Preprocess the raw data acquired with a direct measurement operator
    assuming Poisson noise. It also compensates for the affine transformation
    applied to the images to get positive intensities.

    It computes :math:`m = \frac{2}{\alpha}y - H1` and the variance
    :math:`\sigma^2 = 4\frac{y}{\alpha^{2}}`, where :math:`y = Hx` are obtained
    using a direct linear measurement operator (see
    :class:`spyrit.core.meas.Linear`), :math:`\alpha` is the image intensity,
    and 1 is the all-ones vector.

    Args:
        :attr:`alpha`: maximum image intensity :math:`\alpha` (in counts)

        :attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`).
        It must expose the attributes ``N`` (number of pixels) and ``M``
        (number of measurements) and be callable on a tensor of shape
        :math:`(1, N)`.

    Example:
        >>> H = torch.rand([400,32*32])
        >>> meas_op = Linear(H)
        >>> prep_op = DirectPoisson(1.0, meas_op)
    """

    def __init__(self, alpha: float, meas_op):
        super().__init__()
        self.alpha = alpha
        self.N = meas_op.N
        self.M = meas_op.M

        # Not used by the methods of this class; kept so that the attribute
        # set mirrors SplitPoisson (external code may rely on it).
        self.max = nn.MaxPool1d(self.N)

        # H1: measurement of the all-ones image, computed once and stored as a
        # buffer so it follows the module across devices / state_dict.
        # Expected shape (1, M) -- it is broadcast against (*, M) in forward().
        self.register_buffer("H_ones", meas_op(torch.ones((1, self.N))))

    def forward(self, x: torch.tensor) -> torch.tensor:
        r"""
        Preprocess measurements to compensate for the affine image
        normalization.

        It computes :math:`\frac{2}{\alpha}x - H1`, where 1 is the all-ones
        vector (i.e. :math:`H1` is the measurement of a uniform image).

        Args:
            :attr:`x`: batch of measurement vectors

        Shape:
            x: :math:`(*, M)` where :math:`*` denotes one or more leading
            (batch) dimensions

            meas_op: the number of measurements :attr:`meas_op.M` should match
            :math:`M`.

            Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,400], dtype=torch.float)
            >>> H = torch.rand([400,32*32])
            >>> meas_op = Linear(H)
            >>> prep_op = DirectPoisson(1.0, meas_op)
            >>> m = prep_op(x)
            >>> print(m.shape)
            torch.Size([10, 400])
        """
        # Broadcast H1 over *all* leading dimensions of x (not only the first),
        # exactly as SplitPoisson.forward does.
        H_ones = self.H_ones.expand(x.shape[:-1] + torch.Size([self.M]))
        x = 2 * x / self.alpha - H_ones
        return x

    def sigma(self, x: torch.tensor) -> torch.tensor:
        r"""Estimates the variance of the preprocessed measurements.

        The variance is estimated as :math:`\frac{4}{\alpha^2} x`.

        Args:
            :attr:`x`: batch of measurement vectors

        Shape:
            :attr:`x`: :math:`(*, M)` where :math:`*` denotes one or more
            leading (batch) dimensions

            Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,400], dtype=torch.float)
            >>> v = prep_op.sigma(x)
            >>> print(v.shape)
            torch.Size([10, 400])
        """
        # Preprocessed images live in [-1, 1] rather than [0, 1], hence the
        # factor 4 (= 2**2) on the variance.
        x = 4 * x / (self.alpha**2)
        return x

    def denormalize_expe(self, x, beta, h, w):
        r"""
        Denormalize images from the range [-1;1] to the range [0; :math:`\beta`].

        It computes :math:`m = \frac{\beta}{2}(x+1)`, where :math:`\beta` is
        the normalization factor.

        Args:
            :attr:`x`: Batch of images

            :attr:`beta`: Normalization factor (one value per image)

            :attr:`h`: Image height

            :attr:`w`: Image width

        Shape:
            :attr:`x`: :math:`(B, 1, h, w)`

            :attr:`beta`: :math:`(B)` or :math:`(B, 1)`

            :attr:`h`: int

            :attr:`w`: int

            Output: :math:`(B, 1, h, w)`

        Example:
            >>> x = torch.rand([10, 1, 32,32], dtype=torch.float)
            >>> beta = 9*torch.rand([10])
            >>> y = prep_op.denormalize_expe(x, beta, 32, 32)
            >>> print(y.shape)
            torch.Size([10, 1, 32, 32])
        """
        bc = x.shape[0]

        # One beta per image, broadcast over the whole (1, h, w) image.
        beta = beta.reshape(bc, 1, 1, 1)
        beta = beta.expand(bc, 1, h, w)
        x = (x + 1) / 2 * beta

        return x
# ==============================================================================
class SplitPoisson(nn.Module):
    r"""
    Preprocess the raw data acquired with a split measurement operator
    assuming Poisson noise. It also compensates for the affine transformation
    applied to the images to get positive intensities.

    It computes :math:`m = \frac{2(y_{+}-y_{-})}{\alpha} - H1` and the variance
    :math:`\sigma^2 = \frac{4(y_{+} + y_{-})}{\alpha^{2}}`, where
    :math:`y_{+} = H_{+}x` and :math:`y_{-} = H_{-}x` are obtained using
    a split measurement operator (see :class:`spyrit.core.meas.LinearSplit`),
    :math:`\alpha` is the image intensity, and 1 is the all-ones vector.

    Split measurements are read interleaved along the last dimension: the
    positive parts :math:`y_{+}` at even indices and the negative parts
    :math:`y_{-}` at odd indices.

    Args:
        alpha (float): maximum image intensity :math:`\alpha` (in counts)

        :attr:`meas_op`: measurement operator (see :mod:`~spyrit.core.meas`).
        It must expose ``N``, ``M`` and ``forward_H``.

    Example:
        >>> H = torch.rand([400,32*32])
        >>> meas_op = LinearSplit(H)
        >>> split_op = SplitPoisson(10, meas_op)

    Example 2:
        >>> Perm = torch.rand([32,32])
        >>> meas_op = HadamSplit(400, 32, Perm)
        >>> split_op = SplitPoisson(10, meas_op)
    """

    def __init__(self, alpha: float, meas_op):
        super().__init__()
        self.alpha = alpha
        self.N = meas_op.N
        self.M = meas_op.M

        # Positions of the positive (even) and negative (odd) split
        # measurements along the last dimension of a (*, 2M) input.
        self.even_index = range(0, 2 * self.M, 2)
        self.odd_index = range(1, 2 * self.M, 2)

        # Max over N samples; used by forward_expe() to estimate alpha.
        self.max = nn.MaxPool1d(self.N)

        # H1: unsplit measurement of the all-ones image, stored as a buffer.
        self.register_buffer(
            "H_ones",
            meas_op.forward_H(torch.ones((1, self.N))),
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        r"""
        Preprocess to compensate for image normalization and splitting of the
        measurement operator.

        It computes :math:`\frac{2(x[0::2]-x[1::2])}{\alpha} - H1`.

        Args:
            :attr:`x`: batch of measurement vectors

        Shape:
            x: :math:`(*, 2M)` where :math:`*` indicates one or more dimensions

            meas_op: the number of measurements :attr:`meas_op.M` should match
            :math:`M`.

            Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,2*400], dtype=torch.float)
            >>> H = torch.rand([400,32*32])
            >>> meas_op = LinearSplit(H)
            >>> split_op = SplitPoisson(10, meas_op)
            >>> m = split_op(x)
            >>> print(m.shape)
            torch.Size([10, 400])

        Example 2:
            >>> x = torch.rand([10,2*400], dtype=torch.float)
            >>> Perm = torch.rand([32,32])
            >>> meas_op = HadamSplit(400, 32, Perm)
            >>> split_op = SplitPoisson(10, meas_op)
            >>> m = split_op(x)
            >>> print(m.shape)
            torch.Size([10, 400])
        """
        s = x.shape[:-1] + torch.Size([self.M])  # torch.Size([*,M])
        H_ones = self.H_ones.expand(s)

        # unsplit: y+ - y-
        x = x[..., self.even_index] - x[..., self.odd_index]

        # normalize
        x = 2 * x / self.alpha - H_ones
        return x

    def forward_expe(
        self, x: torch.tensor, meas_op: Union[LinearSplit, HadamSplit]
    ) -> Tuple[torch.tensor, torch.tensor]:
        r"""
        Preprocess to compensate for image normalization and splitting of the
        measurement operator.

        It computes :math:`m = \frac{2(x[0::2]-x[1::2])}{\alpha} - H1`, where
        :math:`\alpha = \max H^\dagger (x[0::2]-x[1::2])`.

        Contrary to :meth:`~forward`, the image intensity :math:`\alpha` is
        estimated from the pseudoinverse of the unsplit measurements. This
        method is typically called for the reconstruction of experimental
        measurements, while :meth:`~forward` is called in simulations.

        The method returns a tuple containing both :math:`m` and :math:`\alpha`.

        Args:
            :attr:`x`: batch of measurement vectors

            :attr:`meas_op`: measurement operator (required to estimate
            :math:`\alpha`)

            Output (:math:`m`, :math:`\alpha`): preprocessed measurements and
            estimated intensities.

        Shape:
            x: :math:`(B, 2M)` where :math:`B` is the batch dimension

            meas_op: the number of measurements :attr:`meas_op.M` should match
            :math:`M`.

            :math:`m`: :math:`(B, M)`

            :math:`\alpha`: :math:`(B)`

        Example:
            >>> x = torch.rand([10,2*400], dtype=torch.float)
            >>> Perm = torch.rand([32,32])
            >>> meas_op = HadamSplit(400, 32, Perm)
            >>> split_op = SplitPoisson(10, meas_op)
            >>> m, alpha = split_op.forward_expe(x, meas_op)
            >>> print(m.shape)
            >>> print(alpha.shape)
            torch.Size([10, 400])
            torch.Size([10])
        """
        bc = x.shape[0]

        # unsplit
        x = x[:, self.even_index] - x[:, self.odd_index]

        # estimate alpha as the maximum of the pseudoinverse image;
        # assumes pinv returns (B, N) so that MaxPool1d(N) yields (B, 1)
        x_pinv = meas_op.pinv(x)
        alpha = self.max(x_pinv)
        alpha = alpha.expand(bc, self.M)  # shape is (b*c, M)

        # normalize
        H_ones = self.H_ones.expand(bc, self.M)
        x = torch.div(x, alpha)
        x = 2 * x - H_ones

        alpha = alpha[:, 0]  # shape is (b*c,)

        return x, alpha

    def sigma(self, x: torch.tensor) -> torch.tensor:
        r"""Estimates the variance of the preprocessed measurements.

        The variance is estimated as
        :math:`\frac{4}{\alpha^2} (x[0::2]+x[1::2])`.

        Args:
            :attr:`x`: batch of images in the Hadamard domain

        Shape:
            - Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions
            - Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,2*400], dtype=torch.float)
            >>> v = split_op.sigma(x)
            >>> print(v.shape)
            torch.Size([10, 400])
        """
        x = x[..., self.even_index] + x[..., self.odd_index]
        # Cov is in [-1,1] (not [0,1]) so *4
        x = 4 * x / (self.alpha**2)
        return x

    def set_expe(self, gain=1.0, mudark=0.0, sigdark=0.0, nbin=1.0):
        r"""
        Sets experimental parameters of the sensor.

        These attributes are required by :meth:`sigma_expe`.

        Args:
            - :attr:`gain` (float): gain (in count/electron)
            - :attr:`mudark` (float): average dark current (in counts)
            - :attr:`sigdark` (float): standard deviation of dark current (in counts)
            - :attr:`nbin` (float): number of raw bins in each spectral channel
              (if input x results from the summation/binning of the raw data)

        Example:
            >>> split_op.set_expe(gain=1.6)
            >>> print(split_op.gain)
            1.6
        """
        self.gain = gain
        self.mudark = mudark
        self.sigdark = sigdark
        self.nbin = nbin

    def sigma_expe(self, x: torch.tensor) -> torch.tensor:
        r"""
        Estimates the variance of the measurements that are compensated for
        splitting but **NOT** for image normalization.

        With :math:`g` the gain, :math:`n` the number of bins,
        :math:`\mu_d` the dark-current mean and :math:`\sigma_d` its standard
        deviation, it computes
        :math:`4\left(g\,(x[0::2]+x[1::2] - 2n\mu_d) + 2n\sigma_d^2\right)`.

        :meth:`set_expe` must be called first, as it defines the sensor
        attributes read here.

        Args:
            :attr:`x`: Batch of images in the Hadamard domain.

        Shape:
            Input: :math:`(*,2*M)` where :math:`*` indicates one or more
            dimensions

            Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,2*400], dtype=torch.float)
            >>> split_op.set_expe(gain=1.6)
            >>> v = split_op.sigma_expe(x)
            >>> print(v.shape)
            torch.Size([10, 400])
        """
        # sum of the positive and negative parts: (*, 2M) -> (*, M)
        x = x[..., self.even_index] + x[..., self.odd_index]
        x = (
            self.gain * (x - 2 * self.nbin * self.mudark)
            + 2 * self.nbin * self.sigdark**2
        )
        x = 4 * x  # to get the cov of an image in [-1,1], not in [0,1]

        return x

    def sigma_from_image(
        self, x: torch.tensor, meas_op: Union[LinearSplit, HadamSplit]
    ) -> torch.tensor:
        r"""
        Estimates the variance of the preprocessed measurements corresponding
        to images through a measurement operator.

        The variance is estimated as
        :math:`\frac{4}{\alpha} \{(Px)[0::2] + (Px)[1::2]\}`.

        Args:
            :attr:`x`: Batch of images

            :attr:`meas_op`: Measurement operator

        Shape:
            :attr:`x`: :math:`(*,N)`

            :attr:`meas_op`: An operator such that :attr:`meas_op.N` :math:`=N`
            and :attr:`meas_op.M` :math:`=M`

            Output: :math:`(*, M)`

        Example:
            >>> x = torch.rand([10,32*32], dtype=torch.float)
            >>> Perm = torch.rand([32,32])
            >>> meas_op = HadamSplit(400, 32, Perm)
            >>> split_op = SplitPoisson(10, meas_op)
            >>> v = split_op.sigma_from_image(x, meas_op)
            >>> print(v.shape)
            torch.Size([10, 400])
        """
        x = meas_op(x)
        # (*, 2M) -> (*, M); ellipsis indexing supports any leading dims
        x = x[..., self.even_index] + x[..., self.odd_index]
        x = 4 * x / self.alpha  # here alpha should not be squared
        return x

    def denormalize_expe(self, x, beta, h, w):
        r"""
        Denormalize images from the range [-1;1] to the range [0; :math:`\beta`].

        It computes :math:`m = \frac{\beta}{2}(x+1)`, where :math:`\beta` is
        the normalization factor.

        Args:
            - :attr:`x`: Batch of images
            - :attr:`beta`: Normalization factor (one value per image)
            - :attr:`h`: Image height
            - :attr:`w`: Image width

        Shape:
            - :attr:`x`: :math:`(B, 1, h, w)`
            - :attr:`beta`: :math:`(B)` or :math:`(B, 1)`
            - :attr:`h`: int
            - :attr:`w`: int
            - Output: :math:`(B, 1, h, w)`

        Example:
            >>> x = torch.rand([10, 1, 32,32], dtype=torch.float)
            >>> beta = 9*torch.rand([10])
            >>> y = split_op.denormalize_expe(x, beta, 32, 32)
            >>> print(y.shape)
            torch.Size([10, 1, 32, 32])
        """
        bc = x.shape[0]

        # One beta per image, broadcast over the whole (1, h, w) image.
        beta = beta.reshape(bc, 1, 1, 1)
        beta = beta.expand(bc, 1, h, w)
        x = (x + 1) / 2 * beta

        return x
# ============================================================================== # class SplitRowPoisson(nn.Module): # # ============================================================================== # r""" # Preprocess raw data acquired with a split measurement operator # It computes :math:`m = \frac{y_{+}-y_{-}}{\alpha}` and the variance # :math:`\sigma^2 = \frac{2(y_{+} + y_{-})}{\alpha^{2}}`, where # :math:`y_{+} = H_{+}x` and :math:`y_{-} = H_{-}x` are obtained using # a split measurement operator such as :class:`spyrit.core.LinearSplit`. # Args: # - :math:`\alpha` (float): maximun image intensity (in counts) # - :math:`M` (int): number of measurements # - :math:`h` (int): number of rows in the image, i.e., image height # Example: # >>> split_op = SplitRawPoisson(2.0, 24, 64) # """ # def __init__(self, alpha: float, M: int, h: int): # super().__init__() # self.alpha = alpha # self.M = M # self.h = h # self.even_index = range(0, 2 * M, 2) # self.odd_index = range(1, 2 * M, 2) # # self.max = nn.MaxPool1d(h) # def forward( # self, # x: torch.tensor, # meas_op: LinearSplit, # ) -> torch.tensor: # """ # Args: # x: batch of images that are Hadamard transformed across rows # meas_op: measurement operator # Shape: # x: :math:`(b*c, 2M, w)` with :math:`b` the batch size, :math:`c` the # number of channels, :math:`2M` is twice the number of patterns (as # it includes both positive and negative components), and :math:`w` # is the image width. # meas_op: The number of measurement `meas_op.M` should match `M`, # while the length of the measurements :math:`meas_op.N` should match # image height :math:`h`. 
# Output: :math:`(b*c,M)` # Example: # >>> x = torch.rand([10,48,64], dtype=torch.float) # >>> H_pos = torch.rand([24,64]) # >>> H_neg = torch.rand([24,64]) # >>> meas_op = LinearSplit(H_pos, H_neg) # >>> m = split_op(x, meas_op) # >>> print(m.shape) # torch.Size([10, 24, 64]) # """ # # unsplit # x = x[:, self.even_index] - x[:, self.odd_index] # # normalize # e = torch.ones([x.shape[0], meas_op.N, self.h], device=x.device) # print("shape of e:", e.shape) # print("shape of x:", x.shape) # print("shape of fwd:", meas_op.forward_H(e).shape) # x = 2 * x / self.alpha - meas_op.forward_H(e) # return x