"""
Neural network models for image denoising.
"""
# from __future__ import print_function, division
import torch
import torch.nn as nn
from collections import OrderedDict
import copy
# =============================================================================
def _conv_relu_bn(in_channels, out_channels, kernel_size, padding):
    """Return ``[Conv2d, ReLU, BatchNorm2d]`` as a list of fresh layers."""
    return [
        torch.nn.Conv2d(
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels,
            padding=padding,
        ),
        torch.nn.ReLU(),
        torch.nn.BatchNorm2d(out_channels),
    ]


class Unet(nn.Module):
    """Three-level U-Net for image-to-image mapping.

    The encoder applies three double-convolution blocks (16, 32 and 64
    feature maps), each followed by a 2x2 max-pooling. After a bottleneck,
    the decoder upsamples three times with stride-2 transposed convolutions
    and concatenates the matching encoder output along the channel
    dimension before each subsequent block.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels of the output tensor.

    Note:
        The input is downsampled by 2 three times and upsampled by exactly
        2 three times, so the skip connections only line up when the height
        and width of the input are multiples of 8.
    """

    def __init__(self, in_channel=1, out_channel=1):
        super().__init__()
        # Descending branch
        self.conv_encode1 = self.contract(in_channels=in_channel, out_channels=16)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contract(16, 32)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contract(32, 64)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        # Bottleneck
        self.bottleneck = self.bottle_neck(64)
        # Decode branch; from conv_decode3 on, the input channel count is
        # the upsampled channels plus the concatenated skip connection.
        self.conv_decode4 = self.expans(64, 64, 64)
        self.conv_decode3 = self.expans(128, 64, 32)
        self.conv_decode2 = self.expans(64, 32, 16)
        self.final_layer = self.final_block(32, 16, out_channel)

    def contract(self, in_channels, out_channels, kernel_size=3, padding=1):
        """Build an encoder block: two (Conv2d, ReLU, BatchNorm2d) stages.

        Args:
            in_channels: Channels of the block input.
            out_channels: Channels produced by both convolutions.
            kernel_size: Kernel size of both convolutions.
            padding: Padding of both convolutions.

        Returns:
            A ``torch.nn.Sequential`` with six layers.
        """
        return torch.nn.Sequential(
            *_conv_relu_bn(in_channels, out_channels, kernel_size, padding),
            *_conv_relu_bn(out_channels, out_channels, kernel_size, padding),
        )

    def expans(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Build a decoder block ending with a stride-2 transposed convolution.

        Two (Conv2d, ReLU, BatchNorm2d) stages with ``mid_channel`` feature
        maps are followed by a ``ConvTranspose2d`` to ``out_channels``. With
        the default ``kernel_size=3`` and ``padding=1`` the transposed
        convolution exactly doubles the height and width.

        Args:
            in_channels: Channels of the block input.
            mid_channel: Channels of the two intermediate convolutions.
            out_channels: Channels after upsampling.
            kernel_size: Kernel size of every layer in the block.
            padding: Padding of every layer in the block.

        Returns:
            A ``torch.nn.Sequential`` with seven layers.
        """
        return torch.nn.Sequential(
            *_conv_relu_bn(in_channels, mid_channel, kernel_size, padding),
            *_conv_relu_bn(mid_channel, mid_channel, kernel_size, padding),
            torch.nn.ConvTranspose2d(
                in_channels=mid_channel,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
                output_padding=1,
            ),
        )

    def concat(self, upsampled, bypass):
        """Concatenate ``upsampled`` and ``bypass`` along dimension 1.

        All other dimensions of the two tensors must match.
        """
        return torch.cat((upsampled, bypass), 1)

    def bottle_neck(self, in_channels, kernel_size=3, padding=1):
        """Build the bottleneck: Conv2d/ReLU to twice the channels and back.

        No batch normalisation is used in this block.

        Args:
            in_channels: Channels of the block input and output.
            kernel_size: Kernel size of both convolutions.
            padding: Padding of both convolutions.

        Returns:
            A ``torch.nn.Sequential`` with four layers.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                kernel_size=kernel_size,
                in_channels=in_channels,
                out_channels=2 * in_channels,
                padding=padding,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                kernel_size=kernel_size,
                in_channels=2 * in_channels,
                out_channels=in_channels,
                padding=padding,
            ),
            torch.nn.ReLU(),
        )

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3, padding=1):
        """Build the output block; its last convolution has no activation.

        Two (Conv2d, ReLU, BatchNorm2d) stages with ``mid_channel`` feature
        maps are followed by a plain ``Conv2d`` to ``out_channels``.

        Args:
            in_channels: Channels of the block input.
            mid_channel: Channels of the two intermediate convolutions.
            out_channels: Channels of the network output.
            kernel_size: Kernel size of the three convolutions.
            padding: Padding of the three convolutions (1 by default, which
                preserves the spatial size for ``kernel_size=3``).

        Returns:
            A ``torch.nn.Sequential`` with seven layers.
        """
        return torch.nn.Sequential(
            *_conv_relu_bn(in_channels, mid_channel, kernel_size, padding),
            *_conv_relu_bn(mid_channel, mid_channel, kernel_size, padding),
            torch.nn.Conv2d(
                kernel_size=kernel_size,
                in_channels=mid_channel,
                out_channels=out_channels,
                padding=padding,
            ),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Batched tensor whose dimension 1 holds ``in_channel`` channels
                (skip connections are concatenated along dimension 1). Height
                and width should be multiples of 8.

        Returns:
            Tensor with ``out_channel`` channels and the same height and
            width as ``x``.
        """
        # Encode; keep each pre-pooling activation for its skip connection.
        encode_block1 = self.conv_encode1(x)
        x = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(x)
        x = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(x)
        x = self.conv_maxpool3(encode_block3)
        # Bottleneck
        x = self.bottleneck(x)
        # Decode; each upsampled map is joined with the encoder output of
        # the same resolution.
        x = self.conv_decode4(x)
        x = self.concat(x, encode_block3)
        x = self.conv_decode3(x)
        x = self.concat(x, encode_block2)
        x = self.conv_decode2(x)
        x = self.concat(x, encode_block1)
        x = self.final_layer(x)
        return x
# ===========================================================================================
class ConvNet(nn.Module):
    """Three-layer convolutional network without batch normalisation.

    The layers are a 9x9 convolution (1 -> 64 channels), a 1x1 convolution
    (64 -> 32 channels) and a 5x5 convolution (32 -> 1 channel), with a ReLU
    after the first two. Stride is 1 and the paddings are chosen so that the
    height and width of the input are preserved.
    """

    def __init__(self):
        super().__init__()
        # Layer names are part of the state_dict keys; keep them stable.
        self.convnet = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)),
                    ("relu1", nn.ReLU()),
                    ("conv2", nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)),
                    ("relu2", nn.ReLU()),
                    ("conv3", nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)),
                ]
            )
        )

    def forward(self, x):
        """Apply the network to ``x``, a tensor with a single channel.

        Returns:
            Single-channel tensor with the same height and width as ``x``.
        """
        return self.convnet(x)
# ===========================================================================================
class ConvNetBN(nn.Module):
    """Three-layer convolutional network with batch normalisation.

    The layers are a 9x9 convolution (1 -> 64 channels), a 1x1 convolution
    (64 -> 32 channels) and a 5x5 convolution (32 -> 1 channel). Each of the
    first two convolutions is followed by a ReLU and then a
    ``BatchNorm2d``. Stride is 1 and the paddings are chosen so that the
    height and width of the input are preserved.
    """

    def __init__(self):
        super().__init__()
        # Layer names are part of the state_dict keys; keep them stable.
        self.convnet = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)),
                    ("relu1", nn.ReLU()),
                    ("BN1", nn.BatchNorm2d(64)),
                    ("conv2", nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)),
                    ("relu2", nn.ReLU()),
                    ("BN2", nn.BatchNorm2d(32)),
                    ("conv3", nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)),
                ]
            )
        )

    def forward(self, x):
        """Apply the network to ``x``, a batched tensor with a single channel.

        Returns:
            Single-channel tensor with the same height and width as ``x``.
        """
        return self.convnet(x)
# ===========================================================================================
class DConvNet(nn.Module):
    """Four-layer convolutional network with batch normalisation.

    The layers are a 9x9 convolution (1 -> 64 channels), a 1x1 convolution
    (64 -> 64 channels), a 3x3 convolution (64 -> 32 channels) and a 5x5
    convolution (32 -> 1 channel). Each of the first three convolutions is
    followed by a ReLU and then a ``BatchNorm2d``. Stride is 1 and the
    paddings are chosen so that the height and width of the input are
    preserved.
    """

    def __init__(self):
        super().__init__()
        # Layer names are part of the state_dict keys; keep them stable.
        self.convnet = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, 64, kernel_size=9, stride=1, padding=4)),
                    ("relu1", nn.ReLU()),
                    ("BN1", nn.BatchNorm2d(64)),
                    ("conv2", nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)),
                    ("relu2", nn.ReLU()),
                    ("BN2", nn.BatchNorm2d(64)),
                    ("conv3", nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)),
                    ("relu3", nn.ReLU()),
                    ("BN3", nn.BatchNorm2d(32)),
                    ("conv4", nn.Conv2d(32, 1, kernel_size=5, stride=1, padding=2)),
                ]
            )
        )

    def forward(self, x):
        """Apply the network to ``x``, a batched tensor with a single channel.

        Returns:
            Single-channel tensor with the same height and width as ``x``.
        """
        return self.convnet(x)
# ===========================================================================================
class List_denoi(nn.Module):
    """Bank of independent copies of a module, selected by iteration index.

    Args:
        Denoi: Module to replicate. It is deep-copied ``n_denoi`` times, so
            the stored copies share no parameters with each other or with
            the module passed in.
        n_denoi: Number of copies to store. It should be at least 1,
            otherwise ``forward`` has no module to apply.
    """

    def __init__(self, Denoi, n_denoi):
        super().__init__()
        self.n_denoi = n_denoi
        self.conv = nn.ModuleList([copy.deepcopy(Denoi) for _ in range(n_denoi)])

    def forward(self, x, iterate):
        """Apply the copy associated with iteration ``iterate`` to ``x``.

        Args:
            x: Input passed unchanged to the selected module.
            iterate: Index of the copy to use. Values beyond the last copy
                are clamped, so every iteration ``>= n_denoi - 1`` reuses
                the last copy.

        Returns:
            The output of the selected module.
        """
        index = min(iterate, self.n_denoi - 1)
        return self.conv[index](x)
# ===========================================================================================
class Identity(nn.Module):  # Can be useful for ablation study
    """Module that returns its input unchanged and has no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x