{ "cells": [ { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 1: Microlearning\n", "\n", "**Week 2, Day 3: Microlearning**\n", "\n", "**By Neuromatch Academy** \n", "\n", "__Content creators:__ Blake Richards, Roman Pogodin, Daniel Levenstein, Colin Bredenberg, Jonathan Cornford\n", "\n", "__Content reviewers:__ Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Samuele Bolotta, Patrick Mineault, Hlib Solodzhuk\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "___\n", "\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 2 hours*\n", "\n", "In this tutorial, you will learn about normative models of synaptic plasticity. Normative models of synaptic plasticity are learning rules for parameters in neural networks that have two important features:\n", "\n", " * They optimize global objective functions that define behavioral/perceptual goals for an agent.\n", "\n", " * Unlike learning algorithms like backpropagation, they demonstrate how learning is 'local', i.e. it uses only information that could conceivably be available to a single synapse.\n", "\n", "These two features together make such learning algorithms good candidate models for how learning could work in the brain.\n", "\n", "In this tutorial, we will:\n", "\n", "* Relate local plasticity rules to estimates of loss gradients.\n", "* Understand the impact of variance and bias in gradient estimators and how they affect the scalability, performance, and generalization capabilities of learning algorithms.\n", "* Implement 2-3 learning rules in toy tasks.\n", "* Describe issues with biological plausibility in some learning algorithms, most notably, weight transport." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/rx89q/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/rx89q/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install vibecheck datatops matplotlib numpy torch pandas torchvision tqdm --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "feedback_prefix = \"W2D3_T1\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Import dependencies\n", "\n", "# Standard library imports\n", "import logging\n", "from datetime import datetime\n", "import pdb # we encourage you to use the debugger, rather than print statements!\n", "import time\n", "\n", "# Third-party imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision import datasets, transforms\n", "from tqdm import tqdm\n", "from IPython.display import display, HTML" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina' # perfrom high definition rendering for images and plots\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper functions\n", "\n", "# The sigmoid activation function\n", "def sigmoid(X):\n", " \"\"\"\n", " Returns the sigmoid function, i.e. 
1/(1+exp(-X))\n", " \"\"\"\n", "\n", " # to avoid runtime warnings, if abs(X) is more than 500, we just cap it there\n", " Y = X.copy() # this ensures we don't overwrite entries in X - Python can be a trickster!\n", " toobig = X > 500\n", " toosmall = X < -500\n", " Y[toobig] = 500\n", " Y[toosmall] = -500\n", "\n", " return 1.0 / (1.0 + np.exp(-Y))\n", "\n", "# The ReLU activation function\n", "def ReLU(X):\n", " \"\"\"\n", " Returns the ReLU function, i.e. X if X > 0, 0 otherwise\n", " \"\"\"\n", "\n", " # to avoid runtime warnings, if abs(X) is more than 500, we just cap it there\n", " Y = X.copy() # this ensures we don't overwrite entries in X - Python can be a trickster!\n", " neg = X < 0\n", " Y[neg] = 0\n", "\n", " return Y\n", "\n", "\n", "# A helper function to add an \"always on\" unit to the inputs, let's us keep the biases in the weight matrices\n", "def add_bias(inputs):\n", " \"\"\"\n", " Append an \"always on\" bias unit to some inputs\n", " \"\"\"\n", " return np.append(inputs, np.ones((1, inputs.shape[1])), axis=0)\n", "\n", "\n", "# Creates a random set of batches, returns an array of indices, one for each batch\n", "def create_batches(rng, batch_size, num_samples):\n", " \"\"\"\n", " For a given number of samples, returns an array of indices of random batches of the specified size.\n", "\n", " If the size of the data is not divisible by the batch size some samples will not be included.\n", " \"\"\"\n", "\n", " # determine the total number of batches\n", " num_batches = int(np.floor(num_samples / batch_size))\n", "\n", " # get the batches (without replacement)\n", " return rng.choice(np.arange(num_samples), size=(num_batches, batch_size), replace=False)\n", "\n", "\n", "# Calculate the accuracy of the network on some data\n", "def calculate_accuracy(outputs, targets):\n", " \"\"\"\n", " Calculate the accuracy in categorization of some outputs given some targets.\n", " \"\"\"\n", "\n", " # binarize the outputs for an easy calculation\n", " categories = (outputs == np.tile(outputs.max(axis=0), (10, 1))).astype('float')\n", "\n", " # get the accuracy\n", " accuracy = np.sum(categories * targets) / targets.shape[1]\n", "\n", " return accuracy * 100.0\n", "\n", "\n", "def calculate_cosine_similarity(grad_1, grad_2):\n", " \"\"\"\n", " Calculate the cosine similarity between two gradients\n", " \"\"\"\n", " grad_1 = grad_1.flatten()\n", " grad_2 = grad_2.flatten()\n", " return np.dot(grad_1, grad_2) / np.sqrt(np.dot(grad_1, grad_1)) / np.sqrt(np.dot(grad_2, grad_2))\n", "\n", "\n", "def calculate_grad_snr(grad, epsilon=1e-3):\n", " \"\"\"\n", " Calculate the average SNR |mean|/std across all parameters in a gradient update\n", " \"\"\"\n", " return np.mean(np.abs(np.mean(grad, axis=0)) / (np.std(grad, axis=0) + epsilon))\n", "\n", "# The main network class\n", "# This will function as the parent class for our networks, which will implement different learning algorithms\n", "class MLP(object):\n", " \"\"\"\n", " The class for creating and training a two-layer perceptron.\n", " \"\"\"\n", "\n", " # The initialization function\n", " def __init__(self, rng, N=100, sigma=1.0, activation='sigmoid'):\n", " \"\"\"\n", " The initialization function for the MLP.\n", "\n", " - N is the number of hidden units\n", " - sigma is the SD for initializing the weights\n", " - activation is the function to use for unit activity, options are 'sigmoid' and 'ReLU'\n", " \"\"\"\n", "\n", " # store the variables for easy access\n", " self.N = N\n", " self.sigma = sigma\n", " self.activation = 
activation\n", "\n", " # initialize the weights\n", " self.W_h = rng.normal(scale=self.sigma, size=(self.N, 784 + 1)) # input-to-hidden weights & bias\n", " self.W_y = rng.normal(scale=self.sigma, size=(10, self.N + 1)) # hidden-to-output weights & bias\n", " self.B = rng.normal(scale=self.sigma, size=(self.N, 10)) # feedback weights\n", "\n", " # The non-linear activation function\n", " def activate(self, inputs):\n", " \"\"\"\n", " Pass some inputs through the activation function.\n", " \"\"\"\n", " if self.activation == 'sigmoid':\n", " Y = sigmoid(inputs)\n", " elif self.activation == 'ReLU':\n", " Y = ReLU(inputs)\n", " else:\n", " raise Exception(\"Unknown activation function\")\n", " return Y\n", "\n", " # The function for performing a forward pass up through the network during inference\n", " def inference(self, rng, inputs, W_h=None, W_y=None, noise=0.):\n", " \"\"\"\n", " Recognize inputs, i.e. do a forward pass up through the network. If desired, alternative weights\n", " can be provided\n", " \"\"\"\n", "\n", " # load the current network weights if no weights given\n", " if W_h is None:\n", " W_h = self.W_h\n", " if W_y is None:\n", " W_y = self.W_y\n", "\n", " # calculate the hidden activities\n", " hidden = self.activate(np.dot(W_h, add_bias(inputs)))\n", " if not (noise == 0.):\n", " hidden += rng.normal(scale=noise, size=hidden.shape)\n", "\n", " # calculate the output activities\n", " output = self.activate(np.dot(W_y, add_bias(hidden)))\n", "\n", " if not (noise == 0.):\n", " output += rng.normal(scale=noise, size=output.shape)\n", "\n", " return hidden, output\n", "\n", " # A function for calculating the derivative of the activation function\n", " def act_deriv(self, activity):\n", " \"\"\"\n", " Calculate the derivative of some activations with respect to the inputs\n", " \"\"\"\n", " if self.activation == 'sigmoid':\n", " derivative = activity * (1 - activity)\n", " elif self.activation == 'ReLU':\n", " derivative = 1.0 * (activity > 1)\n", " else:\n", " raise Exception(\"Unknown activation function\")\n", " return derivative\n", "\n", " def mse_loss_batch(self, rng, inputs, targets, W_h=None, W_y=None, output=None):\n", " \"\"\"\n", " Calculate the mean-squared error loss on the given targets (average over the batch)\n", " \"\"\"\n", "\n", " # do a forward sweep through the network\n", " if (output is None):\n", " (hidden, output) = self.inference(rng, inputs, W_h, W_y)\n", " return np.sum((targets - output) ** 2, axis=0)\n", "\n", " # The function for calculating the mean-squared error loss\n", " def mse_loss(self, rng, inputs, targets, W_h=None, W_y=None, output=None):\n", " \"\"\"\n", " Calculate the mean-squared error loss on the given targets (average over the batch)\n", " \"\"\"\n", " return np.mean(self.mse_loss_batch(rng, inputs, targets, W_h=W_h, W_y=W_y, output=output))\n", "\n", " # function for calculating perturbation updates\n", " def perturb(self, rng, inputs, targets, noise=1.0):\n", " \"\"\"\n", " Calculates the weight updates for perturbation learning, using noise with SD as given\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", " def node_perturb(self, rng, inputs, targets, noise=1.0):\n", " \"\"\"\n", " Calculates the weight updates for node perturbation learning, using noise with SD as given\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", " # function for calculating gradient updates\n", " def gradient(self, rng, inputs, targets):\n", " \"\"\"\n", " Calculates the weight updates for gradient descent learning\n", " \"\"\"\n", "\n", 
" # do a forward pass\n", " hidden, output = self.inference(rng, inputs)\n", "\n", " # calculate the gradients\n", " error = targets - output\n", " delta_W_h = np.dot(\n", " np.dot(self.W_y[:, :-1].transpose(), error * self.act_deriv(output)) * self.act_deriv(hidden), \\\n", " add_bias(inputs).transpose())\n", " delta_W_y = np.dot(error * self.act_deriv(output), add_bias(hidden).transpose())\n", "\n", " return delta_W_h, delta_W_y\n", "\n", " # function for calculating feedback alignment updates\n", " def feedback(self, rng, inputs, targets):\n", " \"\"\"\n", " Calculates the weight updates for feedback alignment learning\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", " # function for calculating Kolen-Pollack updates\n", " def kolepoll(self, rng, inputs, targets, eta_back=0.01):\n", " \"\"\"\n", " Calculates the weight updates for Kolen-Polack learning\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", " def return_grad(self, rng, inputs, targets, algorithm='backprop', eta=0., noise=1.0):\n", " # calculate the updates for the weights with the appropriate algorithm\n", " if algorithm == 'perturb':\n", " delta_W_h, delta_W_y = self.perturb(rng, inputs, targets, noise=noise)\n", " elif algorithm == 'node_perturb':\n", " delta_W_h, delta_W_y = self.node_perturb(rng, inputs, targets, noise=noise)\n", " elif algorithm == 'feedback':\n", " delta_W_h, delta_W_y = self.feedback(rng, inputs, targets)\n", " elif algorithm == 'kolepoll':\n", " delta_W_h, delta_W_y = self.kolepoll(rng, inputs, targets, eta_back=eta)\n", " else:\n", " delta_W_h, delta_W_y = self.gradient(rng, inputs, targets)\n", "\n", " return delta_W_h, delta_W_y\n", "\n", " # function for updating the network\n", " def update(self, rng, inputs, targets, algorithm='backprop', eta=0.01, noise=1.0):\n", " \"\"\"\n", " Updates the synaptic weights (and unit biases) using the given algorithm, options are:\n", "\n", " - 'backprop': backpropagation-of-error (default)\n", " - 'perturb' : weight perturbation (use noise with SD as given)\n", " - 'feedback': feedback alignment\n", " - 'kolepoll': Kolen-Pollack\n", " \"\"\"\n", "\n", " delta_W_h, delta_W_y = self.return_grad(rng, inputs, targets, algorithm=algorithm, eta=eta, noise=noise)\n", "\n", " # do the updates\n", " self.W_h += eta * delta_W_h\n", " self.W_y += eta * delta_W_y\n", "\n", " # train the network using the update functions\n", " def train(self, rng, images, labels, num_epochs, test_images, test_labels, learning_rate=0.01, batch_size=20, \\\n", " algorithm='backprop', noise=1.0, report=False, report_rate=10):\n", " \"\"\"\n", " Trains the network with algorithm in batches for the given number of epochs on the data provided.\n", "\n", " Uses batches with size as indicated by batch_size and given learning rate.\n", "\n", " For perturbation methods, uses SD of noise as given.\n", "\n", " Categorization accuracy on a test set is also calculated.\n", "\n", " Prints a message every report_rate epochs if requested.\n", "\n", " Returns an array of the losses achieved at each epoch (and accuracies if test data given).\n", " \"\"\"\n", "\n", " # provide an output message\n", " if report:\n", " print(\"Training starting...\")\n", "\n", " # make batches from the data\n", " batches = create_batches(rng, batch_size, images.shape[1])\n", "\n", " # create arrays to store loss and accuracy values\n", " losses = np.zeros((num_epochs * batches.shape[0],))\n", " accuracy = np.zeros((num_epochs,))\n", " cosine_similarity = np.zeros((num_epochs,))\n", "\n", " # estimate the gradient 
SNR on the test set\n", " grad = np.zeros((test_images.shape[1], *self.W_h.shape))\n", " for t in range(test_images.shape[1]):\n", " inputs = test_images[:, [t]]\n", " targets = test_labels[:, [t]]\n", " grad[t, ...], _ = self.return_grad(rng, inputs, targets, algorithm=algorithm, eta=0., noise=noise)\n", " snr = calculate_grad_snr(grad)\n", " # run the training for the given number of epochs\n", " update_counter = 0\n", " for epoch in range(num_epochs):\n", "\n", " # step through each batch\n", " for b in range(batches.shape[0]):\n", " # get the inputs and targets for this batch\n", " inputs = images[:, batches[b, :]]\n", " targets = labels[:, batches[b, :]]\n", "\n", " # calculate the current loss\n", " losses[update_counter] = self.mse_loss(rng, inputs, targets)\n", "\n", " # update the weights\n", " self.update(rng, inputs, targets, eta=learning_rate, algorithm=algorithm, noise=noise)\n", " update_counter += 1\n", "\n", " # calculate the current test accuracy\n", " (testhid, testout) = self.inference(rng, test_images)\n", " accuracy[epoch] = calculate_accuracy(testout, test_labels)\n", " grad_test, _ = self.return_grad(rng, test_images, test_labels, algorithm=algorithm, eta=0., noise=noise)\n", " grad_bp, _ = self.return_grad(rng, test_images, test_labels, algorithm='backprop', eta=0., noise=noise)\n", " cosine_similarity[epoch] = calculate_cosine_similarity(grad_test, grad_bp)\n", "\n", " # print an output message every 10 epochs\n", " if report and np.mod(epoch + 1, report_rate) == 0:\n", " print(\"...completed \", epoch + 1,\n", " \" epochs of training. Current loss: \", round(losses[update_counter - 1], 2), \".\")\n", "\n", " # provide an output message\n", " if report:\n", " print(\"Training complete.\")\n", "\n", " return (losses, accuracy, cosine_similarity, snr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data retrieval\n", "\n", "import contextlib\n", "import io\n", "import pandas as pd\n", "import requests\n", "import os\n", "import hashlib\n", "\n", "def download_file(fname, url, expected_md5):\n", " \"\"\"\n", " Downloads a file from the given URL and saves it locally.\n", " \"\"\"\n", " if not os.path.isfile(fname):\n", " try:\n", " r = requests.get(url)\n", " except requests.ConnectionError:\n", " print(\"!!! Failed to download data !!!\")\n", " return\n", " if r.status_code != requests.codes.ok:\n", " print(\"!!! Failed to download data !!!\")\n", " return\n", " if hashlib.md5(r.content).hexdigest() != expected_md5:\n", " print(\"!!! 
Data download appears corrupted !!!\")\n", " return\n", " with open(fname, \"wb\") as fid:\n", " fid.write(r.content)\n", "\n", "data_files = [\n", " {\n", " \"fname\": \"accuracy.csv\",\n", " \"url\": \"https://osf.io/aqhd3/download\",\n", " \"expected_md5\": \"bfcad2350de4c4a6eeeb1f3342371390\"\n", " },\n", " {\n", " \"fname\": \"cosine_similarity.csv\",\n", " \"url\": \"https://osf.io/w4pv7/download\",\n", " \"expected_md5\": \"97ac863216a44909bb930715855a1d9e\"\n", " },\n", " {\n", " \"fname\": \"losses.csv\",\n", " \"url\": \"https://osf.io/drfg6/download\",\n", " \"expected_md5\": \"e6a50509c676b6a934d653afae3a60c6\"\n", " },\n", " {\n", " \"fname\": \"snr.csv\",\n", " \"url\": \"https://osf.io/z5mjy/download\",\n", " \"expected_md5\": \"13b4f0e43cc8dce12a4d191ae2d31c0e\"\n", " }\n", "]\n", "\n", "for data_file in data_files:\n", " download_file(data_file[\"fname\"], data_file[\"url\"], data_file[\"expected_md5\"])\n", "\n", "accuracy_data = pd.read_csv(\"accuracy.csv\")\n", "cosine_similarity_data = pd.read_csv(\"cosine_similarity.csv\")\n", "losses_data = pd.read_csv(\"losses.csv\")\n", "snr_data = pd.read_csv(\"snr.csv\")\n", "\n", "losses_weight_perturbation_solution = losses_data[\"weight_perturbation\"]\n", "losses_node_perturbation_solution = losses_data[\"node_perturbation\"]\n", "losses_feedback_alignment_solution = losses_data[\"feedback_alignment\"]\n", "losses_kolen_pollack_solution = losses_data[\"kolen_pollack\"]\n", "losses_backpropagation_solution = losses_data[\"backpropagation\"]\n", "\n", "cosine_similarity_feedback_alignment_solution = cosine_similarity_data[\"feedback_alignment\"]\n", "cosine_similarity_kolen_pollack_solution = cosine_similarity_data[\"kolen_pollack\"]\n", "cosine_similarity_backpropagation_solution = cosine_similarity_data[\"backpropagation\"]\n", "accuracy_weight_perturbation_solution = accuracy_data[\"weight_perturbation\"]\n", "\n", "accuracy_node_perturbation_solution = accuracy_data[\"node_perturbation\"]\n", "accuracy_feedback_alignment_solution = accuracy_data[\"feedback_alignment\"]\n", "accuracy_kolen_pollack_solution = accuracy_data[\"kolen_pollack\"]\n", "accuracy_backpropagation_solution = accuracy_data[\"backpropagation\"]\n", "\n", "snr_weight_perturbation_solution = snr_data[\"weight_perturbation\"][0]\n", "snr_node_perturbation_solution = snr_data[\"node_perturbation\"][0]\n", "snr_backpropagation_solution = snr_data[\"backpropagation\"][0]\n", "\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " # Load the MNIST dataset, 50K training images, 10K validation, 10K testing\n", " train_set = datasets.MNIST('./', transform=transforms.ToTensor(), train=True, download=True)\n", " test_set = datasets.MNIST('./', transform=transforms.ToTensor(), train=False, download=True)\n", "\n", " rng_data = np.random.default_rng(seed=42)\n", " train_num = 50000\n", " shuffled_train_idx = rng_data.permutation(train_num)\n", "\n", " full_train_images = train_set.data.numpy().astype(float) / 255\n", " train_images = full_train_images[shuffled_train_idx[:train_num]].reshape((-1, 784)).T.copy()\n", " valid_images = full_train_images[shuffled_train_idx[train_num:]].reshape((-1, 784)).T.copy()\n", " test_images = (test_set.data.numpy().astype(float) / 255).reshape((-1, 784)).T\n", "\n", " full_train_labels = torch.nn.functional.one_hot(train_set.targets, num_classes=10).numpy()\n", " train_labels = full_train_labels[shuffled_train_idx[:train_num]].T.copy()\n", " valid_labels = full_train_labels[shuffled_train_idx[train_num:]].T.copy()\n", " 
test_labels = torch.nn.functional.one_hot(test_set.targets, num_classes=10).numpy().T\n", "\n", " full_train_images = None\n", " full_train_labels = None\n", " train_set = None\n", " test_set = None\n", "\n", "#Plot some example images to make sure everything is loaded in properly\n", "with plt.xkcd():\n", " fig, axs = plt.subplots(1,10)\n", " for c in range(10):\n", " axs[c].imshow(train_images[:,c].reshape((28,28)), cmap='gray')\n", " axs[c].axis(\"off\")\n", " fig.suptitle(\"Data download check!\", fontsize=16)\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: Weight Perturbation\n", "\n", "In this section, we will start exploring the learning algorithms which exhibit increased variance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Weight Perturbation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Weight Perturbation\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'WOSTwEQXdlc'), ('Bilibili', 'BV1Gz42187Dr')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_video_1\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "\n", "$\n", "\\newcommand{\\stim}{\\mathbf{x}}\n", "\\newcommand{\\noisew}{\\boldsymbol \\Psi}\n", "\\newcommand{\\noiser}{\\boldsymbol \\xi}\n", "\\newcommand{\\target}{y}\n", "\\newcommand{\\targetdim}{\\mathbf{y}}\n", "\\newcommand{\\identity}{\\mathbf{I}}\n", "\\newcommand{\\blackbox}{f}\n", "\\newcommand{\\weight}{\\mathbf{W}}\n", 
"\\newcommand{\\loss}{\\mathcal{L}}\n", "\\newcommand{\\derivative}[2]{\\frac{d#1}{d#2}}\n", "\\newcommand{\\pderivative}[2]{\\frac{\\partial#1}{\\partial#2}}\n", "\\newcommand{\\rate}{\\mathbf{r}}\n", "\\newcommand{\\T}{^{\\top}}\n", "\\newcommand{\\RR}{\\mathbb{R}}\n", "\\newcommand{\\EE}{\\mathbb{E}\\,}\n", "\\newcommand{\\brackets}[1]{\\left(#1\\right)}\n", "\\newcommand{\\sqbrackets}[1]{\\left[#1\\right]}\n", "\\newcommand{\\var}[1]{\\mathbb{V}\\mathrm{ar}\\brackets{#1}}$\n", "\n", "In this first section, we will be deriving and implementing the __Weight Perturbation__ algorithm. In the next section, we will be deriving and implementing the __Node Perturbation__ algorithm. Both of these methods of gradient estimation are very closely related to *finite differences* derivative approximation.\n", "\n", " Suppose that we have some loss function, $\\loss(\\Delta \\weight)$, which we would like to minimize by making some change in our synaptic weights, $\\Delta \\weight$. The most natural way to decrease the loss would be to perform gradient descent; however, it is not reasonable to assume that a synapse in the brain could perform analytic gradient calculations for general loss functions $\\loss(\\Delta \\weight)$, which may depend on the activity of many downstream neurons and the external environment. Biological systems could solve this problem by *approximating* the gradient of the loss, which could be accomplished in many ways. \n", " \n", " To start, we will provide the __weight perturbation__ update, and will subsequently demonstrate why it provides an estimate of the gradient. We will first add noise to our weights, using $\\weight' = \\weight + \\noisew$, where $\\noisew \\sim \\mathcal N(0, \\sigma^2)$. We take as our update:\n", "\n", "\\begin{equation}\n", " \\Delta \\weight = - \\eta \\mathbb{E}_{\\noisew} \\left [\\left (\\loss(\\noisew) - \\loss(0)\\right ) \\frac{(\\weight' - \\weight)}{\\sigma^2} \\right ].\n", "\\end{equation}\n", "First, we will clarify why this parameter update is interesting from a neuroscientific perspective. If we look at the parameter update for a *single synapse*, $\\weight_{ij}$, we have:\n", "\\begin{align}\n", " \\Delta \\weight_{ij} &= - \\eta \\mathbb{E}_{\\noisew} \\left [\\left (\\loss(\\noisew) - \\loss(0)\\right ) \\frac{(\\weight'_{ij} - \\weight_{ij})}{\\sigma^2} \\right ] \\\\\n", " & \\approx - \\eta \\frac{1}{K}\\sum_{k = 0}^K\\left [\\left (\\loss(\\noisew^{(k)}) - \\loss(0)\\right ) \\frac{(\\weight'^{(k)}_{ij} - \\weight_{ij})}{\\sigma^2} \\right ],\n", "\\end{align}\n", "\n", "where for the last approximate equality we are substituting an expectation over $\\noisew$ for an empirical approximation over $K$ samples of $\\noisew$. This update only requires information about the global loss, $\\loss(\\noisew^{(k)})$ and the local parameter values, $\\weight'^{(k)}_{ij}$: using this update, a synapse in a neural network can adapt its strength with *very little* information about what is going on in the rest of the neural circuit.\n", "\n", "Lastly, we will show why this update is an approximation of the loss gradient: this section is only to satisfy your curiosity, and is not necessary for completing the coding exercises. We first notice that by first-order Taylor expansion $\\loss(\\noisew) \\approx \\loss(0) + \\derivative{\\loss}{\\weight}\\T \\noisew$. 
Plugging this approximation into our update equation, we get:\n", "\\begin{align}\n", " \\Delta \\weight_{ij} &= - \\eta \\mathbb{E}_{\\noisew} \\left [\\left (\\derivative{\\loss}{\\weight}\\T \\noisew\\right ) \\frac{\\noisew_{ij}}{\\sigma^2} \\right ] \\\\\n", " &= - \\eta \\derivative{\\loss}{\\weight_{ij}},\n", "\\end{align}\n", "where this last equality follows from the fact that $\\mathbb{E}_{\\noisew} \\left[\\noisew_{ij} \\noisew_{kl} \\right] = \\sigma^2$ if and only if $i = k$ and $j = l$, and is 0 otherwise. Therefore, in expectation over many noise samples $\\noisew$, our parameter update based purely on measuring how perturbations of the weights $\\weight'$ correlate with changes in the loss function $\\loss(\\noisew)$, ends up being an unbiased approximation of gradient descent.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_weight_perturbation\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 1: Perturb the weights\n", "\n", "In this section, fill out the function 'perturb' for the WeightPerturbMLP class. This function is used to update the parameters of our MLP network using the weight perturbation algorithm, using the parameter update equations from the preceding section." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "class WeightPerturbMLP(MLP):\n", " \"\"\"\n", " A multilayer perceptron that is capable of learning through weight perturbation\n", " \"\"\"\n", "\n", " def perturb(self, rng, inputs, targets, noise=1.0):\n", " \"\"\"\n", " Calculates the weight updates for perturbation learning, using noise with SD as given\n", " \"\"\"\n", " ###################################################################\n", " ## Fill out the following then remove\n", " raise NotImplementedError(\"Student exercise: determine the sign of the updates\")\n", " ###################################################################\n", "\n", " # get the random perturbations\n", " delta_W_h = rng.normal(scale=noise, size=self.W_h.shape)\n", " delta_W_y = rng.normal(scale=noise, size=self.W_y.shape)\n", "\n", " # calculate the loss with and without the perturbations\n", " loss_now = self.mse_loss(rng, inputs, targets)\n", " loss_per = self.mse_loss(rng, inputs, targets, self.W_h + delta_W_h, self.W_y + delta_W_y)\n", "\n", " # updates\n", " delta_loss = ...\n", " W_h_update = delta_loss * delta_W_h / noise ** 2\n", " W_y_update = delta_loss * delta_W_y / noise ** 2\n", " return W_h_update, W_y_update" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W2D3_Microlearning/solutions/W2D3_Tutorial1_Solution_01f74aae.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_perturb_the_weights\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Training an MLP on MNIST with weight 
perturbation\n", "\n", "Having implemented the appropriate update function, we will now verify that it works by training a network. We will train a simple MLP on the MNIST dataset, using the weight perturbation algorithm to estimate the gradient. This MLP consists of:\n", "\n", "* An input layer with 784 units\n", "* A hidden layer with 500 units\n", "* An output layer with 10 units, one for each of the digits 0-9\n", "\n", "We will use the **mean-squared error loss** on the one-hot encoded labels–this is just to make later calculations more tractable, but typically one would use the cross-entropy loss. We'll train the network for 3 epochs. All in all, the weight perturbation update acts as a drop-in replacement for conventional gradient descent." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Hyperparameters definition\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Hyperparameters definition\n", "\n", "numhidden = 500\n", "batchsize = 200\n", "initweight = 0.1\n", "learnrate = 0.001\n", "noise = 0.1\n", "numepochs = 3\n", "numrepeats = 1\n", "numbatches = int(train_images.shape[1] / batchsize)\n", "numupdates = numepochs * numbatches\n", "activation = 'sigmoid'\n", "report = True\n", "rep_rate = 1\n", "seed = 12345" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train WeightPerturbMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Train WeightPerturbMLP\n", "\n", "rng_wp = np.random.default_rng(seed=seed)\n", "losses_perturb = np.zeros((numupdates,))\n", "accuracy_perturb = np.zeros((numepochs,))\n", "\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_wp.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "# create a network and train it using weight perturbation\n", "netperturb = WeightPerturbMLP(rng_wp, numhidden, sigma=initweight, activation=activation)\n", "(losses_perturb[:], accuracy_perturb[:], _, snr_perturb) = \\\n", " netperturb.train(rng_wp, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, algorithm='perturb', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "# save metrics for plots\n", "losses_weight_perturbation_solution = losses_perturb\n", "accuracy_weight_perturbation_solution = accuracy_perturb\n", "snr_weight_perturbation_solution = snr_perturb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Observe the performance of WeightPerturbMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Observe the performance of WeightPerturbMLP\n", "\n", "# plot performance over time\n", "with plt.xkcd():\n", " plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b')\n", " plt.xlabel(\"Updates\")\n", " plt.ylabel(\"MSE\")\n", " plt.legend()\n", " plt.title(\"Training loss\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", 
"content_review(f\"{feedback_prefix}_training_with_weight_perturbation\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Node Perturbation\n", "\n", "Estimated timing to here from start of tutorial: 30 minutes\n", "\n", "While we can get an unbiased derivative approximation based solely on perturbations of the weights, we will show later on that this is actually a very inefficient method, because it requires averaging out $MN$ noise sources, where $M$ is the dimension of the input $\\stim$ and $N$ is the dimension of the hidden activity $\\rate$. \n", "\n", "![Network.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W2D3_Microlearning/static/network.png?raw=true)\n", "\n", "For simplicity, consider what happens in a network with a linear hidden layer. If we add noise at the level of the hidden units $\\rate = \\weight \\stim$, we will only have to average over $N$ noise sources. To do this, we can use the following update, taking $\\rate' = \\rate + \\noiser$, where $\\noiser \\sim \\mathcal{N}(0,\\sigma^2)$:\n", "\n", "\\begin{equation}\n", " \\Delta \\weight = - \\eta \\mathbb{E}_{\\noiser} \\left [\\left(\\loss(\\noiser) - \\loss(0) \\right ) \\frac{(\\rate' - \\rate)}{\\sigma^2} \\stim\\T \\right ].\n", "\\end{equation}\n", "We will now show why this update is interesting from a neuroscience perspective (for much the same reason as for weight perturbation). For a single synapse, the approximate update using samples of $\\noiser$ is given by:\n", "\\begin{equation}\n", " \\Delta \\weight_{ij} \\approx - \\eta \\frac{1}{K} \\sum_{k=0}^{K} \\left [\\left(\\loss(\\noiser^{(k)}) - \\loss(0) \\right ) \\frac{(\\rate'^{(k)}_i - \\rate_i)}{\\sigma^2} \\stim_j \\right ].\n", "\\end{equation}\n", "Once again this update requires very little knowledge about the rest of the neural circuit in order for a synapse to compute it. It requires knowledge of the global loss, $\\loss(\\noiser^{(k)})$, postsynaptic activity $\\rate^{(k)}_i$, and presynaptic activity $\\stim_j$. This form of parameter update is often called a Reward (loss)-modulated Hebbian plasticity rule, or a 3-factor plasticity rule.\n", "\n", "Lastly, we will show why this update is an unbiased gradient estimate: again, this section is only to satisfy your curiosity, and is not needed for the coding exercises. We again employ a first-order Taylor expansion: $\\loss(\\noiser) \\approx \\loss(0) + \\derivative{\\loss}{\\rate}\\T\\noiser$, to get:\n", "\\begin{align}\n", " \\Delta \\weight_{ij} &= - \\eta \\mathbb{E}_{\\noiser} \\left [\\left(\\derivative{\\loss}{\\rate}\\T\\noiser \\right ) \\frac{\\noiser_i}{\\sigma^2} \\stim_j \\right ] \\\\\n", " &= - \\eta \\pderivative{\\loss}{\\rate_i} \\stim_j \\\\\n", " &= - \\eta \\pderivative{\\loss}{\\rate_i} \\pderivative{\\rate_i}{\\weight_{ij}}\\\\\n", " &= - \\eta \\pderivative{\\loss}{\\weight_{ij}},\n", "\\end{align}\n", "Where the second equality follows from the fact that $\\mathbb{E}_{\\noiser} \\left [ \\noiser_i \\noiser_k \\right ] = \\sigma^2$ if and only if $i = k$, and is 0 otherwise. This analysis shows that we can estimate derivatives by correlating fluctuations in either $\\weight$ *or* $\\rate$ with fluctuations in the loss function. Neither strategy requires evaluating derivatives of $\\blackbox(\\cdot)$, they only require some extrinsic measure of performance, given by $\\mathcal{L}$ and how performance varies in response to perturbations in either weights or nodes, respectively. 
In subsequent sections, we will investigate how these different methods compare in terms of their ability to estimate gradients in systems with large numbers of neurons. We will show that there is no free lunch--though these methods require less information, they are less *efficient* than analytic gradient calculations." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Below, we provide an implementation of the node perturbation algorithm, so that you will be able to compare it to the weight perturbation algorithm in subsequent sections. Running this code will take around 9 minutes--you can move on to subsequent sections while you wait!\n", "\n", "One important detail: there are two different notions of efficiency we could consider here: 1) sample efficiency and 2) runtime efficiency. Node perturbation is more sample efficient: in general it brings the loss lower with fewer samples than weight perturbation. However, our particular implementation of node perturbation runs a little slower than weight perturbation, so you could argue that it has worse runtime efficiency. This is just because the two algorithms were implemented by different people, and the node perturbation implementation exploits Python's parallel array computation a little less effectively." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "class NodePerturbMLP(MLP):\n", " \"\"\"\n", " A multilayer perceptron that is capable of learning through node perturbation\n", " \"\"\"\n", "\n", " def node_perturb(self, rng, inputs, targets, noise=1.0):\n", " \"\"\"\n", " Calculates the weight updates for node perturbation learning, using noise with SD as given\n", " \"\"\"\n", "\n", " # do a clean forward pass and a noise-perturbed forward pass\n", " hidden, output = self.inference(rng, inputs)\n", " hidden_p, output_p = self.inference(rng, inputs, noise=noise)\n", "\n", " loss_now = self.mse_loss_batch(rng, inputs, targets, output=output)\n", " loss_per = self.mse_loss_batch(rng, inputs, targets, output=output_p)\n", " delta_loss = loss_now - loss_per\n", "\n", " hidden_update = np.mean(\n", " delta_loss * (((hidden_p - hidden) / noise ** 2)[:, None, :] * add_bias(inputs)[None, :, :]), axis=2)\n", " output_update = np.mean(\n", " delta_loss * (((output_p - output) / noise ** 2)[:, None, :] * add_bias(hidden_p)[None, :, :]), axis=2)\n", "\n", " return (hidden_update, output_update)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train NodePerturbMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Train NodePerturbMLP\n", "\n", "losses_node_perturb = np.zeros((numupdates,))\n", "accuracy_node_perturb = np.zeros((numepochs,))\n", "\n", "# set the random seed\n", "rng_np = np.random.default_rng(seed=seed)\n", "\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_np.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "# create a network and train it using node perturbation\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " netnodeperturb = NodePerturbMLP(rng_np, numhidden, sigma=initweight, activation=activation)\n", " (losses_node_perturb[:], accuracy_node_perturb[:], _, snr_node_perturb) = \\\n", " netnodeperturb.train(rng_np, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, 
algorithm='node_perturb', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "# save metrics for plots\n", "losses_node_perturbation_solution = losses_node_perturb\n", "accuracy_node_perturbation_solution = accuracy_node_perturb\n", "snr_node_perturbation_solution = snr_node_perturb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Observe the performance of NodePerturbMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Observe the performance of NodePerturbMLP\n", "\n", "# plot performance over time\n", "with plt.xkcd():\n", " plt.plot(losses_node_perturbation_solution, label=\"Node Perturbation\", color='c') #pre-saved history of loss\n", " plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b') #pre-saved history of loss\n", " plt.xlabel(\"Updates\")\n", " plt.ylabel(\"MSE\")\n", " plt.legend()\n", " plt.title(\"Training loss\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_node_perturbation\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 3: Assessing the variance of learning algorithms\n", "\n", "Estimated timing to here from start of tutorial: 45 minutes\n", "\n", "In this section, we will evaluate the robustness of the introduced learning algorithms by assessing their variance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: Assessing Variance\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: Assessing Variance\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'C8L2n8VBQlc'), ('Bilibili', 'BV14y41187Tn')]\n", "tab_contents 
= display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_video_2\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "$\\newcommand{\\stim}{\\mathbf{x}}$\n", "$\\newcommand{\\noisew}{\\boldsymbol \\Psi}$\n", "$\\newcommand{\\noiser}{\\boldsymbol \\xi}$\n", "$\\newcommand{\\target}{y}$\n", "$\\newcommand{\\targetdim}{\\mathbf{y}}$\n", "$\\newcommand{\\identity}{\\mathbf{I}}$\n", "$\\newcommand{\\blackbox}{f}$\n", "$\\newcommand{\\weight}{\\mathbf{W}}$\n", "$\\newcommand{\\loss}{\\mathcal{L}}$\n", "$\\newcommand{\\derivative}[2]{\\frac{d#1}{d#2}}$\n", "$\\newcommand{\\rate}{\\mathbf{r}}$\n", "$\\newcommand{\\T}{^{\\top}}$\n", "$\\newcommand{\\RR}{\\mathbb{R}}$\n", "$\\newcommand{\\EE}{\\mathbb{E}\\,}$\n", "$\\newcommand{\\brackets}[1]{\\left(#1\\right)}$\n", "$\\newcommand{\\sqbrackets}[1]{\\left[#1\\right]}$\n", "$\\newcommand{\\var}[1]{\\mathbb{V}\\mathrm{ar}\\brackets{#1}}$\n", "\n", "The main issue of perturbation methods is noise, meaning that across many samples of input stimuli and network perturbations, the gradient estimates will be much more variable than would be the case for backpropagation. This means that many, many more perturbations/training samples will be required to obtain an accurate gradient estimate: the consequence will be either very slow or much less effective learning. \n", "\n", "Here, we will demonstrate the noisiness of these learning algorithms analytically for a simplified loss and network. This derivation is principally to satisfy your curiosity: no subsequent exercises will depend on your understanding of the mathematics here, and we will subsequently provide empirical evidence based on network simulations as well. First, we will work with a linear network so $\\widehat\\targetdim =\\weight\\stim$, where $\\widehat\\targetdim\\in\\RR^M$, $\\weight\\in\\RR^{M\\times N}$ and $\\stim\\in\\RR^N$. Second, we will assume that the target output is zero $\\targetdim=0$, so the loss becomes $\\loss(\\weight)=\\frac{1}{2}\\|\\weight\\stim\\|^2_2$. 
(This is equivalent to saying that $\\targetdim=\\weight^*\\stim$ and then shifting the actual weights to be $\\weight - \\weight^*$; notice that here we treat the loss as a function of $\\weight$, rather than $\\Delta \\weight$.)\n", "\n", "\n", "With these changes, we will compute the variance of weight updates for a given input $\\stim$, i.e.\n", "\\begin{equation*}\n", " \\var{\\Delta \\weight}=\\EE\\brackets{\\Delta \\weight - \\EE\\Delta\\weight}^2 = \\EE\\brackets{\\Delta \\weight}^2 - \\brackets{\\EE\\Delta\\weight}^2\\,.\n", "\\end{equation*}\n", "We already know that $\\EE\\Delta\\weight$ is the gradient update, so\n", "\\begin{equation}\n", " \\brackets{\\EE\\Delta\\weight_{ij}}^2 = \\eta^2 \\brackets{\\derivative{\\loss}{\\weight}}_{ij}^2.\n", "\\end{equation}\n", "\n", "Therefore we only need to compute $\\EE(\\Delta\\weight)^2$ for both algorithms.\n", "\n", "**Weight perturbation.** For a single weight $\\weight_{ij}$, we can use the approximate weight change:\n", "\\begin{align}\n", " \\Delta \\weight_{ij} \\,&= - \\eta \\sum_{kl} \\brackets{\\brackets{\\derivative{\\loss}{\\weight}}_{kl} \\noisew_{kl}} \\frac{\\noisew_{ij}}{\\sigma^2}\\,,\\\\\n", " \\brackets{\\Delta \\weight_{ij}}^2 \\,&= \\frac{\\eta^2}{\\sigma^4} \\brackets{\\sum_{kl}\\brackets{\\derivative{\\loss}{\\weight}}_{kl} \\noisew_{kl}}^2 \\noisew_{ij}^2\\\\\n", " &=\\frac{\\eta^2}{\\sigma^4} \\brackets{\\sum_{kldn}\\brackets{\\derivative{\\loss}{\\weight}}_{kl}\\brackets{\\derivative{\\loss}{\\weight}}_{dn} \\noisew_{kl}\\noisew_{dn}} \\noisew_{ij}^2\\,.\n", "\\end{align}\n", "\n", "Now we can take the expectation of the last line w.r.t. the noise $\\noisew$. Since all entries of the noise matrix are independent and zero-mean Gaussian, we will have non-zero terms in two cases: $kl=dn\\neq ij$ and $kl=dn=ij$:\n", "\\begin{align}\n", " \\EE\\noisew_{kl}\\noisew_{dn}\\noisew_{ij}^2 = \\begin{cases}\n", " 0 & k \\neq d\\ \\mathrm{or}\\ l\\neq n\\\\\n", " \\sigma^4 & k=d, l=n, (k\\neq i\\ \\mathrm{or}\\ l\\neq j)\\\\\n", " 3\\,\\sigma^4 & k=d=i,l=n=j\n", " \\end{cases}\n", "\\end{align}\n", "\n", "Therefore,\n", "\\begin{align}\n", " \\EE_{\\noisew}\\brackets{\\brackets{\\Delta \\weight_{ij}}^2} \\,& = \\frac{\\eta^2}{\\sigma^4} \\brackets{\\derivative{\\loss}{\\weight}}_{ij}^2 \\EE \\noisew_{ij}^4 + \\frac{\\eta^2}{\\sigma^4} \\sum_{kl\\neq ij} \\brackets{\\derivative{\\loss}{\\weight}}_{kl}^2 \\EE\\brackets{\\noisew_{kl}^2 \\noisew_{ij}^2}\\\\\n", " &=3\\eta^2 \\brackets{\\derivative{\\loss}{\\weight}}_{ij}^2 + \\eta^2\\sum_{kl\\neq ij} \\brackets{\\derivative{\\loss}{\\weight}}_{kl}^2\\,,\n", "\\end{align}\n", "\n", "where we used that the 4th central moment of the Gaussian is $\\EE \\noisew_{ij}^4=3\\sigma^4$.\n", "\n", "Using the above result, we arrive at\n", "\\begin{align}\n", " \\var{\\Delta \\weight_{ij}} = \\eta^2 \\brackets{\\derivative{\\loss}{\\weight}}_{ij}^2 + \\eta^2\\sum_{kl} \\brackets{\\derivative{\\loss}{\\weight}}_{kl}^2 = O(MN)\\,,\n", "\\end{align}\n", "where the scaling comes from having $MN$ terms in the sum.\n", "\n", "**Node perturbation.** Again, for a single weight $\\weight_{ij}$, we can use the approximate weight change:\n", "\\begin{align}\n", " \\Delta \\weight_{ij} \\,&= -\\frac{\\eta}{\\sigma^2}\\brackets{\\sum_{k}\\brackets{\\derivative{\\loss}{\\rate}}_k\\noiser_k} \\noiser_i\\stim_j\\,,\\\\\n", " \\brackets{\\Delta \\weight_{ij}}^2 \\,&= \\frac{\\eta^2}{\\sigma^4}\\brackets{\\sum_{k}\\brackets{\\derivative{\\loss}{\\rate}}_k\\noiser_k}^2 \\noiser_i^2\\stim_j^2\\\\\n", " 
&=\\frac{\\eta^2}{\\sigma^4}\\brackets{\\sum_{k,d}\\brackets{\\derivative{\\loss}{\\rate}}_k\\brackets{\\derivative{\\loss}{\\rate}}_d\\noiser_k\\noiser_d} \\noiser_i^2\\stim_j^2\\,.\n", "\\end{align}\n", "\n", "Again, computing the expectation over the last line will make use of the independent zero-mean Gaussian noise:\n", "\\begin{align}\n", " \\EE\\noiser_k\\noiser_d\\noiser_i^2 = \\begin{cases}\n", " 0 & k \\neq d\\\\\n", " \\sigma^4 & k=d\\neq i\\\\\n", " 3\\,\\sigma^4 & k=d=i\n", " \\end{cases}\n", "\\end{align}\n", "\n", "Since only $k=d\\neq i$ and $k=d=i$ terms will remain non-zero, we obtain\n", "\\begin{align}\n", " \\EE_{\\noiser}\\brackets{\\brackets{\\Delta \\weight_{ij}}^2} \\,&= \\frac{\\eta^2}{\\sigma^4}\\brackets{\\derivative{\\loss}{\\rate}}_i^2 \\EE\\brackets{\\noiser_i^4}\\stim_j^2 + \\frac{\\eta^2}{\\sigma^4}\\brackets{\\sum_{k\\neq i}\\brackets{\\derivative{\\loss}{\\rate}}_k^2\\EE\\brackets{\\noiser_k^2 \\noiser_i^2}\\stim_j^2}\\\\\n", " &=3 \\eta^2\\brackets{\\derivative{\\loss}{\\rate}}_i^2 \\stim_j^2 + \\eta^2\\sum_{k\\neq i}\\brackets{\\derivative{\\loss}{\\rate}}_k^2\\stim_j^2\\,.\n", "\\end{align}\n", "\n", "Now since $\\brackets{\\EE_{\\noiser}\\Delta \\weight_{ij}}^2=\\eta^2\\brackets{\\derivative{\\loss}{\\rate}}_i^2 \\stim_j^2$, we have\n", "\\begin{equation}\n", " \\var{\\Delta \\weight_{ij}} = \\eta^2\\brackets{\\derivative{\\loss}{\\rate}}_i^2 \\stim_j^2 + \\eta^2\\sum_{k}\\brackets{\\derivative{\\loss}{\\rate}}_k^2\\stim_j^2 = O(M)\\,,\n", "\\end{equation}\n", "where the scaling comes from the sum over $M$ outputs. \n", "\n", "To conclude, we found that the variance of the __weight perturbation__ method scales as $O(MN)$ (variance increases if there are more inputs and/or more outputs), while the __node perturbation__ variance scales as $O(M)$ (variance increases only if there are more outputs). As such, node perturbation will scale better as the number of inputs or the number of neurons in the network increases, while weight perturbation will do worse. As we will see below, neither of these methods will scale as well as backpropagation. Becoming less effective at scale is a major problem for a learning algorithm operating in the brain, where synaptic modifications may occur in billions of neurons, and potentially trillions of synapses." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Empirical demonstration\n", "\n", "Below, we provide an empirical comparison between the gradient estimates provided by three algorithms: __Weight Perturbation__, __Node Perturbation__, and __Backpropagation__.\n", "\n", "To compare the variances of these different algorithms, we will use the __Signal-to-Noise Ratio (SNR)__, averaged across network parameters. For a random variable $X$, the SNR is defined as:\n", "\n", "\\begin{equation}\n", "\\text{SNR}(X) = \\frac{|\\text{Mean}(X)|}{\\text{Std}(X)}.\n", "\\end{equation}\n", "\n", "Here the mean and standard deviation are taken over the test samples (across the batch dimension). The SNR is smaller when X is noisier. For the purposes of comparing gradient estimates, the SNR is a superior measure compared to the variance, because it is *scale invariant*: while the variance scales quadratically if the gradient estimate is multiplied by a scalar $\\eta$ (for instance a learning rate), the SNR remains unchanged! 
This means that we will be able to meaningfully compare gradient updates that are of different sizes, and will prevent algorithms with smaller gradient updates from appearing spuriously low-noise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Compare the SNRs for Weight Perturbation, Node Perturbation, and Backpropagation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Compare the SNRs for Weight Perturbation, Node Perturbation, and Backpropagation\n", "\n", "# initialize the loss and accuracy holders\n", "losses_backprop = np.zeros((numupdates,))\n", "accuracy_backprop = np.zeros((numepochs,))\n", "\n", "# First, we have to train a network with Backpropagation for comparison\n", "\n", "# set the random seed\n", "rng_bp = np.random.default_rng(seed=seed)\n", "\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_bp.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "# create a network and train it using backprop\n", "netbackprop = MLP(rng_bp, numhidden, sigma=initweight, activation=activation)\n", "(losses_backprop[:], accuracy_backprop[:], _, snr_backprop) = \\\n", " netbackprop.train(rng_bp, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, algorithm='backprop', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "# save metrics for plots\n", "losses_backpropagation_solution = losses_backprop\n", "accuracy_backpropagation_solution = accuracy_backprop\n", "snr_backpropagation_solution = snr_backprop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Compare the performance and SNRs for Weight Perturbation, Node Perturbation, and Backpropagation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Compare the performance and SNRs for Weight Perturbation, Node Perturbation, and Backpropagation\n", "\n", "# plot performance over time\n", "with plt.xkcd():\n", " plt.plot(losses_node_perturbation_solution, label=\"Node Perturbation\", color='c')\n", " plt.plot(losses_weight_perturbation_solution, label=\"Weight Perturbation\", color='b')\n", " plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n", " plt.xlabel(\"Updates\")\n", " plt.ylabel(\"MSE\")\n", " plt.legend()\n", " plt.title(\"Training loss\")\n", " plt.show()\n", "\n", "# plot the SNR at initialization for the three learning algorithms\n", "with plt.xkcd():\n", " plt.figure()\n", " x = [0, 1, 2]\n", " snr_vals = [snr_weight_perturbation_solution, snr_node_perturbation_solution, snr_backpropagation_solution] #pre-saved snrs\n", " colors = ['b', 'c', 'r']\n", " labels = ['Weight Perturbation', 'Node Perturbation', 'Backprop']\n", " plt.bar(x, snr_vals, color=colors, tick_label=labels)\n", " plt.xticks(rotation=90)\n", " plt.ylabel('SNR')\n", " plt.xlabel('Algorithm')\n", " plt.title('Gradient SNR')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "As should be evident, the signal-to-noise ratios for both weight and node perturbation are much worse than for backpropagation. This is also reflected in the poor performance of both algorithms relative to backpropagation. 
This shows that locality of parameter updates often comes at the price of poor performance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_assessing_variance\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 4: Feedback Alignment \n", "\n", "Estimated timing to here from start of tutorial: 1 hour\n", "\n", "This section will introduce another family of learning algorithms that exhibit no variance but become biased." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 3: Feedback Alignment\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: Feedback Alignment\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'aTsuAveKf90'), ('Bilibili', 'BV1jD421u75F')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_video_3\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "\n", "$\\newcommand{\\stim}{\\mathbf{x}}$\n", "$\\newcommand{\\h}{\\mathbf{h}}$\n", "$\\newcommand{\\noisew}{\\boldsymbol \\Psi}$\n", "$\\newcommand{\\noiser}{\\boldsymbol \\xi}$\n", "$\\newcommand{\\target}{y}$\n", "$\\newcommand{\\pred}{\\mathbf{\\hat{y}}}$\n", "$\\newcommand{\\identity}{\\mathbf{I}}$\n", "$\\newcommand{\\blackbox}{f}$\n", 
"$\\newcommand{\\weight}{\\mathbf{W}}$\n", "$\\newcommand{\\weightout}{\\mathbf{W}^{\\textrm{out}}}$\n", "$\\newcommand{\\loss}{\\mathcal{L}}$\n", "$\\newcommand{\\derivative}[2]{\\frac{\\partial#1}{\\partial#2}}$\n", "$\\newcommand{\\rate}{\\mathbf{r}}$\n", "$\\newcommand{\\error}{\\boldsymbol \\delta}$\n", "$\\newcommand{\\losserror}{\\mathbf{e}}$\n", "$\\newcommand{\\backweight}{\\mathbf{B}}$\n", "\n", "In this section, we describe the __Feedback Alignment__ algorithm. Unlike weight and node perturbation, feedback alignment provides a mechanism whereby individual neurons can receive *targeted* error signals. To start, we assume the following network setup:\n", "\n", "\\begin{align}\n", " \\pred = \\blackbox(\\weight \\stim) = \\weightout\\sigma(\\weight\\stim) =\\weightout \\h\n", "\\end{align}\n", "\n", "With a mean squared error loss over all of the output neurons.\n", "\\begin{equation}\n", " \\loss = \\frac{1}{2n} \\sum_{k=1}^{n}\\left (\\target_k - \\hat{y}_k \\right )^2\n", "\\end{equation}\n", "\n", "Note here we have suppressed the batch index notation, and will calculate the following gradients as averages over batch elements.\n", "\n", "Backpropagation updates parameters using the gradient of the loss scaled by the learning rate $\\eta$.\n", "\n", "\\begin{align}\n", " \\Delta \\weight_{ji} &= - \\eta \\derivative{\\loss}{\\weight}_{ji} \\\\\n", " &= - \\eta \\underbrace{\\derivative{\\loss}{\\pred}\\derivative{\\pred}{h_j}}_{\\delta_j}\\derivative{h_j}{\\weight_{ji}}\\\\\n", " &= - \\eta \\delta_j \\sigma^{\\prime}(\\weight\\stim)_j\\stim_i \\\\\n", " &= - \\eta \\delta_j h^{\\prime}_j\\stim_i\n", "\\end{align}\n", "\n", "While $h^{\\prime}_j$ and $\\stim_i$ are available locally to the neuron, calculating $\\delta_j$\n", "involves non-local information, and is therefore biologically implausible.\n", "\n", "\\begin{align}\n", " \\delta_j &= \\derivative{\\loss}{h_j} \\\\\n", " &= \\sum_{k=1}^n \\derivative{\\loss}{\\hat{y}_k}\\derivative{\\hat{y}_k}{h_j} \\\\\n", " &= \\sum_{k=1}^n \\overbrace{(y_k - \\hat{y_k})}^{e_k} \\weightout_{kj} \\\\\n", " &= e_1 {\\color{red}\\weightout_{1j}} + e_2 {\\color{green}\\weightout_{2j}} + e_3{\\color{magenta}\\weightout_{3j}}\n", "\\end{align}\n", "\n", "In order to calculate $\\delta_j$ we need to use all of of the outgoing weights from neuron $h_j$.\n", "\n", "Writing $\\error$ as a column vector (i.e. $\\derivative{\\loss}{\\h}$ in [denominator layout](https://en.wikipedia.org/wiki/Matrix_calculus#Layout_conventions)) we see that in order to calculate $\\error$ we need the transpose of the forward weights.\n", "\\begin{align}\n", " \\error &= \\weight_{out}^T \\losserror .\n", "\\end{align}\n", "\n", "\n", "\n", "*From Lillicrap et al. (2016), CC-BY*\n", "\n", "Feedback alignment replaces $\\weight_{out}^T $ with a random matrix, $\\backweight$. This resolves the 'weight transport' problem, because the feedback weights are no longer the same as the feedforward weights. However, by replacing $\\weight_{out}^T$ with $\\backweight$, we are no longer calculating an accurate gradient! Interestingly, we will see empirically in subsequent sections that this replacement still produces reasonably good gradient estimates, though it still introduces *bias*." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 2: Feedback alignment algorithm\n", "\n", "For this exercise, you are tasked with implementing the __Feedback Alignment__ algorithm in our MLP network. Fill in the proper gradient calculation steps below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "class FeedbackAlignmentMLP(MLP):\n", " \"\"\"\n", " A multilayer perceptron that is capable of learning through the Feedback Alignment algorithm\n", " \"\"\"\n", "\n", " # function for calculating feedback alignment updates\n", " def feedback(self, rng, inputs, targets):\n", " \"\"\"\n", " Calculates the weight updates for feedback alignment learning\n", " \"\"\"\n", " ###################################################################\n", " ## Fill out the following then remove\n", " raise NotImplementedError(\"Student exercise: calculate the updates\")\n", " ###################################################################\n", "\n", " # do a forward pass\n", " hidden, output = self.inference(rng, inputs)\n", "\n", " # calculate the updates\n", " error = ...\n", " delta_W_h = np.dot(np.dot(self.B, error * self.act_deriv(output)) * self.act_deriv(hidden),\n", " add_bias(inputs).transpose())\n", " delta_W_y = ...\n", "\n", " return delta_W_h, delta_W_y" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W2D3_Microlearning/solutions/W2D3_Tutorial1_Solution_d1bc17ea.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now that you have an implementation of Feedback Alignment, we can verify that it works properly by training a sample network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define hyperparameters\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Define hyperparameters\n", "numhidden = 500\n", "batchsize = 200\n", "initweight = 0.1\n", "learnrate = 0.001\n", "noise = 0.1\n", "numepochs = 3\n", "numrepeats = 1\n", "numbatches = int(train_images.shape[1] / batchsize)\n", "numupdates = numepochs * numbatches\n", "activation = 'sigmoid'\n", "report = True\n", "rep_rate = 1\n", "seed = 12345" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train FeedbackAlignmentMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Train FeedbackAlignmentMLP\n", "\n", "rng_fa = np.random.default_rng(seed=seed)\n", "\n", "losses_feedback = np.zeros((numupdates,))\n", "accuracy_feedback = np.zeros((numepochs,))\n", "cosine_sim_feedback = np.zeros((numepochs,))\n", "\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_fa.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "# create a network and train it using feedback alignment\n", "netfeedback = FeedbackAlignmentMLP(rng_fa, numhidden, sigma=initweight, activation=activation)\n", "(losses_feedback[:], accuracy_feedback[:], cosine_sim_feedback[:], _) = \\\n", " netfeedback.train(rng_fa, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, algorithm='feedback', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "# save metrics for plots\n", "losses_feedback_alignment_solution = losses_feedback\n", "accuracy_feedback_alignment_solution = accuracy_feedback\n", "cosine_similarity_feedback_alignment_solution = cosine_sim_feedback\n", "\n", "# 
Train a network with Backpropagation for comparison\n", "\n", "# set the random seed for reproducibility\n", "rng_bp2 = np.random.default_rng(seed=seed)\n", "\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_bp2.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "losses_backprop = np.zeros((numupdates,))\n", "accuracy_backprop = np.zeros((numepochs,))\n", "cosine_sim_backprop = np.zeros((numepochs,))\n", "\n", "# create a network and train it using backprop\n", "netbackprop = MLP(rng_bp2, numhidden, sigma=initweight, activation=activation)\n", "(losses_backprop[:], accuracy_backprop[:], cosine_sim_backprop, _) = \\\n", " netbackprop.train(rng_bp2, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, algorithm='backprop', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "\n", "# save metrics for plots\n", "losses_backpropagation_solution = losses_backprop\n", "accuracy_backpropagation_solution = accuracy_backprop\n", "cosine_similarity_backpropagation_solution = cosine_sim_backprop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Observe the performance of FeedbackAlignmentMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Observe the performance of FeedbackAlignmentMLP\n", "\n", "# plot performance over time\n", "with plt.xkcd():\n", " plt.plot(losses_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n", " plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n", " plt.xlabel(\"Updates\")\n", " plt.ylabel(\"MSE\")\n", " plt.legend()\n", " plt.title(\"Training loss\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_feedback_alignment\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 5: Kolen-Pollack\n", "\n", "Estimated timing to here from start of tutorial: 1 hour 20 minutes\n", "\n", "This section presents the last method of the day, the Kolen-Pollack method, which, like feedback alignment, belongs to the family of biased (rather than high-variance) learning rules."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 4: Kolen-Pollack\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Kolen-Pollack\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'T1K8lL7XYEY'), ('Bilibili', 'BV16y411B7qm')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_video_4\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "\n", "\n", "$\\newcommand{\\error}{\\boldsymbol \\delta}$\n", "$\\newcommand{\\losserror}{\\mathbf{e}}$\n", "$\\newcommand{\\backweight}{\\mathbf{B}}$\n", "$\\newcommand{\\h}{\\mathbf{h}}$\n", "$\\newcommand{\\y}{\\mathbf{y}}$\n", "\n", "\n", "As we've just seen, to update a feed-forward matrix using back-propagated error, we need to simply follow the weight update equation:\n", "\\begin{align}\n", " %\\Delta \\weight_{ji} &= - \\eta \\delta_j h^{\\prime}_j\\stim_i \\\\\n", " \\Delta \\weight_{out} &= - \\eta \\losserror \\h^T \\\\\n", " \\Delta \\weight &= - \\eta (\\error \\h^{\\prime})\\stim^T,\n", "\\end{align}\n", "where\n", "\\begin{align}\n", " \\error &= \\weight_{out}^T \\losserror .\n", "\\end{align}\n", "While directly \"transporting\" the weights, $\\weight_{out}^T$, is not biologically plausible, we showed that a random feedback matrix, $\\backweight$, can align the weights to propagate an approximated error,\n", "\\begin{align}\n", " \\error &= \\backweight \\losserror .\n", "\\end{align}\n", "\n", "However, this approach fails with deeper networks and more complicated datasets. 
We will now show a biologically plausible approach to modifying $\\backweight$, such that over learning, $\\backweight$ and $ \\weight_{out}^T $ become equal. This approach builds on an observation by Kolen and Pollack (1994): if two matrices repeatedly receive the same update $\\mathbf{A}(t)$ together with weight decay,\n", "\n", "\\begin{align}\n", " \\Delta \\weight(t) &= \\mathbf{A}(t) - \\lambda \\weight(t) \\\\\n", " \\Delta \\backweight(t) &= \\mathbf{A}(t) - \\lambda \\backweight(t) ,\n", "\\end{align}\n", "then\n", "\\begin{align}\n", " \\weight(t+1) - \\backweight(t+1) &= \\weight(t) + \\Delta \\weight(t) - \\backweight(t) - \\Delta \\backweight(t) \\\\\n", " &= \\weight(t) - \\backweight(t) - \\lambda[\\weight(t) - \\backweight(t)] \\\\\n", " &= (1-\\lambda)[\\weight(t) - \\backweight(t)] \\\\\n", " &= (1-\\lambda)^{t+1} [\\weight(0) - \\backweight(0)] ,\n", "\\end{align}\n", "where the last line follows from applying the recursion repeatedly, back to the initial matrices at $t=0$. That is, as $t \\rightarrow \\infty$, the difference between the two matrices converges to 0 (since $0 < 1-\\lambda < 1$ for a small positive weight decay).\n", "\n", "The key observation is that the corresponding elements of $\\weight_{out}^T$ and $ \\backweight $ have access to the same locally available information. We can thus pick a plausible learning rule for the backward weights:\n", "\n", "\\begin{align}\n", " \\Delta \\backweight &= - \\eta \\h \\losserror^T - \\lambda \\backweight ,\n", "\\end{align}\n", "such that the updates to $\\backweight$ correspond to the transpose of the updates to $\\weight_{out}$,\n", "\\begin{align}\n", " \\Delta \\weight_{out} &= - \\eta \\losserror \\h^T - \\lambda \\weight_{out} \\\\\n", " \\Delta \\weight_{out}^T &= - \\eta \\h \\losserror^T - \\lambda \\weight_{out}^T.\n", "\\end{align}\n", "\n", "Thus, over many weight updates, $ \\backweight $ will converge to $\\weight_{out}^T$ and can be used to propagate errors back to inform updates to $\\weight$. Note that the same reasoning can be applied to networks of many layers." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 3: Kolen-Pollack algorithm\n", "\n", "For this exercise, you will implement the Kolen-Pollack algorithm for our MLP network. Fill in the proper gradient calculation below."
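,
"\n",
"\n",
"Before you implement it, here is a tiny numerical check of the convergence argument above (a standalone sketch; the matrix shapes and the name `A` are arbitrary and not tied to the network in this notebook). Two different random matrices receive identical updates plus weight decay, and the norm of their difference shrinks by a factor of exactly $(1-\\lambda)$ per step:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(1)\n",
"lam = 0.05                       # weight decay strength\n",
"W = rng.normal(size=(4, 6))      # stands in for the transpose of the forward weights\n",
"B = rng.normal(size=(4, 6))      # backward weights, initialized differently\n",
"\n",
"gaps = []\n",
"for t in range(100):\n",
"    A = rng.normal(size=(4, 6))  # the same (arbitrary) update for both matrices\n",
"    W = W + A - lam * W\n",
"    B = B + A - lam * B\n",
"    gaps.append(np.linalg.norm(W - B))\n",
"\n",
"print(gaps[0] / gaps[1])         # = 1 / (1 - lam), the per-step shrink factor\n",
"print(gaps[-1])                  # after 100 steps the two matrices nearly coincide\n",
"```"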
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "class KolenPollackMLP(MLP):\n", " \"\"\"\n", " A multilayer perceptron that is capable of learning through the Kolen-Pollack algorithm\n", " \"\"\"\n", "\n", " def kolepoll(self, rng, inputs, targets, eta_back=0.01):\n", " \"\"\"\n", " Calculates the weight updates for Kolen-Polack learning\n", " \"\"\"\n", " ###################################################################\n", " ## Fill out the following then remove\n", " raise NotImplementedError(\"Student exercise: calculate updates.\")\n", " ###################################################################\n", "\n", " # do a forward pass\n", " (hidden, output) = self.inference(rng, inputs)\n", "\n", " # calculate the updates for the forward weights\n", " error = targets - output\n", " delta_W_h = np.dot(np.dot(self.B, error * self.act_deriv(output)) * self.act_deriv(hidden), \\\n", " add_bias(inputs).transpose())\n", " delta_err = ...\n", " delta_W_y = delta_err - 0.1 * self.W_y\n", "\n", " # calculate the updates for the backwards weights and implement them\n", " delta_B = delta_err[:, :-1].transpose() - 0.1 * self.B\n", " self.B += ...\n", " return (delta_W_h, delta_W_y)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W2D3_Microlearning/solutions/W2D3_Tutorial1_Solution_95265523.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now that you have implemented Kolen-Pollack, we will test that it works by training a sample MLP network." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train KolenPollackMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Train KolenPollackMLP\n", "rng_kp = np.random.default_rng(seed=seed)\n", "\n", "losses_kolepoll = np.zeros((numupdates,))\n", "accuracy_kolepoll = np.zeros((numepochs,))\n", "cosine_sim_kolepoll = np.zeros((numepochs,))\n", "# select 1000 random images to test the accuracy on\n", "indices = rng_kp.choice(range(test_images.shape[1]), size=(1000,), replace=False)\n", "\n", "# create a network and train it using feedback alignment\n", "netkolepoll = KolenPollackMLP(rng_kp, numhidden, sigma=initweight, activation=activation)\n", "(losses_kolepoll[:], accuracy_kolepoll[:], cosine_sim_kolepoll[:], _) = \\\n", " netkolepoll.train(rng_kp, train_images, train_labels, numepochs, test_images[:, indices], test_labels[:, indices], \\\n", " learning_rate=learnrate, batch_size=batchsize, algorithm='kolepoll', noise=noise, \\\n", " report=report, report_rate=rep_rate)\n", "\n", "# save metrics for plots\n", "losses_kolen_pollack_solution = losses_kolepoll\n", "accuracy_kolen_pollack_solution = accuracy_kolepoll\n", "cosine_similarity_kolen_pollack_solution = cosine_sim_kolepoll" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Observe the performance of KolenPollackMLP\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Observe the performance of KolenPollackMLP\n", "\n", "# plot performance over time\n", "with plt.xkcd():\n", " plt.plot(losses_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n", " 
plt.plot(losses_backpropagation_solution, label=\"Backprop\", color='r')\n", " plt.plot(losses_kolen_pollack_solution, label=\"Kolen-Pollack\", color='k')\n", " plt.xlabel(\"Updates\")\n", " plt.ylabel(\"MSE\")\n", " plt.legend()\n", " plt.title(\"Training loss\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_kolen_pollack\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 6: Assessing the bias of learning algorithms\n", "\n", "Estimated timing to here from start of tutorial: 1 hour 50 minutes\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Having implemented both Feedback Alignment and Kolen-Pollack, we should now compare the parameter updates obtained by these algorithms to the Backpropagation update itself, so that we can assess how accurate these methods actually are.\n", "\n", "To do this, we will measure the alignment between two gradient estimates using the *cosine similarity*, which is defined as follows, for two proposed parameter updates $\\Delta \\theta_1$ and $\\Delta \\theta_2$, which we assume are vectors of equal length:\n", "\n", "\\begin{equation}\n", "CSim(\\Delta \\theta_1, \\Delta \\theta_2) = \\frac{\\Delta \\theta_1^T \\Delta \\theta_2}{\\|\\Delta \\theta_1\\|_2 \\| \\Delta \\theta_2 \\|_2},\n", "\\end{equation}\n", "\n", "This is simply the inner product between a unit vector pointing in the direction of $\\Delta \\theta_1$ and a unit vector pointing in the direction of $\\Delta \\theta_2$. It takes value 1 if the two vectors are perfectly aligned, 0 if they are orthogonal, and -1 if they point in the exact opposite direction from one another. Like the SNR, the cosine similarity measure is also *scale invariant*, so multiplying $\\Delta \\theta_1$ by a constant (e.g. a learning rate) will not change the cosine similarity at all. We care about parameter updates pointing in the same *direction* as the gradient, but for our purposes, it does not really matter if the update is greater or lesser in magnitude, because we can always decrease or increase the learning rate. For this reason, using a scale invariant measure of alignment is advantageous, because it ignores magnitudes.\n", "\n", "Below, we show the cosine similarity between backpropagation, feedback alignment, and Kolen-Pollack, all compared to the backpropagation algorithm itself. The backpropagation-to-backpropagation cosine similarity will always be 1, because the vectors trivially point in the same direction." 
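,
"\n",
"\n",
"As a quick illustration of the measure (this helper is a standalone sketch for this explanation, not the function used inside the `MLP` class), the snippet below computes the cosine similarity between two flattened parameter updates and confirms that it is scale invariant:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cosine_similarity(update_1, update_2):\n",
"    # flatten both updates and compare their directions only\n",
"    u1 = np.ravel(update_1)\n",
"    u2 = np.ravel(update_2)\n",
"    return np.dot(u1, u2) / (np.linalg.norm(u1) * np.linalg.norm(u2))\n",
"\n",
"grad = np.array([1.0, -2.0, 0.5])\n",
"print(cosine_similarity(grad, grad))         # 1.0: same direction\n",
"print(cosine_similarity(grad, 10.0 * grad))  # 1.0: rescaling does not change it\n",
"print(cosine_similarity(grad, -grad))        # -1.0: opposite direction\n",
"```"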
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the gradient similarity to backpropagation over training with shaded error regions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plot the gradient similarity to backpropagation over training with shaded error regions\n", "with plt.xkcd():\n", " plt.plot(cosine_similarity_backpropagation_solution, label=\"Backprop\", color='r')\n", " plt.plot(cosine_similarity_feedback_alignment_solution, label=\"Feedback Alignment\", color='g')\n", " plt.plot(cosine_similarity_kolen_pollack_solution, label=\"Kolen-Pollack\", color='k')\n", " plt.xlabel(\"Epochs\")\n", " plt.ylabel(\"Cosine Sim\")\n", " plt.legend()\n", " plt.title(\"Cosine Similarity to Backprop\")\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Finally, let's show that these improvements in the alignment between the backprop gradient estimates and the various local estimates correlate with the accuracy of the classification on MNIST." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classification accuracy comparison\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Classification accuracy comparison\n", "with plt.xkcd():\n", " plt.plot(accuracy_weight_perturbation_solution)\n", " plt.plot(accuracy_node_perturbation_solution)\n", " plt.plot(accuracy_feedback_alignment_solution)\n", " plt.plot(accuracy_kolen_pollack_solution)\n", " plt.plot(accuracy_backpropagation_solution)\n", " plt.legend(['Weight perturbation', 'Node perturbation', 'Feedback alignment', 'Kolen-Pollack', 'Backprop'])\n", " plt.xlabel('Epochs')\n", " plt.ylabel('Accuracy (%)')\n", " plt.title('Accuracy over epochs')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_assessing_bias\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "*Estimated timing of tutorial: 2 hours*\n", "\n", "Here, we will end the tutorial with a few final remarks:\n", "* There are many learning algorithms that are capable of optimizing system-wide objective functions using only information that could be plausibly available to a synapse (e.g. weight and node perturbation, feedback alignment, and Kolen-Pollack).\n", "* Local learning algorithms typically either introduce high variance or biases into gradient estimates.\n", "* Bias and variance in gradient estimates can impede the ability of a neural system to learn effectively at scale, either on complex datasets or in large neural networks.\n", "* Therefore, good generalization (which depends on large, powerful network architectures), will come from learning algorithms that have as little variance and bias in their gradient estimates as possible.\n", "* The neuroscience community does not yet know which, if any, of the learning algorithms discussed in this tutorial map onto learning in the brain. The algorithms we have introduced are best thought of as 'candidate models' for how the brain could be learning." 
] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "W2D3_Tutorial1", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 4 }