Source code for spyrit.core.train

# -----------------------------------------------------------------------------
#   This software is distributed under the terms
#   of the GNU Lesser General  Public Licence (LGPL)
#   See LICENSE.md for further details
# -----------------------------------------------------------------------------

"""
Training functions for deep learning models.
"""

from __future__ import print_function, division
import sys
import os
import time
import datetime
import copy
import pickle
from collections import OrderedDict
import re

# from pathlib import Path
# import statistics

import torch
import torch.nn as nn
import numpy as np
import torchvision
import matplotlib.pyplot as plt

# import torch.optim as optim
# from torchvision import datasets, models, transforms


######################################################################
# 1. Visualize a few images from the training set
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's visualize a few training images so as to understand the data
# augmentations.


def imshow(img, title=""):
    """Display a channels-first image tensor with matplotlib.

    The image is un-normalized with ``img / 2 + 0.5`` (i.e. it is assumed to
    hold values in [-1, 1]) and transposed from (C, H, W) to (H, W, C) before
    display.

    Args:
        img (torch.Tensor): 3-D image tensor, channels first. Must be on the
            CPU since ``.numpy()`` is called on it.
        title (str, optional): title of the axes. Skipped when ``None``.
    """
    plt.ion()  # interactive mode
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # Set the title before showing so it is part of the displayed figure.
    if title is not None:
        plt.title(title)
    plt.show()
    plt.pause(0.0001)
########################################################################
# 2. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# We loop over our data iterator, feed the inputs to the
# network and optimize.
def count_trainable_param(model):
    """Print and return the number of trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    print(f"Number of trainable parameters: {total}")
    return total
def count_param(model):
    """Print and return the total number of parameters of ``model``.

    Frozen parameters (``requires_grad=False``) are included.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    print(f"Total number of parameters: {total}")
    return total
def count_memory(model):
    """Print and return the storage size of ``model``'s tensors in bytes.

    Both parameters and buffers are counted, each as
    ``nelement() * element_size()``.
    """

    def _nbytes(tensors):
        # Byte size of an iterable of tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    mem = _nbytes(model.parameters()) + _nbytes(model.buffers())
    print(f"Memory requirement: {mem} bytes")
    return mem
def images_norm(images):
    """Map values from [-1, 1] to [0, 1] with the affine map ``(x + 1) / 2``."""
    shifted = images + 1
    return shifted * 0.5
def tb_writer_init(tb_path, samples=None):
    """Create a TensorBoard ``SummaryWriter`` logging into ``tb_path``.

    If ``samples`` is given, a grid of those images (rescaled with
    ``images_norm``) is logged under the tag ``"data_sample"``.

    Returns:
        The writer. It is closed before being returned; callers keep using
        it afterwards (presumably relying on ``SummaryWriter`` re-opening
        its event file on the next write — confirm).
    """
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(tb_path)
    if samples is None:
        writer.close()
        return writer

    print("samples")
    grid = torchvision.utils.make_grid(images_norm(samples))
    writer.add_image("data_sample", grid)
    writer.close()
    return writer
def tb_writer_add_scalar(writer, name_metric, val_metric, step):
    """Log one scalar value (e.g. a loss) to a TensorBoard writer.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)``.
        name_metric (str): tag of the scalar.
        val_metric (float): value to log.
        step (int): global step associated with the value.
    """
    log_scalar = writer.add_scalar
    log_scalar(name_metric, val_metric, step)
def tb_writer_add_image(writer, name_metric, images, step):
    """Log a batch of images as one grid image to a TensorBoard writer.

    The images are rescaled with ``images_norm`` (``(x + 1) / 2``) and tiled
    with ``torchvision.utils.make_grid`` before being written under the tag
    ``name_metric`` at global step ``step``.
    """
    rescaled = images_norm(images)
    grid = torchvision.utils.make_grid(rescaled)
    writer.add_image(name_metric, grid, global_step=step)
