
Tutorial 3: Meta-learning#

Week 2, Day 4: Macro-Learning

By Neuromatch Academy

Content creators: Hlib Solodzhuk, Ximeng Mao, Grace Lindsay

Content reviewers: Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk, Ximeng Mao, Samuele Bolotta, Grace Lindsay

Production editors: Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk


Tutorial Objectives#

Estimated timing of tutorial: 50 minutes

In this tutorial, you will examine how meta-learning separates the problem of continual learning into two stages: an inner loop that rapidly adapts the model to each individual task, and an outer loop that slowly adjusts the shared weights so that this adaptation works well across tasks.
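
To make the two stages concrete, here is a minimal, self-contained sketch of a MAML-style update on a toy regression problem (the task family, network size, and learning rates are illustrative assumptions; the tutorial builds its own model and dataset below):

import torch
import torch.nn.functional as F

def forward(weights, x):
    """Two-layer MLP applied with explicit weights, so adapted copies of the weights can be evaluated."""
    hidden = F.relu(F.linear(x, weights[0], weights[1]))
    return F.linear(hidden, weights[2], weights[3])

#shared ("outer") weights of a 1-16-1 network
weights = [torch.randn(16, 1, requires_grad=True), torch.zeros(16, requires_grad=True),
           torch.randn(1, 16, requires_grad=True), torch.zeros(1, requires_grad=True)]
inner_lr, outer_lr = 0.01, 0.001

for epoch in range(100):
    meta_loss = 0.0
    for _ in range(5):  #tasks per outer step
        #toy task: y = a * x with a randomly drawn slope a
        a = torch.rand(1) * 2 - 1
        x_inner, x_outer = torch.randn(10, 1), torch.randn(10, 1)
        y_inner, y_outer = a * x_inner, a * x_outer

        #inner stage: one gradient step adapts the shared weights to this task
        inner_loss = F.mse_loss(forward(weights, x_inner), y_inner)
        grads = torch.autograd.grad(inner_loss, weights, create_graph=True)
        adapted = [w - inner_lr * g for w, g in zip(weights, grads)]

        #outer stage: evaluate the adapted weights on held-out data from the same task
        meta_loss = meta_loss + F.mse_loss(forward(adapted, x_outer), y_outer)

    #outer update: move the shared weights so that one inner step succeeds across tasks
    outer_grads = torch.autograd.grad(meta_loss, weights)
    with torch.no_grad():
        for w, g in zip(weights, outer_grads):
            w -= outer_lr * g

The inner stage adapts quickly to a single task; the outer stage slowly accumulates, across tasks, a set of starting weights from which that adaptation works.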


Setup#

Install and import feedback gadget#

# @title Install and import feedback gadget

!pip install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_neuroai",
            "user_key": "wb2cxze8",
        },
    ).render()


feedback_prefix = "W2D4_T3"

Imports#

# @title Imports

#working with data
import numpy as np
from functools import partial

#plotting
import matplotlib.pyplot as plt
import logging
from sklearn.decomposition import PCA
from matplotlib.lines import Line2D

#interactive display
import ipywidgets as widgets

#modeling
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import r2_score

#utils
from tqdm import tqdm

Figure settings#

# @title Figure settings

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

Plotting functions#

# @title Plotting functions

def plot_tasks(task_days, task_prices):
    """
    Plot the tasks' prices over time.

    Inputs:
    - task_days (list): A list of three lists, where each sub-list contains the days for a specific task.
    - task_prices (list): A list of three lists, where each sub-list contains the prices for a specific task.
    """
    with plt.xkcd():
        for days_list, prices_list, label in zip(task_days, task_prices, ["First Task", "Second Task", "Third Task"]):
            #sort each task's (day, price) pairs by day so the line plot is monotone in x
            sorted_days, sorted_prices = zip(*sorted(zip(days_list, prices_list), key=lambda pair: pair[0]))
            plt.plot(sorted_days, sorted_prices, label=label)
        plt.xlabel('Week')
        plt.ylabel('Price')
        plt.legend()
        plt.show()
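
#usage sketch (synthetic data, for illustration only): the days may arrive unsorted,
#since plot_tasks sorts each task's (day, price) pairs by day before plotting, e.g.
#  example_days = [np.random.permutation(52) for _ in range(3)]
#  example_prices = [0.1 * d + i for i, d in enumerate(example_days)]
#  plot_tasks(example_days, example_prices)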

def plot_inner_outer_weights(pca_parameters, epoch):
    """
    Plot the PCA-transformed outer weights of the model in 2D over the epochs, as well as the inner / outer weights for the given epoch.

    Inputs:
    - pca_parameters (np.ndarray): array of model parameters (already in 2D).
    - epoch (int): given epoch.
    """
    with plt.xkcd():
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        def draw_epoch_weights(ax):
            """Scatter the given epoch's weight points and draw arrows from the starting point to each update."""
            for j in range(pca_parameters.shape[1]):
                ax.scatter(pca_parameters[epoch - 1, j, 0], pca_parameters[epoch - 1, j, 1], color='k', s=10)
            start_point = pca_parameters[epoch - 1, 0]
            for j in range(1, pca_parameters.shape[1]):
                end_point = pca_parameters[epoch - 1, j]
                #the last point corresponds to the outer update (blue); all others are inner updates (green)
                arrow_color = 'b' if j == pca_parameters.shape[1] - 1 else 'g'
                ax.annotate('', xy=end_point, xytext=start_point,
                            arrowprops=dict(arrowstyle='->', color=arrow_color))

        #both panels show the given epoch's inner / outer updates
        draw_epoch_weights(ax1)
        draw_epoch_weights(ax2)

        #the left panel additionally shows the outer-weight trajectory over all previous epochs
        for j in range(epoch - 1):
            ax1.scatter(pca_parameters[j, 0, 0], pca_parameters[j, 0, 1], color='k', s=10)
            ax1.annotate('', xy=pca_parameters[j + 1, 0], xytext=pca_parameters[j, 0],
                         arrowprops=dict(arrowstyle='->', color='b', alpha=0.2))

        ax1.set_title("Outer weights evolution across epochs")
        ax2.set_title(f"Inner and outer weights for epoch {epoch}")

        #legend handles for the arrow colors
        inner_arrow = Line2D([0], [0], color='g', lw=2, label='Inner weights')
        outer_arrow = Line2D([0], [0], color='b', lw=2, label='Outer weights')
        ax2.legend(handles=[inner_arrow, outer_arrow], loc='upper left')

        fig.suptitle(f'Epoch {epoch}', fontsize=16)
        plt.show()

def value_to_saturation(value, vmin, vmax):
    """
    Return the saturation of a point based on the min/max values in the array:
    a linear map from 0.2 (at vmin) to 1.0 (at vmax), so that even the lowest value remains visible.

    Inputs:
    - value (float): value of the point.
    - vmin (float): min value over all points.
    - vmax (float): max value over all points.
    """
    norm_value = (value - vmin) / (vmax - vmin)
    saturation = 0.2 + 0.8 * norm_value
    return saturation
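
#worked example: the map is linear from 0.2 (faintest) at vmin to 1.0 (fully saturated)
#at vmax; e.g. value_to_saturation(0.8, vmin=0.65, vmax=0.95) returns 0.2 + 0.8 * 0.5 = 0.6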

def plot_sensitivity_r_squared(name, list_gradient_steps, list_num_samples_finetune, fix_scale = False):
    """Perform fine-tuning on two example tasks over a grid of hyperparameter values and plot a 3D sensitivity plot for each.

    Inputs:
    - name (str): name of the model's file (without the '.pt' extension).
    - list_gradient_steps (np.ndarray): list of numbers of gradient descent steps to perform.
    - list_num_samples_finetune (np.ndarray): list of numbers of samples to fine-tune on.
    - fix_scale (bool, default = False): whether to fix the same range of the R-squared metric for both plots.
    """
    model_path = name + '.pt'
    meta_model = MetaLearningModel(model = model_path, mean = days_mean, std = days_std)
    dataset = FruitSupplyDataset(also_sample_outer = False)

    tasks = [[0.005, 0.1, 0.0, 1.0], [-0.005, 0.1, 0.0, 4.0]]

    cmap = plt.colormaps.get_cmap('Reds')

    with plt.xkcd():
        fig, axes = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': '3d'})

        for task, ax, title in zip(tasks, axes, ['Positive Squared Term Task', 'Negative Squared Term Task']):
            A, B, phi, C = task
            #ground-truth prices for this task over the full range of days
            prices = A * days ** 2 + B * np.sin(np.pi * days + phi) + C
            normalized_days = torch.tensor((np.expand_dims(days, 1) - days_mean) / days_std).type(torch.float32)

            legend_num_samples_finetune = []
            legend_gradient_steps = []
            legend_r_squared_score = []
            for num_samples_finetune in list_num_samples_finetune:
                x_finetune, y_finetune = dataset.sample_particular_task(*task, num_samples_finetune)
                for gradient_steps in list_gradient_steps:
                    if gradient_steps:
                        finetuned_model = finetune(meta_model,
                                                   torch.tensor(x_finetune).type(torch.float32),
                                                   torch.tensor(y_finetune).type(torch.float32),
                                                   gradient_steps)
                        prediction = finetuned_model(normalized_days).detach().numpy()
                        legend_num_samples_finetune.append(num_samples_finetune)
                        legend_gradient_steps.append(gradient_steps)
                        legend_r_squared_score.append(r2_score(prices, prediction))

            #saturation encodes each point's R-squared score relative to this panel's range
            vmin = np.min(legend_r_squared_score)
            vmax = np.max(legend_r_squared_score)
            colors = [cmap(value_to_saturation(value, vmin, vmax)) for value in legend_r_squared_score]

            ax.scatter(legend_num_samples_finetune, legend_gradient_steps, legend_r_squared_score, c=colors, marker='o')
            ax.set_xlabel('Number of samples')
            ax.set_ylabel('Number of gradient steps')
            ax.set_zlabel('R-squared score')
            ax.set_title(title)
            if fix_scale:
                ax.set_zlim(0.65, 1)

        plt.show()
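
The `finetune` helper used above (together with `MetaLearningModel`, `days_mean`, and `days_std`) is defined later in the tutorial. For orientation only, here is a hypothetical sketch of what such a helper could look like, assuming it adapts a clone of the meta-trained network with a few full-batch gradient steps (the name, optimizer, and update rule are assumptions, not the tutorial's definition):

def finetune_sketch(meta_model, x, y, gradient_steps):
    """Adapt a copy of the meta-trained network to one task and return the adapted network."""
    model = meta_model.deep_clone_model(meta_model.model)  #leave the meta weights untouched
    optimizer = optim.SGD(model.parameters(), lr=meta_model.inner_learning_rate)
    for _ in range(gradient_steps):
        optimizer.zero_grad()
        loss = meta_model.loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return model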

Helper functions#

# @title Helper functions

class UtilModel(nn.Module):
    def __init__(self, model, mean = 0, std = 1, outer_learning_rate=0.001, inner_learning_rate=0.01):
        """Super class for model; hide utility code.
        """
        super(UtilModel, self).__init__()

        self.model = self.__load_model_from_context(model)

        self.outer_learning_rate = outer_learning_rate
        self.inner_learning_rate = inner_learning_rate

        self.mean = mean
        self.std = std

        self.loss_fn = nn.MSELoss()

    def __load_model_from_context(self, model):
        """Load weights of the model from file or as defined architecture.
        """
        if isinstance(model, str):
            return torch.load(model)
        return model

    def deep_clone_model(self, model):
        """Create clone of the model.
        """
        clone = type(model)()
        clone.load_state_dict(model.state_dict())
        return clone

    def save_parameters(self, path):
        """Save the parameters as a state dictionary.
        """
        torch.save(self.model, path)

    def inference(self, x):
        """Implement forward pass for inference.
        """
        #apply normalization on days
        x = (x - self.mean) / self.std
        return self.model(x)

    def manual_output(self, weights, x):
        """Calculate the result of forward pass on the external values of the model parameters (weights).
        """
        for j in range(len(weights) // 2):
            kernel, bias = weights[2 * j], weights[2 * j + 1]
            if j == len(weights) // 2 - 1:
                #last layer doesn't possess ReLU activation
                return F.linear(x, kernel, bias = bias)
            else:
                x = F.relu(F.linear(x, kernel, bias = bias))
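
#usage sketch for `manual_output` (hypothetical weights for a 1-16-1 MLP): the weights
#list alternates each layer's kernel and bias, and ReLU follows every layer except the last, e.g.
#  weights = [torch.randn(16, 1), torch.zeros(16),   #hidden layer kernel, bias
#             torch.randn(1, 16), torch.zeros(1)]    #output layer kernel, bias
#  prediction = UtilModel(nn.Identity()).manual_output(weights, torch.randn(5, 1))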

#days measured in weeks, from -26 to 26, sampled daily (in steps of 1/7 of a week): 365 points in total
days = np.arange(-26, 26 + 1/7, 1/7, dtype = np.float32)

class FruitSupplyDatasetComplete(Dataset):
    def __init__(self, num_epochs = 1, num_tasks = 1, num_samples = 1, days = days, also_sample_outer = True):
        """Initialize particular instance of `FruitSupplyDataset` dataset.

        Inputs:
        - num_epochs (int): Number of epochs the model is going to be trained on.
        - num_tasks (int): Number of tasks to sample for each epoch (the loss and the improvement are represented as a sum over the considered tasks).
        - num_samples (int): Number of days to sample for each task.
        - days (np.ndarray): Summer and autumn days to sample from.
        - also_sample_outer (bool): `True` if we want to sample inner and outer data (necessary for training).

        Raises:
        - ValueError: If the number of sampled days `num_samples` exceeds the number of days available to sample from.
        """

        if also_sample_outer:
            if num_samples > days.shape[0] // 2:
                raise ValueError("Number of sampled days for one task should be less or equal to the total amount of days divided by two as we sample inner and outer data.")
        else:
            if num_samples > days.shape[0]:
                raise ValueError("Number of sampled days for one task should be less or equal to the total amount of days.")

        #total amount of data is (2 or 4) x num_epochs x num_tasks x num_samples (x_inner, y_inner and, optionally, x_outer, y_outer)
        self.num_epochs = num_epochs
        self.num_tasks = num_tasks
        self.num_samples = num_samples
        self.also_sample_outer = also_sample_outer
        self.days = days

    def __len__(self):
        """Calculate the length of the dataset. It is obligatory for PyTorch to know in advance how many samples to expect (before training),
        thus we enforced to icnlude number of epochs and tasks per epoch in `FruitSupplyDataset` parameters."""

        return self.num_epochs * self.num_tasks

    def __getitem__(self, idx):
        """Generate particular instance of task with prefined number of samples `num_samples`."""

        #task parameters are drawn uniformly from their allowed ranges (the min/max constants are defined elsewhere in the notebook)
        A = np.random.uniform(min_A, max_A, size = 1)
        B = np.random.uniform(min_B, max_B, size = 1)
        phi = np.random.uniform(min_phi, max_phi, size = 1)
        C = np.random.uniform(min_C, max_C, size = 1)

        #`replace = False` is an important flag here, as we don't want repeated days within a task
        inner_sampled_days = np.expand_dims(np.random.choice(self.days, size = self.num_samples, replace = False), 1)

        if self.also_sample_outer:

            #we don't want inner and outer data to overlap
            outer_sampled_days = np.expand_dims(np.random.choice(np.setdiff1d(self.days, inner_sampled_days), size = self.num_samples, replace = False), 1)

            return inner_sampled_days, A * inner_sampled_days ** 2 + B * np.sin(np.pi * inner_sampled_days + phi) + C, outer_sampled_days, A * outer_sampled_days ** 2 + B * np.sin(np.pi * outer_sampled_days + phi) + C

        return inner_sampled_days, A * inner_sampled_days ** 2 + B * np.sin(np.pi * inner_sampled_days + phi) + C

    def sample_particular_task(self, A, B, phi, C, num_samples):
        """Samples for the particular instance of the task defined by the tuple of parameters (A, B, phi, C) and `num_samples`."""

        sampled_days = np.expand_dims(np.random.choice(self