Source code for spyrit.misc.pattern_choice

# -----------------------------------------------------------------------------
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy
from abc import ABC, abstractmethod
import pywt

########################################################################
# 1. Define Abstract Pattern Class
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# All pattern classes (Basis, Custom and Optimized) will
# inherit this abstract class. The only real abstract
# methods are the ones that set and add the desired
# patterns.


class Patterns(ABC):
    """Abstract holder for desired patterns and their measurement matrix.

    ``Q`` is a ``nn.Conv2d`` whose kernels are the desired patterns; ``P``
    and ``T`` are whatever the function named by ``method`` returns for
    ``(Q, dyn)``.  Concrete subclasses must implement
    ``set_desired_pattern`` and ``add_desired_pattern``.
    """

    def __init__(self, img_size, method="split", binarized=False, par=2, lvl=1, dyn=8):
        """Store the settings and allocate placeholder ``Q``, ``P`` and ``T``.

        Args:
            img_size: Side length used as kernel size of ``Q`` and ``P``.
            method: Name of the module-level function that turns ``Q`` into
                ``(P, T)``.
            binarized: Stored as-is; not used by this class.
            par: Stored as-is; subclasses forward it to the basis functions.
            lvl: Stored as-is; subclasses forward it to the basis functions.
            dyn: Passed to the ``method`` function together with ``Q``.
        """
        super().__init__()
        self.n = img_size
        self.method = method
        self.binarized = binarized
        self.par = par
        self.lvl = lvl
        self.dyn = dyn
        # Placeholders: the Conv2d kernels keep PyTorch's default random
        # initialisation until a subclass or set_measurement_matrix replaces
        # them.
        self.Q = nn.Conv2d(1, img_size, kernel_size=img_size, stride=1, padding=0)
        self.T = np.zeros((img_size, 2 * img_size))
        self.P = nn.Conv2d(1, 2 * img_size, kernel_size=img_size, stride=1, padding=0)
        # steps[k] is the index in Q of the first kernel added at step k.
        self.steps = [0]

    def get_desired_pattern(self):
        """Return the kernels of ``Q`` added by the most recent step.

        NOTE(review): the original expression ``self.Q[self.start, 1, :, :]``
        could not run (``Conv2d`` is not subscriptable, ``start`` is never
        set, and ``Q`` has a single input channel).  It is interpreted here
        as "everything from ``steps[-1]`` onwards" -- confirm.
        """
        first = self.steps[-1]
        return self.Q.weight[first:, 0, :, :]

    def get_all_desired_pattern(self):
        """Return the ``nn.Conv2d`` holding every desired pattern."""
        return self.Q

    def get_measurement_matrix(self):
        """Return the current ``(P, T)`` pair."""
        return self.P, self.T

    def set_measurement_matrix(self):
        """Recompute ``P`` and ``T`` from ``Q`` using the ``method`` function.

        The function is looked up by name among this module's globals
        (instead of ``eval``-ing a string built from ``self.method``).

        Raises:
            ValueError: If ``self.method`` does not name a module-level
                callable.
        """
        builder = globals().get(self.method)
        if not callable(builder):
            raise ValueError(
                "unknown measurement method: {!r}".format(self.method)
            )
        self.P, self.T = builder(self.Q, self.dyn)

    def save_measurement_matrix(self, root):
        """Write ``T`` to ``root + "T.npy"`` and each kernel of ``P`` to a PNG.

        Args:
            root: String prefix prepended to every file name.
        """
        T = self.T
        if isinstance(T, nn.Module):
            # `split` returns T as a bias-free nn.Linear; save its matrix
            # rather than pickling the module object.
            T = T.weight.detach().cpu().numpy()
        np.save(root + "T.npy", T)

        # PNG cannot store float images, so round to an integer type.  The
        # choice of 8 vs 16 bits follows `dyn`; values are clipped to the
        # range of that type.
        dtype = np.uint8 if self.dyn <= 8 else np.uint16
        top = np.iinfo(dtype).max
        kernels = self.P.weight.detach().cpu().numpy()[:, 0, :, :]
        for i, pattern in enumerate(kernels):
            pixels = np.clip(np.rint(pattern), 0, top).astype(dtype)
            # NOTE(review): the original template "pat_{}x{}" had two fields
            # but one argument (IndexError); only the index is kept -- confirm
            # the intended file name.
            Image.fromarray(pixels).save(root + "pat_{}.png".format(i))

    @abstractmethod
    def set_desired_pattern(self, def_matrix):
        """Replace the desired patterns (implemented by subclasses)."""

    @abstractmethod
    def add_desired_pattern(self, def_matrix):
        """Append desired patterns (implemented by subclasses)."""

    def add_desired_patterns(self, def_matrix):
        """Backward-compatible alias of ``add_desired_pattern``.

        The plural name used to be the abstract one, which no subclass in
        this module implemented, so none of them could be instantiated.
        """
        return self.add_desired_pattern(def_matrix)
