spyrit.core.recon.Denoise_layer

class spyrit.core.recon.Denoise_layer(M: int)

Bases: Module

Wiener filter that assumes additive white Gaussian noise.

\[y = \frac{\sigma^2_\text{prior}}{\sigma^2_\text{prior} + \sigma^2_\text{meas}}\, x,\]
where \(\sigma^2_\text{prior}\) is the prior variance, \(\sigma^2_\text{meas}\) is the variance of the measurement, \(x\) is the input vector and \(y\) is the output vector.
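As a minimal sketch of the scaling in the equation above, assuming the prior and measurement variances are given per component as tensors that broadcast against the input. The helper name `wiener_shrink` is hypothetical and not part of spyrit:

```python
import torch

def wiener_shrink(x: torch.Tensor, var_prior: torch.Tensor, var_meas: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: scale x by var_prior / (var_prior + var_meas).

    Arguments broadcast against each other, e.g. x of shape (*, M)
    with variances of shape (M,).
    """
    gain = var_prior / (var_prior + var_meas)  # lies in [0, 1] for non-negative variances
    return gain * x

# Equal prior and noise variances give a gain of 1/2 on every component
x = torch.tensor([[2.0, -4.0, 6.0]])
var_prior = torch.ones(3)
var_meas = torch.ones(3)
y = wiener_shrink(x, var_prior, var_meas)  # each component of x is halved
```

With zero measurement variance the gain is 1 and \(x\) passes through unchanged; as \(\sigma^2_\text{meas}\) grows relative to \(\sigma^2_\text{prior}\), the component is shrunk toward zero.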
Args:

M (int): size of incoming vector

Shape:
  • Input: \((*, M)\).

  • Output: \((*, M)\).

Attributes:

weight: The learnable standard deviation prior \(\sigma_\text{prior}\) of shape \((M, 1)\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = 1/M\).

in_features: The number of input features equal to \(M\).
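The initialization described for `weight` can be sketched with `torch.nn.init.uniform_`. This is an illustration under the documented shape \((M, 1)\) and bound \(\sqrt{k}\) with \(k = 1/M\), not spyrit's own `reset_parameters`:

```python
import math
import torch
from torch import nn

M = 30
k = 1.0 / M
bound = math.sqrt(k)

# Learnable prior standard deviation, shaped as documented: (M, 1)
weight = nn.Parameter(torch.empty(M, 1))
# Draw every entry from U(-sqrt(k), sqrt(k)) with k = 1/M (done in-place, without tracking gradients)
nn.init.uniform_(weight, -bound, bound)
```

Because the transformation applied by the layer uses \(\sigma_\text{prior}^2\), the sign of a sampled entry does not change the result.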

Example:
>>> import torch
>>> from spyrit.core.recon import Denoise_layer
>>> m = Denoise_layer(30)
>>> input = torch.randn(128, 30)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

Methods

forward(inputs)

Applies a transformation to the incoming data: \(y = A^2/(A^2+x)\), where \(A\) is the layer weight \(\sigma_\text{prior}\) and \(x\) is the input.

reset_parameters()

Resets the standard deviation prior \(\sigma_\text{prior}\).

tikho(inputs, weight)

Applies a transformation to the incoming data: \(y = A^2/(A^2+x)\), where \(A\) is the given weight and \(x\) is inputs.
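The three methods can be pictured together in a minimal, hypothetical module. It is a sketch under stated assumptions rather than spyrit's implementation: the class name `WienerGainSketch` is invented for illustration, `weight` is stored with shape `(M,)` so that it broadcasts against inputs of shape `(*, M)`, and the input \(x\) is taken to be the measurement variance \(\sigma^2_\text{meas}\), so that `forward` returns the per-component gain \(A^2/(A^2+x)\) with \(A = \sigma_\text{prior}\):

```python
import math
import torch
from torch import nn

torch.manual_seed(0)  # reproducible example

class WienerGainSketch(nn.Module):
    """Hypothetical re-creation of the documented interface, not spyrit code."""

    def __init__(self, M: int):
        super().__init__()
        self.in_features = M
        # Assumption: 1-D storage of shape (M,) so it broadcasts over (*, M) inputs
        self.weight = nn.Parameter(torch.empty(M))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # U(-sqrt(k), sqrt(k)) with k = 1/M, as stated for `weight`
        bound = math.sqrt(1.0 / self.in_features)
        nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Uses the layer's own weight as A
        return self.tikho(inputs, self.weight)

    def tikho(self, inputs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        a = weight ** 2          # A^2, the prior variance
        return a / (a + inputs)  # y = A^2 / (A^2 + x), elementwise

m = WienerGainSketch(30)
var_meas = torch.rand(128, 30)  # non-negative stand-in for measurement variances
gain = m(var_meas)
print(gain.size())  # torch.Size([128, 30])
```

Under these assumptions, multiplying `gain` elementwise by a measurement vector gives the filtered output \(y\) of the class-level equation.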