# -----------------------------------------------------------------------------
# This software is distributed under the terms
# of the GNU Lesser General Public Licence (LGPL)
# See LICENSE.md for further details
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image
import numpy as np
from numpy import linalg as LA
import time
from scipy import signal
from scipy import misc
from scipy import sparse
import torch
import math
[docs]
def display_vid(video, fps, title="", colormap=plt.cm.gray):
    """
    Play a single-channel video frame by frame in one matplotlib figure.

    video is a numpy array of shape [nb_frames, 1, nx, ny]; only channel 0
    of every frame is shown.

    Args:
        video: 4-D array indexed as [frame, channel, row, column].
        fps: value handed unchanged to ``plt.pause`` after each frame, i.e.
            despite its name it acts as the pause (in seconds) between two
            frames, not as a frame rate.
        title: title drawn above the image.
        colormap: colormap used to render the frames.
    """
    plt.ion()
    try:
        # The 4-way unpacking doubles as a check that the input is 4-D.
        (nb_frames, channels, nx, ny) = video.shape
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        # NOTE(review): the divider itself is never used; the call is kept
        # in case its effect on the axes layout is relied upon.
        make_axes_locatable(ax)

        image = None
        for i in range(nb_frames):
            current_frame = video[i, 0, :, :]
            if image is None:
                # Create the image and its colorbar once. Creating a new
                # colorbar axes per frame would stack axes on top of each
                # other and leave the colorbar axes as the current one, so
                # later frames would be drawn into the colorbar.
                image = ax.imshow(current_frame, cmap=colormap)
                ax.set_title(title)
                cax = fig.add_axes([0.85, 0.1, 0.075, 0.8])
                fig.colorbar(image, cax=cax)
            else:
                image.set_data(current_frame)
                # Rescale the colour limits to the new frame, as a fresh
                # imshow call would do.
                image.autoscale()
            plt.show()
            plt.pause(fps)
    finally:
        # Always restore non-interactive mode, even if a frame fails.
        plt.ioff()
[docs]
def display_rgb_vid(video, fps, title=""):
    """
    Play a colour video frame by frame in one matplotlib figure.

    video is a numpy array of shape [nb_frames, 3, nx, ny]. Each frame is
    reordered from channel-first to channel-last before being displayed.

    Args:
        video: 4-D array indexed as [frame, channel, row, column].
        fps: value handed unchanged to ``plt.pause`` after each frame.
        title: title drawn above the image.
    """
    plt.ion()
    # The 4-way unpacking doubles as a check that the input is 4-D.
    (nb_frames, channels, nx, ny) = video.shape
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    for frame_index in range(nb_frames):
        # imshow expects the colour channel on the last axis.
        channel_last = np.moveaxis(video[frame_index, :, :, :], 0, -1)
        plt.imshow(channel_last)
        plt.title(title)
        plt.show()
        plt.pause(fps)
    plt.ioff()
[docs]
def fitPlots(N, aspect=(16, 9)):
    """
    Choose a subplot grid able to hold N plots for a given aspect ratio.

    The grid starts from the floor of the width and height scaled so that
    their product approximates N, then grows one row or one column at a
    time, alternating between the two, until it has at least N cells.

    Args:
        N: number of plots to place.
        aspect: (width, height) pair; only elements 0 and 1 are read.

    Returns:
        Tuple ``(rows, cols)`` with ``rows * cols >= N``.
    """
    width, height = aspect[0], aspect[1]
    scale = (N / (width * height * 1.0)) ** (1 / 2.0)
    n_rows = math.floor(height * scale)
    n_cols = math.floor(width * scale)

    # For portrait layouts the first extra cell goes to the rows.
    grow_rows = width < height
    while True:
        if n_rows * n_cols >= N:
            return n_rows, n_cols
        if grow_rows:
            n_rows += 1
        else:
            n_cols += 1
        grow_rows = not grow_rows
