Tutorial 3: Meta-learning#
Week 2, Day 4: Macro-Learning
By Neuromatch Academy
Content creators: Hlib Solodzhuk, Ximeng Mao, Grace Lindsay
Content reviewers: Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk, Ximeng Mao, Samuele Bolotta, Grace Lindsay
Production editors: Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk
Tutorial Objectives#
Estimated timing of tutorial: 50 minutes
In this tutorial, you will examine how meta-learning separates the problem of continual learning into two stages: an inner loop that rapidly adapts to each new task from a few samples, and an outer loop that learns parameters which make that adaptation effective.
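As a preview of those two stages, below is a minimal, self-contained sketch of one MAML-style meta-training step on a toy regression task. All names and the 1-16-1 architecture are illustrative, not the tutorial's actual classes: the inner stage adapts a copy of the weights to one task, and the outer stage updates the original weights so that such adaptation succeeds.
import torch
import torch.nn.functional as F

def forward(weights, x):
    # two-layer MLP applied with explicit weight tensors
    hidden = F.relu(F.linear(x, weights[0], weights[1]))
    return F.linear(hidden, weights[2], weights[3])

# meta-parameters of a 1-16-1 network
weights = [torch.randn(16, 1, requires_grad=True), torch.zeros(16, requires_grad=True),
           torch.randn(1, 16, requires_grad=True), torch.zeros(1, requires_grad=True)]
inner_lr, outer_lr = 0.01, 0.001

# one toy task: inner (support) and outer (query) data
x_inner, y_inner = torch.randn(10, 1), torch.randn(10, 1)
x_outer, y_outer = torch.randn(10, 1), torch.randn(10, 1)

# inner stage: one gradient step adapts a copy of the weights to this task
inner_loss = F.mse_loss(forward(weights, x_inner), y_inner)
grads = torch.autograd.grad(inner_loss, weights, create_graph=True)
fast_weights = [w - inner_lr * g for w, g in zip(weights, grads)]

# outer stage: evaluate the adapted weights on held-out data and update the originals
outer_loss = F.mse_loss(forward(fast_weights, x_outer), y_outer)
outer_grads = torch.autograd.grad(outer_loss, weights)
with torch.no_grad():
    for w, g in zip(weights, outer_grads):
        w -= outer_lr * g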
Setup#
Install and import feedback gadget#
# @title Install and import feedback gadget
!pip install vibecheck datatops --quiet
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_neuroai",
"user_key": "wb2cxze8",
},
).render()
feedback_prefix = "W2D4_T3"
Imports#
# @title Imports
#working with data
import numpy as np
from functools import partial
#plotting
import matplotlib.pyplot as plt
import logging
from sklearn.decomposition import PCA
from matplotlib.lines import Line2D
#interactive display
import ipywidgets as widgets
#modeling
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import r2_score
#utils
from tqdm import tqdm
Figure settings#
# @title Figure settings
logging.getLogger('matplotlib.font_manager').disabled = True
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")
Plotting functions#
# @title Plotting functions
def plot_tasks(task_days, task_prices):
"""
Plot the tasks' prices over time.
Inputs:
- task_days (list): A list of three lists, where each sub-list contains the days for a specific task.
- task_prices (list): A list of three lists, where each sub-list contains the prices for a specific task.
"""
sorted_first_task_days, sorted_first_task_prices = zip(*sorted(zip(task_days[0], task_prices[0]), key=lambda pair: pair[0]))
sorted_second_task_days, sorted_second_task_prices = zip(*sorted(zip(task_days[1], task_prices[1]), key=lambda pair: pair[0]))
sorted_third_task_days, sorted_third_task_prices = zip(*sorted(zip(task_days[2], task_prices[2]), key=lambda pair: pair[0]))
with plt.xkcd():
plt.plot(sorted_first_task_days, sorted_first_task_prices, label = "First Task")
plt.plot(sorted_second_task_days, sorted_second_task_prices, label = "Second Task")
plt.plot(sorted_third_task_days, sorted_third_task_prices, label = "Third Task")
plt.xlabel('Week')
plt.ylabel('Price')
plt.legend()
plt.show()
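# Illustrative usage with hypothetical task parameters (relies on `np` from the imports above):
# task_days = [np.random.uniform(-26, 26, size = 40) for _ in range(3)]
# task_prices = [0.005 * d ** 2 + 0.1 * np.sin(np.pi * d) + c for d, c in zip(task_days, [1.0, 2.0, 4.0])]
# plot_tasks(task_days, task_prices)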
def plot_inner_outer_weights(pca_parameters, epoch):
"""
Plot the PCA-transformed outer weights of the model in 2D across epochs, as well as the inner / outer weights for the given epoch.
Inputs:
- pca_parameters (np.ndarray): array of model parameters (already in 2D).
- epoch (int): given epoch.
"""
with plt.xkcd():
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
#plot points for the given epoch
for j in range(pca_parameters.shape[1]):
ax1.scatter(pca_parameters[epoch - 1, j, 0], pca_parameters[epoch - 1, j, 1], color='k', s=10)
start_point = pca_parameters[epoch - 1, 0]
#plot arrows from start point to all other
for j in range(1, pca_parameters.shape[1]):
#inner
end_point = pca_parameters[epoch - 1, j]
arrow_color = 'g'
#outer
if j == pca_parameters.shape[1] - 1:
arrow_color = 'b'
ax1.annotate('', xy=end_point, xytext=start_point,
arrowprops=dict(arrowstyle='->', color=arrow_color))
#plot arrows for previous outer
for j in range(epoch - 1):
ax1.scatter(pca_parameters[j, 0, 0], pca_parameters[j, 0, 1], color='k', s=10)
start_point = pca_parameters[j, 0]
end_point = pca_parameters[j + 1, 0]
ax1.annotate('', xy=end_point, xytext=start_point,
arrowprops=dict(arrowstyle='->', color='b', alpha = 0.2))
#plot points for the given epoch
for j in range(pca_parameters.shape[1]):
ax2.scatter(pca_parameters[epoch - 1, j, 0], pca_parameters[epoch - 1, j, 1], color='k', s=10)
start_point = pca_parameters[epoch - 1, 0]
#plot arrows from start point to all other
for j in range(1, pca_parameters.shape[1]):
#inner
end_point = pca_parameters[epoch - 1, j]
arrow_color = 'g'
#outer
if j == pca_parameters.shape[1] - 1:
arrow_color = 'b'
ax2.annotate('', xy=end_point, xytext=start_point,
arrowprops=dict(arrowstyle='->', color=arrow_color))
ax1.set_title("Outer weights evolution across epochs")
ax2.set_title(f"Inner and outer weights for epoch {epoch}")
# Create legend handles
inner_arrow = Line2D([0], [0], color='g', lw=2, label='Inner weights')
outer_arrow = Line2D([0], [0], color='b', lw=2, label='Outer weights')
# Add legend to the second subplot (ax2)
ax2.legend(handles=[inner_arrow, outer_arrow], loc='upper left')
fig.suptitle(f'Epoch {epoch}', fontsize=16)
plt.show()
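# Illustrative usage with synthetic PCA coordinates of shape (num_epochs, num_points, 2):
# per epoch, point 0 is the starting weights, the following points are inner-loop
# (task-adapted) weights, and the last point is the outer (meta-updated) weights.
# pca_parameters = np.cumsum(np.random.normal(scale = 0.1, size = (5, 6, 2)), axis = 0)
# plot_inner_outer_weights(pca_parameters, epoch = 3)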
def value_to_saturation(value, vmin, vmax):
"""
Return saturation of the point based on the min/max values in the array.
Inputs:
- value (float): value of point.
- vmin (float): min value in all points.
- vmax (float): max value in all points.
"""
norm_value = (value - vmin) / (vmax - vmin)
saturation = 0.2 + 0.8 * norm_value
return saturation
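# Worked example: 0.5 lies halfway between vmin = 0 and vmax = 1,
# so the returned saturation is 0.2 + 0.8 * 0.5 = 0.6:
# value_to_saturation(0.5, 0.0, 1.0)  # -> 0.6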
def plot_sensitivity_r_squared(name, list_gradient_steps, list_num_samples_finetune, fix_scale = False):
"""Performs fine-tuning for a couple of tasks for different hyperparameter values and plots 3D sensitivity plot.
Inputs:
- name (str): name of the model's file.
- gradient_steps (np.ndarray): list of number of steps to perform gradient descent.
- num_samples_finetune (np.ndarray) list of number of samples.
- fix_scale (bool, default = False): whether to fix the same values of R-squared metric for both plots.
"""
model_path = name + '.pt'
meta_model = MetaLearningModel(model = model_path, mean = days_mean, std = days_std)
dataset = FruitSupplyDataset(also_sample_outer = False)
tasks = [[0.005, 0.1, 0.0, 1.0], [-0.005, 0.1, 0.0, 4.0]]
cmap = plt.colormaps.get_cmap('Reds')
with plt.xkcd():
legend_num_samples_finetune = []
legend_gradient_steps = []
legend_r_squared_score = []
prices = tasks[0][0] * days ** 2 + tasks[0][1] * np.sin(np.pi * days + tasks[0][2]) + tasks[0][3]
for num_samples_finetune in list_num_samples_finetune:
x_finetune, y_finetune = dataset.sample_particular_task(*tasks[0], num_samples_finetune)
for gradient_steps in list_gradient_steps:
if gradient_steps:
finetuned_model = finetune(meta_model, torch.tensor(x_finetune).type(torch.float32), torch.tensor(y_finetune).type(torch.float32), gradient_steps)
normalized_days = torch.tensor((np.expand_dims(days, 1) - days_mean) / days_std).type(torch.float32)
prediction = finetuned_model(normalized_days).detach().numpy()
legend_num_samples_finetune.append(num_samples_finetune)
legend_gradient_steps.append(gradient_steps)
legend_r_squared_score.append(r2_score(prices, prediction))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})
vmin = np.min(legend_r_squared_score)
vmax = np.max(legend_r_squared_score)
colors = [cmap(value_to_saturation(value, vmin, vmax)) for value in legend_r_squared_score]
ax1.scatter(legend_num_samples_finetune, legend_gradient_steps, legend_r_squared_score, c=colors, marker='o')
ax1.set_xlabel('Number of samples')
ax1.set_ylabel('Number of gradient steps')
ax1.set_zlabel('R-squared score')
ax1.set_title('Positive Squared Term Task')
legend_num_samples_finetune = []
legend_gradient_steps = []
legend_r_squared_score = []
prices = tasks[1][0] * days ** 2 + tasks[1][1] * np.sin(np.pi * days + tasks[1][2]) + tasks[1][3]
for num_samples_finetune in list_num_samples_finetune:
x_finetune, y_finetune = dataset.sample_particular_task(*tasks[1], num_samples_finetune)
for gradient_steps in list_gradient_steps:
if gradient_steps:
finetuned_model = finetune(meta_model, torch.tensor(x_finetune).type(torch.float32), torch.tensor(y_finetune).type(torch.float32), gradient_steps)
normalized_days = torch.tensor((np.expand_dims(days, 1) - days_mean) / days_std).type(torch.float32)
prediction = finetuned_model(normalized_days).detach().numpy()
legend_num_samples_finetune.append(num_samples_finetune)
legend_gradient_steps.append(gradient_steps)
legend_r_squared_score.append(r2_score(prices, prediction))
vmin = np.min(legend_r_squared_score)
vmax = np.max(legend_r_squared_score)
colors = [cmap(value_to_saturation(value, vmin, vmax)) for value in legend_r_squared_score]
ax2.scatter(legend_num_samples_finetune, legend_gradient_steps, legend_r_squared_score, c=colors, marker='o')
ax2.set_xlabel('Number of samples')
ax2.set_ylabel('Number of gradient steps')
ax2.set_zlabel('R-squared score')
ax2.set_title('Negative Squared Term Task')
if fix_scale:
ax1.set_zlim(0.65, 1)
ax2.set_zlim(0.65, 1)
plt.show()
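# Illustrative call (assumes `MetaLearningModel`, `finetune`, `days_mean`, `days_std` and the
# `FruitSupplyDataset` class are defined later in the tutorial, and that a trained
# model was saved as 'meta_model.pt'):
# plot_sensitivity_r_squared('meta_model', np.arange(1, 11), np.array([5, 10, 20, 40]))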
Helper functions#
# @title Helper functions
class UtilModel(nn.Module):
def __init__(self, model, mean = 0, std = 1, outer_learning_rate=0.001, inner_learning_rate=0.01):
"""Super class for model; hide utility code.
"""
super(UtilModel, self).__init__()
self.model = self.__load_model_from_context(model)
self.outer_learning_rate = outer_learning_rate
self.inner_learning_rate = inner_learning_rate
self.mean = mean
self.std = std
self.loss_fn = nn.MSELoss()
def __load_model_from_context(self, model):
"""Load weights of the model from file or as defined architecture.
"""
if isinstance(model, str):
return torch.load(model)
return model
def deep_clone_model(self, model):
"""Create clone of the model.
"""
clone = type(model)()
clone.load_state_dict(model.state_dict())
return clone
def save_parameters(self, path):
"""Save the parameters as a state dictionary.
"""
torch.save(self.model, path)
def inference(self, x):
"""Implement forward pass for inference.
"""
#apply normalization on days
x = (x - self.mean) / self.std
return self.model(x)
def manual_output(self, weights, x):
"""Calculate the result of forward pass on the external values of the model parameters (weights).
"""
for j in range(len(weights) // 2):
kernel, bias = weights[2 * j], weights[2 * j + 1]
if j == len(weights) // 2 - 1:
#last layer doesn't possess ReLU activation
return F.linear(x, kernel, bias = bias)
else:
x = F.relu(F.linear(x, kernel, bias = bias))
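# Why `manual_output` exists: in the inner loop of meta-learning the adapted
# ("fast") weights are plain tensors rather than registered module parameters,
# so the forward pass must be expressed functionally. A minimal sketch below,
# assuming a hypothetical 1-16-1 architecture (all demo names are illustrative):
def _demo_manual_forward():
    demo_weights = [torch.randn(16, 1), torch.zeros(16),  # layer 1: kernel, bias
                    torch.randn(1, 16), torch.zeros(1)]   # layer 2: kernel, bias
    x = torch.randn(8, 1)
    hidden = F.relu(F.linear(x, demo_weights[0], bias = demo_weights[1]))
    return F.linear(hidden, demo_weights[2], bias = demo_weights[3])  # shape (8, 1)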
# days are measured in weeks: samples at daily resolution (1/7 week) from week -26 to week 26
days = np.arange(-26, 26 + 1/7, 1/7, dtype = np.float32)
class FruitSupplyDatasetComplete(Dataset):
def __init__(self, num_epochs = 1, num_tasks = 1, num_samples = 1, days = days, also_sample_outer = True):
"""Initialize particular instance of `FruitSupplyDataset` dataset.
Inputs:
- num_epochs (int): Number of epochs the model is going to be trained on.
- num_tasks (int): Number of tasks to sample for each epoch (the loss and improvement are summed over the considered tasks).
- num_samples (int): Number of days to sample for each task.
- days (np.ndarray): Summer and autumn days to sample from.
- also_sample_outer (bool): `True` if we want to sample inner and outer data (necessary for training).
Raises:
- ValueError: If the number of sampled days `num_samples` exceeds the number of days available to sample from.
"""
if also_sample_outer:
if num_samples > days.shape[0] // 2:
raise ValueError("Number of sampled days for one task should be less or equal to the total amount of days divided by two as we sample inner and outer data.")
else:
if num_samples > days.shape[0]:
raise ValueError("Number of sampled days for one task should be less or equal to the total amount of days.")
#total amount of data is (2 or 4) x num_epochs x num_tasks x num_samples: inner x/y always, outer x/y only when `also_sample_outer` is True
self.num_epochs = num_epochs
self.num_tasks = num_tasks
self.num_samples = num_samples
self.also_sample_outer = also_sample_outer
self.days = days
def __len__(self):
"""Calculate the length of the dataset. It is obligatory for PyTorch to know in advance how many samples to expect (before training),
thus we enforced to icnlude number of epochs and tasks per epoch in `FruitSupplyDataset` parameters."""
return self.num_epochs * self.num_tasks
def __getitem__(self, idx):
"""Generate particular instance of task with prefined number of samples `num_samples`."""
A = np.random.uniform(min_A, max_A, size = 1)
B = np.random.uniform(min_B, max_B, size = 1)
phi = np.random.uniform(min_phi, max_phi, size = 1)
C = np.random.uniform(min_C, max_C, size = 1)
#`replace = False` is an important flag here, as we don't want repeated days
inner_sampled_days = np.expand_dims(np.random.choice(self.days, size = self.num_samples, replace = False), 1)
if self.also_sample_outer:
#we don't want inner and outer data to overlap
outer_sampled_days = np.expand_dims(np.random.choice(np.setdiff1d(self.days, inner_sampled_days), size = self.num_samples, replace = False), 1)
return inner_sampled_days, A * inner_sampled_days ** 2 + B * np.sin(np.pi * inner_sampled_days + phi) + C, outer_sampled_days, A * outer_sampled_days ** 2 + B * np.sin(np.pi * outer_sampled_days + phi) + C
return inner_sampled_days, A * inner_sampled_days ** 2 + B * np.sin(np.pi * inner_sampled_days + phi) + C
def sample_particular_task(self, A, B, phi, C, num_samples):
"""Samples for the particular instance of the task defined by the tuple of parameters (A, B, phi, C) and `num_samples`."""
sampled_days = np.expand_dims(np.random.choice(self.days, size = num_samples, replace = False), 1)
return sampled_days, A * sampled_days ** 2 + B * np.sin(np.pi * sampled_days + phi) + C