########################################################################
# 2. Define children of that abstract class
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# All pattern classes (Basis, Custom and Optimized) inherit the
# abstract class Patterns and implement its two main abstract
# methods.
#
class Basis_patterns(Patterns):
    """Desired patterns drawn from a named analytic basis.

    ``basis`` is the name of a module-level function called as
    ``basis(def_matrix, par, lvl)`` that returns a ``nn.Conv2d`` of patterns.
    """

    def __init__(self, img_size, basis, method="split", binarized=False):
        """Initialise the base class and the selection bookkeeping.

        Args:
            img_size: Side length of the (square) patterns.
            basis: Name of the basis function, e.g. ``"Haar"``.
            method: Forwarded to ``Patterns``.
            binarized: Forwarded to ``Patterns``.
        """
        # The original passed `self` explicitly and hard-coded the keyword
        # values, which raised TypeError and ignored the caller's choices.
        super().__init__(img_size, method=method, binarized=binarized)
        self.basis = basis
        self.img_size = img_size
        self.indexes = np.zeros((img_size, img_size))
        self.cumulated = np.zeros((img_size, img_size))
        # Historical attribute name, kept bound to the same array.
        self.cumulated_indexes = self.cumulated

    def _basis_filters(self, def_matrix):
        """Call the basis function named by ``self.basis``."""
        builder = globals().get(self.basis)
        if not callable(builder):
            raise ValueError("unknown basis: {!r}".format(self.basis))
        return builder(def_matrix, self.par, self.lvl)

    def _append_filters(self, extra):
        """Concatenate the kernels of ``extra`` after those of ``self.Q``."""
        old_w = self.Q.weight.data
        new_w = extra.weight.data
        n_old = old_w.shape[0]
        n_all = n_old + new_w.shape[0]
        merged = nn.Conv2d(1, n_all, kernel_size=self.img_size, stride=1, padding=0)
        # Write through `.data`: assigning a plain Tensor to `bias`, or
        # slicing into a grad-requiring `weight`, raises in PyTorch.
        merged.bias.data = torch.zeros(n_all)
        merged.weight.data[:n_old] = old_w
        merged.weight.data[n_old:] = new_w
        merged.bias.requires_grad = False
        merged.weight.requires_grad = False
        self.Q = merged
        self.steps.append(n_old)

    def add_desired_pattern(self, def_matrix):
        """Append the basis patterns selected by ``def_matrix`` to ``Q``.

        Args:
            def_matrix: 2-D selection mask handed to the basis function.
        """
        def_matrix = np.asarray(def_matrix)
        self.indexes += def_matrix
        # NOTE(review): the original `temp(self.cumulated > 0)` called an
        # array; it is read as "add one wherever an index was already
        # selected" -- confirm the intended bookkeeping.
        previously_set = (self.cumulated > 0).astype(self.cumulated.dtype)
        self.cumulated += def_matrix + previously_set
        # The number of added kernels is taken from the returned layer,
        # because a basis may emit more than one kernel per selected index
        # (Fourier emits two).
        self._append_filters(self._basis_filters(def_matrix))

    # The abstract base historically spelled this method in the plural.
    add_desired_patterns = add_desired_pattern

    def set_desired_pattern(self, def_matrix):
        """Replace ``Q`` with the basis patterns selected by ``def_matrix``.

        Args:
            def_matrix: 2-D selection mask handed to the basis function.
        """
        # Independent float copies: later in-place updates must not touch
        # the caller's array nor couple `indexes` with `cumulated`.
        self.indexes = np.array(def_matrix, dtype=float)
        self.cumulated = np.array(def_matrix, dtype=float)
        self.cumulated_indexes = self.cumulated
        self.Q = self._basis_filters(def_matrix)
        self.steps = [0]
