spyrit.core.recon.Denoise_layer
class spyrit.core.recon.Denoise_layer(std_dev_prior_or_size: tensor | int, requires_grad=True)
Bases: Module
Defines a learnable Wiener filter that assumes additive white Gaussian noise.
The filter is initialized either with the standard deviation prior (if known) or with an integer giving the size of the input vector. In the second case, the standard deviation prior is initialized at random from a uniform \(\mathcal{U}(0, 2/\text{size})\) distribution.
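The two initialization modes described above can be sketched in plain Python as follows. `init_std_dev_prior` is a hypothetical helper, not part of spyrit; it only mirrors the documented behavior, assuming the \(\mathcal{U}(0, 2/\text{size})\) range stated in this paragraph. The actual layer stores the result as a torch `nn.Parameter`.

```python
import random

def init_std_dev_prior(std_dev_prior_or_size, seed=None):
    """Hypothetical sketch of the two initialization modes (not spyrit code).

    - an int n: draw n standard deviations from U(0, 2/n), as described above;
    - a sequence: use it directly as the known standard deviation prior.
    """
    if isinstance(std_dev_prior_or_size, int):
        size = std_dev_prior_or_size
        rng = random.Random(seed)  # seeded for reproducibility
        return [rng.uniform(0.0, 2.0 / size) for _ in range(size)]
    # Known prior: copy it into a flat list of floats.
    return [float(s) for s in std_dev_prior_or_size]

prior = init_std_dev_prior(30, seed=0)
print(len(prior))                      # 30
print(init_std_dev_prior([0.5, 1, 2]))  # [0.5, 1.0, 2.0]
```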
When the forward method is called (implicitly, by calling the layer), the filter is fully defined as:
\[\sigma_\text{prior}^2/(\sigma^2_\text{prior} + \sigma^2_\text{meas})\]
where \(\sigma^2_\text{prior}\) is the variance prior (the square of the standard deviation prior set at initialization) and \(\sigma^2_\text{meas}\) is the measurement variance passed to the forward method. Multiplying the measurement vector by this value yields the denoised measurement vector.
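As a worked illustration of the equation above, the following plain-Python sketch computes the filter element-wise and multiplies it by a measurement vector. The names `wiener_filter` and `denoise` are hypothetical and only mirror the formula; they are not spyrit functions.

```python
def wiener_filter(sigma_prior, var_meas):
    """Element-wise sigma_prior^2 / (sigma_prior^2 + var_meas).

    Assumes sigma_prior^2 + var_meas > 0 for every element.
    """
    return [s * s / (s * s + v) for s, v in zip(sigma_prior, var_meas)]

def denoise(measurement, sigma_prior, var_meas):
    """Multiply each measurement by its Wiener filter coefficient."""
    gain = wiener_filter(sigma_prior, var_meas)
    return [g * m for g, m in zip(gain, measurement)]

sigma_prior = [1.0, 2.0, 0.0]
var_meas = [1.0, 4.0, 1.0]
print(wiener_filter(sigma_prior, var_meas))               # [0.5, 0.5, 0.0]
print(denoise([2.0, 3.0, 5.0], sigma_prior, var_meas))    # [1.0, 1.5, 0.0]
```

The coefficient tends to 1 when the prior variance dominates the measurement variance (little denoising) and to 0 when the measurement variance dominates.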
Note
The weight (defined at initialization or accessible through the attribute weight) should not be squared, as it is squared when the forward method is called.
Args:
std_dev_prior_or_size (torch.Tensor or int): 1D tensor representing the standard deviation prior, or an integer defining the size of the randomly initialized standard deviation prior. If the tensor passed is not 1D, it is flattened. It is stored internally as an nn.Parameter, whose data attribute is accessed through the sigma attribute and whose requires_grad attribute is accessed through the requires_grad attribute.
requires_grad (bool): Whether autograd should record operations on the standard deviation prior. Default is True.
Shape for forward call:
Input: \((*, in\_features)\) measurement variance.
Output: \((*, in\_features)\) fully defined Wiener filter.
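The forward call squares the stored weight before combining it with the input measurement variance, which is why the Note warns against passing an already squared prior. The plain-Python sketch below uses a hypothetical `filter_gain` helper (not a spyrit function) to mimic that computation for a single feature, and shows the effect of passing a variance where a standard deviation is expected.

```python
import math

def filter_gain(weight, var_meas):
    # Mimics the forward computation: the stored weight is squared here.
    return weight ** 2 / (weight ** 2 + var_meas)

var_prior = 4.0  # known prior variance
var_meas = 4.0   # measurement variance (the forward input)

correct = filter_gain(math.sqrt(var_prior), var_meas)  # weight = std dev = 2.0
wrong = filter_gain(var_prior, var_meas)               # variance squared again

print(correct)  # 0.5
print(wrong)    # 0.8
```

With a known variance prior \(v\), passing \(\sqrt{v}\) (for example via `torch.sqrt`) at initialization recovers the intended filter.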
Attributes:
weight: The learnable standard deviation prior \(\sigma_\text{prior}\) of shape \((in\_features, 1)\). The values are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = 1/in\_features\).
sigma: The learnable standard deviation prior \(\sigma_\text{prior}\) of shape \((in\_features,)\). If the input is an integer, the standard deviation prior is initialized at random from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\), where \(k = 1/in\_features\).
in_features: The number of input features.
requires_grad: A boolean indicating whether autograd should record operations on the standard deviation tensor. Default is True.
Example:
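>>> import torch
>>> from spyrit.core.recon import Denoise_layer
>>> m = Denoise_layer(30)
>>> input = torch.rand(128, 30)  # non-negative measurement variances
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])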
Methods
forward(sigma_meas_squared): Fully defines the Wiener filter with the measurement variance.
Resets the standard deviation prior \(\sigma_\text{prior}\).
tikho(inputs, weight): Applies a transformation to the incoming data: \(y = \sigma_\text{prior}^2/(\sigma_\text{prior}^2+x)\), where \(x\) is inputs and \(\sigma_\text{prior}\) is weight.