"""
Noise models for simulating measurements in imaging.
This module provides four classes, each simulating a different type of
noise in the measurements:
- NoNoise: Simulates measurements with no noise
- Poisson: Simulates measurements corrupted by Poisson noise (each pixel
receives a number of photons that follows a Poisson distribution)
- PoissonApproxGauss: Simulates measurements corrupted by Poisson noise, but
approximates the Poisson distribution with a Gaussian distribution
- PoissonApproxGaussSameNoise: Simulates measurements corrupted by Poisson
noise, but all measurements in a batch are corrupted with the same noise
sample (approximated by a Gaussian distribution)
"""
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import poisson
from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit
# =============================================================================
class NoNoise(nn.Module):
    # =========================================================================
    r"""
    Simulates noiseless measurements by computing
    :math:`y = \frac{1}{2} H(1+x)`.

    The incoming images :math:`x` are first mapped from [-1;1] to [0;1] with
    :math:`(1+x)/2`; the measurement operator :math:`H` is then applied.

    .. note::
        Assumes that the incoming images :math:`x` are in the range [-1;1]

    The class is constructed from a measurement operator (see the
    :mod:`~spyrit.core.meas` submodule)

    Args:
        :attr:`meas_op` : Measurement operator (see the
        :mod:`~spyrit.core.meas` submodule)

    Example 1: Using a :class:`~spyrit.core.meas.Linear` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> linear_op = Linear(H)
        >>> linear_acq = NoNoise(linear_op)

    Example 2: Using a :class:`~spyrit.core.meas.HadamSplit` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> Perm = torch.rand([32*32,32*32])
        >>> split_op = HadamSplit(H, Perm, 32, 32)
        >>> split_acq = NoNoise(split_op)
    """

    def __init__(self, meas_op: Union[Linear, LinearSplit, HadamSplit]):
        super().__init__()
        self.meas_op = meas_op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Simulates noiseless measurements :math:`y = \frac{1}{2} H(1+x)`.

        Args:
            :attr:`x`: Batch of images in the range [-1;1]

        Shape:
            - :attr:`x`: :math:`(*, N)`
            - :attr:`Output`: :math:`(*, M)`

        Example 1: Using a :class:`~spyrit.core.meas.Linear` measurement operator
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = linear_acq(x)
            >>> print(y.shape)
            torch.Size([10, 400])

        Example 2: Using a :class:`~spyrit.core.meas.HadamSplit` measurement operator
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = split_acq(x)
            >>> print(y.shape)
            torch.Size([10, 800])
        """
        x = (x + 1) / 2  # map images from [-1;1] to [0;1]
        x = self.meas_op(x)
        return x

    def reindex(
        self, x: torch.Tensor, axis: str = "rows", inverse_permutation: bool = False
    ) -> torch.Tensor:
        """Sorts a tensor along a specified axis using the indices tensor.

        This delegates to ``self.meas_op.reindex``; the indices tensor is the
        one held by the measurement operator (:attr:`self.meas_op.indices`).

        The indices tensor contains the new indices of the elements of `x`:
        `x[0]` will be placed at the index `indices[0]`, `x[1]` at
        `indices[1]`, and so on.

        Using the inverse permutation reverts the permutation: in this case,
        the element at index `indices[0]` is placed at index `0`, the element
        at index `indices[1]` is placed at index `1`, and so on.

        .. note::
            See :func:`~spyrit.core.torch.reindex()` for more details.

        Args:
            x (torch.Tensor): The tensor to sort. Can be 1D, 2D, or any
            multi-dimensional batch of 2D tensors.

            axis (str, optional): The axis to sort along. Must be either 'rows'
            or 'cols'. If `x` is 1D, `axis` is not used. Default is 'rows'.

            inverse_permutation (bool, optional): Whether to apply the inverse
            permutation. Default is False.

        Raises:
            ValueError: If `axis` is not 'rows' or 'cols'.

        Returns:
            torch.Tensor: The tensor sorted by the given indices along the
            specified axis.
        """
        return self.meas_op.reindex(x, axis, inverse_permutation)
# =============================================================================
class Poisson(NoNoise):
    # =========================================================================
    r"""
    Simulates measurements corrupted by Poisson noise.

    Assuming incoming images :math:`x` in the range [-1;1], measurements are
    first simulated for images in the range [0; :math:`\alpha`]. Negative
    values are then truncated to zero and Poisson noise is applied:
    :math:`y = \mathcal{P}(\frac{\alpha}{2} H(1+x))`.

    .. note::
        Assumes that the incoming images :math:`x` are in the range [-1;1]

    The class is constructed from a measurement operator and an image
    intensity :math:`\alpha` that controls the noise level.

    Args:
        :attr:`meas_op`: Measurement operator :math:`H` (see the :mod:`~spyrit.core.meas` submodule)

        :attr:`alpha` (float): Image intensity (in photoelectrons). Default is 50.0.

    Example 1: Using a :class:`~spyrit.core.meas.Linear` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> linear_op = Linear(H)
        >>> linear_acq = Poisson(linear_op, 10.0)

    Example 2: Using a :class:`~spyrit.core.meas.HadamSplit` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> Perm = torch.rand([32*32,32*32])
        >>> split_op = HadamSplit(H, Perm, 32, 32)
        >>> split_acq = Poisson(split_op, 200.0)

    Example 3: Using a :class:`~spyrit.core.meas.LinearSplit` measurement operator
        >>> H = torch.rand(24,64)
        >>> split_row_op = LinearSplit(H)
        >>> split_acq = Poisson(split_row_op, 50.0)
    """

    def __init__(
        self,
        meas_op: Union[Linear, LinearSplit, HadamSplit],
        alpha: float = 50.0,
    ):
        super().__init__(meas_op)
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Simulates measurements corrupted by Poisson noise.

        Computes :math:`\mathcal{P}(\max(0, \frac{\alpha}{2} H(1+x)))`. A new
        noise realization is drawn at every call. The printed value ranges in
        the examples below vary from run to run.

        Args:
            :attr:`x`: Batch of images in the range [-1;1]

        Shape:
            - :attr:`x`: :math:`(*, N)`
            - :attr:`Output`: :math:`(*, M)`

        Example 1: Two noisy measurement vectors from a :class:`~spyrit.core.meas.Linear` measurement operator
            >>> H = torch.rand([400,32*32])
            >>> meas_op = Linear(H)
            >>> noise_op = Poisson(meas_op, 10.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 400])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2249.00 , 2896.00)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2237.00 , 2880.00)

        Example 2: Two noisy measurement vectors from a :class:`~spyrit.core.meas.HadamSplit` operator
            >>> Perm = torch.rand([32*32,32*32])
            >>> meas_op = HadamSplit(H, Perm, 32, 32)
            >>> noise_op = Poisson(meas_op, 200.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 800])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 55338.00)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 55077.00)

        Example 3: Two noisy measurement vectors from a :class:`~spyrit.core.meas.LinearSplit` operator
            >>> H = torch.rand(24,64)
            >>> meas_op = LinearSplit(H)
            >>> noise_op = Poisson(meas_op, 50.0)
            >>> x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 48, 92])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (500.00 , 1134.00)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (465.00 , 1140.00)
        """
        x = super().forward(x)  # noiseless measurements H(1+x)/2
        x = self.alpha * x
        # Poisson rates must be non-negative: truncate negative values to zero.
        x = F.relu(x)
        x = poisson(x)
        return x