class Custom_patterns(Patterns):
    """Desired patterns supplied directly by the caller as ``nn.Conv2d``."""

    def __init__(self, img_size, Q, method="split", binarized=False):
        """Initialise the base class and adopt ``Q`` as the pattern bank.

        Args:
            img_size: Side length of the (square) patterns.
            Q: ``nn.Conv2d`` whose kernels are the desired patterns.
            method: Forwarded to ``Patterns``.
            binarized: Forwarded to ``Patterns``.
        """
        # The original called super(Basis_patterns, self) with an extra
        # `self` and hard-coded keywords, which raised TypeError.
        super().__init__(img_size, method=method, binarized=binarized)
        self.img_size = img_size
        self.Q = Q

    def add_desired_pattern(self, Q):
        """Append the kernels of ``Q`` after the current ones.

        Args:
            Q: ``nn.Conv2d`` holding the patterns to add.
        """
        old_w = self.Q.weight.data
        new_w = Q.weight.data
        n_old = old_w.shape[0]
        n_all = n_old + new_w.shape[0]
        merged = nn.Conv2d(1, n_all, kernel_size=self.img_size, stride=1, padding=0)
        # Write through `.data`: assigning a plain Tensor to `bias`, or
        # slicing into a grad-requiring `weight`, raises in PyTorch.
        merged.bias.data = torch.zeros(n_all)
        merged.weight.data[:n_old] = old_w
        merged.weight.data[n_old:] = new_w
        merged.bias.requires_grad = False
        merged.weight.requires_grad = False
        self.Q = merged
        self.steps.append(n_old)

    # The abstract base historically spelled this method in the plural.
    add_desired_patterns = add_desired_pattern

    def set_desired_pattern(self, Q):
        """Replace the pattern bank with ``Q`` and reset the step history."""
        self.Q = Q
        self.steps = [0]
class Optimized_patterns(Patterns):
    """Desired patterns produced by the ``<basis>_opt`` functions.

    ``basis`` is the stem of a module-level function called as
    ``<basis>_opt(M, par, lvl)`` that returns a ``nn.Conv2d`` of patterns.
    """

    def __init__(self, img_size, basis, M, method="split", binarized=False):
        """Initialise the base class and remember ``basis`` and ``M``.

        Args:
            img_size: Side length of the (square) patterns.
            basis: Stem of the ``*_opt`` function name, e.g. ``"Haar"``.
            M: Stored as-is; not used by the methods of this class.
            method: Forwarded to ``Patterns``.
            binarized: Forwarded to ``Patterns``.
        """
        # The original called super(Basis_patterns, self) with an extra
        # `self` and hard-coded keywords, which raised TypeError.
        super().__init__(img_size, method=method, binarized=binarized)
        self.basis = basis
        self.img_size = img_size
        self.M = M

    def _optimized_filters(self, M):
        """Call the ``<basis>_opt`` function named by ``self.basis``."""
        builder = globals().get(self.basis + "_opt")
        if not callable(builder):
            raise ValueError("unknown optimized basis: {!r}".format(self.basis))
        return builder(M, self.par, self.lvl)

    def add_desired_pattern(self, M_prim):
        """Append the patterns returned by ``<basis>_opt(M_prim, ...)``.

        Args:
            M_prim: Argument handed to the ``*_opt`` function.
        """
        extra = self._optimized_filters(M_prim)
        old_w = self.Q.weight.data
        new_w = extra.weight.data
        n_old = old_w.shape[0]
        # The count comes from the returned layer rather than from
        # `M_prim`, so it always matches what is copied.
        n_all = n_old + new_w.shape[0]
        merged = nn.Conv2d(1, n_all, kernel_size=self.img_size, stride=1, padding=0)
        # Write through `.data`: assigning a plain Tensor to `bias`, or
        # slicing into a grad-requiring `weight`, raises in PyTorch.
        merged.bias.data = torch.zeros(n_all)
        merged.weight.data[:n_old] = old_w
        merged.weight.data[n_old:] = new_w
        merged.bias.requires_grad = False
        merged.weight.requires_grad = False
        self.Q = merged
        self.steps.append(n_old)

    # The abstract base historically spelled this method in the plural.
    add_desired_patterns = add_desired_pattern

    def set_desired_pattern(self, M):
        """Replace ``Q`` with the patterns of ``<basis>_opt(M, ...)``.

        Args:
            M: Argument handed to the ``*_opt`` function.
        """
        self.Q = self._optimized_filters(M)
        self.steps = [0]
