spyrit.core.recon.DCNet
- class spyrit.core.recon.DCNet(acqu: HadamSplit2d, prep: Rescale | RescaleEstim, sigma: tensor, denoi=Identity(), *, device: device = device(type='cpu'))
Bases: _PrebuiltFullNet

Pre-built FullNet that uses a TikhonovMeasurementPriorDiag reconstruction operator.

As a FullNet, this network has two modules: one for the measurements and one for the reconstruction. The measurement module only contains the acquisition operator, which must be a spyrit.core.meas.HadamSplit2d operator. The reconstruction module contains a preprocessing operator, a Tikhonov regularization operator (spyrit.core.inverse.TikhonovMeasurementPriorDiag), and a denoising operator.
- Args:
  acqu: Acquisition operator (see HadamSplit2d).
  prep: Preprocessing operator (see prep).
  sigma: Measurement covariance prior (for details, see the TikhonovMeasurementPriorDiag class).
  denoi (optional): Image denoising operator (see nnet). Default: Identity.
- Attributes:
  acqu: Acquisition operator, initialized as acqu.
  acqu_modules (nn.Sequential): Measurement modules. Only contains the acquisition operator.
  prep: Preprocessing operator, initialized as prep.
  tikho: Tikhonov regularization operator, initialized as a TikhonovMeasurementPriorDiag operator.
  denoi: Image denoising operator, initialized as denoi.
  recon_modules (nn.Sequential): Reconstruction modules. Contains the preprocessing operator, the Tikhonov regularization operator, and the denoising operator.
- Input / Output:
  input: Ground-truth images with shape \((b,c,h,w)\), where \(b\) is the batch size, \(c\) the number of channels, and \(h\) and \(w\) the height and width of the images.
  output: Reconstructed images with shape \((b,c,h,w)\).
- Example:
>>> from spyrit.core.meas import HadamSplit2d
>>> from spyrit.core.prep import UnsplitRescale
>>> from spyrit.core.recon import DCNet
>>> import torch
>>> acqu = HadamSplit2d(32)
>>> prep = UnsplitRescale()
>>> sigma = torch.eye(32*32, 32*32)
>>> dcnet = DCNet(acqu, prep, sigma)
>>> y = torch.randn(10, 1, 2048)
>>> z = dcnet.reconstruct(y)
>>> print(z.shape)
torch.Size([10, 1, 32, 32])
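For intuition about what the chain computes, here is a minimal NumPy sketch of the same stages (split Hadamard acquisition, unsplit preprocessing, a diagonal-prior Tikhonov step, identity denoiser) on tiny 4 x 4 images. It is not spyrit's implementation: every function name is hypothetical, and interleaving the positive and negative patterns, omitting the rescaling done by the real prep operator, and using the simple per-coefficient shrinkage s / (s + noise_var) with a fully sampled transform are all assumptions of the sketch. It does reproduce the shape bookkeeping of the docstring: (b,c,h,w) images give (b,c,2hw) split measurements, which is why the example feeds 2048 = 2 x 32 x 32 values per image.

```python
import numpy as np


def sylvester_hadamard(n):
    """n x n Sylvester Hadamard matrix; n must be a power of two."""
    h = np.ones((1, 1))
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h


def make_split_operator(n):
    """Hypothetical helper: toy 2D split Hadamard operator for n x n images.

    h2 = kron(H, H) acts on row-major flattened images. Assumption of this
    sketch: even rows of `split` hold max(h2, 0), odd rows hold max(-h2, 0).
    """
    h1 = sylvester_hadamard(n)
    h2 = np.kron(h1, h1)
    split = np.empty((2 * n * n, n * n))
    split[0::2] = np.maximum(h2, 0)
    split[1::2] = np.maximum(-h2, 0)
    return h2, split


def acquire(x, split):
    """(b, c, h, w) ground-truth images -> (b, c, 2*h*w) split measurements."""
    b, c, h, w = x.shape
    return x.reshape(b, c, h * w) @ split.T


def unsplit(y):
    """Toy preprocessing: positive minus negative measurements (no rescaling)."""
    return y[..., 0::2] - y[..., 1::2]


def tikhonov_diag(m, prior_var, noise_var):
    """Diagonal-prior Tikhonov (Wiener-type) shrinkage of each Hadamard coefficient."""
    return prior_var / (prior_var + noise_var) * m


def reconstruct(y, h2, prior_var, noise_var, h, w, denoi=lambda img: img):
    """(b, c, 2*h*w) split measurements -> (b, c, h, w) images."""
    coeffs = tikhonov_diag(unsplit(y), prior_var, noise_var)
    # h2 @ h2.T == (h*w) * I, so multiplying by h2 and dividing by h*w
    # inverts the full 2D Hadamard transform.
    z = coeffs @ h2 / h2.shape[0]
    return denoi(z.reshape(*z.shape[:-1], h, w))


rng = np.random.default_rng(0)
n = 4
h2, split = make_split_operator(n)
x = rng.random((2, 1, n, n))      # batch of 2 single-channel 4 x 4 images
y = acquire(x, split)
print(y.shape)                     # (2, 1, 32)
prior_var = np.ones(n * n)         # identity prior, like sigma = torch.eye(...) above
z = reconstruct(y, h2, prior_var, noise_var=0.0, h=n, w=n)
print(z.shape, np.allclose(z, x))  # (2, 1, 4, 4) True
```

With noise_var = 0 the shrinkage factor is 1 and the fully sampled transform is inverted exactly; a larger noise_var relative to the prior variance pulls each coefficient toward zero, which is the regularizing effect of the Tikhonov step.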
Methods
acquire(x): Apply the measurement modules to the input signal.
forward(x): Apply the full network to the input signal.
reconstruct(y): Reconstruct an image from measurements.
Reconstruct an image from measurements without denoising.
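The method list suggests that forward chains the two halves of the network. Below is a minimal pure-Python sketch, assuming forward(x) is equivalent to reconstruct(acquire(x)) and modeling each nn.Sequential attribute as a list of callables applied in order. ToyFullNet and the lambda stages are hypothetical stand-ins, not spyrit classes.

```python
from functools import reduce


class ToyFullNet:
    """Hypothetical stand-in mirroring acqu_modules / recon_modules.

    Each module list is applied in order, like nn.Sequential.
    """

    def __init__(self, acqu_modules, recon_modules):
        self.acqu_modules = list(acqu_modules)
        self.recon_modules = list(recon_modules)

    @staticmethod
    def _apply(modules, value):
        # Feed the output of each stage into the next one.
        return reduce(lambda acc, module: module(acc), modules, value)

    def acquire(self, x):
        return self._apply(self.acqu_modules, x)

    def reconstruct(self, y):
        return self._apply(self.recon_modules, y)

    def forward(self, x):
        # Assumption: the full network is acquisition followed by reconstruction.
        return self.reconstruct(self.acquire(x))


net = ToyFullNet(
    acqu_modules=[lambda x: [2 * v for v in x]],         # toy "acquisition"
    recon_modules=[
        lambda y: [v / 2 for v in y],                    # toy "preprocessing"
        lambda z: [max(v, 0.0) for v in z],              # toy "denoiser"
    ],
)
x = [1, 2, 3]
print(net.acquire(x))  # [2, 4, 6]
print(net.forward(x))  # [1.0, 2.0, 3.0]
```

Keeping acquisition and reconstruction as separate entry points lets real measurements, which have already been acquired, be passed straight to reconstruct, as the DCNet example does with y.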