"""
Reconstruction methods and networks.
"""
import warnings
from typing import Union
import math
import torch
import torch.nn as nn
import spyrit.core.meas as meas
from spyrit.core.noise import NoNoise
from spyrit.core.prep import DirectPoisson, SplitPoisson
# Silence the "Sparse CSR tensor support is in beta state" warning that
# (presumably) torch emits whenever sparse CSR tensors are manipulated.
warnings.filterwarnings(
    action="ignore",
    message=".*Sparse CSR tensor support is in beta state.*",
)
# =============================================================================
class PositiveParameters(nn.Module):
    r"""Learnable parameters constrained to be non-negative.

    The raw values are stored in :attr:`params`. The forward pass returns
    their absolute value, so the effective parameters are always
    non-negative while the raw values stay unconstrained during
    optimization.

    Args:
        :attr:`params`: Initial values. This is anything accepted by
        :func:`torch.tensor`, or a tensor, which is detached and copied.

        :attr:`requires_grad` (bool): Whether the parameters are learnable.
        Default is True.

    Attributes:
        :attr:`params`: The raw values, registered as a
        :class:`torch.nn.Parameter` so that they are listed by
        :meth:`parameters`, saved in :meth:`state_dict` and moved by
        :meth:`to`.

    Example:
        >>> p = PositiveParameters([-1.0, 2.0])
        >>> p().tolist()
        [1.0, 2.0]
    """

    def __init__(self, params, requires_grad=True):
        super().__init__()
        # Copy the input so the module never aliases caller-owned memory.
        # Detach tensors so an incoming autograd graph is not carried along.
        if isinstance(params, torch.Tensor):
            data = params.detach().clone()
        else:
            data = torch.tensor(params)
        # Registering as nn.Parameter (instead of a plain tensor attribute)
        # makes the values visible to optimizers and device moves.
        self.params = nn.Parameter(data, requires_grad=requires_grad)

    def forward(self):
        """Returns the element-wise absolute value of :attr:`params`."""
        return torch.abs(self.params)
