spyrit.core.recon.FullNet

class spyrit.core.recon.FullNet(acqu_modules: OrderedDict | Sequential, recon_modules: OrderedDict | Sequential, *, device: device = device(type='cpu'))

Bases: Sequential

Defines an arbitrary full (measurement + reconstruction) network.

The forward pass of this network simulates measurements of a signal (or image) and reconstructs it from the measurements. To this end, it sequentially applies the measurement and reconstruction modules stored in the network under the keys acqu_modules and recon_modules, respectively.

The measurement and reconstruction modules may themselves contain arbitrary submodules.
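
The sequential composition described above can be sketched with plain PyTorch. This snippet does not use FullNet itself. It assumes only that the class behaves like an nn.Sequential whose two children are registered under the keys acqu_modules and recon_modules. The layer types and sizes are arbitrary choices for illustration.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Illustrative measurement and reconstruction stages (sizes are arbitrary).
acqu = nn.Sequential(nn.Linear(10, 5), nn.Sigmoid())
recon = nn.Sequential(nn.Linear(5, 10))

# Stand-in for the full network: two children stored under named keys.
net = nn.Sequential(
    OrderedDict([("acqu_modules", acqu), ("recon_modules", recon)])
)

x = torch.rand(3, 10)         # batch of 3 signals of length 10
y = net.acqu_modules(x)       # simulated measurements, shape (3, 5)
x_hat = net.recon_modules(y)  # reconstruction, shape (3, 10)

# The forward pass applies both stages in order, so it should match
# the result of chaining the two stages by hand.
same = torch.allclose(net(x), x_hat)
```

Note that the output size of the measurement stage must match the input size of the reconstruction stage for the forward pass to run.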

Args:

acqu_modules (Union[OrderedDict, nn.Sequential]): Measurement modules.

recon_modules (Union[OrderedDict, nn.Sequential]): Reconstruction modules.

device (torch.device, optional): Keyword-only argument. Defaults to torch.device('cpu').

Raises:

TypeError: If acqu_modules or recon_modules is not an OrderedDict or nn.Sequential.

Attributes:

acqu_modules (nn.Sequential): Measurement modules.

recon_modules (nn.Sequential): Reconstruction modules.

Example:
>>> import torch.nn as nn
>>> from spyrit.core.recon import FullNet
>>> acqu1 = nn.Linear(10,5)
>>> acqu2 = nn.Sigmoid()
>>> acqu = nn.Sequential(acqu1, acqu2)
>>> recon1 = nn.Linear(5,10)
>>> recon = nn.Sequential(recon1)
>>> net = FullNet(acqu, recon)
>>> print(net)
FullNet(
  (acqu_modules): Sequential(
    (0): Linear(in_features=10, out_features=5, bias=True)
    (1): Sigmoid()
  )
  (recon_modules): Sequential(
    (0): Linear(in_features=5, out_features=10, bias=True)
  )
)

Methods

acquire(x)

Apply the measurement modules to the input signal.

forward(x)

Apply the full network to the input signal.

reconstruct(y)

Apply the reconstruction modules to the input measurements.
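
The relationship between the three methods can be illustrated with a small stand-in class. FullNetSketch below is hypothetical and is not the library's implementation. It assumes, based on the description on this page, that acquire and reconstruct apply the corresponding child module, that forward chains them, and that a TypeError is raised for unsupported argument types. The device argument is omitted.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class FullNetSketch(nn.Sequential):
    """Hypothetical stand-in that mimics the documented behaviour."""

    def __init__(self, acqu_modules, recon_modules):
        # Only OrderedDict and nn.Sequential are accepted, as documented.
        for name, mods in (("acqu_modules", acqu_modules),
                           ("recon_modules", recon_modules)):
            if not isinstance(mods, (OrderedDict, nn.Sequential)):
                raise TypeError(
                    f"{name} must be an OrderedDict or nn.Sequential"
                )
        # Wrap dictionaries so both attributes end up as nn.Sequential.
        if isinstance(acqu_modules, OrderedDict):
            acqu_modules = nn.Sequential(acqu_modules)
        if isinstance(recon_modules, OrderedDict):
            recon_modules = nn.Sequential(recon_modules)
        super().__init__(OrderedDict([
            ("acqu_modules", acqu_modules),
            ("recon_modules", recon_modules),
        ]))

    def acquire(self, x):
        # Apply the measurement modules to the input signal.
        return self.acqu_modules(x)

    def reconstruct(self, y):
        # Apply the reconstruction modules to the input measurements.
        return self.recon_modules(y)


acqu = OrderedDict([("lin", nn.Linear(10, 5)), ("act", nn.Sigmoid())])
recon = nn.Sequential(nn.Linear(5, 10))
net = FullNetSketch(acqu, recon)

x = torch.rand(3, 10)
y = net.acquire(x)            # measurements only, shape (3, 5)
x_hat = net.reconstruct(y)    # reconstruction only, shape (3, 10)
full = net(x)                 # inherited forward: acquire, then reconstruct
```

Calling acquire and reconstruct separately is useful when the measurements need to be inspected or modified (for example, by adding noise) before reconstruction.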