def tb_profiler(
    path_prof,
    model,
    criterion,
    optimizer,
    dataloader,
    device,
    wait=1,
    warmup=1,
    active=3,
    repeat=2,
):
    """Tensorboard profiler: profile a few training steps.

    Runs at most ``(wait + warmup + active) * repeat`` training iterations
    over ``dataloader`` inside :class:`torch.profiler.profile`. The profiling
    schedule is built from ``wait``, ``warmup``, ``active`` and ``repeat``,
    and traces are written to ``path_prof`` with
    :func:`torch.profiler.tensorboard_trace_handler`.

    Args:
        path_prof (str): folder receiving the TensorBoard traces.
        model (torch.nn.Module): network to profile.
        criterion (callable): loss, called as ``criterion(inputs, outputs, model)``.
        optimizer (torch.optim.Optimizer): optimizer of ``model``.
        dataloader: iterable of ``(inputs, labels)``; labels are ignored.
        device: device the inputs are moved to.
        wait, warmup, active, repeat (int): profiler schedule parameters.
    """
    n_steps = (wait + warmup + active) * repeat
    schedule = torch.profiler.schedule(
        wait=wait, warmup=warmup, active=active, repeat=repeat
    )
    # The context manager guarantees the profiler is stopped even if a
    # training step raises.
    with torch.profiler.profile(
        schedule=schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(path_prof),
        record_shapes=True,
        with_stack=True,
    ) as prof:
        for step, (inputs, _labels) in enumerate(dataloader):
            if step >= n_steps:
                break
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(inputs, outputs, model)
            loss.backward()
            optimizer.step()
            prof.step()
