# -----------------------------------------------------------------------------
# This software is distributed under the terms
# of the GNU Lesser General Public Licence (LGPL)
# See LICENSE.md for further details
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import torch.nn.functional as F
import imageio
import matplotlib.pyplot as plt
# import skimage.metrics as skm
[docs]
def batch_psnr(torch_batch, output_batch):
list_psnr = []
for i in range(torch_batch.shape[0]):
img = torch_batch[i, 0, :, :]
img_out = output_batch[i, 0, :, :]
img = img.cpu().detach().numpy()
img_out = img_out.cpu().detach().numpy()
list_psnr.append(psnr(img, img_out))
return list_psnr
[docs]
def batch_psnr_(torch_batch, output_batch, r=2):
list_psnr = []
for i in range(torch_batch.shape[0]):
img = torch_batch[i, 0, :, :]
img_out = output_batch[i, 0, :, :]
img = img.cpu().detach().numpy()
img_out = img_out.cpu().detach().numpy()
list_psnr.append(psnr_(img, img_out, r=r))
return list_psnr
[docs]
def batch_ssim(torch_batch, output_batch):
list_ssim = []
for i in range(torch_batch.shape[0]):
img = torch_batch[i, 0, :, :]
img_out = output_batch[i, 0, :, :]
img = img.cpu().detach().numpy()
img_out = img_out.cpu().detach().numpy()
list_ssim.append(ssim(img, img_out))
return list_ssim
[docs]
def dataset_meas(dataloader, model, device):
meas = []
for inputs, labels in dataloader:
inputs = inputs.to(device)
# with torch.no_grad():
b, c, h, w = inputs.shape
net_output = model.acquire(inputs, b, c, h, w)
raw = net_output[:, 0, :]
raw = raw.cpu().detach().numpy()
meas.extend(raw)
return meas
#
# def dataset_psnr_different_measures(dataloader, model, model_2, device):
# psnr = [];
# #psnr_fc = [];
# for inputs, labels in dataloader:
# inputs = inputs.to(device)
# m = model_2.normalized measure(inputs);
# net_output = model.forward_reconstruct(inputs);
# #net_output2 = model.evaluate_fcl(inputs);
#
# psnr += batch_psnr(inputs, net_output);
# #psnr_fc += batch_psnr(inputs, net_output2);
# psnr = np.array(psnr);
# #psnr_fc = np.array(psnr_fc);
# return psnr;
#
[docs]
def dataset_psnr(dataloader, model, device):
psnr = []
psnr_fc = []
for inputs, labels in dataloader:
inputs = inputs.to(device)
# with torch.no_grad():
# b,c,h,w = inputs.shape;
net_output = model.evaluate(inputs)
net_output2 = model.evaluate_fcl(inputs)
psnr += batch_psnr(inputs, net_output)
psnr_fc += batch_psnr(inputs, net_output2)
psnr = np.array(psnr)
psnr_fc = np.array(psnr_fc)
return psnr, psnr_fc
[docs]
def dataset_ssim(dataloader, model, device):
ssim = []
ssim_fc = []
for inputs, labels in dataloader:
inputs = inputs.to(device)
# evaluate full model and fully connected layer
net_output = model.evaluate(inputs)
net_output2 = model.evaluate_fcl(inputs)
# compute SSIM and concatenate
ssim += batch_ssim(inputs, net_output)
ssim_fc += batch_ssim(inputs, net_output2)
ssim = np.array(ssim)
ssim_fc = np.array(ssim_fc)
return ssim, ssim_fc
[docs]
def dataset_psnr_ssim(dataloader, model, device):
# init lists
psnr = []
ssim = []
# loop over batches
for inputs, labels in dataloader:
inputs = inputs.to(device)
# evaluate full model
net_output = model.evaluate(inputs)
# compute PSNRs and concatenate
psnr += batch_psnr(inputs, net_output)
# compute SSIMs and concatenate
ssim += batch_ssim(inputs, net_output)
# convert
psnr = np.array(psnr)
ssim = np.array(ssim)
return psnr, ssim
[docs]
def dataset_psnr_ssim_fcl(dataloader, model, device):
# init lists
psnr = []
ssim = []
# loop over batches
for inputs, labels in dataloader:
inputs = inputs.to(device)
# evaluate fully connected layer
net_output = model.evaluate_fcl(inputs)
# compute PSNRs and concatenate
psnr += batch_psnr(inputs, net_output)
# compute SSIMs and concatenate
ssim += batch_ssim(inputs, net_output)
# convert
psnr = np.array(psnr)
ssim = np.array(ssim)
return psnr, ssim
[docs]
def psnr(I1, I2):
"""
Computes the psnr between two images I1 and I2
"""
d = np.amax(I1) - np.amin(I1)
diff = np.square(I2 - I1)
MSE = diff.sum() / I1.size
Psnr = 10 * np.log(d**2 / MSE) / np.log(10)
return Psnr
[docs]
def psnr_(img1, img2, r=2):
"""
Computes the psnr between two image with values expected in a given range
Args:
img1, img2 (np.ndarray): images
r (float): image range
Returns:
Psnr (float): Peak signal-to-noise ratio
"""
MSE = np.mean((img1 - img2) ** 2)
Psnr = 10 * np.log(r**2 / MSE) / np.log(10)
return Psnr
[docs]
def psnr_torch(img_gt, img_rec, mask=None, dim=(-2, -1), img_dyn=None):
r"""
Computes the Peak Signal-to-Noise Ratio (PSNR) between two images.
.. math::
\text{PSNR} = 20 \, \log_{10} \left( \frac{\text{d}}{\sqrt{\text{MSE}}} \right), \\
\text{MSE} = \frac{1}{L}\sum_{\ell=1}^L \|I_\ell - \tilde{I}_\ell\|^2_2,
where :math:`d` is the image dynamic and :math:`\{I_\ell\}` (resp. :math:`\{\tilde{I}_\ell\}`) is the set of ground truth (resp. reconstructed) images.
Args:
:attr:`img_gt`: Tensor containing the *ground-truth* image.
:attr:`img_rec`: Tensor containing the reconstructed image.
:attr:`mask`: Mask where the squared error is computed. Defaults :attr:`None`, i.e., no mask is considered.
:attr:`dim`: Dimensions where the squared error is computed. If mask is :attr:`None`, defaults to :attr:`-1` (i.e., the last dimension). Othewise defaults to :attr:`(-2,-1)` (i.e., the last two dimensions).
:attr:`img_dyn`: Image dynamic range (e.g., 1.0 for normalized images, 255 for 8-bit images). When :attr:`img_dyn` is :attr:`None`, the dynamic range is computed from the ground-truth image.
Returns:
PSNR value.
.. note::
:attr:`psnr_torch(img_gt, img_rec)` is different from :attr:`psnr_torch(img_rec, img_gt)`. The first expression assumes :attr:`img_gt` is the ground truth while the second assumes that this is :attr:`img_rec`. This leads to different dynamic ranges.
Example 1: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise
>>> x = torch.rand(10,1,64,64)
>>> n = x + 0.05*torch.randn(x.shape)
>>> out = psnr_torch(x,n)
>>> print(out.shape)
torch.Size([10, 1])
Example 2: 10 images of size 64x64 with values in [0,1) corrupted with 5% noise
>>> psnr_torch(n,x)
tensor(...)
>>> psnr_torch(x,n)
tensor(...)
>>> psnr_torch(n,x,img_dyn=1.0)
tensor(...)
"""
if mask is not None:
dim = -1
img_gt = img_gt[mask > 0]
img_rec = img_rec[mask > 0]
print("mask")
mse = (img_gt - img_rec) ** 2
mse = torch.mean(mse, dim=dim)
if img_dyn is None:
img_dyn = torch.amax(img_gt, dim=dim) - torch.amin(img_gt, dim=dim)
return 10 * torch.log10(img_dyn**2 / mse)
[docs]
def ssim(I1, I2):
"""
Computes the ssim between two images I1 and I2
"""
L = np.amax(I1) - np.amin(I1)
mu1 = np.mean(I1)
mu2 = np.mean(I2)
s1 = np.std(I1)
s2 = np.std(I2)
s12 = np.mean(np.multiply((I1 - mu1), (I2 - mu2)))
c1 = (0.01 * L) ** 2
c2 = (0.03 * L) ** 2
result = ((2 * mu1 * mu2 + c1) * (2 * s12 + c2)) / (
(mu1**2 + mu2**2 + c1) * (s1**2 + s2**2 + c2)
)
return result
# def ssim_sk(x_gt, x, img_dyn=None):
# """
# SSIM from skimage
# Args:
# torch tensors
# Returns:
# torch tensor
# """
# if not isinstance(x, np.ndarray):
# x = x.cpu().detach().numpy().squeeze()
# x_gt = x_gt.cpu().detach().numpy().squeeze()
# ssim_val = np.zeros(x.shape[0])
# for i in range(x.shape[0]):
# ssim_val[i] = skm.structural_similarity(x_gt[i], x[i], data_range=img_dyn)
# return torch.tensor(ssim_val)
[docs]
def batch_psnr_vid(input_batch, output_batch):
list_psnr = []
batch_size, seq_length, c, h, w = input_batch.shape
input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w)
output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w)
for i in range(input_batch.shape[0]):
img = input_batch[i, 0, :, :]
img_out = output_batch[i, 0, :, :]
img = img.cpu().detach().numpy()
img_out = img_out.cpu().detach().numpy()
list_psnr.append(psnr(img, img_out))
return list_psnr
[docs]
def batch_ssim_vid(input_batch, output_batch):
list_ssim = []
batch_size, seq_length, c, h, w = input_batch.shape
input_batch = input_batch.reshape(batch_size * seq_length * c, 1, h, w)
output_batch = output_batch.reshape(batch_size * seq_length * c, 1, h, w)
for i in range(input_batch.shape[0]):
img = input_batch[i, 0, :, :]
img_out = output_batch[i, 0, :, :]
img = img.cpu().detach().numpy()
img_out = img_out.cpu().detach().numpy()
list_ssim.append(ssim(img, img_out))
return list_ssim
[docs]
def compare_video_nets_supervised(net_list, testloader, device):
psnr = [[] for i in range(len(net_list))]
ssim = [[] for i in range(len(net_list))]
for batch, (inputs, labels) in enumerate(testloader):
[batch_size, seq_length, c, h, w] = inputs.shape
print("Batch :{}/{}".format(batch + 1, len(testloader)))
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
for i in range(len(net_list)):
outputs = net_list[i].evaluate(inputs)
psnr[i] += batch_psnr_vid(labels, outputs)
ssim[i] += batch_ssim_vid(labels, outputs)
return psnr, ssim
[docs]
def compare_nets_unsupervised(net_list, testloader, device):
psnr = [[] for i in range(len(net_list))]
ssim = [[] for i in range(len(net_list))]
for batch, (inputs, labels) in enumerate(testloader):
[batch_size, seq_length, c, h, w] = inputs.shape
print("Batch :{}/{}".format(batch + 1, len(testloader)))
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
for i in range(len(net_list)):
outputs = net_list[i].evaluate(inputs)
psnr[i] += batch_psnr_vid(outputs, labels)
ssim[i] += batch_ssim_vid(outputs, labels)
return psnr, ssim
[docs]
def print_mean_std(x, tag=""):
print("{}psnr = {} +/- {}".format(tag, np.mean(x), np.std(x)))