[docs]
def Multi_plots(
    img_list,
    title_list,
    shape,
    suptitle="",
    colormap=plt.cm.gray,
    axis_off=True,
    aspect=(16, 9),
    savefig="",
    fontsize=14,
):
    """
    Show a list of images on a rows x cols grid of subplots.

    When fewer images or titles than grid cells are supplied, the missing
    cells are filled with a 64x64 zero image and an empty title. The padding
    is applied to internal copies: the lists passed by the caller are left
    untouched.

    Args:
        img_list: sequence of images accepted by ``imshow``.
        title_list: sequence of subplot titles.
        shape: pair ``[rows, cols]`` giving the grid layout.
        suptitle: figure-level title.
        colormap: colormap applied to every image.
        axis_off: when true, the axes of every subplot are hidden.
        aspect: currently unused; kept for signature compatibility.
        savefig: if non-empty, path the figure is saved to before showing.
        fontsize: font size of the subplot titles.
    """
    [rows, cols] = shape
    nb_cells = rows * cols

    # Pad copies rather than the caller's lists to avoid a hidden side effect.
    images = list(img_list)
    titles = list(title_list)
    images.extend(np.zeros((64, 64)) for _ in range(nb_cells - len(images)))
    titles.extend("" for _ in range(nb_cells - len(titles)))

    plt.figure()
    plt.suptitle(suptitle, fontsize=16)
    for k in range(nb_cells):
        ax = plt.subplot(rows, cols, k + 1)
        ax.imshow(images[k], cmap=colormap)
        ax.set_title(titles[k], fontsize=fontsize)
        if axis_off:
            ax.axis("off")
    if savefig:
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()
[docs]
def compare_video_frames(
    vid_list,
    nb_disp_frames,
    title_list,
    suptitle="",
    colormap=plt.cm.gray,
    aspect=(16, 9),
    savefig="",
    fontsize=14,
):
    """
    Display the first frames of several videos, one video per row.

    Row ``i`` shows ``vid_list[i][0, j, 0, :, :]`` for ``j`` in
    ``range(nb_disp_frames)``, titled with ``title_list[i][j]``. All axes
    are hidden.

    Args:
        vid_list: sequence of videos, each indexed with five indices as
            shown above.
        nb_disp_frames: number of frames (columns) displayed per video.
        title_list: nested sequence of titles, indexed [video][frame].
        suptitle: figure-level title.
        colormap: colormap applied to every frame.
        aspect: passed to ``plt.figure`` as ``figsize``.
        savefig: if non-empty, path the figure is saved to before showing.
        fontsize: font size of the subplot titles.
    """
    n_rows = len(vid_list)
    n_cols = nb_disp_frames
    plt.figure(figsize=aspect)
    plt.suptitle(suptitle, fontsize=16)
    for row in range(n_rows):
        for col in range(n_cols):
            # Subplot indices are 1-based and run row by row.
            cell = row * n_cols + col + 1
            ax = plt.subplot(n_rows, n_cols, cell)
            ax.imshow(vid_list[row][0, col, 0, :, :], cmap=colormap)
            ax.set_title(title_list[row][col], fontsize=fontsize)
            plt.axis("off")
    if savefig:
        plt.savefig(savefig, bbox_inches="tight")
    plt.show()
[docs]
def torch2numpy(torch_tensor):
    """
    Convert a torch tensor into a numpy array.

    The tensor is moved to the CPU and detached from the autograd graph
    before the conversion.
    """
    on_cpu = torch_tensor.cpu()
    without_grad = on_cpu.detach()
    return without_grad.numpy()
[docs]
def uint8(dsp):
    """
    Rescale an array linearly to the range [0, 255] and cast it to uint8.

    The minimum of ``dsp`` is mapped to 0 and its maximum to 255;
    intermediate values are truncated by the integer cast.

    Args:
        dsp: array-like of numbers.

    Returns:
        numpy array of dtype uint8 with the same shape as ``dsp``. When all
        values of ``dsp`` are equal, the range is zero and an all-zero array
        is returned instead of dividing by zero.
    """
    dsp = np.asarray(dsp)
    lowest = np.amin(dsp)
    span = np.amax(dsp) - lowest
    if span == 0:
        # Constant input: the normalisation would be 0/0 (NaN), whose cast
        # to uint8 is undefined.
        return np.zeros(dsp.shape, dtype="uint8")
    x = (dsp - lowest) / span * 255
    return x.astype("uint8")
[docs]
def imagesc(
    Img,
    title="",
    colormap=plt.cm.gray,
    show=True,
    figsize=None,
    cbar_pos=None,
    title_fontsize=16,
):
    """
    Display an image in a new figure, with a title and a colorbar.

    Examples:
        imagesc(Img)                        grey colormap, no title
        imagesc(Img, title=ttl)             same, with the title ttl
        imagesc(Img, title=ttl, colormap=cmap)
                                            same, with the colormap cmap

    Args:
        Img: image passed to ``plt.imshow``.
        title: title drawn above the image.
        colormap: colormap passed to ``plt.imshow``.
        show: ``plt.show()`` is called only when this is exactly ``True``.
        figsize: forwarded to ``plt.figure``.
        cbar_pos: ``"bottom"`` draws a horizontal colorbar in an inset below
            the image; any other value draws a vertical colorbar on the
            right of the figure.
        title_fontsize: font size of the title.
    """
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    plt.imshow(Img, cmap=colormap)
    plt.title(title, fontsize=title_fontsize)
    divider = make_axes_locatable(ax)

    horizontal_bar = cbar_pos == "bottom"
    if horizontal_bar:
        cax = inset_axes(
            ax, width="100%", height="5%", loc="lower center", borderpad=-5
        )
    else:
        cax = plt.axes([0.85, 0.1, 0.075, 0.8])
    plt.colorbar(
        cax=cax, orientation="horizontal" if horizontal_bar else "vertical"
    )

    # fig.tight_layout() is deliberately not called: it raises warnings in
    # some cases.
    if show is True:
        plt.show()
