spyrit.core.recon.FullNet
- class spyrit.core.recon.FullNet(acqu_modules: OrderedDict | Sequential, recon_modules: OrderedDict | Sequential, *, device: device = device(type='cpu'))
Bases: Sequential

Defines an arbitrary full (measurement + reconstruction) network.
The forward pass of this network simulates measurements of a signal (or image) and reconstructs it from the measurements. To this end, it sequentially applies the measurement and reconstruction modules stored in the network under the keys acqu_modules and recon_modules, respectively.
The submodules of acqu_modules and recon_modules can be arbitrary, provided the output of the last measurement module is a valid input to the first reconstruction module.
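The sequential composition described above can be sketched with plain PyTorch containers. This is a minimal illustration, not spyrit's implementation; the layer sizes are arbitrary assumptions, chosen so that the 5 outputs of the measurement stage match the 5 inputs of the reconstruction stage.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed measurement stage: 10-dim signal -> 5 measurements.
acqu = nn.Sequential(nn.Linear(10, 5), nn.Sigmoid())
# Assumed reconstruction stage: 5 measurements -> 10-dim estimate.
recon = nn.Sequential(nn.Linear(5, 10))

x = torch.randn(4, 10)  # batch of 4 signals
y = acqu(x)             # simulated measurements, shape (4, 5)
x_hat = recon(y)        # reconstruction, shape (4, 10)

# Chaining both stages in a single container gives the same result.
full = nn.Sequential(acqu, recon)
assert torch.allclose(full(x), x_hat)
print(y.shape, x_hat.shape)  # torch.Size([4, 5]) torch.Size([4, 10])
```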
- Args:
acqu_modules (Union[OrderedDict, nn.Sequential]): Measurement modules.
recon_modules (Union[OrderedDict, nn.Sequential]): Reconstruction modules.
- Raises:
TypeError: If acqu_modules or recon_modules are not of type OrderedDict or nn.Sequential.
- Attributes:
acqu_modules (nn.Sequential): Measurement modules.
recon_modules (nn.Sequential): Reconstruction modules.
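The type check and normalization described under Raises and Attributes can be sketched as follows. as_sequential is a hypothetical helper name, not part of spyrit, and the library's actual check may differ; the sketch only relies on the fact that nn.Sequential accepts an OrderedDict and keeps its keys as submodule names.

```python
from collections import OrderedDict

import torch.nn as nn


def as_sequential(modules):
    """Hypothetical helper: normalize an OrderedDict or nn.Sequential
    into an nn.Sequential, raising TypeError for anything else."""
    if isinstance(modules, nn.Sequential):
        return modules
    if isinstance(modules, OrderedDict):
        # nn.Sequential keeps the OrderedDict keys as submodule names.
        return nn.Sequential(modules)
    raise TypeError(
        f"Expected OrderedDict or nn.Sequential, got {type(modules).__name__}"
    )


acqu = as_sequential(OrderedDict(meas=nn.Linear(10, 5), act=nn.Sigmoid()))
print(acqu.meas)  # Linear(in_features=10, out_features=5, bias=True)

try:
    as_sequential([nn.Linear(5, 10)])  # a plain list is rejected
except TypeError as err:
    print(err)  # Expected OrderedDict or nn.Sequential, got list
```

Passing an OrderedDict has the advantage that each submodule can later be reached by name (here acqu.meas) rather than by index.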
- Example:
>>> import torch.nn as nn
>>> from spyrit.core.recon import FullNet
>>> acqu1 = nn.Linear(10, 5)
>>> acqu2 = nn.Sigmoid()
>>> acqu = nn.Sequential(acqu1, acqu2)
>>> recon1 = nn.Linear(5, 10)  # input size matches the 5 measurements
>>> recon = nn.Sequential(recon1)
>>> net = FullNet(acqu, recon)
>>> print(net)
FullNet(
  (acqu_modules): Sequential(
    (0): Linear(in_features=10, out_features=5, bias=True)
    (1): Sigmoid()
  )
  (recon_modules): Sequential(
    (0): Linear(in_features=5, out_features=10, bias=True)
  )
)
Methods
acquire(x): Apply the measurement modules to the input signal.
forward(x): Apply the full network to the input signal.
reconstruct(y): Apply the reconstruction modules to the input measurements.
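To make the relationship between the three methods concrete, here is a hypothetical minimal stand-in written in plain PyTorch. MiniFullNet is an illustrative name, not part of spyrit; it assumes the two stages are registered as named children of an nn.Sequential (as the printed structure in the example suggests) and it omits the device argument. In this sketch, forward(x) is simply reconstruct(acquire(x)).

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class MiniFullNet(nn.Sequential):
    """Hypothetical stand-in for FullNet, for illustration only."""

    def __init__(self, acqu_modules: nn.Sequential, recon_modules: nn.Sequential):
        # Register both stages as named children, so print() and
        # attribute access use the names acqu_modules / recon_modules.
        super().__init__(
            OrderedDict(acqu_modules=acqu_modules, recon_modules=recon_modules)
        )

    def acquire(self, x):
        # Simulate measurements: apply only the measurement modules.
        return self.acqu_modules(x)

    def reconstruct(self, y):
        # Estimate the signal: apply only the reconstruction modules.
        return self.recon_modules(y)

    def forward(self, x):
        # Full pipeline: measure, then reconstruct.
        return self.reconstruct(self.acquire(x))


torch.manual_seed(0)
net = MiniFullNet(
    nn.Sequential(nn.Linear(10, 5), nn.Sigmoid()),
    nn.Sequential(nn.Linear(5, 10)),
)
x = torch.randn(2, 10)
y = net.acquire(x)  # shape (2, 5)
x_hat = net(x)      # shape (2, 10)
assert torch.allclose(x_hat, net.reconstruct(y))
```

Keeping acquire and reconstruct as separate methods lets the same network be used to simulate measurements only, or to reconstruct from measurements obtained elsewhere.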