# Source code for spyrit.core.nnet

"""
Neural network models for image denoising.
"""

# from __future__ import print_function, division
import torch
import torch.nn as nn
from collections import OrderedDict
import copy


# =============================================================================
class Unet(nn.Module):
    """Three-level U-Net built from 3x3 convolutions.

    The encoder applies three double-convolution blocks (16, 32 and 64
    feature maps), each followed by a 2x2 max-pooling.  After a bottleneck
    the decoder upsamples three times with stride-2 transposed convolutions
    and concatenates, along the channel axis, the encoder output of the
    matching resolution before the next block.

    Because every pooling halves the spatial size and every transposed
    convolution doubles it, height and width of the input must be
    divisible by 8 for the skip connections to line up.

    Args:
        in_channel: number of channels expected by the first convolution.
        out_channel: number of channels produced by the last convolution.
    """

    def __init__(self, in_channel=1, out_channel=1):
        super().__init__()

        # Encoder: double convolution, then 2x2 max-pooling.
        self.conv_encode1 = self.contract(in_channels=in_channel, out_channels=16)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contract(16, 32)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contract(32, 64)
        self.conv_maxpool3 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck at the coarsest resolution.
        self.bottleneck = self.bottle_neck(64)

        # Decoder: the input width of each block is the width of the
        # previous upsampled output plus that of the concatenated skip
        # connection (64 + 64, 32 + 32, 16 + 16).
        self.conv_decode4 = self.expans(64, 64, 64)
        self.conv_decode3 = self.expans(128, 64, 32)
        self.conv_decode2 = self.expans(64, 32, 16)
        self.final_layer = self.final_block(32, 16, out_channel)

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel_size, padding):
        """Return the layers ``[Conv2d, ReLU, BatchNorm2d]`` of one stage."""
        return [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
            ),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]

    def contract(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Build an encoder block: two conv -> ReLU -> batch-norm stages."""
        layers = [
            *self._conv_stage(in_channels, out_channels, kernel_size, padding),
            *self._conv_stage(out_channels, out_channels, kernel_size, padding),
        ]
        return nn.Sequential(*layers)

    def expans(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Build a decoder block: two conv stages, then a stride-2 upsampling.

        With the default ``kernel_size`` and ``padding`` the final transposed
        convolution exactly doubles the height and width of its input.
        """
        layers = [
            *self._conv_stage(in_channels, mid_channel, kernel_size, padding),
            *self._conv_stage(mid_channel, mid_channel, kernel_size, padding),
            nn.ConvTranspose2d(
                mid_channel,
                out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                output_padding=1,
            ),
        ]
        return nn.Sequential(*layers)

    def concat(self, upsampled, bypass):
        """Concatenate ``upsampled`` and ``bypass`` along dimension 1."""
        return torch.cat((upsampled, bypass), dim=1)

    def bottle_neck(self, in_channels, kernel_size=3, padding=1):
        """Build the bottleneck: widen to ``2 * in_channels`` and back.

        Unlike the other blocks, no batch normalisation is applied here.
        """
        widened = 2 * in_channels
        return nn.Sequential(
            nn.Conv2d(in_channels, widened, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.Conv2d(widened, in_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        """Build the output block: two conv stages and a last convolution.

        The padding is fixed to 1 for all three convolutions, whatever the
        value of ``kernel_size``.
        """
        layers = [
            *self._conv_stage(in_channels, mid_channel, kernel_size, 1),
            *self._conv_stage(mid_channel, mid_channel, kernel_size, 1),
            nn.Conv2d(mid_channel, out_channels, kernel_size=kernel_size, padding=1),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last block."""
        # Encoder; the pre-pooling activations are kept as skip connections.
        skip1 = self.conv_encode1(x)
        skip2 = self.conv_encode2(self.conv_maxpool1(skip1))
        skip3 = self.conv_encode3(self.conv_maxpool2(skip2))

        # Bottleneck
        out = self.bottleneck(self.conv_maxpool3(skip3))

        # Decoder; each upsampled map is merged with the skip connection
        # of the same resolution.
        out = self.concat(self.conv_decode4(out), skip3)
        out = self.concat(self.conv_decode3(out), skip2)
        out = self.concat(self.conv_decode2(out), skip1)
        return self.final_layer(out)
# ===========================================================================================
class ConvNet(nn.Module):
    """Three-layer convolutional network without batch normalisation.

    The layers are a 9x9, a 1x1 and a 5x5 convolution, with a ReLU after
    the first two.  Paddings are chosen so that height and width are
    preserved; input and output both have a single channel.
    """

    def __init__(self):
        super().__init__()

        # Layer names are part of the state_dict keys and must not change.
        layers = OrderedDict()
        layers["conv1"] = nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)
        layers["relu1"] = nn.ReLU()
        layers["conv2"] = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)
        layers["relu2"] = nn.ReLU()
        layers["conv3"] = nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)
        self.convnet = nn.Sequential(layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.convnet(x)
# ===========================================================================================
class ConvNetBN(nn.Module):
    """Three-layer convolutional network with batch normalisation.

    The layers are a 9x9, a 1x1 and a 5x5 convolution; each of the first
    two is followed by a ReLU and a ``BatchNorm2d``.  Paddings are chosen
    so that height and width are preserved; input and output both have a
    single channel.
    """

    def __init__(self):
        super().__init__()

        # Layer names are part of the state_dict keys and must not change.
        layers = OrderedDict()
        layers["conv1"] = nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)
        layers["relu1"] = nn.ReLU()
        layers["BN1"] = nn.BatchNorm2d(64)
        layers["conv2"] = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)
        layers["relu2"] = nn.ReLU()
        layers["BN2"] = nn.BatchNorm2d(32)
        layers["conv3"] = nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)
        self.convnet = nn.Sequential(layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.convnet(x)
# ===========================================================================================
class DConvNet(nn.Module):
    """Four-layer convolutional network with batch normalisation.

    The layers are a 9x9, a 1x1, a 3x3 and a 5x5 convolution; each of the
    first three is followed by a ReLU and a ``BatchNorm2d``.  Paddings are
    chosen so that height and width are preserved; input and output both
    have a single channel.
    """

    def __init__(self):
        super().__init__()

        # Layer names are part of the state_dict keys and must not change.
        layers = OrderedDict()
        layers["conv1"] = nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)
        layers["relu1"] = nn.ReLU()
        layers["BN1"] = nn.BatchNorm2d(64)
        layers["conv2"] = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)
        layers["relu2"] = nn.ReLU()
        layers["BN2"] = nn.BatchNorm2d(64)
        layers["conv3"] = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        layers["relu3"] = nn.ReLU()
        layers["BN3"] = nn.BatchNorm2d(32)
        layers["conv4"] = nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)
        self.convnet = nn.Sequential(layers)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.convnet(x)
# ===========================================================================================
class List_denoi(nn.Module):
    """Hold ``n_denoi`` independent copies of a denoiser, selected by index.

    Every entry of ``self.conv`` is a deep copy of ``Denoi``, so the copies
    start from the same weights but do not share parameters.

    Args:
        Denoi: module to replicate.
        n_denoi: number of copies to create; must be at least 1.

    Raises:
        ValueError: if ``n_denoi`` is smaller than 1.
    """

    def __init__(self, Denoi, n_denoi):
        super().__init__()
        # With no copy at all, forward() would compute index -1 and fail
        # with an IndexError on the empty list; reject it up front instead.
        if n_denoi < 1:
            raise ValueError(f"n_denoi must be at least 1, got {n_denoi}")
        self.n_denoi = n_denoi
        self.conv = nn.ModuleList(copy.deepcopy(Denoi) for _ in range(n_denoi))

    def forward(self, x, iterate):
        """Apply the copy selected by ``iterate`` to ``x``.

        Args:
            x: input passed to the selected module.
            iterate: index of the copy to use; values above
                ``n_denoi - 1`` are clamped so the last copy is reused.

        Returns:
            The output of the selected module.
        """
        index = min(iterate, self.n_denoi - 1)
        return self.conv[index](x)
# ===========================================================================================
class Identity(nn.Module):
    """Module that returns its input unchanged.

    Can be useful for ablation study.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x