spyrit.core.nnet.Unet
- class spyrit.core.nnet.Unet(in_channel: int = 1, out_channel: int = 1, upsample: bool = False, upsample_mode: str = 'nearest')
Bases: Module
Defines a U-Net model.
This U-Net model is mostly used for image denoising, but can be used for other image-to-image tasks.
The model is composed of a descending branch and an ascending branch. The descending branch is composed of three convolutional blocks, each followed by a max pooling layer. The bottleneck is a convolutional block with two convolutional layers. The ascending branch is composed of three convolutional blocks, each followed by an upsampling layer. The final layer is a convolutional block with two convolutional layers.
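As a rough illustration of the descending branch, the sketch below traces how the spatial size changes through three pooling stages. It assumes each max pooling layer halves the height and width (kernel size 2, stride 2), which this page does not state. The helper `trace_sizes` is hypothetical and not part of spyrit.

```python
def trace_sizes(size, n_levels=3, pool=2):
    """Return the spatial size before and after each pooling stage.

    Assumption: every pooling layer divides the size by `pool`
    using floor division (the default rounding of torch.nn.MaxPool2d).
    """
    sizes = [size]
    for _ in range(n_levels):
        size = size // pool
        sizes.append(size)
    return sizes


sizes = trace_sizes(256)
print(sizes)  # [256, 128, 64, 32]
```

Under the same assumption, an input whose height and width are multiples of 8 regains its original size after three 2x upsampling steps in the ascending branch.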
The upsampling layer can be either a transposed convolution or an upsampling followed by a convolution. The upsampling method can be specified using the upsample and upsample_mode arguments.
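The two upsampling strategies described above can be compared in plain PyTorch. This is a minimal sketch, not the layers spyrit builds internally: the channel counts, kernel sizes, and the names `up_transposed` and `up_interp` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)  # (batch, channels, height, width)

# Option 1: a transposed convolution that doubles height and width.
up_transposed = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)

# Option 2: parameter-free upsampling followed by a regular convolution.
# The `mode` argument plays the role of `upsample_mode`.
up_interp = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

y_transposed = up_transposed(x)
y_interp = up_interp(x)
print(y_transposed.shape)  # torch.Size([1, 4, 64, 64])
print(y_interp.shape)      # torch.Size([1, 4, 64, 64])
```

Both options produce the same output shape here. They differ in how the new pixels are computed: the transposed convolution learns the upsampling weights, while the second option interpolates first and then applies a learned convolution.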
- Args:
in_channel (int): Number of input channels.
out_channel (int): Number of output channels.
upsample (bool): If True, use an upsampling layer followed by a convolution layer in the ascending branch instead of a transposed convolution.
upsample_mode (str): The upsampling method to use. It is passed directly to the mode argument of the torch.nn.Upsample class. It can be 'nearest', 'bilinear', or 'bicubic'.
- Attributes:
upsample (bool): If True, use an upsampling layer followed by a convolution layer in the ascending branch instead of a transposed convolution.
upsample_mode (str): The upsampling method to use. It is passed directly to the mode argument of the torch.nn.Upsample class. It can be 'nearest', 'bilinear', or 'bicubic'.
conv_encode1 (torch.nn.Sequential): The first convolutional block of the descending branch.
conv_maxpool1 (torch.nn.MaxPool2d): The first max pooling layer.
conv_encode2 (torch.nn.Sequential): The second convolutional block of the descending branch.
conv_maxpool2 (torch.nn.MaxPool2d): The second max pooling layer.
conv_encode3 (torch.nn.Sequential): The third convolutional block of the descending branch.
conv_maxpool3 (torch.nn.MaxPool2d): The third max pooling layer.
bottleneck (torch.nn.Sequential): The bottleneck block.
conv_decode4 (torch.nn.Sequential): The first convolutional block of the ascending branch.
conv_decode3 (torch.nn.Sequential): The second convolutional block of the ascending branch.
conv_decode2 (torch.nn.Sequential): The third convolutional block of the ascending branch.
final_layer (torch.nn.Sequential): The final convolutional block.
- Example:
>>> import torch
>>> from spyrit.core.nnet import Unet
>>> model = Unet(in_channel=1, out_channel=1, upsample=True, upsample_mode='nearest')
>>> x = torch.randn(1, 1, 256, 256)
>>> y = model(x)
Methods
bottle_neck(in_channels[, kernel_size, padding]): Defines the bottleneck block of the U-Net model.
concat(upsampled, bypass): Concatenates two tensors along the channel dimension.
contract(in_channels, out_channels[, ...]): Defines a convolutional block.
expans(in_channels, mid_channel, out_channels): Defines an upsampling block.
final_block(in_channels, mid_channel, ...[, ...]): Defines the final block of the U-Net model.
forward(x): Forward pass of the U-Net model.
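The channel-wise concatenation that `concat(upsampled, bypass)` describes can be illustrated with `torch.cat`. This sketch assumes both tensors already share the same batch and spatial sizes, and it does not reproduce any size handling that spyrit's own method may perform. The tensor shapes are made up for the example.

```python
import torch

upsampled = torch.randn(1, 16, 64, 64)  # output of an upsampling step
bypass = torch.randn(1, 16, 64, 64)     # matching feature map from the descending branch

# Concatenate along dim=1, the channel dimension of (N, C, H, W) tensors.
merged = torch.cat((upsampled, bypass), dim=1)
print(merged.shape)  # torch.Size([1, 32, 64, 64])
```

The channel counts add up while the batch and spatial dimensions are unchanged, which is why a block that consumes the concatenated tensor needs more input channels than the upsampled tensor alone provides.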