spyrit.core.nnet.Unet.forward

Unet.forward(x: tensor) → tensor

Forward pass of the U-Net model.

The number of channels in the input tensor must equal the number of input channels specified at initialization.

Args:

x (torch.Tensor): The input tensor, expected to have shape (b, in_channel, h, w), where b is the batch size, in_channel is the number of input channels, h is the height, and w is the width.

Returns:

torch.Tensor: The output tensor of the U-Net model, with shape (b, out_channel, h, w), where out_channel is the number of output channels specified at initialization.
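The shape contract above, (b, in_channel, h, w) in and (b, out_channel, h, w) out, can be illustrated with a minimal one-level U-Net written in plain PyTorch. This is a sketch under stated assumptions, not spyrit's implementation: the class TinyUNet and its sub_channel parameter are hypothetical, the explicit channel check is added for illustration, and the sketch assumes h and w are even so that the upsampled features line up with the skip connection.

```python
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    """Hypothetical one-level U-Net (NOT spyrit's Unet) illustrating the
    forward() shape contract: (b, in_channel, h, w) -> (b, out_channel, h, w)."""

    def __init__(self, in_channel: int = 1, out_channel: int = 1, sub_channel: int = 8):
        super().__init__()
        self.in_channel = in_channel
        # Encoder: change the channel count, keep spatial size (3x3 kernel, padding=1)
        self.enc = nn.Sequential(nn.Conv2d(in_channel, sub_channel, 3, padding=1), nn.ReLU())
        # Downsample by 2, then widen the channels at the bottleneck
        self.down = nn.MaxPool2d(2)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(sub_channel, 2 * sub_channel, 3, padding=1), nn.ReLU()
        )
        # Upsample by 2 back to the encoder's spatial size
        self.up = nn.ConvTranspose2d(2 * sub_channel, sub_channel, kernel_size=2, stride=2)
        # Decoder consumes upsampled features concatenated with the skip connection
        self.dec = nn.Sequential(nn.Conv2d(2 * sub_channel, sub_channel, 3, padding=1), nn.ReLU())
        # 1x1 convolution maps to the requested number of output channels
        self.head = nn.Conv2d(sub_channel, out_channel, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Illustrative check: input channels must match the value given at init
        if x.dim() != 4 or x.shape[1] != self.in_channel:
            raise ValueError(
                f"expected shape (b, {self.in_channel}, h, w), got {tuple(x.shape)}"
            )
        skip = self.enc(x)                          # (b, sub_channel, h, w)
        z = self.bottleneck(self.down(skip))        # (b, 2*sub_channel, h/2, w/2)
        u = self.up(z)                              # (b, sub_channel, h, w); assumes h, w even
        y = self.dec(torch.cat([u, skip], dim=1))   # concat on channels -> (b, sub_channel, h, w)
        return self.head(y)                         # (b, out_channel, h, w)


net = TinyUNet(in_channel=1, out_channel=3)
x = torch.randn(4, 1, 32, 32)
y = net(x)
print(y.shape)  # torch.Size([4, 3, 32, 32])
```

Only the channel dimension changes between input and output; batch size and spatial size are preserved. In this sketch an odd h or w would make the concatenation fail, which is why U-Net implementations commonly either require sizes divisible by 2 per pooling level or pad/crop before the skip connection.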