######################################################################## # 3. Define functions for basis pattern # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Return a convolution filter that contains all the basis # functions of a given transform. #
def Fourier(def_matrix, par=0, lvl=1):
    """Build the Fourier patterns selected by ``def_matrix``.

    For every nonzero entry ``(i, j)`` the 2-D DFT of a unit impulse at
    ``(i, j)`` is computed; its real part is stored in kernel ``2 * k`` and
    its imaginary part in kernel ``2 * k + 1``.

    Args:
        def_matrix: 2-D mask; nonzero entries select the frequencies.  The
            kernel size is taken from its first dimension, so a square mask
            is expected.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Unused; kept for a uniform basis-function signature.

    Returns:
        ``nn.Conv2d`` with ``2 * I`` frozen kernels and zero bias, where
        ``I`` is the number of selected entries.
    """
    def_matrix = np.asarray(def_matrix)
    nx, ny = def_matrix.shape
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, 2 * I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(2 * I)
    Z = np.zeros((nx, ny))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        # numpy's FFT replaces cv2.dft, which this module never imports.
        spectrum = np.fft.fft2(Z)
        Z[i, j] = 0
        Q.weight.data[2 * index, 0, :, :] = torch.from_numpy(
            np.ascontiguousarray(spectrum.real)
        )
        Q.weight.data[2 * index + 1, 0, :, :] = torch.from_numpy(
            np.ascontiguousarray(spectrum.imag)
        )
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Hadamard(def_matrix, par=0, lvl=1):
    """Build the Hadamard patterns selected by ``def_matrix``.

    For every nonzero entry ``(i, j)`` the pattern is the 2-D
    Walsh-Hadamard transform of a unit impulse at ``(i, j)``, i.e. the outer
    product of column ``i`` and column ``j`` of the Hadamard matrices.

    NOTE(review): the original computed ``pat`` with a commented-out call to
    ``fht2``, which is not defined in this module, so ``pat`` was undefined.
    This version uses unnormalised, natural-order (Sylvester) Hadamard
    matrices -- confirm that this matches the ordering ``fht2`` used.

    Args:
        def_matrix: 2-D mask; nonzero entries select the patterns.  Both
            dimensions must be powers of two.  The kernel size is taken from
            the first dimension, so a square mask is expected.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Unused; kept for a uniform basis-function signature.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.

    Raises:
        ValueError: If a dimension of ``def_matrix`` is not a power of two.
    """
    def_matrix = np.asarray(def_matrix)
    nx, ny = def_matrix.shape
    for size in (nx, ny):
        if size < 1 or size & (size - 1):
            raise ValueError(
                "Hadamard patterns need power-of-two sizes, got {}x{}".format(nx, ny)
            )

    def sylvester(order):
        # H_1 = [1];  H_2n = [[H_n, H_n], [H_n, -H_n]]
        H = np.ones((1, 1))
        while H.shape[0] < order:
            H = np.block([[H, H], [H, -H]])
        return H

    H_rows = sylvester(nx)
    H_cols = H_rows if ny == nx else sylvester(ny)

    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    for index, (i, j) in enumerate(zip(rows, cols)):
        # H_rows @ Z @ H_cols.T with Z a unit impulse reduces to this outer
        # product.
        pat = np.outer(H_rows[:, i], H_cols[:, j])
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Haar(def_matrix, par=0, lvl=1):
    """Build the Haar wavelet patterns selected by ``def_matrix``.

    For every nonzero entry ``(i, j)`` a unit impulse at ``(i, j)`` is
    interpreted as a wavelet-coefficient array and inverse-transformed; the
    reconstruction is the pattern.

    Args:
        def_matrix: 2-D mask; nonzero entries select the coefficients.  The
            kernel size is taken from its first dimension, so a square mask
            is expected.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Decomposition level passed to ``pywt.wavedec2``.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.
    """
    def_matrix = np.asarray(def_matrix)
    nx, ny = def_matrix.shape
    wave = "Haar"
    md = "periodization"
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    Z = np.zeros((nx, ny))
    # Only the coefficient layout is needed from the decomposition; the
    # original recomputed it on every iteration.
    _, layout = pywt.coeffs_to_array(pywt.wavedec2(Z, wave, mode=md, level=lvl))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        coeffs = pywt.array_to_coeffs(Z, layout, output_format="wavedec2")
        # Reconstruct with the same mode that defined the layout; the
        # original fell back to waverec2's default mode here.
        pat = pywt.waverec2(coeffs, wave, mode=md)
        Z[i, j] = 0
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Daubechies(def_matrix, par=2, lvl=1):
    """Build the Daubechies wavelet patterns selected by ``def_matrix``.

    For every nonzero entry ``(i, j)`` a unit impulse at ``(i, j)`` is
    interpreted as a wavelet-coefficient array and inverse-transformed; the
    reconstruction is the pattern.

    Args:
        def_matrix: 2-D mask; nonzero entries select the coefficients.  The
            kernel size is taken from its first dimension, so a square mask
            is expected.
        par: Wavelet order; the wavelet name is ``"db" + str(par)``.
        lvl: Decomposition level passed to ``pywt.wavedec2``.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.
    """
    def_matrix = np.asarray(def_matrix)
    nx, ny = def_matrix.shape
    wave = "db" + str(par)
    md = "periodization"
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    Z = np.zeros((nx, ny))
    # Only the coefficient layout is needed from the decomposition; the
    # original recomputed it on every iteration.
    _, layout = pywt.coeffs_to_array(pywt.wavedec2(Z, wave, mode=md, level=lvl))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        coeffs = pywt.array_to_coeffs(Z, layout, output_format="wavedec2")
        # Reconstruct with the same mode that defined the layout; with
        # waverec2's default mode the output size differs from the kernel
        # size for wavelets longer than Haar.
        pat = pywt.waverec2(coeffs, wave, mode=md)
        Z[i, j] = 0
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
########################################################################
# 3. Define functions for optimized basis pattern
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Return a convolution filter that contains all the basis
# functions obtained through a training phase.
#
def Fourier_opt(M, par=0, lvl=1):
    """Build Fourier patterns for the selection ``M``.

    NOTE(review): the original body used an undefined name ``def_matrix``
    (NameError on every call) and otherwise duplicated the plain Fourier
    builder.  ``M`` is therefore interpreted as the 2-D selection mask --
    confirm against the training pipeline that is meant to provide it.

    For every nonzero entry ``(i, j)`` the 2-D DFT of a unit impulse at
    ``(i, j)`` is computed; its real part is stored in kernel ``2 * k`` and
    its imaginary part in kernel ``2 * k + 1``.

    Args:
        M: 2-D mask; nonzero entries select the frequencies.  The kernel
            size is taken from its first dimension.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Unused; kept for a uniform basis-function signature.

    Returns:
        ``nn.Conv2d`` with ``2 * I`` frozen kernels and zero bias, where
        ``I`` is the number of selected entries.

    Raises:
        ValueError: If ``M`` is not two-dimensional.
    """
    def_matrix = np.asarray(M)
    if def_matrix.ndim != 2:
        raise ValueError("Fourier_opt expects M to be a 2-D selection mask")
    nx, ny = def_matrix.shape
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, 2 * I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(2 * I)
    Z = np.zeros((nx, ny))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        # numpy's FFT replaces cv2.dft, which this module never imports.
        spectrum = np.fft.fft2(Z)
        Z[i, j] = 0
        Q.weight.data[2 * index, 0, :, :] = torch.from_numpy(
            np.ascontiguousarray(spectrum.real)
        )
        Q.weight.data[2 * index + 1, 0, :, :] = torch.from_numpy(
            np.ascontiguousarray(spectrum.imag)
        )
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Hadamard_opt(M, par=0, lvl=1):
    """Build Hadamard patterns for the selection ``M``.

    NOTE(review): the original body used an undefined name ``def_matrix``
    and an undefined ``pat`` (its computation via ``fht2`` was commented
    out), so it failed on every call.  ``M`` is interpreted as the 2-D
    selection mask and the patterns are unnormalised, natural-order
    (Sylvester) Hadamard functions -- confirm both choices.

    Args:
        M: 2-D mask; nonzero entries select the patterns.  Both dimensions
            must be powers of two.  The kernel size is taken from the first
            dimension.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Unused; kept for a uniform basis-function signature.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.

    Raises:
        ValueError: If ``M`` is not two-dimensional or a dimension is not a
            power of two.
    """
    def_matrix = np.asarray(M)
    if def_matrix.ndim != 2:
        raise ValueError("Hadamard_opt expects M to be a 2-D selection mask")
    nx, ny = def_matrix.shape
    for size in (nx, ny):
        if size < 1 or size & (size - 1):
            raise ValueError(
                "Hadamard patterns need power-of-two sizes, got {}x{}".format(nx, ny)
            )

    def sylvester(order):
        # H_1 = [1];  H_2n = [[H_n, H_n], [H_n, -H_n]]
        H = np.ones((1, 1))
        while H.shape[0] < order:
            H = np.block([[H, H], [H, -H]])
        return H

    H_rows = sylvester(nx)
    H_cols = H_rows if ny == nx else sylvester(ny)

    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    for index, (i, j) in enumerate(zip(rows, cols)):
        # Transform of a unit impulse at (i, j) == outer product of the
        # corresponding Hadamard columns.
        pat = np.outer(H_rows[:, i], H_cols[:, j])
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Haar_opt(M, par=0, lvl=1):
    """Build Haar wavelet patterns for the selection ``M``.

    NOTE(review): the original body used an undefined name ``def_matrix``
    (NameError on every call) and otherwise duplicated the plain Haar
    builder.  ``M`` is therefore interpreted as the 2-D selection mask --
    confirm against the training pipeline that is meant to provide it.

    Args:
        M: 2-D mask; nonzero entries select the coefficients.  The kernel
            size is taken from its first dimension.
        par: Unused; kept for a uniform basis-function signature.
        lvl: Decomposition level passed to ``pywt.wavedec2``.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.

    Raises:
        ValueError: If ``M`` is not two-dimensional.
    """
    def_matrix = np.asarray(M)
    if def_matrix.ndim != 2:
        raise ValueError("Haar_opt expects M to be a 2-D selection mask")
    nx, ny = def_matrix.shape
    wave = "Haar"
    md = "periodization"
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    Z = np.zeros((nx, ny))
    # The coefficient layout depends only on shape, wavelet, mode and level.
    _, layout = pywt.coeffs_to_array(pywt.wavedec2(Z, wave, mode=md, level=lvl))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        coeffs = pywt.array_to_coeffs(Z, layout, output_format="wavedec2")
        # Reconstruct with the same mode that defined the layout.
        pat = pywt.waverec2(coeffs, wave, mode=md)
        Z[i, j] = 0
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
def Daubechies_opt(M, par=2, lvl=1):
    """Build Daubechies wavelet patterns for the selection ``M``.

    NOTE(review): the original body used an undefined name ``def_matrix``
    (NameError on every call) and otherwise duplicated the plain Daubechies
    builder.  ``M`` is therefore interpreted as the 2-D selection mask --
    confirm against the training pipeline that is meant to provide it.

    Args:
        M: 2-D mask; nonzero entries select the coefficients.  The kernel
            size is taken from its first dimension.
        par: Wavelet order; the wavelet name is ``"db" + str(par)``.
        lvl: Decomposition level passed to ``pywt.wavedec2``.

    Returns:
        ``nn.Conv2d`` with ``I`` frozen kernels and zero bias, where ``I``
        is the number of selected entries.

    Raises:
        ValueError: If ``M`` is not two-dimensional.
    """
    def_matrix = np.asarray(M)
    if def_matrix.ndim != 2:
        raise ValueError("Daubechies_opt expects M to be a 2-D selection mask")
    nx, ny = def_matrix.shape
    wave = "db" + str(par)
    md = "periodization"
    rows, cols = np.nonzero(def_matrix)
    I = len(rows)
    Q = nn.Conv2d(1, I, kernel_size=nx, stride=1, padding=0)
    Q.bias.data = torch.zeros(I)
    Z = np.zeros((nx, ny))
    # The coefficient layout depends only on shape, wavelet, mode and level.
    _, layout = pywt.coeffs_to_array(pywt.wavedec2(Z, wave, mode=md, level=lvl))
    for index, (i, j) in enumerate(zip(rows, cols)):
        Z[i, j] = 1
        coeffs = pywt.array_to_coeffs(Z, layout, output_format="wavedec2")
        # Reconstruct with the same mode that defined the layout; the
        # default mode yields a different output size for db wavelets.
        pat = pywt.waverec2(coeffs, wave, mode=md)
        Z[i, j] = 0
        Q.weight.data[index, 0, :, :] = torch.from_numpy(pat)
    Q.bias.requires_grad = False
    Q.weight.requires_grad = False
    return Q