# =============================================================================
class PoissonApproxGauss(NoNoise):
    # =========================================================================
    r"""
    Simulates measurements corrupted by Poisson noise. To accelerate the
    computation, we consider a Gaussian approximation to the Poisson
    distribution.

    Assuming incoming images :math:`x` in the range [-1;1], measurements are
    first simulated for images in the range [0; :math:`\alpha`]:
    :math:`y = \frac{\alpha}{2} H(1+x)`, with negative values truncated to
    zero. Then, Gaussian noise is added:
    :math:`y + \sqrt{y} \cdot \mathcal{G}(\mu=0,\sigma^2=1)`. The noisy
    values are not clipped, so the output may contain negative values.

    The class is constructed from a measurement operator :math:`H` and
    an image intensity :math:`\alpha` that controls the noise level.

    .. warning::
        Assumes that the incoming images :math:`x` are in the range [-1;1]

    Args:
        :attr:`meas_op`: Measurement operator :math:`H` (see the :mod:`~spyrit.core.meas` submodule)

        :attr:`alpha` (float): Image intensity (in photoelectrons)

    Example 1: Using a :class:`~spyrit.core.meas.Linear` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> meas_op = Linear(H)
        >>> noise_op = PoissonApproxGauss(meas_op, 10.0)

    Example 2: Using a :class:`~spyrit.core.meas.HadamSplit` operator
        >>> Perm = torch.rand([32*32,32*32])
        >>> meas_op = HadamSplit(H, Perm, 32, 32)
        >>> noise_op = PoissonApproxGauss(meas_op, 200.0)

    Example 3: Using a :class:`~spyrit.core.meas.LinearSplit` operator
        >>> H = torch.rand(24,64)
        >>> meas_op = LinearSplit(H)
        >>> noise_op = PoissonApproxGauss(meas_op, 50.0)
    """

    def __init__(
        self,
        meas_op: Union[Linear, LinearSplit, HadamSplit],
        alpha: float,
    ):
        super().__init__(meas_op)
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Simulates measurements corrupted by (Gaussian-approximated) Poisson
        noise.

        Each measurement gets an independent noise sample, and a new
        realization is drawn at every call. The printed value ranges in the
        examples below vary from run to run.

        Args:
            :attr:`x`: Batch of images in the range [-1;1]

        Shape:
            - :attr:`x`: :math:`(*, N)`
            - :attr:`Output`: :math:`(*, M)`

        Example 1: Two noisy measurement vectors from a :class:`~spyrit.core.meas.Linear` measurement operator
            >>> H = torch.rand([400,32*32])
            >>> meas_op = Linear(H)
            >>> noise_op = PoissonApproxGauss(meas_op, 10.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 400])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2255.57 , 2911.18)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2226.49 , 2934.42)

        Example 2: Two noisy measurement vectors from a :class:`~spyrit.core.meas.HadamSplit` operator
            >>> Perm = torch.rand([32*32,32*32])
            >>> meas_op = HadamSplit(H, Perm, 32, 32)
            >>> noise_op = PoissonApproxGauss(meas_op, 200.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 800])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 55951.41)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 56216.86)

        Example 3: Two noisy measurement vectors from a :class:`~spyrit.core.meas.LinearSplit` operator
            >>> H = torch.rand(24,64)
            >>> meas_op = LinearSplit(H)
            >>> noise_op = PoissonApproxGauss(meas_op, 50.0)
            >>> x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 48, 92])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (460.43 , 1216.94)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (441.85 , 1230.43)
        """
        x = super().forward(x)  # noiseless measurements H(1+x)/2
        x = self.alpha * x
        # Truncate (small) negative values so that sqrt(x) below is real.
        x = F.relu(x)
        # Gaussian approximation of Poisson noise: variance equals the mean.
        x = x + torch.sqrt(x) * torch.randn_like(x)
        return x
# =============================================================================
class PoissonApproxGaussSameNoise(NoNoise):
    # =========================================================================
    r"""
    Simulates measurements corrupted by Poisson noise. To accelerate the
    computation, we consider a Gaussian approximation to the Poisson
    distribution. Contrary to :class:`~spyrit.core.noise.PoissonApproxGauss`,
    all measurements in a batch are corrupted with the same noise sample:
    a single Gaussian realization is drawn and shared along the first
    (batch) dimension.

    Assuming incoming images :math:`x` in the range [-1;1], measurements are
    first simulated for images in the range [0; :math:`\alpha`]:
    :math:`y = \frac{\alpha}{2} H(1+x)`, with negative values truncated to
    zero. Then, Gaussian noise is added:
    :math:`y + \sqrt{y} \cdot \mathcal{G}(\mu=0,\sigma^2=1)`. The noisy
    values are not clipped, so the output may contain negative values.

    The class is constructed from a measurement operator :math:`H` and
    an image intensity :math:`\alpha` that controls the noise level.

    .. warning::
        Assumes that the incoming images :math:`x` are in the range [-1;1]

    Args:
        :attr:`meas_op`: Measurement operator :math:`H` (see the :mod:`~spyrit.core.meas` submodule)

        :attr:`alpha` (float): Image intensity (in photoelectrons)

    Example 1: Using a :class:`~spyrit.core.meas.Linear` measurement operator
        >>> H = torch.rand([400,32*32])
        >>> meas_op = Linear(H)
        >>> noise_op = PoissonApproxGaussSameNoise(meas_op, 10.0)

    Example 2: Using a :class:`~spyrit.core.meas.HadamSplit` operator
        >>> Perm = torch.rand([32*32,32*32])
        >>> meas_op = HadamSplit(H, Perm, 32, 32)
        >>> noise_op = PoissonApproxGaussSameNoise(meas_op, 200.0)
    """

    def __init__(self, meas_op: Union[Linear, LinearSplit, HadamSplit], alpha: float):
        super().__init__(meas_op)
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Simulates measurements corrupted by (Gaussian-approximated) Poisson
        noise, using one noise realization for the whole batch.

        The noise tensor has shape :math:`(1, *)` (the measurement shape
        without the batch dimension) and is broadcast over the first
        dimension. It is created on the same device and with the same dtype
        as the measurements. For a 1D input (a single measurement vector),
        the noise has the same shape as the input. A new realization is drawn
        at every call. The printed value ranges in the examples below vary
        from run to run.

        Args:
            :attr:`x`: Batch of images in the range [-1;1]

        Shape:
            - :attr:`x`: :math:`(*, N)`
            - :attr:`Output`: :math:`(*, M)`

        Example 1: Two noisy measurement vectors from a :class:`~spyrit.core.meas.Linear` measurement operator
            >>> H = torch.rand([400,32*32])
            >>> meas_op = Linear(H)
            >>> noise_op = PoissonApproxGaussSameNoise(meas_op, 10.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 400])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2255.57 , 2911.18)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (2226.49 , 2934.42)

        Example 2: Two noisy measurement vectors from a :class:`~spyrit.core.meas.HadamSplit` operator
            >>> Perm = torch.rand([32*32,32*32])
            >>> meas_op = HadamSplit(H, Perm, 32, 32)
            >>> noise_op = PoissonApproxGaussSameNoise(meas_op, 200.0)
            >>> x = torch.FloatTensor(10, 32*32).uniform_(-1, 1)
            >>> y = noise_op(x)
            >>> print(y.shape)
            torch.Size([10, 800])
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 55951.41)
            >>> y = noise_op(x)
            >>> print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})")
            Measurements in (0.00 , 56216.86)
        """
        x = super().forward(x)  # noiseless measurements H(1+x)/2
        x = self.alpha * x
        # Truncate (small) negative values so that sqrt(x) below is real.
        x = F.relu(x)
        # One noise sample shared along the batch (first) dimension. Taking
        # randn_like of the first slice gives shape (1, *x.shape[1:]), so it
        # broadcasts for inputs of any rank and inherits x's device and dtype.
        if x.dim() > 1:
            noise = torch.randn_like(x[:1])
        else:
            noise = torch.randn_like(x)
        x = x + torch.sqrt(x) * noise
        return x