[docs]
def imagecomp(
    Img1,
    Img2,
    suptitle="",
    title1="",
    title2="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
):
    """
    Show two images side by side, each with its own title and colorbar.

    Args:
        Img1, Img2: images for the left and right panels.
        suptitle: figure-level title.
        title1, title2: titles of the left and right panels.
        colormap1, colormap2: colormaps of the left and right panels.
    """
    fig, axes_pair = plt.subplots(1, 2)
    # One entry per panel: axes, image, title, colormap and the figure
    # rectangle [left, bottom, width, height] of its colorbar.
    panels = (
        (axes_pair[0], Img1, title1, colormap1, [0.43, 0.3, 0.025, 0.4]),
        (axes_pair[1], Img2, title2, colormap2, [0.915, 0.3, 0.025, 0.4]),
    )
    for axis, image, label, cmap, cbar_rect in panels:
        drawn = axis.imshow(image, cmap=cmap)
        axis.set_title(label)
        cbar_axes = plt.axes(cbar_rect)
        plt.colorbar(drawn, cax=cbar_axes)
    plt.suptitle(suptitle, fontsize=16)
    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
[docs]
def imagepanel(
    Img1,
    Img2,
    Img3,
    Img4,
    suptitle="",
    title1="",
    title2="",
    title3="",
    title4="",
    colormap1=plt.cm.gray,
    colormap2=plt.cm.gray,
    colormap3=plt.cm.gray,
    colormap4=plt.cm.gray,
):
    """
    Show four images on a 2x2 grid, each with its own title and colorbar.

    Images 1 and 2 fill the top row (left, right); images 3 and 4 fill the
    bottom row (left, right).

    Args:
        Img1, Img2, Img3, Img4: images to display.
        suptitle: figure-level title.
        title1 .. title4: titles of the corresponding panels.
        colormap1 .. colormap4: colormaps of the corresponding panels.
    """
    fig, axarr = plt.subplots(2, 2, figsize=(20, 10))
    plt.suptitle(suptitle, fontsize=16)

    # One entry per panel: grid position, image, title, colormap and the
    # figure rectangle [left, bottom, width, height] of its colorbar.
    panels = (
        ((0, 0), Img1, title1, colormap1, [0.4, 0.54, 0.025, 0.35]),
        ((0, 1), Img2, title2, colormap2, [0.90, 0.54, 0.025, 0.35]),
        ((1, 0), Img3, title3, colormap3, [0.4, 0.12, 0.025, 0.35]),
        ((1, 1), Img4, title4, colormap4, [0.9, 0.12, 0.025, 0.35]),
    )
    for (row, col), image, label, cmap, cbar_rect in panels:
        panel_axes = axarr[row, col]
        drawn = panel_axes.imshow(image, cmap=cmap)
        panel_axes.set_title(label)
        cbar_axes = plt.axes(cbar_rect)
        plt.colorbar(drawn, cax=cbar_axes)

    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
[docs]
def plot(x, y, title="", xlabel="", ylabel="", color="black"):
    """
    Draw a single curve y(x) in a new figure and show it.

    Args:
        x, y: data passed to the line plot.
        title: title of the plot.
        xlabel, ylabel: axis labels.
        color: line colour.
    """
    figure = plt.figure()
    axes = figure.add_subplot(1, 1, 1)
    axes.plot(x, y, color=color)
    axes.set_title(title)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    plt.show()
[docs]
def add_colorbar(mappable, position="right", size="5%"):
    """
    Attach a colorbar to the axes that owns ``mappable``.

    The colorbar is drawn in a new axes appended next to the image axes.
    The axes that was current before the call is made current again
    afterwards.

    Example:
        f, axs = plt.subplots(1, 2)
        im = axs[0].imshow(img1, cmap='gray')
        add_colorbar(im)
        im = axs[1].imshow(img2, cmap='gray')
        add_colorbar(im)

    Args:
        mappable: object returned by e.g. ``imshow``; its ``axes`` attribute
            designates the axes the colorbar is attached to.
        position: side on which the colorbar axes is appended. ``"bottom"``
            and ``"top"`` give a horizontal colorbar, any other value a
            vertical one.
        size: size of the colorbar axes, forwarded to ``append_axes``.

    Returns:
        The colorbar created by ``fig.colorbar``.
    """
    if position in ("bottom", "top"):
        orientation = "horizontal"
    else:
        orientation = "vertical"

    last_axes = plt.gca()
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)
    # Honour the caller-supplied size (it used to be hard-coded to "5%").
    cax = divider.append_axes(position, size=size, pad=0.05)
    cbar = fig.colorbar(mappable, cax=cax, orientation=orientation)
    # Creating the colorbar axes changed the current axes; restore it.
    plt.sca(last_axes)
    return cbar