def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    dataloaders,
    device,
    root,
    num_epochs=25,
    disp=False,
    do_checkpoint=0,
    tb_path=False,
    tb_prof=False,
    tb_freq=20,
):
    """Trains the pytorch model (the loss compares the inputs to the outputs).

    Each epoch runs a ``"train"`` phase followed by a ``"val"`` phase. The
    weights reaching the lowest validation loss are loaded back into
    ``model`` before returning.

    Args:
        model (torch.nn.Module): network to train.
        criterion (callable): loss, called as ``criterion(inputs, outputs, model)``.
        optimizer (torch.optim.Optimizer): optimizer of ``model``.
        scheduler: learning-rate scheduler, stepped once after each training phase.
        dataloaders (dict): ``"train"`` and ``"val"`` dataloaders yielding
            ``(inputs, labels)`` pairs; labels are ignored.
        device: device the inputs are moved to.
        root: folder where checkpoints are written by ``checkpoint``.
        num_epochs (int): number of epochs.
        disp (bool): print progress and per-epoch losses.
        do_checkpoint (int): if > 0, save a checkpoint every ``do_checkpoint`` epochs.
        tb_path (str or False): TensorBoard log folder; logging is off when falsy.
        tb_prof (bool): run ``tb_profiler`` after training.
        tb_freq (int): log to TensorBoard every ``tb_freq`` batches.

    Returns:
        tuple: ``(model, train_info)``. ``train_info["train"]`` and
        ``train_info["val"]`` list one loss per epoch, averaged over samples
        (assuming ``criterion`` returns a per-sample mean over the batch).

    Raises:
        ValueError: if ``do_checkpoint > 0`` and ``root`` does not exist.
    """
    count_trainable_param(model)
    count_param(model)
    count_memory(model)

    since = time.time()
    best_loss = float("inf")
    best_model_wts = copy.deepcopy(model.state_dict())
    # Number of batches per phase (used for ETA and TensorBoard steps).
    n_batches = {x: len(dataloaders[x]) for x in ["train", "val"]}
    train_info = {"train": [], "val": []}

    # check that the folder `root` exists if do_checkpoint > 0
    if (do_checkpoint > 0) and (not os.path.exists(root)):
        raise ValueError(f"Folder {root} not found")

    # Set tensorboard writer
    writer = None
    if tb_path:
        samples, _ = next(iter(dataloaders["val"]))
        samples = samples[:3, :, :, :].to(device)
        writer = tb_writer_init(tb_path, samples)

    for epoch in range(num_epochs):
        prev_time = time.time()

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            n_samples = 0

            for batch_i, (inputs, _labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                optimizer.zero_grad()

                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(inputs, outputs, model)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by the batch size and divide by the
                # number of samples (not batches) at the end of the epoch.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                n_samples += batch_size

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                if disp:
                    batches_done = epoch * n_batches[phase] + batch_i
                    batches_left = num_epochs * n_batches[phase] - batches_done
                    time_left = datetime.timedelta(
                        seconds=batches_left * (time.time() - prev_time)
                    )
                    prev_time = time.time()
                    sys.stdout.write(
                        "\r[%s] [Epoch %d/%d] [Batch %d/%d] [Loss: %f] ETA: %s "
                        % (
                            phase,
                            epoch + 1,
                            num_epochs,
                            batch_i + 1,
                            n_batches[phase],
                            loss.item(),
                            time_left,
                        )
                    )

                if writer is not None and batch_i % tb_freq == 0:
                    tb_step = epoch * n_batches[phase] + batch_i
                    tb_writer_add_scalar(
                        writer,
                        name_metric=f"{phase}_loss",
                        val_metric=loss.item() * batch_size,
                        step=tb_step,
                    )
                    # Predict in eval mode so the logged samples do not update
                    # training-mode state (e.g. batch-norm running stats),
                    # then restore the current phase's mode.
                    was_training = model.training
                    model.eval()
                    with torch.no_grad():
                        samples_pred = model(samples)
                    model.train(was_training)
                    tb_writer_add_image(
                        writer,
                        name_metric="model_preds",
                        images=samples_pred,
                        step=tb_step,
                    )

                del outputs

            epoch_loss = running_loss / n_samples
            train_info[phase].append(epoch_loss)

            if phase == "train":
                scheduler.step()

            if disp:
                print("")
                print("{} Loss: {:.4f} ".format(phase, epoch_loss))

            # deep copy the model
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        if do_checkpoint > 0 and epoch % do_checkpoint == 0:
            checkpoint(root, epoch, model)

    time_elapsed = time.time() - since
    if disp:
        print(
            "Training complete in {:.0f}m {:.0f}s".format(
                time_elapsed // 60, time_elapsed % 60
            )
        )
        print("Best val Loss: {:.4f}".format(best_loss))

    # Tensorboard profiler
    if tb_prof:
        tb_profiler(
            tb_path,
            model,
            criterion,
            optimizer,
            dataloaders["train"],
            device,
            wait=1,
            warmup=1,
            active=3,
            repeat=2,
        )

    if writer is not None:
        writer.close()

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_info
def train_model_supervised(
    model,
    criterion,
    optimizer,
    scheduler,
    dataloaders,
    device,
    root,
    num_epochs=25,
    disp=False,
    do_checkpoint=0,
):
    """Trains the pytorch model in a supervised way.

    Each epoch runs a ``"train"`` phase followed by a ``"val"`` phase. The
    weights reaching the lowest validation loss are loaded back into
    ``model`` before returning.

    Args:
        model (torch.nn.Module): network to train.
        criterion (callable): loss, called as ``criterion(labels, outputs, model)``.
        optimizer (torch.optim.Optimizer): optimizer of ``model``.
        scheduler: learning-rate scheduler, stepped once after each training phase.
        dataloaders (dict): ``"train"`` and ``"val"`` dataloaders yielding
            ``(inputs, labels)`` pairs.
        device: device inputs and labels are moved to.
        root: folder where checkpoints are written by ``checkpoint``.
        num_epochs (int): number of epochs.
        disp (bool): print progress and per-epoch losses.
        do_checkpoint (int): if > 0, save a checkpoint every ``do_checkpoint`` epochs.

    Returns:
        tuple: ``(model, train_info)``. ``train_info["train"]`` and
        ``train_info["val"]`` list one loss per epoch, averaged over samples
        (assuming ``criterion`` returns a per-sample mean over the batch).

    Raises:
        ValueError: if ``do_checkpoint > 0`` and ``root`` does not exist.
    """
    since = time.time()
    best_loss = float("inf")
    best_model_wts = copy.deepcopy(model.state_dict())
    # Number of batches per phase (used for the ETA display).
    n_batches = {x: len(dataloaders[x]) for x in ["train", "val"]}
    train_info = {"train": [], "val": []}

    # Fail before training rather than at the first checkpoint.
    if (do_checkpoint > 0) and (not os.path.exists(root)):
        raise ValueError(f"Folder {root} not found")

    for epoch in range(num_epochs):
        prev_time = time.time()

        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            n_samples = 0

            for batch_i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(labels, outputs, model)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Weight the batch loss by the batch size and divide by the
                # number of samples (not batches) at the end of the epoch.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                n_samples += batch_size

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                if disp:
                    batches_done = epoch * n_batches[phase] + batch_i
                    batches_left = num_epochs * n_batches[phase] - batches_done
                    time_left = datetime.timedelta(
                        seconds=batches_left * (time.time() - prev_time)
                    )
                    prev_time = time.time()
                    sys.stdout.write(
                        "\r[%s] [Epoch %d/%d] [Batch %d/%d] [Loss: %f] ETA: %s "
                        % (
                            phase,
                            epoch + 1,
                            num_epochs,
                            batch_i + 1,
                            n_batches[phase],
                            loss.item(),
                            time_left,
                        )
                    )

            epoch_loss = running_loss / n_samples
            train_info[phase].append(epoch_loss)

            if phase == "train":
                # Step the scheduler after this epoch's optimizer updates;
                # stepping it before training shifts the schedule by one epoch.
                scheduler.step()

            if disp:
                print("")
                print("{} Loss: {:.4f} ".format(phase, epoch_loss))

            # deep copy the model
            if phase == "val" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        if do_checkpoint > 0 and epoch % do_checkpoint == 0:
            checkpoint(root, epoch, model)

    time_elapsed = time.time() - since
    if disp:
        print(
            "Training complete in {:.0f}m {:.0f}s".format(
                time_elapsed // 60, time_elapsed % 60
            )
        )
        print("Best val Loss: {:.4f}".format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_info
class Weight_Decay_Loss(nn.Module):
    """Adapt a two-argument loss to the ``(x, y, net)`` criterion signature.

    Despite the name, no weight-decay term is added: ``net`` is accepted but
    ignored, and the wrapped loss value is returned as is.

    Args:
        loss (callable): loss computed as ``loss(x, y)``.
    """

    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def forward(self, x, y, net):
        """Return ``self.loss(x, y)``; ``net`` is unused."""
        return self.loss(x, y)
class Train_par:
    """Training hyper-parameters together with the recorded loss curves.

    Args:
        batch_size (int): batch size.
        learning_rate (float): learning rate.
        img_size (int): image size.
        reg (float, optional): regularisation value. Default: 0.

    Attributes:
        train_loss (list): per-epoch training losses (set by ``set_loss``).
        val_loss (list): per-epoch validation losses (set by ``set_loss``).
        minimum (float): smallest validation loss, ``inf`` until ``set_loss``.
    """

    def __init__(self, batch_size, learning_rate, img_size, reg=0):
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.img_size = img_size
        self.reg = reg
        self.train_loss = []
        self.val_loss = []
        self.minimum = float("inf")

    def set_loss(self, train_info):
        """Store the ``"train"``/``"val"`` curves of ``train_info``.

        Also updates ``minimum`` to the smallest validation loss (raises
        ``ValueError`` if the validation list is empty).
        """
        self.train_loss = train_info["train"]
        self.val_loss = train_info["val"]
        self.minimum = min(self.val_loss)

    def __str__(self):
        string1 = "Parameters:\nBatch Size : \t {} \nLearning : \t {} \n".format(
            self.batch_size, self.learning_rate
        )
        string2 = "Image Size : \t {} \nRegularisation : \t {}".format(
            self.img_size, self.reg
        )
        return string1 + string2

    def get_loss(self):
        """Return the loss curves as a ``{"train": ..., "val": ...}`` dict."""
        return {"train": self.train_loss, "val": self.val_loss}

    def title(self):
        """Return a file-name friendly string describing the parameters."""
        string1 = "Batch_Size_{}_Learning_{}".format(
            self.batch_size, self.learning_rate
        )
        string2 = "_size_{}_Regularisation_{}".format(self.img_size, self.reg)
        return string1 + string2

    def _plot_curves(self, train_loss, val_loss, start, title):
        """Plot train (top) and validation (bottom) curves on figure 1."""
        plt.ion()
        plt.figure(1, figsize=(20, 10))
        # Call (do not assign) pyplot's suptitle: assigning to `plt.suptitle`
        # would replace the function module-wide and set no title.
        plt.suptitle(title)
        epochs = [i + 1 for i in range(start, len(train_loss))]
        plt.subplot(2, 1, 1)
        plt.plot(epochs, train_loss[start:], "o-")
        plt.title("Train")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.subplot(2, 1, 2)
        plt.plot(epochs, val_loss[start:], ".-")
        plt.title("Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.show()

    def plot(self, start=0):
        """Plot the loss curves from epoch ``start + 1`` on."""
        string1 = "Batch Size : \t {} \n Learning : \t {} \n".format(
            self.batch_size, self.learning_rate
        )
        string2 = "size : \t {} \nRegularisation : \t {}".format(
            self.img_size, self.reg
        )
        self._plot_curves(self.train_loss, self.val_loss, start, string1 + string2)

    def plot_log(self, start=0):
        """Plot the natural log of the loss curves from epoch ``start + 1`` on."""
        # The "Learning" field shows the learning rate (not the batch size).
        string1 = " Learning : \t {} \n".format(self.learning_rate)
        string2 = "size : \t {} \nRegularisation : \t {}".format(
            self.img_size, self.reg
        )
        self._plot_curves(
            np.log(self.train_loss),
            np.log(self.val_loss),
            start,
            string1 + string2,
        )
def multiplot(train_info1, train_info2, train_info3, start=0):
    """Overlay the loss curves of three ``Train_par`` objects.

    Top subplot: training losses; bottom subplot: validation losses, from
    epoch ``start + 1`` on. Curves are labelled "ConvNet", "U_net" and
    "DC Model" in argument order. The figure title uses the learning rate,
    image size and regularisation of ``train_info1``, and the epoch axis is
    built from ``len(train_info1.train_loss)`` (so the three objects
    presumably need curves of the same length).
    """
    plt.ion()
    string1 = "Learning : \t {} \n".format(train_info1.learning_rate)
    string2 = "size : \t {} \nRegularisation : \t {}".format(
        train_info1.img_size, train_info1.reg
    )
    title = string1 + string2
    infos = (train_info1, train_info2, train_info3)
    legend = ["ConvNet", "U_net", "DC Model"]

    plt.figure(1, figsize=(20, 10))
    # Call (do not assign) pyplot's suptitle: assigning to `plt.suptitle`
    # would replace the function module-wide and set no title.
    plt.suptitle(title)
    epochs = [i + 1 for i in range(start, len(train_info1.train_loss))]

    plt.subplot(2, 1, 1)
    for info in infos:
        plt.plot(epochs, info.train_loss[start:], "o-")
    plt.legend(legend)
    plt.title("Train")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")

    plt.subplot(2, 1, 2)
    for info in infos:
        plt.plot(epochs, info.val_loss[start:], ".-")
    plt.legend(legend)
    plt.title("Validation")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()
def read_param(path):
    """Unpickle and return the object stored in the file at ``path``.

    Warning: :func:`pickle.load` can execute arbitrary code; only read files
    from trusted sources.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)
######################################################################
# 3. Visualizing the model predictions
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Functions to display reconstructions for a few images.
def boxplot(model1, model2, model3, criterion, dataloaders, device):
    """Box-plot the per-image reconstruction error of three models.

    For each batch of the ``"val"`` dataloader, measurements are computed
    once with ``model1.forward_acquire`` and reconstructed by each model's
    ``forward_reconstruct``. ``criterion(image, reconstruction, model)`` is
    evaluated image by image. The title reports ``model1.N0`` and
    ``model1.M``.

    Args:
        model1, model2, model3: models exposing ``forward_reconstruct``
            (``model1`` also provides ``forward_acquire``, ``N0`` and ``M``).
        criterion (callable): per-image loss, called with single images.
        dataloaders (dict): must contain a ``"val"`` dataloader yielding
            ``(inputs, labels)`` with 4-D ``inputs``.
        device: device the inputs are moved to.
    """
    models = [model1, model2, model3]
    mse = [[] for _ in models]
    for model in models:
        model.eval()

    # Evaluation only: disable autograd for the acquisition as well as the
    # reconstruction so no computation graph is built.
    with torch.no_grad():
        for inputs, _labels in dataloaders["val"]:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            inputs = inputs.to(device)
            b, c, h, w = inputs.shape
            measurements = model1.forward_acquire(inputs, b, c, h, w)
            for errors, model in zip(mse, models):
                outputs = model.forward_reconstruct(measurements, b, c, h, w)
                for j in range(b):
                    loss = criterion(
                        inputs[j, :, :, :], outputs[j, :, :, :], model
                    )
                    errors.append(loss.tolist())

    _, ax1 = plt.subplots()
    ax1.set_title(
        "Reconstruction error (MSE)"
        + " with N0 ={} and M = {}".format(model1.N0, model1.M)
    )
    ax1.boxplot(
        mse,
        labels=["ConvNet", "U-net", "Data-consistent Model"],
        showmeans=True,
        showfliers=False,
    )
    plt.show()
def boxplotconsist(model1, model2, model3, criterion, dataloaders, device):
    """Box-plot the per-image error of three models in the measurement domain.

    For each batch of the ``"val"`` dataloader, measurements are computed
    once with ``model1.forward_acquire`` and reconstructed by each model's
    ``forward_reconstruct``. Both the ground truth and the reconstruction are
    then projected with ``model1.Pconv``, and
    ``criterion(projected_truth, projected_reconstruction, model)`` is
    evaluated image by image. The title reports ``model1.N0`` and
    ``model1.M``.

    Args:
        model1, model2, model3: models exposing ``forward_reconstruct``
            (``model1`` also provides ``forward_acquire``, ``Pconv``, ``N0``
            and ``M``).
        criterion (callable): per-image loss.
        dataloaders (dict): must contain a ``"val"`` dataloader yielding
            ``(inputs, labels)`` with 4-D ``inputs``.
        device: device the inputs are moved to.
    """
    models = [model1, model2, model3]
    mse = [[] for _ in models]
    for model in models:
        model.eval()

    # Evaluation only: disable autograd for the acquisition as well as the
    # reconstruction so no computation graph is built.
    with torch.no_grad():
        for inputs, _labels in dataloaders["val"]:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            inputs = inputs.to(device)
            b, c, h, w = inputs.shape
            acquired = model1.forward_acquire(inputs, b, c, h, w)
            for errors, model in zip(mse, models):
                outputs = model.forward_reconstruct(acquired, b, c, h, w)
                reconmeasurements = model1.Pconv(outputs)
                measurements = model1.Pconv(inputs)
                for j in range(b):
                    loss = criterion(
                        measurements[j, :, :, :],
                        reconmeasurements[j, :, :, :],
                        model,
                    )
                    errors.append(loss.tolist())

    _, ax1 = plt.subplots()
    ax1.set_title(
        "Reconstruction error over measures (MSE)"
        + " with N0 ={} and M = {}".format(model1.N0, model1.M)
    )
    ax1.boxplot(
        mse,
        labels=["ConvNet", "U-net", "Data-consistent Model"],
        showmeans=True,
        showfliers=False,
    )
    plt.show()
def visualize_model(model, dataloaders, device, suptitle="", colormap=plt.cm.gray):
    """
    Takes 8 images from the dataloader and shows side by side the input image
    and the reconstructed image.

    Images come from the ``"train"`` dataloader; successive batches are
    concatenated until at least 8 images are available (restarting the
    dataloader if it is exhausted). Only channel 0 is displayed.

    Args:
        model: network mapping a batch of images to reconstructions.
        dataloaders (dict): must contain a ``"train"`` dataloader yielding
            ``(inputs, labels)`` with 4-D ``inputs``.
        device: device on which ``model`` runs.
        suptitle (str, optional): figure title; none is drawn when empty.
        colormap: matplotlib colormap used for every panel.
    """
    plt.ion()  # interactive mode

    # Reuse one iterator so successive batches (not the first one again) are
    # concatenated, and the dataloader is not re-initialized at every step.
    batches = iter(dataloaders["train"])
    inputs, _ = next(batches)
    while inputs.shape[0] < 8:
        try:
            next_input, _ = next(batches)
        except StopIteration:
            batches = iter(dataloaders["train"])
            next_input, _ = next(batches)
        inputs = torch.cat((inputs, next_input), 0)

    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    inputs = inputs.cpu().detach().numpy()
    outputs = outputs.cpu().detach().numpy()

    _, axarr = plt.subplots(4, 4, figsize=(20, 20))
    if suptitle:
        plt.suptitle(suptitle, fontsize=16)
    for i in range(4):
        for j in range(2):
            k = 2 * i + j
            axarr[i, 2 * j].imshow(inputs[k, 0, :, :], cmap=colormap)
            axarr[i, 2 * j].set_title("Ground Truth")
            axarr[i, 2 * j + 1].imshow(outputs[k, 0, :, :], cmap=colormap)
            axarr[i, 2 * j + 1].set_title("Reconstructed")
    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
def compare_model(
    model1, model2, model3, dataloaders, device, suptitle="", colormap=plt.cm.gray
):
    """
    Compare three models.

    Shows 4 images from the ``"train"`` dataloader, one per row: ground
    truth, then the reconstructions by ``model1`` ("ConvNet"), ``model2``
    ("Unet") and ``model3`` ("DC model"). Successive batches are
    concatenated if a batch holds fewer than 4 images (restarting the
    dataloader if it is exhausted). Only channel 0 is displayed. The models
    are moved to ``device``.

    Args:
        model1, model2, model3: networks mapping images to reconstructions.
        dataloaders (dict): must contain a ``"train"`` dataloader yielding
            ``(inputs, labels)`` with 4-D ``inputs``.
        device: device on which the models run.
        suptitle (str, optional): figure title; none is drawn when empty.
        colormap: matplotlib colormap used for every panel.
    """
    plt.ion()  # interactive mode

    # Reuse one iterator so successive batches (not the first one again) are
    # concatenated, and the dataloader is not re-initialized at every step.
    batches = iter(dataloaders["train"])
    inputs, _ = next(batches)
    while inputs.shape[0] < 4:
        try:
            next_input, _ = next(batches)
        except StopIteration:
            batches = iter(dataloaders["train"])
            next_input, _ = next(batches)
        inputs = torch.cat((inputs, next_input), 0)
    inputs = inputs[:4, :, :, :].to(device)

    model1 = model1.to(device)
    model2 = model2.to(device)
    model3 = model3.to(device)
    with torch.no_grad():
        outputs1 = model1(inputs)
        outputs2 = model2(inputs)
        outputs3 = model3(inputs)

    inputs = inputs.cpu().detach().numpy()
    outputs1 = outputs1.cpu().detach().numpy()
    outputs2 = outputs2.cpu().detach().numpy()
    outputs3 = outputs3.cpu().detach().numpy()

    panels = [
        (inputs, "Ground Truth"),
        (outputs1, "Reconstructed with ConvNet"),
        (outputs2, "Reconstructed with Unet"),
        (outputs3, "Reconstructed with DC model"),
    ]
    _, axarr = plt.subplots(4, 4, figsize=(20, 20))
    if suptitle:
        plt.suptitle(suptitle, fontsize=16)
    for i in range(4):
        for col, (images, label) in enumerate(panels):
            axarr[i, col].imshow(images[i, 0, :, :], cmap=colormap)
            axarr[i, col].set_title(label)
    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
def visualize_conv_layers(conv_layer, suptitle="", colormap=plt.cm.gray):
    """Displays the 8 first filters of the convolution layer conv_layer.

    The filters are read from the first parameter of ``conv_layer``, which
    must be 4-D (presumably the convolution weight of shape
    ``(out_channels, in_channels, kH, kW)`` — confirm for custom layers).
    Only input channel 0 is shown; panels beyond the number of filters are
    filled with zeros. Each panel gets its own colorbar.

    Args:
        conv_layer (torch.nn.Module): convolution layer to inspect.
        suptitle (str, optional): figure title; none is drawn when empty.
        colormap: matplotlib colormap used for every panel.
    """
    conv_filters = next(iter(conv_layer.parameters()))
    conv_filters = conv_filters.cpu().detach().numpy()
    plt.ion()
    (nb_filters, _entry_channels, s_x, s_y) = conv_filters.shape
    fig, axarr = plt.subplots(2, 4, figsize=(20, 10))
    if suptitle:
        fig.suptitle(suptitle, fontsize=16)
    for i in range(2):
        for j in range(4):
            nb_pat = 4 * i + j
            if nb_filters > nb_pat:
                img = conv_filters[nb_pat, 0, :, :]
            else:
                img = np.zeros((s_x, s_y))
            im = axarr[i, j].imshow(img, cmap=colormap)
            # Dedicated colorbar axes placed next to panel (i, j).
            cax = plt.axes([0.02 + (j + 1) * 0.225, 0.6 - i * 0.43, 0.005, 0.25])
            plt.colorbar(im, cax=cax)
    plt.subplots_adjust(left=0.08, wspace=0.5, top=0.9, right=0.9)
    plt.show()
######################################################################
# 4. Saving and loading the model so that it can later be used
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
def checkpoint(root, epoch, model):
    """Saves the dictionaries of a given pytorch model for the right epoch.

    The state dict is written to ``root / "model_epoch_{epoch}.pth"`` and
    the destination path is printed.

    Args:
        root (str or pathlib.Path): existing destination folder. Plain
            strings are accepted (they are converted to ``Path``).
        epoch (int): epoch number used in the file name.
        model (torch.nn.Module): model whose ``state_dict`` is saved.
    """
    from pathlib import Path

    model_out_path = Path(root) / "model_epoch_{}.pth".format(epoch)
    torch.save(model.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))
def save_net(title, model):
    """Saves dictionaries of a given pytorch model in the place defined by title.

    ``title`` is used verbatim as the destination path (no extension is
    appended). The path is printed before saving, then ``"Model Saved"``.
    """
    destination = title
    print(destination)
    weights = model.state_dict()
    torch.save(weights, destination)
    print("Model Saved")
def load_net(path, model, device=None, strict=True):
    """Loads network defined by path into model. The network is loaded in-place.

    Args:
        :attr:`path` (str): full path to the model, must contain file extension.

        :attr:`model` (torch.nn.Module): model to load the weights into. The
        model must have the same architecture as the model that was saved.

        :attr:`device` (str): if not None, passed to :func:`torch.load` as
        ``map_location=torch.device(device)``. If None, no ``map_location``
        is given.

        :attr:`strict` (bool): this argument is passed to the
        `load_state_dict` of the `nn.Module`. If True, the keys of the
        state_dict and the model must match exactly.

    Returns:
        `None`. Loading is best-effort: on failure nothing is raised; a
        message is printed instead ("Model not found at ..." when the file
        does not exist, otherwise "Model not loaded at ...: <error>").
    """
    model_out_path = path
    try:
        map_location = None if device is None else torch.device(device)
        state_dict = torch.load(
            model_out_path, weights_only=True, map_location=map_location
        )
        model.load_state_dict(state_dict, strict=strict)
    except Exception as err:  # best-effort: report instead of raising
        if os.path.isfile(model_out_path):
            print("Model not loaded at {}: {}".format(model_out_path, err))
        else:
            print("Model not found at {}".format(model_out_path))
        return
    print("Model Loaded: {}".format(path))
def rename_model_attributes(source, old_name, new_name, target=None):
    """
    Rename the attributes (state-dict keys) of a saved model (nn.Module).

    Every key is rewritten with ``attr_transformation`` (i.e.
    ``re.sub(old_name, new_name, key)``), so ``old_name`` is a regular
    expression and all of its occurrences in a key are replaced. Each
    ``old -> new`` mapping is printed.

    Parameters
    ----------
    source : str
        Path to the saved state dict.
    old_name : str
        Regex pattern for the attribute names to be renamed.
    new_name : str
        Replacement for the matched pattern.
    target : str, optional
        Path of the model with renamed attributes. The default is source
        (the file is overwritten).

    Returns
    -------
    None.

    Example
    -------
    ``rename_model_attributes("source.pth", "Denoi", "denoi", "target.pth")``
    renames the key ``Denoi.layer.0.weight`` as ``denoi.layer.0.weight`` and
    saves the resulting state dict as ``target.pth``.

    Adapted from https://gist.github.com/the-bass/0bf8aaa302f9ba0d26798b11e4dd73e3
    """
    if target is None:
        target = source

    # weights_only=True: only tensors/containers are unpickled (consistent
    # with load_net), so no arbitrary code runs when reading the file.
    state_dict = torch.load(source, weights_only=True)
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = attr_transformation(key, old_name, new_name)
        new_state_dict[new_key] = value
        print(f"{key} -> {new_key} ")

    torch.save(new_state_dict, target)
def remove_model_attributes(source, old_name, target=None):
    """
    Remove some attributes (state-dict keys) of a saved model (nn.Module).

    A key is removed when ``attr_removal`` matches it, i.e. when
    ``re.match(old_name, key)`` succeeds (the regex must match at the start
    of the key). Kept and removed keys are printed.

    Parameters
    ----------
    source : str
        Path to the saved state dict.
    old_name : str
        Regex pattern for the attributes of the model to be removed.
    target : str, optional
        Path of the model without the removed attributes. The default is
        source (the file is overwritten).

    Returns
    -------
    None.

    Example
    -------
    ``remove_model_attributes("source.pth", "Denoi", "target.pth")`` removes
    every key starting with ``Denoi`` (e.g. ``Denoi.layer.0.weight``) and
    saves the resulting state dict as ``target.pth``.
    """
    if target is None:
        target = source

    # weights_only=True: only tensors/containers are unpickled (consistent
    # with load_net), so no arbitrary code runs when reading the file.
    state_dict = torch.load(source, weights_only=True)
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        if attr_removal(key, old_name):
            print(f"{key} has been removed")
        else:
            new_state_dict[key] = value
            print(f"{key} -> {key}")

    torch.save(new_state_dict, target)
def attr_transformation(old_key, old_name, new_name):
    """Return ``old_key`` with every match of the regex ``old_name`` replaced.

    Thin wrapper around :func:`re.sub`.
    """
    return re.sub(old_name, new_name, old_key)
def attr_removal(old_key, old_name):
    """Test whether the regex ``old_name`` matches at the start of ``old_key``.

    Returns the :func:`re.match` result: a match object (truthy) when
    ``old_key`` starts with a match, ``None`` otherwise.
    """
    match = re.match(old_name, old_key)
    return match