######################################################################## # 4. Define functions for pattern splitting and shifting # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # Implements most used splitting and shifting models #
def split(Q, dyn):
    """Split each signed pattern of ``Q`` into two non-negative patterns.

    Kernel ``i`` of ``Q`` is separated into its positive part (kernel
    ``2 * i`` of ``P``) and the magnitude of its negative part (kernel
    ``2 * i + 1``).  Each part is rescaled so that its maximum equals
    ``2**dyn - 1``; all-zero parts are left at zero.  Row ``i`` of ``T``
    holds the two factors that undo the scaling and recombine the parts
    into the original pattern.

    Args:
        Q: ``nn.Conv2d`` with one input channel holding the patterns.
        dyn: Number of bits of the target dynamic range.

    Returns:
        Tuple ``(P, T)``: ``P`` is a frozen ``nn.Conv2d`` with ``2 * I``
        kernels and zero bias, ``T`` a frozen bias-free
        ``nn.Linear(2 * I, I)``.
    """
    I = Q.bias.shape[0]
    img_size = Q.weight.shape[-1]
    full_scale = 2**dyn - 1
    P = nn.Conv2d(1, 2 * I, kernel_size=img_size, stride=1, padding=0)
    # Zero the bias like the other builders in this module do; the original
    # left PyTorch's random initialisation in place.
    P.bias.data = torch.zeros(2 * I)
    T_matrix = torch.zeros((I, 2 * I))
    T = nn.Linear(2 * I, I, bias=False)
    for i in range(I):
        pat = Q.weight.data[i, 0, :, :].cpu().detach().numpy().astype(np.float64)
        pat_pos = np.clip(pat, 0, None)
        pat_neg = np.clip(-pat, 0, None)
        max_pos = np.max(pat_pos)
        max_neg = np.max(pat_neg)
        if max_pos != 0:
            pat_pos = full_scale * pat_pos / max_pos
        if max_neg != 0:
            pat_neg = full_scale * pat_neg / max_neg
        T_matrix[i, 2 * i] = float(max_pos) / full_scale
        T_matrix[i, 2 * i + 1] = -float(max_neg) / full_scale
        P.weight.data[2 * i, 0, :, :] = torch.from_numpy(pat_pos)
        P.weight.data[2 * i + 1, 0, :, :] = torch.from_numpy(pat_neg)
    P.bias.requires_grad = False
    P.weight.requires_grad = False
    T.weight.data = T_matrix.float()
    T.weight.requires_grad = False
    return P, T