[docs]
def noaxis(axs):
    """
    Hide the x and y axes of one axes object or of an array of them.

    Args:
        axs: a single axes object, or a numpy array of axes objects of any
            dimensionality (e.g. the 1-D or 2-D array returned by
            ``plt.subplots``).
    """
    if isinstance(axs, np.ndarray):
        # .flat visits every element, so 2-D grids work as well as 1-D rows.
        all_axes = axs.flat
    else:
        all_axes = (axs,)
    for ax in all_axes:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
[docs]
def string_mean_std(x, prec=3):
    """
    Format the mean and standard deviation of ``x`` as ``"mean +/- std"``.

    Args:
        x: values passed to ``np.mean`` and ``np.std``.
        prec: number of digits after the decimal point.

    Returns:
        The formatted string.
    """
    mean = np.mean(x)
    std = np.std(x)
    return f"{mean:.{prec}f} +/- {std:.{prec}f}"
[docs]
def print_mean_std(x, tag="", prec=3):
    """
    Print the mean and standard deviation of ``x`` as ``"tag = mean +/- std"``.

    Args:
        x: values passed to ``np.mean`` and ``np.std``.
        tag: label printed before the equals sign.
        prec: number of digits after the decimal point.
    """
    mean = np.mean(x)
    std = np.std(x)
    print(f"{tag} = {mean:.{prec}f} +/- {std:.{prec}f}")
[docs]
def histogram(s):
    """
    Show a 30-bin histogram of ``s``, drawn with ``density=True``.

    Args:
        s: data passed to ``plt.hist``.
    """
    # The values returned by plt.hist are not needed here.
    plt.hist(s, bins=30, density=True)
    plt.show()
[docs]
def vid2batch(root, img_dim, start_frame, end_frame):
    """
    Read a range of frames from a video file into a torch tensor.

    Each selected frame is resized to ``img_dim`` x ``img_dim``, converted
    from BGR to LAB, and channel index 1 of the result is stored.

    Frames are counted from 1: the counter is incremented before the range
    test, so the first decoded frame has number 1. A frame with number ``k``
    is stored at index ``k - start_frame`` when
    ``start_frame <= k < end_frame``. Positions for which no frame was read
    stay at zero.

    Args:
        root: path of the video, passed to ``cv2.VideoCapture``.
        img_dim: side length of the square output frames.
        start_frame: number of the first frame to keep (inclusive).
        end_frame: number of the frame at which to stop (exclusive).

    Returns:
        Tensor of shape ``(1, end_frame - start_frame, 1, img_dim, img_dim)``.
    """
    import cv2

    output_batch = torch.zeros(1, end_frame - start_frame, 1, img_dim, img_dim)
    stream = cv2.VideoCapture(root)
    try:
        frame_nb = 0
        while True:
            (grabbed, frame) = stream.read()
            if not grabbed:
                break
            frame_nb += 1
            if frame_nb >= end_frame:
                # Nothing past end_frame is ever stored: stop decoding.
                break
            if frame_nb >= start_frame:
                frame = cv2.resize(frame, (img_dim, img_dim))
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)
                output_batch[0, frame_nb - start_frame, 0, :, :] = torch.Tensor(
                    frame[:, :, 1]
                )
    finally:
        # Release the capture handle even if decoding fails.
        stream.release()
    return output_batch
[docs]
def pre_process_video(video, crop_patch, kernel_size):
    """
    Zero a patch in every frame of a video batch, then median-filter it.

    The input tensor is not modified: every frame is copied before the patch
    is zeroed.

    Args:
        video: tensor of shape (batch_size, seq_length, c, h, w).
        crop_patch: index expression applied to each 2-D (h, w) frame; the
            selected pixels are set to 0 before filtering.
        kernel_size: aperture size passed to ``cv2.medianBlur``.

    Returns:
        New tensor with the same shape as ``video`` holding the filtered
        frames.
    """
    import cv2

    batch_size, seq_length, c, h, w = video.shape
    nb_frames = batch_size * seq_length * c
    batched_frames = video.reshape(nb_frames, h, w)
    output_batch = torch.zeros(batched_frames.shape)
    for i in range(nb_frames):
        # The numpy array returned by torch2numpy can share memory with
        # `video`; copy it so zeroing the patch does not alter the input.
        img = torch2numpy(batched_frames[i, :, :]).copy()
        img[crop_patch] = 0
        median_frame = cv2.medianBlur(img, kernel_size)
        output_batch[i, :, :] = torch.Tensor(median_frame)
    output_batch = output_batch.reshape(batch_size, seq_length, c, h, w)
    return output_batch