spyrit.core.recon.FullNet
- class spyrit.core.recon.FullNet(acqu_modules: OrderedDict | Sequential, recon_modules: OrderedDict | Sequential, *, device: device = device(type='cpu'))
Bases: Sequential
Defines an arbitrary full (measurement + reconstruction) network.
The forward pass of this network simulates measurements of a signal (or image) and reconstructs it from the measurements. To this end, it sequentially applies the measurement and reconstruction modules stored in the network under the keys acqu_modules and recon_modules, respectively.
The modules contained in acqu_modules and recon_modules can be arbitrary, as long as each one is an nn.Module.
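The sequential behaviour described above can be mimicked with a plain nn.Sequential whose two named children are applied in order. The sketch below assumes a linear measurement y = Hx (an nn.Linear without bias) followed by a pseudo-inverse reconstruction; the matrix H, its size, and the class PinvRecon are hypothetical illustrations, not part of spyrit.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical measurement matrix: 32 measurements of a 16-sample signal.
H = torch.randn(32, 16)

# Measurement stage: y = H x (nn.Linear computes x @ W.T, here with W = H).
meas = nn.Linear(16, 32, bias=False)
with torch.no_grad():
    meas.weight.copy_(H)


class PinvRecon(nn.Module):
    """Hypothetical reconstruction stage: x_hat = pinv(H) y."""

    def __init__(self, H: torch.Tensor):
        super().__init__()
        # A buffer moves with the module under .to(device).
        self.register_buffer("H_pinv", torch.linalg.pinv(H))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, 32), H_pinv: (16, 32) -> result: (batch, 16)
        return y @ self.H_pinv.T


# Two named stages, applied in insertion order by nn.Sequential.forward.
net = nn.Sequential(
    OrderedDict(
        acqu_modules=nn.Sequential(meas),
        recon_modules=nn.Sequential(PinvRecon(H)),
    )
)

x = torch.rand(4, 16)  # batch of 4 signals
x_hat = net(x)         # measure, then reconstruct
print(torch.allclose(x_hat, x, atol=1e-4))
```

Because H has full column rank here, the pseudo-inverse undoes the measurement; replacing either child changes one stage without touching the other.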
- Args:
acqu_modules (Union[OrderedDict, nn.Sequential]): Measurement modules.
recon_modules (Union[OrderedDict, nn.Sequential]): Reconstruction modules.
- Raises:
TypeError: If acqu_modules or recon_modules are not of type OrderedDict or nn.Sequential.
- Attributes:
acqu_modules (nn.Sequential): Measurement modules.
recon_modules (nn.Sequential): Reconstruction modules.
- Example:
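>>> import torch
>>> import torch.nn as nn
>>> class Lambda(nn.Module):  # hypothetical helper wrapping a callable as a module
...     def __init__(self, f):
...         super().__init__()
...         self.f = f
...     def forward(self, x):
...         return self.f(x)
>>> acqu = nn.Sequential(Lambda(lambda x: x * 2), Lambda(lambda x: x - 10))
>>> recon = nn.Sequential(Lambda(lambda x: (x + 10) / 2))
>>> net = FullNet(acqu, recon)
>>> x = torch.tensor([1.0, 2.0, 3.0])
>>> torch.equal(net(x), x)
True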
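The Args, Raises, and Attributes entries together imply that an OrderedDict argument ends up as an nn.Sequential, while any other type is rejected. A minimal sketch of such a check, written as a hypothetical standalone helper as_sequential (not a spyrit function), could look like this:

```python
from collections import OrderedDict

import torch.nn as nn


def as_sequential(modules, name: str) -> nn.Sequential:
    """Hypothetical helper: normalise an argument to nn.Sequential."""
    if isinstance(modules, nn.Sequential):
        # Already the target type: use as-is.
        return modules
    if isinstance(modules, OrderedDict):
        # nn.Sequential keeps the OrderedDict keys as submodule names.
        return nn.Sequential(modules)
    raise TypeError(
        f"{name} must be an OrderedDict or nn.Sequential, "
        f"got {type(modules).__name__}"
    )


acqu = as_sequential(OrderedDict(scale=nn.Identity()), "acqu_modules")
recon = as_sequential(nn.Sequential(nn.Identity()), "recon_modules")

try:
    as_sequential([nn.Identity()], "acqu_modules")  # a list is rejected
except TypeError as err:
    print(err)
```

In this sketch a plain list or dict raises TypeError; the text above does not say whether spyrit accepts any other container.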
Methods
- acquire(x): Apply the measurement modules to the input signal.
- forward(x): Apply the full network to the input signal.
- reconstruct(y): Apply the reconstruction modules to the input measurements.
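The three methods fit together so that forward(x) gives the same result as reconstruct(acquire(x)). The stand-in class below is a hypothetical sketch (MiniFullNet is not part of spyrit): it assumes both arguments are already nn.Sequential, skips the type checks, and ignores the device keyword of the real signature.

```python
import torch
import torch.nn as nn


class MiniFullNet(nn.Sequential):
    """Hypothetical stand-in illustrating acquire/reconstruct/forward."""

    def __init__(self, acqu_modules: nn.Sequential, recon_modules: nn.Sequential):
        super().__init__()
        # Assigning modules registers them as named children.
        self.acqu_modules = acqu_modules
        self.recon_modules = recon_modules

    def acquire(self, x: torch.Tensor) -> torch.Tensor:
        # Signal -> measurements.
        return self.acqu_modules(x)

    def reconstruct(self, y: torch.Tensor) -> torch.Tensor:
        # Measurements -> estimated signal.
        return self.recon_modules(y)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Full pipeline: simulate measurements, then reconstruct.
        return self.reconstruct(self.acquire(x))


torch.manual_seed(0)
net = MiniFullNet(
    nn.Sequential(nn.Linear(8, 4, bias=False)),  # 4 measurements of an 8-sample signal
    nn.Sequential(nn.Linear(4, 8, bias=False)),  # untrained linear reconstructor
)
x = torch.rand(2, 8)
y = net.acquire(x)  # measurements only
x_hat = net(x)      # full measurement + reconstruction pass
print(y.shape, x_hat.shape)
```

Keeping acquire and reconstruct as separate methods lets real measurements be fed straight to reconstruct, bypassing the simulated acquisition.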