def shift(Q, dyn):
    """Shift each signed pattern of ``Q`` to a non-negative pattern.

    Kernel ``i`` of ``P`` is kernel ``i`` of ``Q`` minus its minimum,
    rescaled so that its maximum equals ``2**dyn - 1`` (constant patterns
    become all-zero).  The extra kernel ``I`` of ``P`` is uniform with value
    ``2**dyn - 1``.  ``T`` is built so that ``T @ P`` (kernels flattened)
    gives back the original patterns:
    ``pattern_i = T[i, i] * P_i + T[i, I] * P_I``.

    Args:
        Q: ``nn.Conv2d`` with one input channel holding the patterns.
        dyn: Number of bits of the target dynamic range.

    Returns:
        Tuple ``(P, T)``: ``P`` is a frozen ``nn.Conv2d`` with ``I + 1``
        kernels and zero bias, ``T`` a ``numpy`` array of shape
        ``(I, I + 1)``.
    """
    I = Q.bias.shape[0]
    img_size = Q.weight.shape[-1]
    full_scale = 2**dyn - 1
    P = nn.Conv2d(1, I + 1, kernel_size=img_size, stride=1, padding=0)
    # Zero the bias like the other builders in this module do.
    P.bias.data = torch.zeros(I + 1)
    # The uniform pattern does not depend on the loop, so it is written
    # once (and also when Q holds no pattern at all).
    P.weight.data[I, 0, :, :] = full_scale * torch.ones(img_size, img_size)
    T = np.zeros((I, I + 1))
    for i in range(I):
        pat = Q.weight.data[i, 0, :, :].cpu().detach().numpy()
        dc_val = np.min(pat)
        pat_ac = pat - dc_val
        max_ac = np.max(pat_ac)
        if max_ac != 0:
            # Guard against 0/0 for constant patterns.
            pat_ac = full_scale * pat_ac / max_ac
        T[i, i] = max_ac / full_scale
        # The last column is index I (I + 1 was out of bounds) and carries
        # the removed offset: dc_val == T[i, I] * full_scale.
        T[i, I] = dc_val / full_scale
        P.weight.data[i, 0, :, :] = torch.from_numpy(pat_ac)
    P.bias.requires_grad = False
    P.weight.requires_grad = False
    return P, T
def matrix2conv(Matrix):
    """Turn the rows of ``Matrix`` into the kernels of a convolution layer.

    Each row of the ``(M, N)`` matrix is reshaped to a square kernel of side
    ``round(sqrt(N))`` and copied into a frozen ``nn.Conv2d`` with one input
    channel, ``M`` output channels and zero bias.
    """
    n_rows, n_cols = Matrix.shape
    side = int(round(np.sqrt(n_cols)))
    conv = nn.Conv2d(1, n_rows, kernel_size=side, stride=1, padding=0)
    conv.bias.data = torch.zeros(n_rows)
    for k in range(n_rows):
        kernel = torch.from_numpy(np.reshape(Matrix[k, :], (side, side)))
        conv.weight.data[k, 0, :, :] = kernel
    # Freeze both parameters: the layer is a fixed bank of patterns.
    for param in (conv.bias, conv.weight):
        param.requires_grad = False
    return conv