spyrit.core.recon.LearnedPGD.forward

LearnedPGD.forward(x)

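Full pipeline of the reconstruction network: simulates the measurement of the ground-truth images, then reconstructs the images from those measurements.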

Args:

x: ground-truth images

Shape:

x: ground-truth images with shape \((B,C,H,W)\)

output: reconstructed images with shape \((B,C,H,W)\)

Example:
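>>> import torch
>>> from spyrit.core.meas import HadamSplit
>>> from spyrit.core.noise import NoNoise
>>> from spyrit.core.prep import SplitPoisson
>>> from spyrit.core.recon import LearnedPGD
>>> B, C, H, M = 10, 1, 64, 64**2  # batch, channels, image size, measurements (M = H*H: full sampling)
>>> Ord = torch.ones((H, H))  # order matrix used to select the Hadamard patterns
>>> meas = HadamSplit(M, H, Ord)
>>> noise = NoNoise(meas)  # noiseless acquisition
>>> prep = SplitPoisson(1.0, M, H*H)
>>> recnet = LearnedPGD(noise, prep)
>>> x = torch.empty(B, C, H, H).uniform_(-1, 1)  # random ground-truth images in [-1, 1]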
>>> z = recnet(x)
>>> print(z.shape)
torch.Size([10, 1, 64, 64])
>>> print(torch.linalg.norm(x - z)/torch.linalg.norm(x))
tensor(5.8912e-06)
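The two-stage pipeline described above (simulate measurements of the ground truth, then reconstruct) can be illustrated with a minimal pure-Python sketch. This is not spyrit's implementation: the functions `hadamard`, `acquire`, `prox`, `reconstruct` and `forward` are hypothetical, a full Sylvester-Hadamard matrix acting on a flattened image stands in for `HadamSplit` with full sampling (the split into positive and negative patterns is not modeled), and clipping to [-1, 1] is an assumed placeholder for whatever proximal or denoising step `LearnedPGD` applies.

```python
import math
import random


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of 2)."""
    A = [[1]]
    while len(A) < n:
        A = [row + row for row in A] + [row + [-v for v in row] for row in A]
    return A


def matvec(A, v):
    """Dense matrix-vector product A @ v."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def acquire(A, x):
    """Hypothetical noiseless acquisition: m = A x."""
    return matvec(A, x)


def prox(v):
    """Hypothetical placeholder for the learned step: clip values to [-1, 1]."""
    return [min(1.0, max(-1.0, t)) for t in v]


def reconstruct(A, m, n_iter=5):
    """Proximal gradient descent on 0.5 * ||A x - m||^2, starting from x = 0."""
    n = len(A[0])
    At = transpose(A)
    step = 1.0 / n  # A^T A = n I for a Hadamard matrix, so 1/n is 1/Lipschitz
    x = [0.0] * n
    for _ in range(n_iter):
        residual = [a - b for a, b in zip(matvec(A, x), m)]
        grad = matvec(At, residual)  # gradient A^T (A x - m)
        x = prox([xi - step * gi for xi, gi in zip(x, grad)])
    return x


def forward(A, x, n_iter=5):
    """Full pipeline: simulate the measurements of x, then reconstruct x from them."""
    return reconstruct(A, acquire(A, x), n_iter)


# Usage: a 4 x 4 image flattened to a vector of length 16.
random.seed(0)
h = 4
A = hadamard(h * h)
x = [random.uniform(-1, 1) for _ in range(h * h)]
z = forward(A, x)
err = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, z))) / math.sqrt(sum(a * a for a in x))
print(f"relative error: {err:.2e}")
```

With full sampling, A^T A = N I, so a step of 1/N makes the first gradient step land on A^{-1} m, which is why the relative error of this sketch is at floating-point level. The sketch says nothing about the number of iterations, step sizes or learned components the actual class uses.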