spyrit.core.recon.FullNet

class spyrit.core.recon.FullNet(acqu_modules: OrderedDict | Sequential, recon_modules: OrderedDict | Sequential, *, device: device = device(type='cpu'))

Bases: Sequential

Defines an arbitrary full (measurement + reconstruction) network.

The forward pass of this network simulates measurements of a signal (or image) and reconstructs it from the measurements. To this end, it sequentially applies the measurement and reconstruction modules stored in the network under the keys acqu_modules and recon_modules, respectively.

The individual modules inside acqu_modules and recon_modules can be arbitrary.

Args:

acqu_modules (Union[OrderedDict, nn.Sequential]): Measurement modules.

recon_modules (Union[OrderedDict, nn.Sequential]): Reconstruction modules.

Raises:

TypeError: If acqu_modules or recon_modules is neither an OrderedDict nor an nn.Sequential.

Attributes:

acqu_modules (nn.Sequential): Measurement modules.

recon_modules (nn.Sequential): Reconstruction modules.

Example:

# >>> import torch.nn as nn
# >>> acqu1 = lambda x: x*2
# >>> acqu2 = lambda x: x - 10
# >>> acqu = nn.Sequential(acqu1, acqu2)
# >>> recon1 = lambda x: (x + 10) / 2
# >>> recon = nn.Sequential(recon1)
# >>> net = FullNet(acqu, recon)
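A runnable variant of this example might look like the sketch below. nn.Sequential accepts only nn.Module instances, so the bare lambdas are wrapped in a small helper module; the Lambda class is hypothetical and not part of spyrit or PyTorch. To avoid assuming anything about FullNet beyond what is documented here, the sketch composes the two chains with a plain nn.Sequential keyed as acqu_modules and recon_modules, which is only a stand-in for the real class.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class Lambda(nn.Module):
    """Hypothetical helper: wraps a plain function so nn.Sequential accepts it."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


# Measurement chain: y = 2*x - 10
acqu = nn.Sequential(Lambda(lambda x: x * 2), Lambda(lambda x: x - 10))
# Reconstruction chain: x = (y + 10) / 2, the exact inverse of the chain above
recon = nn.Sequential(Lambda(lambda y: (y + 10) / 2))

# Stand-in for FullNet(acqu, recon): a plain nn.Sequential keyed the same way.
# With spyrit installed, one would presumably write instead:
#   from spyrit.core.recon import FullNet
#   net = FullNet(acqu, recon)
net = nn.Sequential(OrderedDict([("acqu_modules", acqu), ("recon_modules", recon)]))

x = torch.tensor([1.0, 2.0, 3.0])
y = net.acqu_modules(x)  # simulated measurements
x_hat = net(x)           # measure, then reconstruct
```

Because the reconstruction chain inverts the measurement chain exactly, x_hat should match x in this toy setting.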

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Apply the reconstruction modules to the input measurements.
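The relationship between the three methods can be sketched without PyTorch. The class below is a hypothetical stand-in, not the spyrit implementation: it uses plain functions in place of modules and assumes only what is described above, namely that forward applies the measurement steps and then the reconstruction steps.

```python
from functools import reduce


class FullNetSketch:
    """Hypothetical stand-in for illustration; not the spyrit class."""

    def __init__(self, acqu_steps, recon_steps):
        self.acqu_steps = list(acqu_steps)
        self.recon_steps = list(recon_steps)

    @staticmethod
    def _chain(steps, value):
        # Apply each step in order, feeding each output to the next step.
        return reduce(lambda v, step: step(v), steps, value)

    def acquire(self, x):
        # Signal -> measurements
        return self._chain(self.acqu_steps, x)

    def reconstruct(self, y):
        # Measurements -> reconstructed signal
        return self._chain(self.recon_steps, y)

    def forward(self, x):
        # Full pass: measure, then reconstruct
        return self.reconstruct(self.acquire(x))


net = FullNetSketch(
    acqu_steps=[lambda x: x * 2, lambda x: x - 10],
    recon_steps=[lambda y: (y + 10) / 2],
)
y = net.acquire(3.0)        # 3*2 - 10 = -4.0
x_hat = net.reconstruct(y)  # (-4 + 10) / 2 = 3.0
```

Calling net.forward(3.0) gives the same value as chaining acquire and reconstruct by hand, which is the property the method descriptions imply.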