# =============================================================================
class PseudoInverse(nn.Module):
    r"""Moore-Penrose pseudoinverse reconstruction.

    For linear measurements :math:`y = Hx`, with :math:`H` the measurement
    matrix and :math:`x` a vectorized image, this module estimates
    :math:`x` as :math:`\hat{x} = H^\dagger y`. Here :math:`H^\dagger`
    denotes the Moore-Penrose pseudoinverse of :math:`H`.

    The module holds no state. The actual computation is delegated to the
    :meth:`pinv` method of the measurement operator given to
    :meth:`forward`.

    Example:
        >>> H = torch.rand([400,32*32])
        >>> Perm = torch.rand([32*32,32*32])
        >>> meas_op = HadamSplit(H, Perm, 32, 32)
        >>> y = torch.rand([85,400], dtype=torch.float)
        >>> pinv_op = PseudoInverse()
        >>> x = pinv_op(y, meas_op)
        >>> print(x.shape)
        torch.Size([85, 1024])
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.tensor,
        meas_op: Union[meas.Linear, meas.DynamicLinear],
        **kwargs,
    ) -> torch.tensor:
        r"""Applies the pseudoinverse of :attr:`meas_op` to measurements.

        Args:
            :attr:`x`: Batch of measurement vectors.

            :attr:`meas_op`: Measurement operator. Any object exposing a
            :meth:`pinv` method works, e.g.
            :class:`~spyrit.core.meas.HadamSplit`.

            :attr:`kwargs`: Extra keyword arguments forwarded unchanged to
            :meth:`meas_op.pinv`. They can be used, for instance, to pass a
            regularization parameter.

        Shape:
            :attr:`x`: :math:`(*, M)`

            :attr:`meas_op`: not applicable

            :attr:`output`: :math:`(*, N)`

        Example:
            >>> H = torch.rand([400,32*32])
            >>> Perm = torch.rand([32*32,32*32])
            >>> meas_op = HadamSplit(H, Perm, 32, 32)
            >>> y = torch.rand([85,400], dtype=torch.float)
            >>> pinv_op = PseudoInverse()
            >>> x = pinv_op(y, meas_op)
            >>> print(x.shape)
            torch.Size([85, 1024])
        """
        solver = meas_op.pinv
        return solver(x, **kwargs)
# =============================================================================
class TikhonovMeasurementPriorDiag(nn.Module):
    r"""
    Tikhonov regularisation with prior in the measurement domain.

    Considering linear measurements :math:`m = Hx \in\mathbb{R}^M`, where
    :math:`H = GF` is the measurement matrix and :math:`x\in\mathbb{R}^N` is a
    vectorized image, it estimates :math:`x` from :math:`m` by approximately
    minimizing

    .. math::
        \|m - GFx \|^2_{\Sigma^{-1}_\alpha} + \|F(x - x_0)\|^2_{\Sigma^{-1}}

    where :math:`x_0\in\mathbb{R}^N` is a mean image prior,
    :math:`\Sigma\in\mathbb{R}^{N\times N}` is a covariance prior, and
    :math:`\Sigma_\alpha\in\mathbb{R}^{M\times M}` is the measurement noise
    covariance. The matrix :math:`G\in\mathbb{R}^{M\times N}` is a
    subsampling matrix.

    .. note::
        The class is instantiated from :math:`\Sigma`, which represents the
        covariance of :math:`Fx`.

    With :math:`\Sigma_1 = \Sigma_{[:M, :M]}` (covariance of the acquired
    measurements) and :math:`\Sigma_{21} = \Sigma_{[M:, :M]}`
    (cross-covariance of missing and acquired measurements):

    Args:
        - :attr:`sigma`: covariance prior with shape :math:`N` x :math:`N`
        - :attr:`M`: number of measurements :math:`M`, with :math:`M \le N`

    Attributes:
        :attr:`comp`: Completion matrix
        :math:`\Sigma_{21} \Sigma_1^{-1}` of shape :math:`(N-M, M)`. It maps
        acquired measurements to estimates of the missing ones. It is stored
        as a non-trainable :class:`nn.Parameter`
        (``requires_grad=False``).

        :attr:`denoise_weights`: Prior standard deviations
        :math:`\sqrt{\textrm{diag}(\Sigma_1)}` of shape :math:`(M,)`. They
        are stored as a non-trainable :class:`nn.Parameter`.

    Raises:
        ValueError: If :attr:`sigma` is not a square matrix or if :attr:`M`
        exceeds its size.

    Example:
        >>> sigma = torch.rand([32*32, 32*32])
        >>> recon_op = TikhonovMeasurementPriorDiag(sigma, 400)
    """

    def __init__(self, sigma: torch.Tensor, M: int):
        super().__init__()
        if sigma.ndim != 2 or sigma.shape[0] != sigma.shape[1]:
            raise ValueError(
                f"sigma must be a square matrix, got shape {tuple(sigma.shape)}."
            )
        if M > sigma.shape[0]:
            raise ValueError(
                f"M ({M}) cannot exceed the size of sigma ({sigma.shape[0]})."
            )
        # Prior variances of the acquired measurements.
        var_prior = sigma.diag()[:M]
        self.denoise_weights = nn.Parameter(torch.sqrt(var_prior), requires_grad=False)
        Sigma1 = sigma[:M, :M]
        Sigma21 = sigma[M:, :M]
        # W = Sigma21 @ inv(Sigma1), obtained by solving the transposed system
        # Sigma1^T W^T = Sigma21^T instead of forming an explicit inverse.
        W = torch.linalg.solve(Sigma1.T, Sigma21.T).T
        self.comp = nn.Parameter(W, requires_grad=False)

    def wiener_denoise(self, x: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
        r"""Returns a denoised version of the input tensor using the variance prior.

        Computes :math:`w^2 / (w^2 + \text{var}) \odot x`, where :math:`w` is
        :attr:`self.denoise_weights` (the prior standard deviations). The
        weights are created with ``requires_grad=False``; they are only
        trained if a caller re-enables gradients.

        Inputs:
            x (torch.Tensor): The input tensor to be denoised, shape
            :math:`(*, M)`.
            var (torch.Tensor): The measurement noise variance, broadcastable
            to :math:`(*, M)`.

        Returns:
            torch.Tensor: The denoised tensor.
        """
        weights_squared = self.denoise_weights**2
        return torch.mul((weights_squared / (weights_squared + var)), x)

    def forward(
        self,
        x: torch.Tensor,
        x_0: torch.Tensor,
        var: torch.Tensor,
        meas_op: meas.HadamSplit,
    ) -> torch.Tensor:
        r"""
        Computes the Tikhonov regularization with prior in the measurement domain.

        We approximate the solution as:

        .. math::
            \hat{x} = x_0 + F^{-1} \begin{bmatrix} m_1 \\ m_2\end{bmatrix}

        with :math:`m_1 = D_1(D_1 + \Sigma_\alpha)^{-1} (m - GF x_0)` and
        :math:`m_2 = \Sigma_{21} \Sigma_1^{-1} m_1`, where
        :math:`\Sigma = \begin{bmatrix} \Sigma_1 & \Sigma_{21}^\top \\ \Sigma_{21} & \Sigma_2\end{bmatrix}`
        and :math:`D_1 =\textrm{Diag}(\Sigma_1)`. Assuming the noise
        covariance :math:`\Sigma_\alpha` is diagonal, the matrix inversion
        involved in the computation of :math:`m_1` is straightforward.

        This is an approximation to the exact solution

        .. math::
            \hat{x} = x_0 + F^{-1}\begin{bmatrix}\Sigma_1 \\ \Sigma_{21} \end{bmatrix}
            [\Sigma_1 + \Sigma_\alpha]^{-1} (m - GF x_0)

        See Lemma B.0.5 of the PhD dissertation of A. Lorente Mur (2021):
        https://theses.hal.science/tel-03670825v1/file/these.pdf

        Args:
            - :attr:`x`: A batch of measurement vectors :math:`m`
            - :attr:`x_0`: A batch of prior images :math:`x_0`, in the shape
              accepted by :meth:`meas_op.forward_H`
            - :attr:`var`: A batch of measurement noise variances :math:`\Sigma_\alpha`
            - :attr:`meas_op`: A measurement operator that provides :math:`GF`
              (:meth:`forward_H`) and :math:`F^{-1}` (:meth:`inverse`)

        Shape:
            - :attr:`x`: :math:`(*, M)`
            - :attr:`x_0`: :math:`(*, N)`
            - :attr:`var` :math:`(*, M)`
            - Output: same as :attr:`x_0` plus the output of
              :meth:`meas_op.inverse`, typically :math:`(*, N)`

        Example:
            >>> B, H, M = 85, 32, 512
            >>> sigma = torch.rand([H**2, H**2])
            >>> recon_op = TikhonovMeasurementPriorDiag(sigma, M)
            >>> Ord = torch.ones((H,H))
            >>> meas = HadamSplit(M, H, Ord)
            >>> y = torch.rand([B,M], dtype=torch.float)
            >>> x_0 = torch.zeros((B, H**2), dtype=torch.float)
            >>> var = torch.zeros((B, M), dtype=torch.float)
            >>> x = recon_op(y, x_0, var, meas)
            >>> print(x.shape)
            torch.Size([85, 1024])
        """
        # Residual of the measurements with respect to the prior image.
        x = x - meas_op.forward_H(x_0)
        # Denoised acquired measurements m_1.
        y1 = self.wiener_denoise(x, var)
        # Estimated missing measurements m_2 = Sigma21 Sigma1^{-1} m_1.
        y2 = y1 @ self.comp.T
        y = torch.cat((y1, y2), -1)
        x = x_0 + meas_op.inverse(y)
        return x
# =============================================================================
class Tikhonov(nn.Module):
    r"""Tikhonov regularization (aka ridge regression).

    It estimates the signal :math:`x\in\mathbb{R}^{N}` from linear
    measurements :math:`y = Ax\in\mathbb{R}^{M}` corrupted by noise by
    minimizing

    .. math::
        \| y - Ax \|^2_{\Gamma^{-1}} + \|x\|^2_{\Sigma^{-1}},

    where :math:`\Gamma` is the covariance of the noise and :math:`\Sigma` is
    the signal covariance. In the case :math:`M\le N`, the solution can be
    computed as

    .. math::
        \hat{x} = \Sigma A^\top (A \Sigma A^\top + \Gamma)^{-1} y,

    where we assume that both covariance matrices are positive definite. The
    class is constructed from :math:`A` and :math:`\Sigma`, while
    :math:`\Gamma` is passed as an argument to :meth:`forward()`. Passing
    :math:`\Gamma` to :meth:`forward()` is useful in the presence of
    signal-dependent noise.

    .. note::
        * :math:`x` can be a 1d signal or a vectorized image/volume. This can
          be specified by setting the :attr:`meas_shape` attribute of the
          measurement operator.
        * The above formulation assumes that the signal :math:`x` has zero
          mean.

    Args:
        - :attr:`meas_op` : Measurement operator (see
          :class:`~spyrit.core.meas`). Its matrix :attr:`meas_op.H` has shape
          :math:`(M, N)`, with :math:`M` the number of measurements and
          :math:`N` the number of pixels in the image.
        - :attr:`sigma` : Signal covariance prior, of shape :math:`(N, N)`.
        - :attr:`approx` : A boolean indicating whether to discard the
          non-diagonal elements of :math:`A \Sigma A^\top`. Default is False.
          If True, the inverse :math:`(A \Sigma A^\top + \Gamma)^{-1}` is
          replaced by an element-wise division, which is much faster.

    Attributes:
        - :attr:`meas_op` : Measurement operator initialized as :attr:`meas_op`.
        - :attr:`approx` : Indicates whether the diagonal approximation is used.
        - :attr:`img_shape` : Shape of the image, initialized as
          :attr:`meas_op.img_shape`.
        - :attr:`sigma_meas` : Buffer holding the measurement covariance prior
          :math:`A \Sigma A^\top`, of shape :math:`(M, M)`. If :attr:`approx`
          is True, only its diagonal is stored (shape :math:`(M,)`).
        - :attr:`sigma_A_T` : Buffer holding :math:`\Sigma A^\top`, of shape
          :math:`(N, M)`.
        - :attr:`noise_scale` : Hidden hyperparameter scaling the noise
          covariance in the inverse
          :math:`(A \Sigma A^\top + \text{noise\_scale} \times \Gamma)^{-1}`.
          Default is 1.

    Example:
        >>> B, H, M, N = 85, 17, 32, 64
        >>> sigma = torch.rand(N, N)
        >>> gamma = torch.rand(M, M)
        >>> A = torch.rand([M,N])
        >>> meas = Linear(A, meas_shape=(1,N))
        >>> recon = Tikhonov(meas, sigma)
        >>> y = torch.rand(B,H,M)
        >>> x = recon(y, gamma)
        >>> print(y.shape)
        >>> print(x.shape)
        torch.Size([85, 17, 32])
        torch.Size([85, 17, 64])
    """

    def __init__(self, meas_op, sigma: torch.Tensor, approx=False):
        super().__init__()
        # The products are computed in the dtype of A, then stored in the
        # dtype of the user-supplied covariance.
        dtype = sigma.dtype
        A = meas_op.H
        sigma = sigma.to(A.dtype)
        if approx:
            # Diagonal approximation: keep only diag(A @ sigma @ A.T).
            sigma_meas = torch.diag(A @ sigma @ A.T).to(dtype)
        else:
            sigma_meas = (A @ sigma @ A.T).to(dtype)
        sigma_A_T = torch.mm(sigma, A.mT).to(dtype)
        self.register_buffer("sigma_meas", sigma_meas)
        self.register_buffer("sigma_A_T", sigma_A_T)
        self.meas_op = meas_op
        self.img_shape = meas_op.img_shape
        self.approx = approx
        # hidden parameter to use as a hyperparameter for dynamic reconstructions
        self.noise_scale = 1

    def divide(self, y: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        r"""Computes :math:`(A \Sigma A^\top + \text{noise\_scale} \times \Gamma)^{-1} y`.

        If :attr:`self.approx` is True, the matrix is approximated by its
        diagonal, so the product reduces to an element-wise division.
        Otherwise a batched linear system is solved with the full matrix.

        Args:
            y (torch.Tensor): Input measurement tensor. Shape :math:`(*, M)`.
            gamma (torch.Tensor): Noise covariance tensor. Shape
            :math:`(*, M, M)` (or :math:`(M, M)`, broadcast over the batch).

        Returns:
            torch.Tensor: The divided tensor. Shape :math:`(*, M)`.
        """
        if self.approx:
            return y / (
                self.sigma_meas
                + self.noise_scale * torch.diagonal(gamma, dim1=-2, dim2=-1)
            )
        # torch.linalg.solve needs the matrix batch to match the batch of y.
        batch_shape = y.shape[:-1]
        expand_shape = batch_shape + self.sigma_meas.shape
        cov = (self.sigma_meas + self.noise_scale * gamma).expand(expand_shape)
        # Treat y as a batch of column vectors for the solve.
        solution = torch.linalg.solve(cov, y.unsqueeze(-1))
        return solution.squeeze(-1)

    def forward(
        self, y: torch.Tensor, gamma: torch.Tensor
    ) -> torch.Tensor:
        r"""Reconstructs the signal from measurements and noise covariance.

        The Tikhonov solution is computed as

        .. math::
            \hat{x} = B (C + \Gamma)^{-1} y

        with :math:`B = \Sigma A^\top` and :math:`C = A \Sigma A^\top`. When
        :attr:`self.approx` is True, it is approximated as

        .. math::
            \hat{x} = B \frac{y}{\text{diag}(C + \Gamma)}

        Args:
            :attr:`y` (torch.Tensor): A batch of measurement vectors :math:`y`

            :attr:`gamma` (torch.Tensor): A batch of noise covariance :math:`\Gamma`

        Shape:
            :attr:`y` (torch.Tensor): :math:`(*, M)`

            :attr:`gamma` (torch.Tensor): :math:`(*, M, M)`

            Output (torch.Tensor): :math:`(*, N)`
        """
        y = self.divide(y, gamma)
        return torch.matmul(self.sigma_A_T, y.unsqueeze(-1)).squeeze(-1)
# =============================================================================
class Denoise_layer(nn.Module):
    r"""Defines a learnable Wiener filter that assumes additive white Gaussian noise.

    .. warning::
        This class is deprecated and emits a :class:`DeprecationWarning` on
        construction. Use
        :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` instead.

    The filter is defined upon initialization either from a standard
    deviation prior (if known) or from an integer giving the size of the
    input vector. In the second case, the standard deviation prior is drawn
    at random from :math:`\mathcal{U}(0, 2/\sqrt{in\_features})` (see
    :meth:`reset_parameters`).

    Calling the module (its :meth:`forward` method) with the measurement
    variance fully defines the filter:

    .. math::
        \sigma_\text{prior}^2/(\sigma^2_\text{prior} + \sigma^2_\text{meas})

    where :math:`\sigma^2_\text{prior}` is the variance prior defined at
    initialization and :math:`\sigma^2_\text{meas}` is the measurement
    variance passed to the forward method. The result can then be multiplied
    by the measurement vector to obtain the denoised measurement vector.

    .. note::
        The weight (defined at initialization, accessible through the
        attribute :attr:`weight`) is a standard deviation. It should not be
        squared, because it is squared when the forward method is called.

    Args:
        :attr:`std_dev_prior_or_size` (torch.Tensor or int): Tensor holding
        the standard deviation prior, or a positive integer giving the size
        of a randomly-initialized prior. A tensor that is not 1D is
        flattened.

        :attr:`requires_grad` (bool): Whether the prior is learnable.
        Default is True.

    Shape for forward call:
        - Input: :math:`(*, in\_features)` measurement variance.
        - Output: :math:`(*, in\_features)` fully defined Wiener filter.

    Attributes:
        :attr:`weight`:
        The standard deviation prior :math:`\sigma_\text{prior}`, stored as a
        1D :class:`nn.Parameter` of shape :math:`(in\_features,)`.

        :attr:`in_features`:
        The number of input features, i.e. the number of elements of
        :attr:`weight`. This is a read-only property.

    Raises:
        TypeError: If :attr:`std_dev_prior_or_size` is neither an int nor a
        tensor.
        ValueError: If an integer size smaller than 1 is given.

    Example:
        >>> m = Denoise_layer(30)
        >>> input = torch.randn(128, 30)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    def __init__(
        self, std_dev_prior_or_size: Union[torch.Tensor, int], requires_grad=True
    ):
        super().__init__()
        warnings.warn(
            "This class is deprecated and will be removed in a future release. "
            "Please use the `TikhonovMeasurementPriorDiag` class instead.",
            DeprecationWarning,
        )
        if isinstance(std_dev_prior_or_size, int):
            # A size of 0 would make reset_parameters divide by sqrt(0).
            if std_dev_prior_or_size < 1:
                raise ValueError(
                    "std_dev_or_size should be a positive integer, "
                    f"got {std_dev_prior_or_size}"
                )
            self.weight = nn.Parameter(
                torch.empty(std_dev_prior_or_size), requires_grad=requires_grad
            )
            self.reset_parameters()
        else:
            if not isinstance(std_dev_prior_or_size, torch.Tensor):
                raise TypeError(
                    "std_dev_or_size should be an integer or a torch.Tensor"
                )
            self.weight = nn.Parameter(
                std_dev_prior_or_size.reshape(-1), requires_grad=requires_grad
            )

    @property
    def in_features(self):
        """Number of elements of the standard deviation prior."""
        return self.weight.data.numel()

    def reset_parameters(self):
        r"""
        Resets the standard deviation prior :math:`\sigma_\text{prior}`.

        The values are drawn from
        :math:`\mathcal{U}(0, 2/\sqrt{in\_features})` and stored in the
        :attr:`weight` attribute.
        """
        nn.init.uniform_(self.weight, 0, 2 / math.sqrt(self.in_features))

    def forward(self, sigma_meas_squared: torch.Tensor) -> torch.Tensor:
        r"""
        Fully defines the Wiener filter with the measurement variance.

        This outputs
        :math:`\sigma_\text{prior}^2/(\sigma_\text{prior}^2 + \sigma^2_\text{meas})`,
        where :math:`\sigma^2_\text{meas}` is the measurement variance (see
        :attr:`sigma_meas_squared`) and :math:`\sigma_\text{prior}` is the
        standard deviation prior defined upon construction of the class (see
        :attr:`self.weight`).

        .. note::
            The measurement variance should be squared before being passed to
            this method, unlike the standard deviation prior (defined at
            construction).

        Args:
            :attr:`sigma_meas_squared` (torch.Tensor): input tensor
            :math:`\sigma^2_\text{meas}` of shape :math:`(*, in\_features)`

        Returns:
            torch.Tensor: The multiplicative filter of shape
            :math:`(*, in\_features)`

        Raises:
            ValueError: If the last dimension of the input does not match
            :attr:`in_features`.
        """
        if sigma_meas_squared.shape[-1] != self.in_features:
            raise ValueError(
                "The last dimension of the input tensor "
                + f"({sigma_meas_squared.shape[-1]}) should be equal to the number of "
                + f"input features ({self.in_features})."
            )
        return self.tikho(sigma_meas_squared, self.weight)

    def extra_repr(self):
        return "in_features={}".format(self.in_features)

    @staticmethod
    def tikho(inputs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        r"""
        Computes :math:`y = \sigma_\text{prior}^2/(\sigma_\text{prior}^2+x)`.

        :math:`x` is the input tensor (see :attr:`inputs`) and
        :math:`\sigma_\text{prior}` is the standard deviation prior (see
        :attr:`weight`).

        Args:
            :attr:`inputs` (torch.Tensor): measurement variances :math:`x` of
            shape :math:`(*, in\_features)`

            :attr:`weight` (torch.Tensor): standard deviation prior
            :math:`\sigma_\text{prior}` of shape :math:`(in\_features)`

        Returns:
            torch.Tensor: The transformed data :math:`y`, with the broadcast
            shape of :attr:`inputs` and :attr:`weight`, i.e.
            :math:`(*, in\_features)`.
        """
        # The prior is squared rather than used as a variance directly: a
        # learnt standard deviation may turn negative, its square cannot.
        a = weight**2
        return a / (a + inputs)
# -----------------------------------------------------------------------------
# | RECONSTRUCTION NETWORKS |
# -----------------------------------------------------------------------------
# =============================================================================
class PinvNet(nn.Module):
    r"""Pseudo inverse reconstruction network.

    Args:
        :attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)

        :attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)

        :attr:`denoi` (optional): Image denoising operator
        (see :class:`~spyrit.core.nnet`). Default is None, in which case a
        new :class:`torch.nn.Identity` is created for this network.

    Input / Output:
        :attr:`input`: Ground-truth images with shape :math:`(B,C,H,W)`
        corresponding to the batch size, number of channels, height, and width.

        :attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
        corresponding to the batch size, number of channels, height, and width.

    Attributes:
        :attr:`acqu`: Acquisition operator initialized as :attr:`noise`

        :attr:`prep`: Preprocessing operator initialized as :attr:`prep`

        :attr:`pinv`: Analytical reconstruction operator initialized as
        :class:`~spyrit.core.recon.PseudoInverse()`

        :attr:`denoi`: Image denoising operator initialized as :attr:`denoi`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> recnet = PinvNet(noise, prep)
        >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
        >>> z = recnet(x)
        >>> print(z.shape)
        >>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
        torch.Size([10, 1, 64, 64])
        tensor(5.8912e-06)
    """

    def __init__(self, noise, prep, denoi=None):
        super().__init__()
        self.acqu = noise
        self.prep = prep
        self.pinv = PseudoInverse()
        # A module instance used as a default argument would be shared by
        # every PinvNet built without `denoi`; create a fresh one instead.
        self.denoi = nn.Identity() if denoi is None else denoi

    @property
    def device(self):
        """Device of the acquisition operator (``self.acqu.device``)."""
        return self.acqu.device

    def forward(self, x):
        r"""Full pipeline of the reconstruction network.

        Args:
            :attr:`x`: ground-truth images

        Shape:
            :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

            :attr:`output`: reconstructed images with shape :math:`(B,C,H,W)`

        Example:
            >>> B, C, H, M = 10, 1, 64, 64**2
            >>> Ord = torch.ones((H,H))
            >>> meas = HadamSplit(M, H, Ord)
            >>> noise = NoNoise(meas)
            >>> prep = SplitPoisson(1.0, M, H*H)
            >>> recnet = PinvNet(noise, prep)
            >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
            >>> z = recnet(x)
            >>> print(z.shape)
            >>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
            torch.Size([10, 1, 64, 64])
            tensor(5.8912e-06)
        """
        x = self.acquire(x)
        x = self.reconstruct(x)
        return x

    def acquire(self, x):
        r"""Simulates data acquisition by applying :attr:`acqu`.

        Args:
            :attr:`x`: ground-truth images

        Shape:
            :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

            :attr:`output`: measurement vectors with shape :math:`(BC,2M)`

        Example:
            >>> B, C, H, M = 10, 1, 64, 64**2
            >>> Ord = torch.ones((H,H))
            >>> meas = HadamSplit(M, H, Ord)
            >>> noise = NoNoise(meas)
            >>> prep = SplitPoisson(1.0, M, H*H)
            >>> recnet = PinvNet(noise, prep)
            >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
            >>> z = recnet.acquire(x)
            >>> print(z.shape)
            torch.Size([10, 8192])
        """
        return self.acqu(x)

    def meas2img(self, y):
        """Returns images from raw measurement vectors.

        The measurements are preprocessed, zero-padded from :math:`M` to
        :math:`N` values, reindexed with :meth:`meas_op.reindex` and reshaped
        to the image size.

        Args:
            :attr:`y`: raw measurement vectors

        Shape:
            :attr:`y`: :math:`(*,2M)`

            :attr:`output`: :math:`(*,H,W)`

        Example:
            >>> B, C, H, M = 10, 3, 64, 64**2
            >>> Ord = torch.ones(H,H)
            >>> meas = HadamSplit(M, H, Ord)
            >>> noise = NoNoise(meas)
            >>> prep = SplitPoisson(1.0, M, H**2)
            >>> recnet = PinvNet(noise, prep)
            >>> y = torch.rand((B,C,2*M), dtype=torch.float32)
            >>> z = recnet.meas2img(y)
            >>> print(z.shape)
            torch.Size([10, 3, 64, 64])
        """
        meas_op = self.acqu.meas_op
        m = self.prep(y)
        # Fill the N - M non-acquired positions with zeros.
        m = torch.nn.functional.pad(m, (0, meas_op.N - meas_op.M))
        # reindex the measurements
        z = meas_op.reindex(m, "cols", False)
        return z.reshape(*z.shape[:-1], meas_op.h, meas_op.w)

    def reconstruct(self, x):
        r"""Preprocesses, reconstructs, and denoises raw measurement vectors.

        Args:
            :attr:`x`: raw measurement vectors

        Shape:
            :attr:`x`: :math:`(BC,2M)`

            :attr:`output`: :math:`(BC,1,H,W)`

        Example:
            >>> B, C, H, M = 10, 1, 64, 64**2
            >>> Ord = torch.ones((H,H))
            >>> meas = HadamSplit(M, H, Ord)
            >>> noise = NoNoise(meas)
            >>> prep = SplitPoisson(1.0, M, H**2)
            >>> recnet = PinvNet(noise, prep)
            >>> x = torch.rand((B*C,2*M), dtype=torch.float)
            >>> z = recnet.reconstruct(x)
            >>> print(z.shape)
            torch.Size([10, 1, 64, 64])
        """
        return self.denoi(self.reconstruct_pinv(x))

    def reconstruct_pinv(self, x):
        r"""Preprocesses and reconstructs raw measurement vectors (no denoising).

        Args:
            :attr:`x`: raw measurement vectors

        Shape:
            :attr:`x`: :math:`(BC,2M)`

            :attr:`output`: :math:`(BC,1,H,W)`

        Example:
            >>> B, C, H, M = 10, 1, 64, 64**2
            >>> Ord = torch.ones((H,H))
            >>> meas = HadamSplit(M, H, Ord)
            >>> noise = NoNoise(meas)
            >>> prep = SplitPoisson(1.0, M, H**2)
            >>> recnet = PinvNet(noise, prep)
            >>> x = torch.rand((B*C,2*M), dtype=torch.float)
            >>> z = recnet.reconstruct_pinv(x)
            >>> print(z.shape)
            torch.Size([10, 1, 64, 64])
        """
        x = self.prep(x)
        x = self.pinv(x, self.acqu.meas_op)
        return x

    def reconstruct_expe(self, x):
        r"""Reconstruction step of a reconstruction network for experimental data.

        Same as :meth:`reconstruct` except that:

        1. The preprocessing step estimates the image intensity for normalization.

        2. The output images are "denormalized", i.e., have units of photon counts.

        Args:
            :attr:`x`: raw measurement vectors

        Shape:
            :attr:`x`: :math:`(BC,2M)`

            :attr:`output`: :math:`(BC,1,H,W)`

        Returns:
            A tuple ``(images, N0_est)`` where ``N0_est`` is the intensity
            estimated by :meth:`prep.forward_expe`.
        """
        meas_op = self.acqu.meas_op
        # x of shape [b*c, 2M]
        bc, _ = x.shape
        # Preprocessing; also estimates the intensity N0.
        x, N0_est = self.prep.forward_expe(x, meas_op)  # shape x = [b*c, M]
        # measurements to image domain processing
        x = self.pinv(x, meas_op)  # shape x = [b*c,N]
        # Image domain denoising
        x = x.reshape(bc, 1, meas_op.h, meas_op.w)  # shape x = [b*c,1,h,w]
        x = self.denoi(x)  # shape x = [b*c,1,h,w]
        # Denormalization back to photon counts
        x = self.prep.denormalize_expe(x, N0_est, meas_op.h, meas_op.w)
        return x, N0_est
# =============================================================================
class Pinv1Net(nn.Module):
    r"""1D pseudo inverse reconstruction network.

    Considering linear measurements :math:`Y = HX`, where
    :math:`H\in\mathbb{R}^{k\times w}` is the measurement matrix and
    :math:`X \in\mathbb{R}^{h\times w}` is an image, it estimates :math:`X`
    from :math:`Y` by computing

    .. math:: \hat{X} = \mathcal{G}_\theta(H^\dagger Y),

    where :math:`H^\dagger` is the Moore-Penrose pseudo inverse of :math:`H`,
    and :math:`\mathcal{G}_\theta` is a neural network.

    The pseudo-inverse is computed along the last dimension, while (learnable)
    denoising applies to the last two dimensions.

    Args:
        :attr:`noise`: Acquisition operator that computes (noisy) measurements
        :math:`Y = HX` (see :class:`~spyrit.core.noise`)

        :attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)

        :attr:`denoi` (optional): Image denoising operator
        :math:`\mathcal{G}_\theta` (see :class:`~spyrit.core.nnet`). Defaults
        to None, in which case a new :class:`torch.nn.Identity` is created for
        this network.

    Input / Output:
        :attr:`input`: Ground-truth images :math:`X` with shape :math:`(b,c,h,w)`.

        :attr:`output`: Reconstructed images :math:`\hat{X}` with shape :math:`(b,c,h,w)`.

    Attributes:
        :attr:`acqu`: Acquisition operator initialized as :attr:`noise`.

        :attr:`prep`: Preprocessing operator initialized as :attr:`prep`.

        :attr:`pinv`: Analytical reconstruction operator initialized as
        :class:`~spyrit.core.recon.PseudoInverse()`.

        :attr:`denoi`: Image denoising operator initialized as :attr:`denoi`.

    Example:
        >>> b,c,h,w = 10,1,48,64
        >>> H = torch.rand(15,w)
        >>> meas = Linear(H, meas_shape=(1,w))
        >>> noise = NoNoise(meas)
        >>> prep = DirectPoisson(1.0, meas)
        >>> recnet = Pinv1Net(noise, prep)
        >>> x = torch.FloatTensor(b,c,h,w).uniform_(-1, 1)
        >>> z = recnet(x)
        >>> print(z.shape)
        torch.Size([10, 1, 48, 64])

    .. note::
        The measurement operator applies to the last dimension of the input
        tensor, contrary to :class:`~spyrit.core.recon.PinvNet` where it
        applies to the last two dimensions. In both cases, the denoising
        operator applies to the last two dimensions.
    """

    def __init__(self, noise, prep, denoi=None):
        super().__init__()
        self.acqu = noise
        self.prep = prep
        self.pinv = PseudoInverse()
        # A module instance used as a default argument would be shared by
        # every Pinv1Net built without `denoi`; create a fresh one instead.
        self.denoi = nn.Identity() if denoi is None else denoi

    def forward(self, x):
        r"""Full pipeline (image-to-image mapping).

        Args:
            :attr:`x` (torch.Tensor): Ground-truth images with shape :math:`(b,c,h,w)`.

        Output:
            torch.Tensor: Reconstructed images with shape :math:`(b,c,h,w)`.

        Example:
            >>> b,c,h,w = 10,1,48,64
            >>> H = torch.rand(15,w)
            >>> meas = Linear(H, meas_shape=(1,w))
            >>> noise = NoNoise(meas)
            >>> prep = DirectPoisson(1.0, meas)
            >>> recnet = Pinv1Net(noise, prep)
            >>> x = torch.FloatTensor(b,c,h,w).uniform_(-1, 1)
            >>> z = recnet(x)
            >>> print(z.shape)
            torch.Size([10, 1, 48, 64])
        """
        # Acquisition
        x = self.acqu(x)
        # Reconstruction
        x = self.reconstruct(x)
        return x

    def reconstruct(self, x):
        r"""Reconstruction (measurement-to-image mapping).

        Args:
            :attr:`x` (torch.Tensor): Raw measurement vectors with shape :math:`(b,c,h,k)`.

        Output:
            torch.Tensor: Reconstructed images with shape :math:`(b,c,h,w)`
        """
        # Preprocessing in the measurement domain
        x = self.prep(x)
        # measurements to image-domain processing
        x = self.pinv(x, self.acqu.meas_op)
        # Drop the singleton dimension that is presumably introduced by the
        # (1, w) measurement shape: [*,1,N] -> [*,N] (TODO confirm).
        x = x.squeeze(-2)
        # Image-domain denoising
        x = self.denoi(x)
        return x

    def reconstruct_expe(self, x):
        r"""Reconstruction (measurement-to-image mapping) for experimental data.

        Args:
            :attr:`x`: Raw measurement vectors with shape :math:`(b,c,h,k)`.

        Output:
            A tuple ``(images, norm)``: reconstructed images with shape
            :math:`(b,c,h,w)` and the normalization returned by
            :meth:`prep.forward_expe`.
        """
        # Preprocessing with the default dimensions of forward_expe (noted as
        # a per-row estimate); passing (-2, -1) was the whole-image
        # alternative.
        x, norm = self.prep.forward_expe(x, self.acqu.meas_op)  # shape: [*, M]
        # measurements to image domain processing
        x = self.pinv(x, self.acqu.meas_op)  # shape: [*,N]
        x = x.squeeze(-2)  # shape: [*,1,N] -> [*,N]
        # Image-domain denoising
        x = self.denoi(x)  # shape: [*,h,w]
        # Denormalization
        x = self.prep.denormalize_expe(x, norm, x.shape[-2], x.shape[-1])
        return x, norm
# =============================================================================
[docs]
class DCNet(nn.Module):
r"""Denoised completion reconstruction network.
This is a four step reconstruction method:
#. Denoising in the measurement domain.
#. Estimation of the missing measurements from the denoised ones.
#. Image-domain mapping.
#. (Learned) Denoising in the image domain.
The first three steps corresponds to Tikhonov regularisation. Typically, only the last step involves learnable parameters.
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
:attr:`sigma`: Covariance prior (for details, see the
:class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag()` class)
:attr:`denoi` (optional): Image denoising operator
(see :class:`~spyrit.core.nnet`).
Default :class:`~spyrit.core.nnet.Identity`
Input / Output:
:attr:`input`: Ground-truth images with shape :math:`(*,H,W)`, with
:math:`*` being any batch size.
:attr:`output`: Reconstructed images with shape :math:`(*,H,W)`, with
:math:`*` being any batch size.
Attributes:
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
:attr:`PreP`: Preprocessing operator initialized as :attr:`prep`
:attr:`DC_Layer`: Data consistency layer initialized as :attr:`tikho`
:attr:`Denoi`: Image denoising operator initialized as :attr:`denoi`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = torch.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> sigma = torch.rand([H**2, H**2])
>>> recnet = DCNet(noise,prep,sigma)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
"""
def __init__(
    self,
    noise: NoNoise,
    prep: Union[DirectPoisson, SplitPoisson],
    sigma: torch.tensor,
    denoi=None,
):
    """Builds the network and its Tikhonov data-consistency layer.

    Args:
        noise: Acquisition operator, stored as :attr:`Acq`.
        prep: Preprocessing operator, stored as :attr:`prep`.
        sigma: Image-domain covariance prior. It is cast to float32 and
            reindexed (rows, then columns) with :meth:`noise.reindex` before
            being given to
            :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag`.
        denoi: Image denoising operator. Default is None, in which case a
            new :class:`torch.nn.Identity` is created for this network.
    """
    super().__init__()
    self.Acq = noise
    self.prep = prep
    # A module instance used as a default argument would be shared by every
    # DCNet built without `denoi`; create a fresh one instead.
    self.denoi = nn.Identity() if denoi is None else denoi
    sigma = sigma.to(torch.float32)
    # Reorder the covariance so that it matches the measurement ordering.
    sigma = noise.reindex(sigma, "rows", False)
    sigma = noise.reindex(sigma, "cols", True)
    self.tikho = TikhonovMeasurementPriorDiag(sigma, noise.meas_op.M)
@property
def device(self):
    """Device of the acquisition operator (delegates to ``self.Acq.device``)."""
    acquisition = self.Acq
    return acquisition.device
def forward(self, x):
    r"""Full pipeline of the reconstruction network: acquisition, then reconstruction.

    Args:
        :attr:`x`: ground-truth images

    Shape:
        :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

        :attr:`output`: reconstructed images with shape :math:`(B,C,H,W)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> sigma = torch.rand([H**2, H**2])
        >>> recnet = DCNet(noise,prep,sigma)
        >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
        >>> z = recnet(x)
        >>> print(z.shape)
        torch.Size([10, 1, 64, 64])
    """
    measurements = self.acquire(x)
    return self.reconstruct(measurements)
[docs]
def acquire(self, x):
    r"""Simulate data acquisition with the acquisition operator ``self.Acq``.

    Args:
        :attr:`x`: ground-truth images

    Shape:
        :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

        :attr:`output`: measurement vectors with shape :math:`(BC,2M)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> sigma = torch.rand([H**2, H**2])
        >>> recnet = DCNet(noise,prep,sigma)
        >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
        >>> z = recnet.acquire(x)
        >>> print(z.shape)
        torch.Size([10, 8192])
    """
    acquisition = self.Acq
    return acquisition(x)
[docs]
def reconstruct(self, x):
    r"""Measurement-to-image mapping of the network.

    The noise variance is computed from the raw measurements. The
    preprocessed measurements are then passed to the Tikhonov layer together
    with a zero image, and the result is denoised in the image domain.

    Args:
        :attr:`x`: raw measurement vectors

    Shape:
        :attr:`x`: raw measurement vectors with shape :math:`(BC,2M)`

        :attr:`output`: reconstructed images with shape :math:`(BC,1,H,W)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> sigma = torch.rand([H**2, H**2])
        >>> recnet = DCNet(noise,prep,sigma)
        >>> x = torch.rand((B*C,2*M), dtype=torch.float)
        >>> z = recnet.reconstruct(x)
        >>> print(z.shape)
        torch.Size([10, 1, 64, 64])
    """
    meas_op = self.Acq.meas_op
    # Variance from the raw data, before preprocessing
    noise_var = self.prep.sigma(x)
    y = self.prep(x)
    # Zero image with the batch dims of y and the operator's image shape,
    # passed to the Tikhonov layer as x_0 (presumably the prior mean)
    zero_image = torch.zeros(
        (*y.shape[:-1], *meas_op.meas_shape), device=y.device
    )
    estimate = self.tikho(y, zero_image, noise_var, meas_op)
    # Image-domain denoising
    return self.denoi(estimate)
[docs]
def reconstruct_expe(self, x):
    r"""Reconstruction step of a reconstruction network

    Same as :meth:`reconstruct` reconstruct except that:

    1. The preprocessing step estimates the image intensity. The
    estimated intensity is used for both normalizing the raw
    data and computing the variance of the normalized data.

    2. The output images are "denormalized", i.e., have units of photon
    counts

    Args:
        :attr:`x`: raw measurement vectors

    Shape:
        :attr:`x`: :math:`(BC,2M)`

        :attr:`output`: :math:`(BC,1,H,W)`

    Note:
        ``x`` must be 2-D (it is unpacked as ``bc, _ = x.shape``). The method
        also reads ``self.prep.gain`` and the measurement operator's ``M``,
        ``N``, ``h`` and ``w`` attributes.
    """
    # x of shape [b*c, 2M]
    bc, _ = x.shape
    # Preprocessing expe: the variance is computed from the raw measurements
    var_noi = self.prep.sigma_expe(x)
    x, N0_est = self.prep.forward_expe(x, self.Acq.meas_op)  # x <- x/N0_est
    # Further normalization by the preprocessing gain
    x = x / self.prep.gain
    # Total normalization factor applied to the raw data (gain * N0_est)
    norm = self.prep.gain * N0_est
    # variance of preprocessed measurements: one norm**2 factor per row,
    # broadcast across the M measurements of that row
    var_noi = torch.div(
        var_noi, (norm.reshape(-1, 1).expand(bc, self.Acq.meas_op.M)) ** 2
    )
    # measurements to image domain processing
    # NOTE(review): x_0 is built here with shape (bc, N), whereas reconstruct()
    # builds it with shape (*, *meas_shape). Confirm which layout
    # TikhonovMeasurementPriorDiag expects.
    x_0 = torch.zeros((bc, self.Acq.meas_op.N), device=x.device)
    x = self.tikho(x, x_0, var_noi, self.Acq.meas_op)
    x = x.reshape(
        bc, 1, self.Acq.meas_op.h, self.Acq.meas_op.w
    )  # shape x = [b*c,1,h,w]
    # Image domain denoising
    x = self.denoi(x)  # shape x = [b*c,1,h,w]
    # Denormalization
    x = self.prep.denormalize_expe(x, norm, self.Acq.meas_op.h, self.Acq.meas_op.w)
    return x
# =============================================================================
[docs]
class TikhoNet(nn.Module):
    r"""Tikhonov reconstruction network.

    This is a two-step reconstruction method. Typically, only the last step
    involves learnable parameters.

    #. Tikhonov regularisation.
    #. (Learned) Denoising in the image domain.

    Args:
        :attr:`noise` (spyrit.core.noise): Acquisition operator (see :mod:`~spyrit.core.noise`)

        :attr:`prep` (spyrit.core.prep): Preprocessing operator (see :mod:`~spyrit.core.prep`)

        :attr:`sigma` (torch.tensor): Image-domain covariance prior (for details, see the
        :class:`~spyrit.core.recon.Tikhonov()` class)

        :attr:`denoi` (torch.nn.Module, optional): Image denoising operator
        (see :class:`~spyrit.core.nnet`).
        Default :class:`~spyrit.core.nnet.Identity`

    Input / Output:
        :attr:`input` (torch.tensor): Ground-truth images with shape :math:`(B,C,H,W)`.

        :attr:`output` (torch.tensor): Reconstructed images with shape :math:`(B,C,H,W)`.

    Attributes:
        :attr:`acqu`: Acquisition operator initialized as :attr:`noise`

        :attr:`prep`: Preprocessing operator initialized as :attr:`prep`

        :attr:`tikho`: Data consistency layer initialized as :attr:`Tikhonov(noise.meas_op, sigma)`

        :attr:`denoi`: Image denoising operator initialized as :attr:`denoi`

    Example:
        >>> B, H, M, N = 85, 17, 32, 64
        >>> sigma = torch.rand(N, N)
        >>> A = torch.rand([M,N])
        >>> meas = Linear(A, meas_shape=(1,N))
        >>> noise = NoNoise(meas)
        >>> prep = DirectPoisson(1, meas)
        >>> recon = TikhoNet(noise, prep, sigma)
        >>> y = torch.rand(B,H,M)
        >>> x = recon.reconstruct(y)
        >>> print(y.shape)
        torch.Size([85, 17, 32])
        >>> print(x.shape)
        torch.Size([85, 17, 1, 64])
    """

    def __init__(self, noise, prep, sigma: torch.tensor, denoi=nn.Identity()):
        super().__init__()
        self.acqu = noise
        self.prep = prep
        self.tikho = Tikhonov(noise.meas_op, sigma)
        self.denoi = denoi

    def forward(self, x):
        """Full pipeline (image-to-image mapping)

        Args:
            x (torch.tensor): Ground-truth images with shape :math:`(B,C,H,W)`.

        Returns:
            torch.tensor: Reconstruction images with shape :math:`(B,C,H,W)`.
        """
        # Acquisition: simulated raw measurements
        x = self.acqu(x)
        # Reconstruction: measurements -> images of shape (*, h, w)
        x = self.reconstruct(x)
        return x

    def reconstruct(self, x):
        r"""Reconstruction (measurement-to-image mapping)

        Args:
            :attr:`x` (torch.tensor): Raw measurement vectors with shape :math:`(B,C,M)`.

        Output:
            torch.tensor: Reconstructed images with shape :math:`(B,C,H,W)`
        """
        # Preprocessing; the variance is computed from the raw measurements
        cov_meas = self.prep.sigma(x)
        x = self.prep(x)
        # Diagonal covariance matrix of the measurements
        cov_meas = torch.diag_embed(cov_meas)
        # Measurements to image domain
        x = self.tikho(x, cov_meas)
        x = x.reshape(*x.shape[:-1], self.acqu.meas_op.h, self.acqu.meas_op.w)
        # Image domain denoising
        x = self.denoi(x)
        return x

    def reconstruct_expe(self, x):
        r"""Reconstruction (measurement-to-image mapping) for experimental data.

        The preprocessing operator estimates a normalization factor from the
        raw data. The measurement covariance is rescaled accordingly, and the
        denoised images are denormalized with that same factor.

        Args:
            :attr:`x` (torch.tensor): Raw measurement vectors with shape :math:`(B,C,M)`.

        Returns:
            tuple: ``(x, norm)`` where ``x`` (torch.tensor) are the denormalized
            reconstructed images with shape :math:`(B,C,H,W)` and ``norm`` is the
            normalization factor returned by ``prep.forward_expe``.
        """
        # Variance of the raw data
        cov_meas = self.prep.sigma_expe(x)
        # Normalized measurements and normalization factor (estimated on each row)
        x, norm = self.prep.forward_expe(x, self.acqu.meas_op)  # shape: [*, M]
        # Covariance of the normalized measurements.
        # NOTE(review): assumes `norm` broadcasts against (*, M), e.g. has
        # shape (*, 1) -- confirm against prep.forward_expe.
        cov_meas = cov_meas / norm**2
        cov_meas = torch.diag_embed(cov_meas)
        # Measurements to image domain
        x = self.tikho(x, cov_meas)
        x = x.reshape(*x.shape[:-1], self.acqu.meas_op.h, self.acqu.meas_op.w)
        # Image domain denoising
        x = self.denoi(x)
        # Denormalization
        x = self.prep.denormalize_expe(x, norm, x.shape[-2], x.shape[-1])
        return x, norm
# =============================================================================
[docs]
class LearnedPGD(nn.Module):
r"""Learned Proximal Gradient Descent reconstruction network.
Iterative algorithm that alternates between a gradient step and a proximal
step, where the proximal operator is a learned denoiser. The update rule is:
:math:`x_{k+1} = \text{prox}(x_k - \gamma_k H^T (H x_k - y)) = \text{denoi}(x_k - \gamma_k H^T (H x_k - y))`,
where :math:`\gamma_k` is the step size at iteration :math:`k`.
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
:attr:`prep`: Preprocessing operator (see :class:`~spyrit.core.prep`)
:attr:`denoi` (optional): Image denoising operator
(see :class:`~spyrit.core.nnet`).
Default :class:`~spyrit.core.nnet.Identity`
:attr:`iter_stop` (int): Number of iterations of the LPGD algorithm
(commonly 3 to 10, trade-off between accuracy and speed).
Default 3 (for speed and with higher accuracy than post-processing denoising)
:attr:`step` (float): Step size of the LPGD algorithm. Default is None,
and it is estimated as the inverse of the Lipschitz constant of the gradient of the
data fidelity term.
- If :math:`meas_op.N` is available, the step size is estimated as
:math:`step=1/L=1/\text{meas_op.N}`, true for Hadamard operators.
- If not, :attr:`step_estimation` is enabled and the step size is
computed at reconstruction time from the singular values :math:`s` of
the Hessian :math:`H^TH` as :math:`2/(s_{\min} + s_{\max})`; 1e-4 is
used as the initial value until then.
:attr:`step_estimation` (bool): Default False. See :attr:`step` for details.
:attr:`step_grad` (bool): Default False. If True, the step size is learned
as a parameter of the network. Not tested yet.
:attr:`wls` (bool): Default False. If True, the data fidelity term is
modified to be the weighted least squares (WLS) term, which approximates
the Poisson likelihood. In this case, the data fidelity term is
:math:`\|Hx-y\|^2_{C^{-1}}`, where :math:`C` is the covariance matrix.
We assume that :math:`C` is diagonal, and the diagonal elements are
the measurement noise variances, estimated from :class:`~spyrit.core.prep.sigma`.
:attr:`gt` (torch.tensor): Ground-truth images. If available, the mean
squared error (MSE) is computed and logged. Default None.
:attr:`log_fidelity` (bool): Default False. If True, the data fidelity term
is logged for each iteration of the LPGD algorithm.
Input / Output:
:attr:`input`: Ground-truth images with shape :math:`(B,C,H,W)`
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
Attributes:
:attr:`acqu`: Acquisition operator initialized as :attr:`noise`
:attr:`prep`: Preprocessing operator initialized as :attr:`prep`
:attr:`pinv`: Analytical reconstruction operator initialized as
:class:`~spyrit.core.recon.PseudoInverse()`
:attr:`denoi`: Image denoising operator initialized as :attr:`denoi`
Example:
>>> B, C, H, M = 10, 1, 64, 64**2
>>> Ord = torch.ones((H,H))
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = LearnedPGD(noise, prep)
>>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
>>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
tensor(5.8912e-06)
"""
def __init__(
    self,
    noise,
    prep,
    denoi=nn.Identity(),
    iter_stop=3,
    x0=0,
    step=None,
    step_estimation=False,
    step_grad=False,
    step_decay=1,
    wls=False,
    gt=None,
    log_fidelity=False,
    res_learn=False,
):
    """Store the sub-modules and LPGD settings and initialise the step size.

    See the class docstring for the meaning of most arguments. In addition:

    Args:
        x0: If non-zero and the measurement operator has a ``pinv`` method,
            the iterations start from the pseudo-inverse of the measurements
            followed by a denoising step. With ``x0 == 0`` they start from a
            zero image.
        step_decay: Multiplicative decay of the step size across iterations
            (see :meth:`step_schedule`).
        res_learn: If True, an early detached estimate is added to the output
            of the last iteration (residual learning, see :meth:`reconstruct`).
    """
    super().__init__()
    # Sub-modules
    self.acqu = noise
    self.prep = prep
    self.denoi = denoi
    self.pinv = PseudoInverse()
    # LPGD algorithm settings
    self.x0 = x0
    self.iter_stop = iter_stop
    self.step_estimation = step_estimation
    self.step_grad = step_grad
    self.step_decay = step_decay
    self.res_learn = res_learn
    # Step-size schedule; set_stepsize may switch step_estimation on
    self.set_stepsize(step)
    # WLS
    self.wls = wls
    # Log fidelity
    self.log_fidelity = log_fidelity
    # Log MSE (ground truth available)
    if gt is not None:
        # Flattened copy of the ground truth (one row per image). as_tensor
        # accepts arrays and tensors; detach().clone() makes the copy without
        # the warning torch.tensor() emits when given a tensor.
        gt_flat = torch.as_tensor(gt).reshape(gt.shape[0], -1).detach().clone()
        # Frozen Parameter so that it is moved along with the module
        self.x_gt = nn.Parameter(gt_flat, requires_grad=False)
    else:
        self.x_gt = None
[docs]
def step_schedule(self, step):
    """Expand a scalar step size into a per-iteration schedule.

    Args:
        step: Initial step size.

    Returns:
        list: ``[step * step_decay**i for i in range(iter_stop)]`` when
        ``step_decay != 1``; otherwise ``step`` repeated ``iter_stop`` times,
        with at least one entry.
    """
    if self.step_decay == 1:
        # Constant schedule, never empty
        return [step] * max(self.iter_stop, 1)
    return [step * self.step_decay**k for k in range(self.iter_stop)]
[docs]
def set_stepsize(self, step):
    """Initialise ``self.step`` as a :class:`PositiveParameters` schedule.

    Args:
        step (float or None): Initial step size. When ``None``, it is set to
            ``1 / meas_op.N`` if the measurement operator defines ``N``.
            Otherwise a placeholder of ``1e-4`` is used and
            ``step_estimation`` is switched on, so that :meth:`stepsize_gd`
            recomputes it at reconstruction time.
    """
    if step is None:
        meas_op = self.acqu.meas_op
        if not hasattr(meas_op, "N"):
            # No closed-form Lipschitz constant: placeholder, estimated later
            # from the Hessian spectrum
            self.step_estimation = True
            step = 1e-4
        else:
            # 1/L with L = N, valid for Hadamard operators
            step = 1 / meas_op.N
    self.step = PositiveParameters(
        self.step_schedule(step), requires_grad=self.step_grad
    )
[docs]
def forward(self, x):
    r"""Image-to-image mapping: simulate the acquisition, then reconstruct.

    Args:
        :attr:`x`: ground-truth images

    Shape:
        :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

        :attr:`output`: reconstructed images with shape :math:`(B,C,H,W)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> recnet = LearnedPGD(noise, prep)
        >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
        >>> z = recnet(x)
        >>> print(z.shape)
        torch.Size([10, 1, 64, 64])
        >>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
        tensor(5.8912e-06)
    """
    return self.reconstruct(self.acquire(x))
[docs]
def acquire(self, x):
    r"""Simulate data acquisition with the acquisition operator ``self.acqu``.

    Args:
        :attr:`x`: ground-truth images

    Shape:
        :attr:`x`: ground-truth images with shape :math:`(B,C,H,W)`

        :attr:`output`: measurement vectors with shape :math:`(BC,2M)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H*H)
        >>> recnet = LearnedPGD(noise, prep)
        >>> x = torch.FloatTensor(B,C,H,H).uniform_(-1, 1)
        >>> z = recnet.acquire(x)
        >>> print(z.shape)
        torch.Size([10, 8192])
    """
    return self.acqu(x)
[docs]
def hessian_sv(self):
    r"""Singular values of the data-fidelity Hessian :math:`H^T H`.

    When :attr:`wls` is True, the rows of :math:`H` are first divided by the
    measurement standard deviations ``sqrt(self.meas_variance)``. That
    attribute is assigned in :meth:`reconstruct`.

    Returns:
        torch.tensor: Singular values of :math:`H^T H`. If that computation
        raises, the squared singular values of :math:`H` are returned instead.
    """
    H = self.acqu.meas_op.H
    if self.wls:
        # NOTE(review): diag() has one entry per element of meas_variance
        # (all batch rows flattened), so the matmul only matches the rows of
        # H when a single measurement vector was processed -- confirm.
        std_mat = 1 / torch.sqrt(self.meas_variance)
        std_mat = torch.diag(std_mat.reshape(-1))
        H = torch.matmul(std_mat, H)
    try:
        s = torch.linalg.svdvals(torch.matmul(H.t(), H))
    except RuntimeError:
        # torch.linalg failures (LinAlgError) and out-of-memory errors derive
        # from RuntimeError; a bare except would also swallow KeyboardInterrupt.
        print("svdvals(H^T*H) failed, trying svdvals(H) instead")
        s = torch.linalg.svdvals(H) ** 2
    return s
[docs]
def stepsize_gd(self):
    r"""Set the step-size schedule from the spectrum of the Hessian.

    The step is :math:`2 / (s_{\min} + s_{\max})`, where :math:`s` are the
    values returned by :meth:`hessian_sv`. It is then expanded into a
    per-iteration schedule with :meth:`step_schedule`.
    """
    s = self.hessian_sv()
    # 2 / (lambda_min + lambda_max): Kressner, EPFL, GD
    # (an earlier alternative was 1/(2*s.max()**2)).
    # NOTE: if H^T H is singular, s.min() ~ 0 and the step approaches
    # 2/lambda_max, the stability limit of gradient descent.
    step = float(2 / (s.min() + s.max()))
    # self.step is a registered child module (PositiveParameters, created in
    # set_stepsize). nn.Module refuses to overwrite a child module with a
    # plain tensor (TypeError), and reconstruct() indexes one step per
    # iteration, so rebuild a full schedule module instead.
    self.step = PositiveParameters(
        self.step_schedule(step), requires_grad=self.step_grad
    )
[docs]
def cost_fun(self, x, y):
    r"""Data-fidelity cost :math:`\|Hx - y\|^2`.

    When :attr:`wls` is True, the residual is first divided by
    ``sqrt(self.meas_variance)``, giving the weighted least-squares cost.
    """
    residual = self.acqu.meas_op.forward_H(x) - y
    weighted = residual / torch.sqrt(self.meas_variance) if self.wls else residual
    return torch.linalg.norm(weighted) ** 2
[docs]
def mse_fun(self, x, x_gt):
    r"""Euclidean distance :math:`\|x - x_{gt}\|_2` used for logging.

    Despite its name, this returns the norm of the error over the whole batch,
    not squared and not averaged.

    Args:
        x (torch.tensor): Current estimate.
        x_gt (torch.tensor): Ground truth.

    Returns:
        torch.tensor: Scalar norm of ``x - x_gt``.
    """
    if x.shape != x_gt.shape and x.numel() == x_gt.numel():
        # reconstruct() iterates on image-shaped estimates (*, h, w), while
        # __init__ stores the ground truth flattened as (B, N). Align the
        # layouts instead of relying on broadcasting, which fails here.
        x = x.reshape(x_gt.shape)
    return torch.linalg.norm(x - x_gt)
[docs]
def reconstruct(self, x):
    r"""Reconstruction step of a reconstruction network

    Runs :attr:`iter_stop` iterations of learned proximal gradient descent on
    the preprocessed measurements. Each iteration is a gradient step on the
    (optionally weighted) data-fidelity term followed by a denoising step.

    Args:
        :attr:`x`: raw measurement vectors

    Shape:
        :attr:`x`: :math:`(BC,2M)`

        :attr:`output`: :math:`(BC,1,H,W)`

    Example:
        >>> B, C, H, M = 10, 1, 64, 64**2
        >>> Ord = torch.ones((H,H))
        >>> meas = HadamSplit(M, H, Ord)
        >>> noise = NoNoise(meas)
        >>> prep = SplitPoisson(1.0, M, H**2)
        >>> recnet = LearnedPGD(noise, prep)
        >>> x = torch.rand((B*C,2*M), dtype=torch.float)
        >>> z = recnet.reconstruct(x)
        >>> print(z.shape)
        torch.Size([10, 1, 64, 64])
    """
    # Preprocessing in the measurement domain
    m = self.prep(x)

    meas_variance = None
    if self.wls:
        # Variance of the measurements (diagonal covariance)
        if hasattr(self.prep, "sigma"):
            meas_variance = self.prep.sigma(x)
        else:
            print(
                "WLS requires the variance of the measurements to be known!. Estimating var==m"
            )
            meas_variance = m
        # Stored on the module: hessian_sv() and cost_fun() read it
        self.meas_variance = meas_variance

    # Step size from the Lipschitz constant. This must come after the
    # variance is known, because hessian_sv() reads self.meas_variance in WLS.
    if self.step_estimation:
        self.stepsize_gd()
    step = self.step
    if not isinstance(step, torch.Tensor):
        step = step.params  # PositiveParameters -> one entry per iteration

    if self.wls:
        # Scale the step by the smallest variance of each measurement vector
        var_min = torch.amin(meas_variance, dim=-1)
        step = step.reshape(self.iter_stop, *([1] * var_min.ndim)).to(x.device)
        step = var_min * step  # shape (iter_stop, *batch)

    # Initial estimate
    z0 = None  # residual-learning anchor, only used when res_learn
    if self.x0 != 0 and hasattr(self.acqu.meas_op, "pinv"):
        # Pseudo-inverse followed by a first proximal (denoising) step
        x = self.acqu.meas_op.pinv(m)
        if isinstance(self.denoi, nn.ModuleList):
            x = self.denoi[0](x)
        else:
            x = self.denoi(x)
        if self.res_learn:
            z0 = x.detach().clone()
    else:
        # Zero initialization. This is also the fallback when the operator
        # has no pinv, rather than iterating from the raw measurements.
        x = torch.zeros(
            (*x.shape[:-1], *self.acqu.meas_op.meas_shape), device=x.device
        )

    if self.log_fidelity:
        self.cost = []
        with torch.no_grad():
            self.cost.append(self.cost_fun(x, m).cpu().numpy().tolist())
    if self.x_gt is not None:
        self.mse = []
        with torch.no_grad():
            self.mse.append(self.mse_fun(x, self.x_gt).cpu().numpy().tolist())

    for i in range(self.iter_stop):
        # Gradient step (data fidelity)
        res = self.acqu.meas_op.forward_H(x) - m
        if self.wls:
            res = res / meas_variance
        grad = self.acqu.meas_op.adjoint(res)
        step_i = step[i]
        if self.wls:
            # Per-measurement-vector step: append singleton dims so that it
            # broadcasts over the image dimensions of the gradient
            step_i = step_i.reshape(step_i.shape + (1,) * (grad.ndim - step_i.ndim))
        x = x - step_i * grad
        if i == 0 and self.res_learn and z0 is None:
            # No pinv initialization: anchor on the first gradient step
            z0 = x.detach().clone()
        # Proximal step (prior)
        if isinstance(self.denoi, nn.ModuleList):
            x = self.denoi[i](x)
        else:
            x = self.denoi(x)
        if self.log_fidelity:
            with torch.no_grad():
                self.cost.append(self.cost_fun(x, m).cpu().numpy().tolist())
        # Compute mse if ground truth is available
        if self.x_gt is not None:
            with torch.no_grad():
                self.mse.append(self.mse_fun(x, self.x_gt).cpu().numpy().tolist())

    if self.log_fidelity:
        print(f"Data fidelity: {(self.cost)}. Stepsize: {self.step}")
    if self.x_gt is not None:
        print(f"|x - x_gt| = {self.mse}")
    if self.res_learn and z0 is not None:
        # z = x - step*grad(L), x = P(z), x_end = z0 + P(z)
        x = x + z0
    return x
[docs]
def reconstruct_expe(self, x):
    r"""Reconstruction step for experimental data.

    Unlike :meth:`reconstruct`, no gradient iterations are performed. The
    estimate is the pseudo-inverse of the preprocessed measurements followed
    by a single denoising pass. In addition:

    1. The preprocessing step estimates the image intensity for normalization.
    2. The output images are "denormalized", i.e., have units of photon counts.

    Args:
        :attr:`x`: raw measurement vectors

    Shape:
        :attr:`x`: :math:`(BC,2M)`

        :attr:`output`: :math:`(BC,1,H,W)`
    """
    # Preprocessing: normalized measurements and estimated intensity N0
    x, N0_est = self.prep.forward_expe(x, self.acqu.meas_op)  # shape x = [b*c, M]
    # Measurements to image domain
    x = self.pinv(x, self.acqu.meas_op)
    # Denoising. A ModuleList holds one denoiser per iteration and is not
    # callable itself; use the first one, as reconstruct() does after pinv.
    if isinstance(self.denoi, nn.ModuleList):
        x = self.denoi[0](x)
    else:
        x = self.denoi(x)
    # Denormalization
    x = self.prep.denormalize_expe(
        x, N0_est, self.acqu.meas_op.h, self.acqu.meas_op.w
    )
    return x