spyrit.core.recon.PinvNet.forward

PinvNet.forward(x)

Apply the full network to the input signal.

The input signal first goes through the stored acquisition modules self.acqu_modules, which simulate its measurements. These measurements are then passed to the reconstruction modules self.recon_modules, which reconstruct the signal from them.
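The two-stage forward pass described above (simulate measurements, then reconstruct) can be sketched as follows. This is a minimal illustration only, not spyrit's implementation: the class name TwoStageNet is hypothetical, and it assumes that self.acqu_modules and self.recon_modules are ordinary torch.nn.Module objects applied one after the other. The toy setup measures a 16-pixel signal with 8 random linear projections H and reconstructs it with the pseudo-inverse of H. This choice is an assumption made for illustration; the real PinvNet builds its own measurement and reconstruction operators.

```python
import torch
from torch import nn


class TwoStageNet(nn.Module):
    """Hypothetical stand-in illustrating acquisition followed by reconstruction."""

    def __init__(self, acqu_modules: nn.Module, recon_modules: nn.Module):
        super().__init__()
        self.acqu_modules = acqu_modules    # simulates measurements of the signal
        self.recon_modules = recon_modules  # maps measurements back to a signal estimate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.acqu_modules(x)       # step 1: simulated measurements
        return self.recon_modules(y)   # step 2: reconstruction from the measurements


# Toy operators (assumption): a random 8x16 measurement matrix H and its pseudo-inverse.
torch.manual_seed(0)
H = torch.randn(8, 16)
acqu = nn.Linear(16, 8, bias=False)   # computes x @ H.T, i.e. H x for each batch row
recon = nn.Linear(8, 16, bias=False)  # computes y @ pinv(H).T, i.e. pinv(H) y
with torch.no_grad():
    acqu.weight.copy_(H)
    recon.weight.copy_(torch.linalg.pinv(H))

net = TwoStageNet(acqu, recon)
x = torch.rand(4, 16)   # batch of 4 flattened 16-pixel signals
x_hat = net(x)
print(x_hat.shape)      # torch.Size([4, 16])
```

Because H pinv(H) H = H, re-measuring the reconstruction gives back the original measurements, even though x_hat generally differs from x when there are fewer measurements than pixels.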

Args:

x (torch.Tensor): input tensor. For images, it is usually of shape (b, c, h, w), where b is the batch size, c is the number of channels, and h and w are the height and width of the images.

Returns:

torch.Tensor: output tensor. Its shape depends on the output of the reconstruction modules.

Example:
>>> import torch
>>> from torch import nn
>>> from spyrit.core.recon import FullNet
>>> acqu1 = nn.Linear(10, 5)
>>> acqu2 = nn.Sigmoid()
>>> acqu = nn.Sequential(acqu1, acqu2)
>>> recon1 = nn.Linear(5, 2)
>>> recon = nn.Sequential(recon1)
>>> net = FullNet(acqu, recon)
>>> x = torch.ones(2, 10)
>>> y = net(x)
>>> print(y.shape)
torch.Size([2, 2])
>>> print(y)
tensor([[...],
        [...]], grad_fn=<AddmmBackward0>)