spyrit.core.recon.FullNet.forward

FullNet.forward(x)

Apply the full network to the input signal.

The input signal is first passed through the stored measurement modules self.acqu_modules, which simulate its measurements. These measurements are then passed to the reconstruction modules self.recon_modules, which reconstruct the signal from them.
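The two-stage flow described above amounts to a function composition: forward(x) = recon(acqu(x)). The framework-free sketch below shows only that control flow. It is not the SPyRiT implementation; the names run_stages and full_forward are hypothetical, and plain lists of callables stand in for the module containers.

```python
from typing import Callable, Sequence

# A stage maps one value to another (hypothetical stand-in for an nn.Module).
Stage = Callable[[float], float]


def run_stages(stages: Sequence[Stage], value: float) -> float:
    """Apply each stage in order, the way nn.Sequential chains its children."""
    for stage in stages:
        value = stage(value)
    return value


def full_forward(x: float, acqu_modules: Sequence[Stage],
                 recon_modules: Sequence[Stage]) -> float:
    """Hypothetical mirror of FullNet.forward: acquire, then reconstruct."""
    y = run_stages(acqu_modules, x)      # simulated measurements
    return run_stages(recon_modules, y)  # signal reconstructed from y


acqu = [lambda v: v * 2, lambda v: v - 10]
recon = [lambda v: (v + 10) / 2]
print(full_forward(5.0, acqu, recon))  # 5.0
```

Here the reconstruction stage exactly inverts the acquisition, so the input value is recovered.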

Args:

x (torch.Tensor): input tensor. For images, it is usually shaped (b, c, h, w), where b is the batch size, c is the number of channels, and h and w are the height and width of the images.

Returns:

torch.Tensor: output tensor. Its shape is determined by the output of the reconstruction modules.
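To illustrate that the output shape is set by the reconstruction modules rather than by the input, the sketch below applies the acquire-then-reconstruct composition directly to a batch shaped (b, c, h, w). The choice of nn.Flatten as a stand-in acquisition and nn.Unflatten or nn.Identity as stand-in reconstructions is an assumption made for illustration; no SPyRiT measurement operators are used.

```python
import torch
import torch.nn as nn

b, c, h, w = 4, 1, 8, 8
x = torch.rand(b, c, h, w)  # a batch of single-channel 8x8 images

# Hypothetical acquisition stage: flatten each image into h*w values.
acqu = nn.Flatten(start_dim=2)  # (b, c, h, w) -> (b, c, h*w)

# Two hypothetical reconstruction stages fed the same measurements.
recon_image = nn.Unflatten(2, (h, w))  # (b, c, h*w) -> (b, c, h, w)
recon_vector = nn.Identity()           # keeps (b, c, h*w) unchanged

y = acqu(x)
out_image = recon_image(y)
out_vector = recon_vector(y)

print(y.shape)           # torch.Size([4, 1, 64])
print(out_image.shape)   # torch.Size([4, 1, 8, 8])
print(out_vector.shape)  # torch.Size([4, 1, 64])
```

The same measurements yield a 4D or a 3D output depending solely on which reconstruction stage is used.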

Example:

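>>> import torch
>>> import torch.nn as nn
>>> from spyrit.core.recon import FullNet
>>> class Lambda(nn.Module):  # hypothetical helper: nn.Sequential rejects bare lambdas
...     def __init__(self, fn):
...         super().__init__()
...         self.fn = fn
...     def forward(self, x):
...         return self.fn(x)
>>> acqu = nn.Sequential(Lambda(lambda x: x * 2), Lambda(lambda x: x - 10))
>>> recon = nn.Sequential(Lambda(lambda x: (x + 10) / 2))
>>> net = FullNet(acqu, recon)
>>> x = torch.tensor(5.0)
>>> y = net(x)
>>> print(y)
tensor(5.)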