spyrit.core.nnet.Unet

class spyrit.core.nnet.Unet(in_channel: int = 1, out_channel: int = 1, upsample: bool = False, upsample_mode: str = 'nearest')

Bases: torch.nn.Module

Defines a U-Net model.

This U-Net model is mostly used for image denoising, but can be used for other image-to-image tasks.

The model is composed of a descending (contracting) branch and an ascending (expanding) branch. The descending branch consists of three convolutional blocks, each followed by a max pooling layer. The bottleneck is a convolutional block with two convolutional layers. The ascending branch consists of three convolutional blocks, each followed by an upsampling layer. The final layer is a convolutional block with two convolutional layers. Skip connections concatenate feature maps from the descending branch with the upsampled feature maps of the ascending branch along the channel dimension (see concat).
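To make the data flow concrete, here is a minimal, self-contained PyTorch sketch of a generic three-level U-Net. It is not spyrit's implementation: the class name TinyUnet, the helper conv_block, the channel widths (16, 32, 64, 128), the 3×3 kernels, the use of BatchNorm, and the placement of each upsampling layer before (rather than after) its decoder block are all assumptions made for illustration.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Hypothetical block: two 3x3 convolutions with padding=1 keep H and W unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_ch),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_ch),
    )


class TinyUnet(nn.Module):
    """Hypothetical 3-level U-Net; channel widths (16, 32, 64, 128) are assumptions."""

    def __init__(self, in_channel: int = 1, out_channel: int = 1):
        super().__init__()
        # Descending branch: conv block, then 2x2 max pooling (halves H and W).
        self.enc1 = conv_block(in_channel, 16)
        self.enc2 = conv_block(16, 32)
        self.enc3 = conv_block(32, 64)
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Bottleneck at 1/8 of the input resolution.
        self.bottleneck = conv_block(64, 128)
        # Ascending branch: transposed convolutions double H and W.
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3 = conv_block(128, 64)  # 64 upsampled + 64 skip channels
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec2 = conv_block(64, 32)  # 32 upsampled + 32 skip channels
        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec1 = conv_block(32, 16)  # 16 upsampled + 16 skip channels
        # 1x1 convolution maps to the requested number of output channels.
        self.final = nn.Conv2d(16, out_channel, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        # Skip connections: concatenate along the channel dimension (dim=1).
        d3 = self.dec3(torch.cat((self.up3(b), e3), dim=1))
        d2 = self.dec2(torch.cat((self.up2(d3), e2), dim=1))
        d1 = self.dec1(torch.cat((self.up1(d2), e1), dim=1))
        return self.final(d1)


model = TinyUnet(in_channel=1, out_channel=1)
y = model(torch.randn(2, 1, 64, 64))
print(y.shape)  # torch.Size([2, 1, 64, 64])
```

With a 64 × 64 input the bottleneck operates at 8 × 8, and the output has the same spatial size as the input, which is what an image-to-image task such as denoising requires.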

The upsampling layer is either a transposed convolution (the default, upsample=False) or a torch.nn.Upsample layer followed by a convolution (upsample=True). In the latter case, the interpolation method is set by upsample_mode.
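The two options can be sketched as follows. The helper make_upsampler is hypothetical, and the specific layer settings (a 2×2 transposed convolution with stride 2, or a ×2 upsampling followed by a 3×3 convolution) are assumptions about how such a layer doubles the resolution, not a copy of spyrit's code.

```python
import torch
from torch import nn


def make_upsampler(in_ch: int, out_ch: int, upsample: bool = False,
                   upsample_mode: str = "nearest") -> nn.Module:
    """Hypothetical helper: doubles H and W and maps in_ch -> out_ch channels."""
    if upsample:
        # Fixed (non-learnable) interpolation, then a learnable 3x3 convolution.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode=upsample_mode),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )
    # Learnable transposed convolution; kernel_size=2, stride=2 exactly doubles H and W.
    return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)


x = torch.randn(1, 8, 16, 16)
for flag in (False, True):
    out = make_upsampler(8, 4, upsample=flag)(x)
    print(flag, out.shape)  # torch.Size([1, 4, 32, 32]) in both cases
```

Both variants produce the same output shape; they differ in how the new pixels are computed. Interpolation followed by a convolution is often chosen to reduce the checkerboard artifacts that transposed convolutions can introduce.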

Args:

in_channel (int): Number of input channels.

out_channel (int): Number of output channels.

upsample (bool): If True, use an upsampling layer followed by a convolution layer in the ascending branch instead of a transposed convolution.

upsample_mode (str): The upsampling method, passed directly to the mode argument of torch.nn.Upsample. One of 'nearest', 'bilinear', or 'bicubic'. Only relevant when upsample is True.
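The snippet below calls torch.nn.Upsample directly (not spyrit) to show what the upsample_mode string controls. With scale_factor=2 (an assumption about how the ascending branch doubles resolution), every mode doubles H and W of a 4-D NCHW tensor; only the interpolated values differ.

```python
import torch
from torch import nn

# A 1x1x2x2 NCHW tensor: [[0, 1], [2, 3]]
x = torch.arange(4.0).reshape(1, 1, 2, 2)

for mode in ("nearest", "bilinear", "bicubic"):
    y = nn.Upsample(scale_factor=2, mode=mode)(x)
    print(mode, tuple(y.shape))  # (1, 1, 4, 4) for every mode

# 'nearest' copies each input pixel into a 2x2 patch of the output.
print(nn.Upsample(scale_factor=2, mode="nearest")(x)[0, 0])
```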

Attributes:

upsample (bool): Value of the upsample argument.

upsample_mode (str): Value of the upsample_mode argument.

conv_encode1 (torch.nn.Sequential): The first convolutional block of the descending branch.

conv_maxpool1 (torch.nn.MaxPool2d): The first max pooling layer.

conv_encode2 (torch.nn.Sequential): The second convolutional block of the descending branch.

conv_maxpool2 (torch.nn.MaxPool2d): The second max pooling layer.

conv_encode3 (torch.nn.Sequential): The third convolutional block of the descending branch.

conv_maxpool3 (torch.nn.MaxPool2d): The third max pooling layer.

bottleneck (torch.nn.Sequential): The bottleneck block.

conv_decode4 (torch.nn.Sequential): The first convolutional block of the ascending branch.

conv_decode3 (torch.nn.Sequential): The second convolutional block of the ascending branch.

conv_decode2 (torch.nn.Sequential): The third convolutional block of the ascending branch.

final_layer (torch.nn.Sequential): The final convolutional block.

Example:
>>> import torch
>>> from spyrit.core.nnet import Unet
>>> model = Unet(in_channel=1, out_channel=1, upsample=True, upsample_mode='nearest')
>>> x = torch.randn(1, 1, 256, 256)
>>> y = model(x)
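Because the descending branch applies three max pooling layers, the input height and width matter. The plain-Python sketch below traces the spatial size through the network, assuming each pooling is a 2×2 MaxPool2d with stride 2 and no padding (so odd sizes are floored) and each upsampling exactly doubles the size. Under these assumptions, and assuming concat does not crop its inputs, the skip connections line up only when H and W are divisible by 2^3 = 8, as 256 in the example above is. The helpers trace_sizes and round_trips are hypothetical.

```python
from typing import List


def trace_sizes(n: int, levels: int = 3) -> List[int]:
    """Spatial size after each pooling (assumed 2x2, stride 2, no padding)."""
    sizes = [n]
    for _ in range(levels):
        sizes.append(sizes[-1] // 2)  # MaxPool2d(kernel_size=2) floors odd sizes
    return sizes


def round_trips(n: int, levels: int = 3) -> bool:
    """True if doubling back up reproduces every skip-connection size."""
    down = trace_sizes(n, levels)
    size = down[-1]
    for skip in reversed(down[:-1]):
        size *= 2  # each upsampling step is assumed to double the size
        if size != skip:
            return False
    return True


for n in (256, 100, 250):
    print(n, trace_sizes(n), round_trips(n))
# 256 [256, 128, 64, 32] True
# 100 [100, 50, 25, 12] False
# 250 [250, 125, 62, 31] False
```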

Methods

bottle_neck(in_channels[, kernel_size, padding]): Defines the bottleneck block of the U-Net model.

concat(upsampled, bypass): Concatenates two tensors along the channel dimension.

contract(in_channels, out_channels[, ...]): Defines a convolutional block.

expans(in_channels, mid_channel, out_channels): Defines an upsampling block.

final_block(in_channels, mid_channel, ...[, ...]): Defines the final block of the U-Net model.

forward(x): Forward pass of the U-Net model.
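The concat method is described only as concatenating two tensors along the channel dimension. Assuming the usual NCHW layout, that corresponds to torch.cat along dim=1. The hypothetical stand-in below (a free function, not spyrit's method) shows the resulting shape and that mismatched spatial sizes raise a RuntimeError.

```python
import torch


def concat(upsampled: torch.Tensor, bypass: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: join NCHW feature maps along dim=1 (channels).
    return torch.cat((upsampled, bypass), dim=1)


up = torch.randn(4, 64, 32, 32)
skip = torch.randn(4, 64, 32, 32)
print(concat(up, skip).shape)  # torch.Size([4, 128, 32, 32])

# All dimensions except dim=1 must match, otherwise torch.cat raises.
try:
    concat(up, torch.randn(4, 64, 31, 31))
except RuntimeError as err:
    print("size mismatch:", type(err).__name__)
```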