{ "cells": [ { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 2: Normalization\n", "\n", "**Week 1, Day 5: Microcircuits**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Alish Dipani, Xaq Pitkow\n", "\n", "__Content reviewers:__ Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk, Patrick Mineault\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "___\n", "\n", "\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 50 minutes*\n", "\n", "In this tutorial, you will learn about normalization, a microcircuit element that is a prominent computation in both brains and machines. You will see different types of normalization, learn how to implement them, and observe some of their benefits for generalization.\n", "\n", "**Tutorial Learning Objectives**\n", "* Understand that, although networks with nonlinearities can be universal function approximators, not all functions are easy to learn.\n", "* Implement a family of normalization mechanisms.\n", "* Demonstrate how normalization helps in learning and information transmission."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/eckvr/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/eckvr/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "\n", "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install vibecheck datatops --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W1D5_T2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Imports\n", "\n", "#working with data\n", "import random\n", "import numpy as np\n", "from collections import OrderedDict\n", "\n", "\n", "#plotting\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patheffects as 
path_effects\n", "import seaborn as sns\n", "import ipywidgets as widgets\n", "import logging\n", "\n", "#utils\n", "from tqdm import tqdm\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "#modeling\n", "import scipy\n", "from sklearn.metrics import ConfusionMatrixDisplay\n", "import torch\n", "import torchvision\n", "import torch.nn.functional as F\n", "from torchvision import transforms\n", "from torch import nn\n", "from torch.utils.data import Dataset, DataLoader, random_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper functions\n", "\n", "# Section 1\n", "\n", "# Validation\n", "VAL_X_LOW = 0\n", "VAL_X_HIGH = 1\n", "\n", "def store_grads(model, grads_mat):\n", " \"\"\"\n", " Store the gradients of a PyTorch model's layers in a dictionary.\n", "\n", " Inputs:\n", " - model (torch.nn.Module): The PyTorch model whose gradients will be stored.\n", " - grads_mat (dict): A dictionary that will store the gradients: the keys correspond to the names of the model's layers, and the values are PyTorch tensors containing the gradients.\n", "\n", " Outputs:\n", " - grads_mat (dict): The input dictionary `grads_mat` with the updated gradient
values.\n", " \"\"\"\n", " for grad_layer_name, grad_layer in model.grad_layers.items():\n", " if grad_layer_name in grads_mat.keys():\n", " grads_mat[grad_layer_name] = torch.vstack((\n", " grads_mat[grad_layer_name],\n", " grad_layer.grad.detach().cpu()\n", " ))\n", " else:\n", " grads_mat[grad_layer_name] = grad_layer.grad.detach().cpu()\n", " return grads_mat\n", "\n", "def train_sec1(model, train_dataloader, learning_rate, n_epochs, VAL_X_LOW, VAL_X_HIGH, \\\n", " track_grads=False):\n", " \"\"\"\n", " Train a model using the given dataloader, learning rate, and number of epochs.\n", "\n", " Inputs:\n", " - model (torch.nn.Module): The PyTorch model to be trained.\n", " - train_dataloader (torch.utils.data.DataLoader): Dataloader for training data.\n", " - learning_rate (float): The learning rate for the optimizer.\n", " - n_epochs (int): Number of epochs for training.\n", " - VAL_X_LOW (float): Lower bound of validation x values.\n", " - VAL_X_HIGH (float): Upper bound of validation x values.\n", " - track_grads (bool): Whether to track and store gradients during training.\n", "\n", " Outputs:\n", " - losses_iter (list of float): List of loss values for each iteration.\n", " - losses_epoch (list of float): List of loss values for each epoch.\n", " - training_dynamics_mat (torch.Tensor): Tensor containing the validation errors over training.\n", " - input_thresholds_tensor (torch.Tensor): Tensor containing the input threshold weights over training.\n", " - output_weights_tensor (torch.Tensor): Tensor containing the output layer weights over training.\n", " - gradients_mat (dict, optional): Dictionary containing the gradients if track_grads is True.\n", " \"\"\"\n", " # Training settings\n", " loss_fn = nn.MSELoss()\n", " optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", " # Train the model\n", " model.train()\n", " losses_iter = []\n", " losses_epoch = []\n", " # training loss dynamics\n", " val_x = torch.arange(VAL_X_LOW+1e-2, 
VAL_X_HIGH, 0.01).unsqueeze(1).to(DEVICE)\n", " val_y = (1/val_x).to(DEVICE)\n", " training_dynamics_mat = None\n", " gradients_mat = {}\n", " # training weight dynamics\n", " input_thresholds_tensor = model.input_threshold_weights.cpu().clone().detach()\n", " output_weights_tensor = model.output_layer.weight.data[0].cpu().clone().detach()\n", " for epoch in range(n_epochs):\n", " epoch_loss = 0\n", " for X_batch, y_batch in train_dataloader:\n", " X_batch = X_batch.to(DEVICE)\n", " y_batch = y_batch.to(DEVICE)\n", " optimizer.zero_grad()\n", " y_pred = model(X_batch)\n", " loss = loss_fn(y_pred, y_batch)\n", " loss.backward()\n", " optimizer.step()\n", " with torch.no_grad():\n", " losses_iter.append(loss.cpu().item())\n", " epoch_loss += loss.cpu().item()\n", " if track_grads:\n", " gradients_mat = store_grads(model, gradients_mat)\n", " losses_epoch.append(epoch_loss)\n", " # Store training loss dynamics\n", " with torch.no_grad():\n", " val_errors = torch.log(nn.functional.mse_loss(model.predict(val_x), val_y, reduction='none'))\n", " val_errors = (val_errors.T).cpu()\n", " if training_dynamics_mat is None:\n", " training_dynamics_mat = val_errors\n", " else:\n", " training_dynamics_mat = torch.vstack((val_errors, training_dynamics_mat))\n", " # Store training weight dynamics\n", " epoch_ths = model.input_threshold_weights.cpu().clone().detach()\n", " epoch_wts = model.output_layer.weight.data[0].cpu().clone().detach()\n", " input_thresholds_tensor = torch.vstack((input_thresholds_tensor, epoch_ths))\n", " output_weights_tensor = torch.vstack((output_weights_tensor, epoch_wts))\n", " if track_grads:\n", " return losses_iter, losses_epoch, training_dynamics_mat, \\\n", " input_thresholds_tensor, output_weights_tensor, gradients_mat\n", " else:\n", " return losses_iter, losses_epoch, training_dynamics_mat, \\\n", " input_thresholds_tensor, output_weights_tensor\n", "\n", "def evaluate_sec1(model, test_dataloader):\n", " \"\"\"\n", " Evaluate a model using the 
given dataloader and return the test loss, input values, true values, and predicted values.\n", "\n", " Inputs:\n", " - model (torch.nn.Module): The PyTorch model to be evaluated.\n", " - test_dataloader (torch.utils.data.DataLoader): Dataloader for test data.\n", "\n", " Outputs:\n", " - test_loss (float): The loss on the test dataset.\n", " - x_all (torch.Tensor): Tensor containing all input values from the test dataset.\n", " - y_all (torch.Tensor): Tensor containing all true values from the test dataset.\n", " - y_pred_all (torch.Tensor): Tensor containing all predicted values from the test dataset.\n", " \"\"\"\n", " # evaluate MSE after training\n", " model.eval()\n", " test_loss = 0\n", " x_all = None\n", " y_all = None\n", " y_pred_all = None\n", " with torch.no_grad():\n", " for X_batch, y_batch in test_dataloader:\n", " X_batch = X_batch.to(DEVICE)\n", " y_batch = y_batch.to(DEVICE)\n", " y_pred = model.predict(X_batch)\n", " # Store (on the CPU, so that batches can be concatenated regardless of DEVICE)\n", " if x_all is None:\n", " x_all = X_batch.flatten().cpu().detach()\n", " else:\n", " x_all = torch.concat((x_all, X_batch.flatten().cpu().detach()))\n", "\n", " if y_all is None:\n", " y_all = y_batch.flatten().cpu().detach()\n", " else:\n", " y_all = torch.concat((y_all, y_batch.flatten().cpu().detach()))\n", "\n", " if y_pred_all is None:\n", " y_pred_all = y_pred.flatten().cpu().detach()\n", " else:\n", " y_pred_all = torch.concat((y_pred_all, y_pred.flatten().cpu().detach()))\n", "\n", " # Calculate loss\n", " loss = nn.functional.mse_loss(y_pred, y_batch)\n", " test_loss += loss.cpu()\n", " return test_loss, x_all, y_all, y_pred_all\n", "\n", "# Section 2.2\n", "\n", "def visualize_images_sec22(data, data_titles, viz_p):\n", " \"\"\"\n", " Visualize pixels in given data arrays.\n", "\n", " Inputs:\n", " - data (list of numpy.array): List of numpy arrays to visualize.\n", " - data_titles (list of str): List of titles for each data array.\n", " - viz_p (int): Number of samples to visualize.\n", "\n", " Outputs:\n", " - None: Displays a
plot of the data arrays.\n", " \"\"\"\n", " with plt.xkcd():\n", " vmin = np.min([np.min(arr[: viz_p, :]) for arr in data])\n", " vmax = np.max([np.max(arr[: viz_p, :]) for arr in data])\n", " cmap = 'gray_r'\n", " # Plot\n", " height_ratios = [10, 2, 10] if len(data)==3 else [10, 2, 10, 10]\n", " figsize = (8, 6) if len(data)==3 else (8, 8)\n", " fig, axs = plt.subplots(len(data), 1, figsize=figsize, \\\n", " gridspec_kw={'height_ratios': height_ratios})\n", " cbar_ax = fig.add_axes([0.98, 0.125, 0.02, 0.75]) # Define position for colorbar\n", " for i, data_ in enumerate(data):\n", " sns.heatmap(data_[:viz_p, :].T, cmap=cmap, annot=False, cbar=(i==0), fmt=\".2f\", annot_kws={\"size\": 10},\n", " ax=axs[i], linewidths=0, linecolor='black', square=True,\n", " cbar_ax=None if i else cbar_ax)\n", " # Use single colorbar for first heatmap\n", " # Show axes on all sides\n", " axs[i].spines['top'].set_visible(True)\n", " axs[i].spines['right'].set_visible(True)\n", " axs[i].spines['bottom'].set_visible(True)\n", " axs[i].spines['left'].set_visible(True)\n", " axs[i].set_xticks([])\n", " axs[i].set_yticks([])\n", " axs[i].set_title(data_titles[i])\n", " axs[i].set_xlabel('')\n", " axs[i].set_ylabel('')\n", " # Set common x and y labels\n", " fig.text(0.5, 0, 'Samples', ha='center', fontsize=15)\n", " fig.text(0, 0.5, 'Pixel Intensity', va='center', rotation='vertical', fontsize=15)\n", " plt.subplots_adjust(wspace=5) # Adjust the horizontal spacing between subplots\n", " plt.show()\n", "\n", "def subsets(arr, k):\n", " \"\"\"\n", " Generate all possible subsets from arr with length k.\n", "\n", " Inputs:\n", " - arr (numpy.array): 1-D numpy array or list.\n", " - k (int): Length of the subsets.\n", "\n", " Outputs:\n", " - subsets (numpy.array): Array of all possible subsets with length k.\n", " \"\"\"\n", " from itertools import combinations\n", " return np.array(list(combinations(arr, k)))\n", "\n", "# Section 2.3\n", "\n", "def train_cnns(model, train_dataloader,
learning_rate, momentum, n_epochs):\n", " \"\"\"\n", " Train a CNN model using the given dataloader, learning rate, momentum, and number of epochs.\n", "\n", " Inputs:\n", " - model (torch.nn.Module): The CNN model to be trained.\n", " - train_dataloader (torch.utils.data.DataLoader): Dataloader for training data.\n", " - learning_rate (float): The learning rate for the optimizer.\n", " - momentum (float): The momentum for the optimizer.\n", " - n_epochs (int): Number of epochs for training.\n", "\n", " Outputs:\n", " - losses_iter (list of float): List of loss values for each iteration.\n", " - losses_epoch (list of float): List of loss values for each epoch.\n", " \"\"\"\n", " # Training settings\n", " loss_fn = nn.CrossEntropyLoss()\n", " optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, \\\n", " momentum=momentum)\n", "\n", " # Train the model\n", " model.train()\n", " losses_iter = []\n", " losses_epoch = []\n", " # training loss dynamics\n", " gradients_mat = {}\n", " # training weight dynamics\n", " for epoch in tqdm(range(n_epochs)):\n", " epoch_loss = 0\n", " for X_batch, y_batch in train_dataloader:\n", " X_batch = X_batch.to(DEVICE)\n", " y_batch = y_batch.to(DEVICE)\n", " optimizer.zero_grad()\n", " y_pred = model(X_batch)\n", " loss = loss_fn(y_pred, y_batch)\n", " loss.backward()\n", " optimizer.step()\n", " with torch.no_grad():\n", " losses_iter.append(loss.cpu().item())\n", " epoch_loss += loss.cpu().item()\n", " losses_epoch.append(epoch_loss)\n", " return losses_iter, losses_epoch\n", "\n", "def evaluate_cnns(model, test_dataloader):\n", " \"\"\"\n", " Evaluate a CNN model using the given dataloader and return the test loss and accuracy.\n", "\n", " Inputs:\n", " - model (torch.nn.Module): The CNN model to be evaluated.\n", " - test_dataloader (torch.utils.data.DataLoader): Dataloader for test data.\n", "\n", " Outputs:\n", " - test_loss (float): The loss on the test dataset.\n", " - accuracy (float): The accuracy on the test 
dataset.\n", " \"\"\"\n", " model.eval()\n", " test_loss = 0\n", " correct = 0\n", " total = 0\n", " with torch.no_grad():\n", " for X_batch, y_batch in test_dataloader:\n", " X_batch = X_batch.to(DEVICE)\n", " y_batch = y_batch.to(DEVICE)\n", " y_pred = model(X_batch).data\n", " # Calculate loss\n", " loss = nn.functional.cross_entropy(y_pred, y_batch)\n", " test_loss += loss.cpu().item()\n", " # Calculate accuracy\n", " # the class with the highest energy is what we choose as prediction\n", " _, predicted = torch.max(y_pred, 1)\n", " total += y_pred.size(0)\n", " correct += (predicted == y_batch).sum().item()\n", " accuracy = 100 * correct / total\n", " return test_loss, accuracy\n", "\n", "def normalize_implemented(x, sigma, p, g):\n", " \"\"\"\n", " Inputs:\n", " - x(np.ndarray): Input array (n_samples * n_dim)\n", " - sigma(float): Smoothing factor\n", " - p(int): p-norm\n", " - g(int): scaling factor\n", "\n", " Outputs:\n", " - xnorm (np.ndarray): normalized values.\n", " \"\"\"\n", " # Raise the absolute value of x to the power p\n", " xp = np.power(np.abs(x), p)\n", " # Sum |x|^p over the dimensions (n_dim) axis\n", " xp_sum = np.sum(xp, axis=1)\n", " # Summing over axis 1 drops that dimension; restore it so that xp_sum\n", " # is a column vector of shape (n_samples, 1) that broadcasts against x\n", " xp_sum = np.expand_dims(xp_sum, axis=1)\n", " # Raise the sum to the power 1/p and add the smoothing factor (sigma)\n", " denominator = sigma + np.power(xp_sum, 1/p)\n", " # Scale the input data with a factor of g\n", " numerator = x*g\n", " # Calculate normalized x\n", " xnorm = numerator/denominator\n", " return xnorm\n", "\n", "# Exercise solutions for correct plot output\n", "\n", "class ReLUNet(nn.Module):\n", " \"\"\"\n", " ReLUNet architecture\n", " Structure is as follows:\n", " y = Σi(wi * ReLU(θi - x))\n", " \"\"\"\n", " # Define the structure of your network\n", " def __init__(self, n_units):\n", " \"\"\"\n", " Args:\n", " n_units (int):
Number of hidden units\n", "\n", " Returns:\n", " Nothing\n", " \"\"\"\n", " super(ReLUNet, self).__init__()\n", " # Create input thresholds\n", " self.input_threshold_weights = nn.Parameter(torch.abs(torch.randn(n_units)))\n", " self.non_linearity = nn.ReLU()\n", " self.output_layer = nn.Linear(n_units, 1)\n", " nn.init.xavier_normal_(self.output_layer.weight)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of shape (n_samples, 1)\n", " \"\"\"\n", " op = self.input_threshold_weights - x #prepare the input to be passed through ReLU\n", " op = self.non_linearity(op) #apply ReLU\n", " op = self.output_layer(op) #run through output layer\n", " return op\n", "\n", " # Return the network's output (for this regression model, identical to the forward pass)\n", " def predict(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of shape (n_samples, 1)\n", " \"\"\"\n", " output = self.forward(x)\n", " return output\n", "\n", "non_linearities = {\n", " 'ReLU': nn.ReLU(),\n", " 'ReLU6': nn.ReLU6(),\n", " 'SoftPlus': nn.Softplus(),\n", " 'Sigmoid': nn.Sigmoid(),\n", " 'Tanh': nn.Tanh()\n", "}\n", "\n", "def HardTanh(x):\n", " \"\"\"\n", " Calculates the hard tanh of the input data, i.e. clips values to the range [-1, 1].\n", "\n", " Inputs:\n", " - x (np.ndarray): input data.\n", "\n", " Outputs:\n", " - output (np.ndarray): `x` clipped to [-1, 1].\n", " \"\"\"\n", " min_val = -1\n", " max_val = 1\n", " output = np.copy(x)\n", " output[output>max_val] = max_val\n", " output[output= 1+leak_slope, \\\n", " (ycopy - np.sign(ycopy))/leak_slope, \\\n", " ycopy/(1+leak_slope)\n", " )\n", " return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set random seed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set random seed\n", "\n", "def set_seed(seed=None, seed_torch=True):\n", " if seed is None:\n", " seed = np.random.choice(2 ** 32)\n", " random.seed(seed)\n",
" np.random.seed(seed)\n", " if seed_torch:\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.cuda.manual_seed(seed)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True\n", "\n", "\n", "# In case that `DataLoader` is used\n", "def seed_worker(worker_id):\n", " worker_seed = torch.initial_seed() % 2**32\n", " np.random.seed(worker_seed)\n", " random.seed(worker_seed)\n", "\n", "set_seed(seed=42, seed_torch=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set device (GPU or CPU)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set device (GPU or CPU)\n", "\n", "def set_device():\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " if device != \"cuda\":\n", " print(\"GPU is not enabled in this notebook. \\n\"\n", " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", " else:\n", " print(\"GPU is enabled in this notebook. \\n\"\n", " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", "\n", " return device\n", "\n", "DEVICE = set_device()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: Can ReLUs implement normalization?\n", "\n", "In this section we will explore how feasible it is to estimate a normalization-like function." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Introduction\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Introduction\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '4CyTyu9KZKA'), ('Bilibili', 'BV1es421u7KR')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", 
"display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_introduction\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The general form of normalization is:\n", "\n", "$$\\hat{x} = \\frac{x}{f(||x||)}$$\n", "\n", "There are many possible choices for the denominator $f(||x||)$. What we want to highlight is the essentially divisive nature of normalization: the input is divided by a function of its own magnitude.\n", "\n", "Evidence suggests that normalization provides a useful inductive bias in both artificial and natural systems. But do we need a dedicated computation that implements normalization, or can a generic network learn it?\n", "\n", "Let's explore whether ReLUs can approximate a normalization-like function. Specifically, we will test whether a fully-connected network with a single hidden layer of ReLUs can estimate the function $y=\\frac{1}{x+\\epsilon}$.\n", "\n", "In the cell below, we visualize the training and test data."
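, "\n", "\n", "As a quick worked example of the divisive form, take the illustrative choice $f(||x||) = ||x||_2$ (the Euclidean norm) and the two-dimensional input $x = (3, 4)$:\n", "\n", "$$||x||_2 = \\sqrt{3^2 + 4^2} = 5, \\qquad \\hat{x} = \\frac{(3, 4)}{5} = (0.6, 0.8)$$\n", "\n", "Scaling the input by any positive factor, e.g. $x = (30, 40)$, leaves $\\hat{x} = (0.6, 0.8)$ unchanged. The output keeps the direction of the input and discards its overall magnitude. The target $y=\\frac{1}{x+\\epsilon}$ used in this section can be read as this divisive operation on its own, applied to a scalar input, with the small $\\epsilon$ keeping the denominator away from zero."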
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate $y=\\frac{1}{x+\\epsilon}$ train and test dataloaders\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " $\\epsilon = 0.01$\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Generate $y=\\frac{1}{x+\\epsilon}$ train and test dataloaders\n", "# Target function y = 1/(x+ε)\n", "\n", "# @markdown $\\epsilon = 0.01$\n", "\n", "N_SAMPLES = 5000\n", "X_LOW = 0\n", "X_HIGH = 10\n", "TRAIN_RATIO = 0.7\n", "EPSILON = 1e-2\n", "\n", "range01_ratio = 0.20 # fraction of samples in the range 0-1\n", "X1 = torch.distributions.uniform.Uniform(X_LOW, 1).rsample(sample_shape=torch.Size([int(N_SAMPLES*range01_ratio), 1]))\n", "X2 = torch.distributions.uniform.Uniform(1, X_HIGH).rsample(sample_shape=torch.Size([int(N_SAMPLES*(1-range01_ratio)), 1]))\n", "X_sec1 = torch.concatenate((X1, X2)) + EPSILON\n", "y_sec1 = 1/X_sec1\n", "\n", "class ReLUDataset(Dataset):\n", " def __init__(self, X, y):\n", " self.X = X\n", " self.y = y\n", "\n", " def __len__(self):\n", " return len(self.y)\n", "\n", " def __getitem__(self, idx):\n", " X = self.X[idx]\n", " y = self.y[idx]\n", " return X, y\n", "\n", "dataset_sec1 = ReLUDataset(X_sec1, y_sec1)\n", "\n", "# Define the sizes for training and testing sets\n", "train_size = int(TRAIN_RATIO * len(dataset_sec1))\n", "test_size = len(dataset_sec1) - train_size\n", "\n", "# Split the dataset into training and testing sets\n", "train_dataset_sec1, test_dataset_sec1 = random_split(dataset_sec1, [train_size, test_size])\n", "\n", "# Dataloaders\n", "# Create DataLoader for the training set\n", "train_dataloader_sec1 = DataLoader(train_dataset_sec1, batch_size=len(train_dataset_sec1), shuffle=True)\n", "\n", "# Create DataLoader for the testing set\n", "test_dataloader_sec1 = DataLoader(test_dataset_sec1, batch_size=len(test_dataset_sec1), shuffle=False)\n",
"\n", "train_data = torch.column_stack((train_dataset_sec1.dataset.X[train_dataset_sec1.indices], train_dataset_sec1.dataset.y[train_dataset_sec1.indices]))\n", "sorted_indices = torch.argsort(train_data[:, 0])\n", "train_data_sorted = torch.index_select(train_data, 0, sorted_indices)\n", "\n", "with plt.xkcd():\n", " plt.plot(train_data_sorted[:, 0], train_data_sorted[:, 1], 's-y', label='train')\n", " plt.plot(test_dataset_sec1.dataset.X[test_dataset_sec1.indices], test_dataset_sec1.dataset.y[test_dataset_sec1.indices], 'Dk', label='test', alpha=0.5, markersize=1.5)\n", " plt.xlabel('Input (x)')\n", " plt.ylabel('Output (y)')\n", " plt.title(r'$y=\\frac{1}{x+\\epsilon}$')\n", " plt.legend(prop={'size': 15})\n", " ax = plt.gca()\n", " for line in ax.get_lines():\n", " line.set_path_effects([path_effects.Normal()])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 1: ReLUNet\n", "Let's define a simple model with a single hidden layer of ReLU units:\n", "\n", "$$\\hat{y} = \\sum_{i}w_{i} \\text{ReLU}(\\theta_{i} - x)$$\n", "\n", "Here $\\theta_{i}$ is the threshold and $w_{i}$ is the output weight of neuron $i$, which sets the slope of its contribution. Both $\\theta_{i}$ and $w_{i}$ are learned parameters (the output layer, an `nn.Linear`, also adds a learned bias that the equation omits). Our network has 100 neurons in total. Complete the forward pass of the model."
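, "\n", "\n", "To make the equation concrete, here is a hand calculation with two units, ignoring any output bias. The values $\\theta = (1, 2)$ and $w = (0.5, -1)$ are made up for illustration and are not taken from a trained network. Evaluating at $x = 1.5$:\n", "\n", "$$\\hat{y} = 0.5\\,\\text{ReLU}(1 - 1.5) + (-1)\\,\\text{ReLU}(2 - 1.5) = 0.5 \\cdot 0 + (-1) \\cdot 0.5 = -0.5$$\n", "\n", "Unit $i$ contributes only when $x < \\theta_{i}$, and in that range it adds a line of slope $-w_{i}$ in $x$. The network output is therefore piecewise linear, with a kink at each threshold $\\theta_{i}$."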
] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {} }, "outputs": [], "source": [ "###################################################################\n", "## Fill out the following then remove\n", "raise NotImplementedError(\"Student exercise: complete forward pass.\")\n", "###################################################################\n", "\n", "class ReLUNet(nn.Module):\n", " \"\"\"\n", " ReLUNet architecture\n", " The structure is the following:\n", " y = Σi(wi * ReLU(θi - x))\n", " \"\"\"\n", " # Define the structure of your network\n", " def __init__(self, n_units):\n", " \"\"\"\n", " Args:\n", " n_units (int): Number of hidden units\n", "\n", " Returns:\n", " Nothing\n", " \"\"\"\n", " super(ReLUNet, self).__init__()\n", " # Create input thresholds\n", " self.input_threshold_weights = nn.Parameter(torch.abs(torch.randn(n_units)))\n", " self.non_linearity = nn.ReLU()\n", " self.output_layer = nn.Linear(n_units, 1)\n", " nn.init.xavier_normal_(self.output_layer.weight)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of shape (n_samples, 1)\n", " \"\"\"\n", " op = ... - ... #prepare the input to be passed through ReLU\n", " op = self.non_linearity(...) #apply ReLU\n", " op = ... #run through output layer\n", " return op\n", "\n", " # Return the network's output (for this regression model, identical to the forward pass)\n", " def predict(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of shape (n_samples, 1)\n", " \"\"\"\n", " output = self.forward(x)\n", " return output" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D5_Microcircuits/solutions/W1D5_Tutorial2_Solution_b46035c9.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, let's train the model and evaluate it."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training & Evaluating model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Training & Evaluating model\n", "\n", "# Variables\n", "# Model\n", "# Training\n", "n_epochs = 200\n", "learning_rate = 5e-2\n", "\n", "# Create a new ReLUNet and transfer it to the device\n", "model = ReLUNet(100).to(DEVICE)\n", "\n", "# Train ReLUNet\n", "losses_iter, losses_epoch, training_dynamics_mat, \\\n", " input_thresholds_tensor, output_weights_tensor = train_sec1(model, \\\n", " train_dataloader_sec1, learning_rate, n_epochs, VAL_X_LOW, VAL_X_HIGH)\n", "\n", "# Evaluate ReLUNet\n", "test_loss, x_all, y_all, y_pred_all = evaluate_sec1(model, test_dataloader_sec1)\n", "\n", "with plt.xkcd():\n", " # Plot training and evaluation performance\n", " fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", " # Plot training loss per epoch\n", " ax[0].plot(range(1, len(losses_epoch)+1), losses_epoch, '-k')\n", " # plot settings\n", " ax[0].set_xlabel('Epochs')\n", " ax[0].set_ylabel('MSE Loss')\n", " ax[0].set_title('Training Loss per Epoch')\n", "\n", " # Plotting evaluation performance\n", " # plot errors\n", " y_errs = nn.functional.mse_loss(y_pred_all, y_all, reduction='none')\n", " ax[1].bar(x_all, y_errs, width=0.1, color='red', alpha=0.5, \\\n", " label='error = $(y - \\^y)^{2}$')\n", " # plot predicted values\n", " eval_plot_data = torch.column_stack((x_all, y_all, y_pred_all)) # Sort data for plotting\n", " sorted_indices = torch.argsort(eval_plot_data[:, 0])\n", " eval_plot_data_sorted = torch.index_select(eval_plot_data, 0, sorted_indices)\n", " ax[1].plot(eval_plot_data_sorted[:, 0], eval_plot_data_sorted[:, 2], 'db', label=r'$\\^y$', markersize=1.5)\n", " # plot ground truth\n", " x_values = np.linspace(X_LOW+1e-2, X_HIGH+1e-2, 1000)\n", " y_values = 1 / x_values\n", " # plot settings\n", " 
ax[1].plot(x_values, y_values, '-k', alpha=0.5, label=r'$y=\\frac{1}{x+\\epsilon}$')\n", " ax[1].set_title(f'Predictions, Test Loss={test_loss:.3f}')\n", " ax[1].set_ylim((-0.5, 10))\n", " ax[1].set_xlabel('Input (x)')\n", " ax[1].set_ylabel('Output (y)')\n", " ax[1].legend()\n", " plt.tight_layout()\n", " ax = plt.gca()\n", " for line in ax.get_lines():\n", " line.set_path_effects([path_effects.Normal()])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Although the model learns, it does not fit the test data well. Let's see where the model struggles during training.\n", "\n", "Below, we plot the log squared error for values of $x$ between 0 and 1 and how it evolves over epochs. Log errors above 5 (that is, squared errors above $e^{5}$) are clipped and shown in blue." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Training Loss Dynamics\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plot Training Loss Dynamics\n", "\n", "MAX_CLIP = 5\n", "\n", "with plt.xkcd():\n", " plt.figure(figsize=(10, 5))\n", "\n", " # Create a custom colormap for clipping\n", " light_pal = sns.light_palette(\"darkred\", as_cmap=True)\n", " clipping_color = [0., 0.75, 1., 1.]
# RGBA\n", " new_colors = np.vstack( (light_pal(np.arange(light_pal.N)), np.array([clipping_color])) )\n", " custom_cmap = sns.blend_palette(new_colors, as_cmap=True)\n", "\n", " ax = sns.heatmap(training_dynamics_mat.numpy(), vmax=MAX_CLIP, vmin = 0, cmap=custom_cmap)\n", "\n", " xptslen = training_dynamics_mat.shape[1]\n", " xticklabels = np.round(np.arange(VAL_X_LOW, VAL_X_HIGH + 0.05, 0.2), decimals=1)\n", " ax.set_xticks(np.linspace(0, xptslen, len(xticklabels)), labels=xticklabels)\n", " ax.set_yticks(np.arange(0, n_epochs+.1, 20), labels=np.arange(n_epochs, -0.1, -20, dtype=int))\n", " ax.set_xlabel('Input (x)')\n", " ax.set_ylabel('Epochs')\n", " plt.title('Log Train MSE Loss')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The model has higher errors for lower values of $x$, and as training progresses, the errors for these values start to decrease. Note that the losses are huge (above $e^5$) for very small values of $x$.\n", "\n", "Does this mean that the model devotes more of its resources to learning the function between 0 and 1?\n", "\n", "To check, let's visualize the ReLU thresholds."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize ReLUs\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize ReLUs\n", "# Get model weights\n", "l1_thresholds = model.input_threshold_weights.data.cpu()\n", "l2_slopes = model.output_layer.weight.data[0].cpu()\n", "l2_bias = model.output_layer.bias.item()\n", "\n", "# Visualizing\n", "# X points\n", "# n_samples * 1\n", "xpoints = torch.arange(-5, X_HIGH, 0.04).unsqueeze(1)\n", "# zi = thetai - x\n", "# n_samples * n_units\n", "thetai = l1_thresholds.repeat(len(xpoints), 1)\n", "# n_samples * n_units\n", "zi = thetai - xpoints\n", "# n_samples * n_units\n", "hi = torch.maximum(zi, torch.tensor(0, dtype=torch.float32))\n", "# n_samples * n_units\n", "ahi = hi * l2_slopes\n", "# y = Σi(ahi)\n", "y_pred = torch.sum(ahi, axis=1) + l2_bias\n", "\n", "with plt.xkcd():\n", " # Visualizing\n", " plt.title(f'Visualize ReLUs')\n", "\n", " # y =1/x\n", " # Generate x values in the range [X_LOW, X_HIGH]\n", " x_values = np.linspace(X_LOW+1e-2, X_HIGH+1e-2, 1000)\n", " # Calculate y values for the function y = 1/x\n", " y_values = 1 / x_values\n", " plt.plot(x_values, y_values, '-k', alpha=1, label=r'$y=\\frac{1}{x+\\epsilon}$')\n", "\n", " # x = 0\n", " # plt.axvline(x=0, c='k',label='x=0')\n", "\n", " # y_hat\n", " plt.plot(xpoints, y_pred, 'sb', markersize=7, label=r'$\\^y$')\n", "\n", " # ReLUs\n", " for i in range(ahi.shape[-1]):\n", " plt.plot(xpoints, ahi[:, i], '-', alpha=1, color='lightblue')\n", " plt.plot([], [], '-', label='ReLUs', color='lightblue')\n", "\n", " # Thresholds\n", " plt.plot(l1_thresholds, np.zeros(len(l1_thresholds)), '|r', \\\n", " markersize=15, label=r'$\\theta_{i}$')\n", "\n", " plt.xlabel('Input (x)')\n", " plt.ylabel('Output (y)')\n", " plt.ylim((0, 5))\n", " plt.xlim((0, 3))\n", " plt.legend()\n", " ax = plt.gca()\n", " for line in 
ax.get_lines():\n", " line.set_path_effects([path_effects.Normal()])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We can see that the thresholds (red lines) are bunched up between 0 and 1, which means that the model dedicates most of its resources to learning the function on this interval. Let's quantify this by plotting the threshold distributions and their dynamics across epochs.\n", "\n", "Below, we plot the empirical cumulative distributions of $\\theta_{i}$ and $w_{i}$ at initialization and after training. We also plot the values of each parameter as they change across epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot Weight Dynamics\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plot Weight Dynamics\n", "\n", "with plt.xkcd():\n", " fig, ax = plt.subplots(2, 2, figsize=(12.5,10))\n", " ax[0, 0].set_xlabel(r'Thresholds ($\\theta_{i}$)')\n", " threshold_weights = model.input_threshold_weights.data.cpu()\n", " sns.ecdfplot(input_thresholds_tensor[0, :], color='b', ax=ax[0, 0], label='initial')\n", " sns.ecdfplot(threshold_weights, color='r', ax=ax[0, 0], label='final')\n", " ax[0, 0].legend()\n", " ax[0, 1].set_xlabel('Slopes ($w_{i}$)')\n", " slopes = model.output_layer.weight.data[0].cpu()\n", " sns.ecdfplot(output_weights_tensor[0, :], color='b', ax=ax[0, 1], label='initial')\n", " sns.ecdfplot(slopes, color='r', ax=ax[0, 1], label='final')\n", " fig.suptitle(r'$\\hat{y} = \\sum_{i}w_{i} ReLU(\\theta_{i} - x)$')\n", " ax[0, 1].legend()\n", "\n", " # Input thresholds\n", " n_cols = input_thresholds_tensor.shape[-1]\n", " n_rows = input_thresholds_tensor.shape[0]\n", " for n_col in range(n_cols):\n", " ax[1, 0].plot(range(n_rows), input_thresholds_tensor[:, n_col], '-k', alpha=0.5)\n", " ax[1, 0].set_xlabel('Epochs')\n", " ax[1, 0].set_ylabel('Input Thresholds')\n", " # Output weights\n", " n_cols = 
output_weights_tensor.shape[-1]\n", " n_rows = output_weights_tensor.shape[0]\n", " for n_col in range(n_cols):\n", " ax[1, 1].plot(range(n_rows), output_weights_tensor[:, n_col], '-k', alpha=0.5)\n", " ax[1, 1].set_xlabel('Epochs')\n", " ax[1, 1].set_ylabel('Output weights')\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "From the cumulative distribution plot of the thresholds ($\\theta_{i}$), we can see that around 80% of them are below $x=1$. Hence, the model struggles most to learn the function between 0 and 1, where its slope changes the most, and it places most of its units there.\n", "\n", "Since the slope of $y=\\frac{1}{x+\\epsilon}$ changes continuously between $x=0$ and 1, while each ReLU unit contributes a single fixed slope on its active side, fitting the function exactly would require infinitely many ReLU units. Hence, even though a ReLU network can approximate the function arbitrarily well in theory, doing so precisely near $x=0$ is not feasible in practice." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Coding Exercise 1 Discussion\n", "\n", "1. Do you think that having more slope changes in the activation function would help?\n", "\n", "Take a minute to think on your own, then discuss with your group." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_relunet\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 2: Test other non-linear activation functions\n", "\n", "Let's see if other non-linear activation functions perform better. Specifically, we test:\n", "1. $\\text{ReLU}(x) = (x)^{+} = \\max(0,x)$\n", "2. 
$\\text{ReLU6}(x) = \\min(\\max(0,x),6)$\n", "3. 
$\\text{SoftPlus}(x, \\beta=1) = \\frac{1}{\\beta} \\log(1+e^{\\beta x})$\n", "4. $\\text{Sigmoid}(x) = \\sigma(x)= \\frac{1}{1+e^{-x}}$\n", "5. $\\text{Tanh}(x) = \\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}$\n", "\n", "Our model is the same as before, with a single hidden layer, except that we change the activation function:\n", "\n", "$$\\hat{y} = \\sum_{i}w_{i} \\text{Activation}(\\theta_{i} - x)$$\n", "\n", "Here $\\theta_{i}$ is the threshold, and $w_{i}$ is the slope of neuron $i$. We train and evaluate each model three times and plot the mean training loss and the distribution of test losses across runs. Your task is to complete the dictionary of the proposed non-linear functions (by defining them using the `torch.nn` library)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "###################################################################\n", "## Fill out the following then remove\n", "raise NotImplementedError(\"Student exercise: complete non-linearities.\")\n", "###################################################################\n", "non_linearities = {\n", " 'ReLU': nn.ReLU(),\n", " 'ReLU6': ...,\n", " 'SoftPlus': nn.Softplus(),\n", " 'Sigmoid': ...,\n", " 'Tanh': ...\n", "}" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D5_Microcircuits/solutions/W1D5_Tutorial2_Solution_a36a7d90.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, let's train the different networks and evaluate them. Note that the cell below takes approximately one minute to run."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train & Evaluate\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Train & Evaluate\n", "\n", "class NonLinearNet(nn.Module):\n", " \"\"\"\n", " NonLinearNet architecture\n", " The structure is the following:\n", " y = Σi(ai * Non-Linearity(θi - x))\n", " \"\"\"\n", " # Define the structure of your network\n", " def __init__(self, n_units, non_linearity):\n", " \"\"\"\n", " Args:\n", " n_units (int): Number of hidden units\n", " non_linearity (nn.Module): Activation function applied to (θi - x)\n", "\n", " Returns:\n", " Nothing\n", " \"\"\"\n", " super(NonLinearNet, self).__init__()\n", " self.n_units = n_units\n", " self.input_threshold_weights = nn.Parameter(torch.normal(0., 0.1, (self.n_units,)))\n", " self.non_linearity = non_linearity\n", " self.output_layer = nn.Linear(n_units, 1)\n", " nn.init.normal_(self.output_layer.weight, mean=0, std=0.1)\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of size ([1])\n", " \"\"\"\n", " # Threshold\n", " op = self.input_threshold_weights - x\n", " op = self.non_linearity(op)\n", " op = self.output_layer(op)\n", " return op\n", "\n", " # Return the network's prediction (identical to forward for this regression model)\n", " def predict(self, x):\n", " \"\"\"\n", " Args:\n", " x: torch.Tensor\n", " Input tensor of size ([1])\n", " \"\"\"\n", " output = self.forward(x)\n", " return output\n", "\n", "# Model\n", "n_units = 1\n", "\n", "# Training\n", "n_epochs = 100\n", "learning_rate = 5e-3\n", "\n", "# Experiment\n", "n_runs = 3\n", "\n", "nls_train_loss_epochs = {}\n", "nls_test_losses = {}\n", "\n", "for n_run in range(n_runs):\n", " for nl_name, nl in non_linearities.items():\n", " model = NonLinearNet(n_units, nl).to(DEVICE)\n", " losses_iter, losses_epoch, training_dynamics_mat, \\\n", " input_thresholds_tensor, 
output_weights_tensor = train_sec1(model, \\\n", " 
train_dataloader_sec1, learning_rate, n_epochs, VAL_X_LOW, VAL_X_HIGH)\n", " if nl_name in nls_train_loss_epochs.keys():\n", " nls_train_loss_epochs[nl_name] = np.vstack((nls_train_loss_epochs[nl_name], np.array(losses_epoch)))\n", " else:\n", " nls_train_loss_epochs[nl_name] = np.array(losses_epoch)\n", "\n", " test_loss, x_all, y_all, y_pred_all = evaluate_sec1(model, test_dataloader_sec1)\n", " if nl_name in nls_test_losses.keys():\n", " nls_test_losses[nl_name].append(test_loss.item())\n", " else:\n", " nls_test_losses[nl_name] = [test_loss.item()]\n", "\n", "with plt.xkcd():\n", " fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", " # Plot the mean training loss across runs (log scale)\n", " colors = iter(plt.cm.rainbow(np.linspace(0, 1, len(nls_train_loss_epochs))))\n", " for nl_ in non_linearities.keys():\n", " c = next(colors)\n", " mean_train_loss = np.mean(nls_train_loss_epochs[nl_], axis=0)\n", " mean_train_loss = np.log(mean_train_loss)\n", " ax[0].plot(range(1, len(mean_train_loss)+1), mean_train_loss, '-', color=c, label=nl_)\n", " ax[0].set_xlabel('Epochs')\n", " ax[0].set_ylabel('Log Mean MSE Loss')\n", " ax[0].legend()\n", " ax[0].set_title('Training Loss')\n", "\n", " # Plot the distribution of test losses across runs (box plot)\n", " colors = iter(plt.cm.rainbow(np.linspace(0, 1, len(nls_train_loss_epochs))))\n", " box = ax[1].boxplot(list(nls_test_losses.values()), showfliers=False, \\\n", " medianprops={'color':'gray'})\n", " for median in box['medians']:\n", " c = next(colors)\n", " median.set_color(c)\n", " # plt.ylim((-0.5, 5))\n", " ax[1].set_xticks(range(1, len(nls_test_losses)+1), labels=list(nls_test_losses.keys()))\n", " ax[1].set_xlabel('Non Linearity')\n", " ax[1].set_ylabel('Test MSE Loss')\n", " ax[1].set_title('Test loss')\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "None of the proposed non-linear activation functions performs very well: swapping one pointwise non-linearity for another does not make a 
divisive function like $y=\\frac{1}{x+\\epsilon}$ easy to learn. 
Hence, it is beneficial to have dedicated computation that implements normalization." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_nonlinear_activation_functions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 2: Summary\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: Summary\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " 
    print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'PrkBtUdXEEM'), ('Bilibili', 'BV1Ff421Q7oC')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_first_section_summary\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Benefits of normalization\n", "\n", "*Estimated timing to here from start of tutorial: 20 minutes.*\n", "\n", "In this section, we will introduce a simple normalization function, test it in different settings, and see how it relates to generalization. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 3: Introduction to Normalization\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: Introduction to Normalization\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'i9cPx6jt590'), ('Bilibili', 'BV1zn4y197yW')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " 
 tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_introduction_to_normalization\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "## Subsection 2.1: Explore normalization" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "This subsection defines a simple normalization function and explores how its hyperparameters affect the result." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Coding Exercise 2.1: Implement normalization\n", "\n", "Let's implement an example of normalization. Specifically, we will use the following function:\n", "\n", "$$\\mathbf{x}_{norm} = \\frac{g \\mathbf{x}}{\\sigma + \\sqrt[p]{\\Sigma_{i = 1}^{N} |x_{i}|^{p}}}$$\n", "\n", "where\n", "\n", "- $\\mathbf{x}$ is an $N$-dimensional vector ($\\mathbf{x} \\in \\mathbb{R}^N$) with components $x_{i}$,\n", "\n", "- $g$ is a scaling factor,\n", "\n", "- $\\sigma$ is a smoothing factor,\n", "\n", "- $p$ defines which $p$-norm of the input vector is used.\n", "\n", "Note that we normalize across dimensions, i.e., each sample is normalized by the norm of its own components (similar to Layer Normalization). You can test your implementation by running the code cell below!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def normalize(x, sigma, p, g):\n", " \"\"\"\n", " Inputs:\n", " - x(np.ndarray): Input array (n_samples * n_dim)\n", " - sigma(float): Smoothing factor\n", " - p(float): Order of the p-norm\n", " - g(float): Scaling factor\n", "\n", " Outputs:\n", " - xnorm (np.ndarray): normalized values.\n", " \"\"\"\n", " #################################################\n", " ## TODO: Implement the normalization example equation ##\n", " # Remove the following line of code once you have completed the exercise:\n", " raise NotImplementedError(\"Student exercise: complete normalization function.\")\n", " #################################################\n", " # Raise the absolute value of x to the power p\n", " xp = ...\n", " # Sum xp over the dimensions (n_dim) axis\n", " xp_sum = ...\n", " # Summing over axis 1 drops that axis, leaving shape (n_samples,)\n", " # Restore it so xp_sum is a column vector of shape (n_samples, 1) that broadcasts against x\n", " xp_sum = np.expand_dims(xp_sum, axis=1)\n", " # Raise the sum to the power 1/p and add the smoothing factor (sigma)\n", " denominator = ...\n", " # Scale the input data with a factor of g\n", " numerator = ...\n", " # Calculate normalized x\n", " xnorm = numerator/denominator\n", " return xnorm" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D5_Microcircuits/solutions/W1D5_Tutorial2_Solution_dd43ccaf.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Test `normalize()` function\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown Test `normalize()` function\n", "def check_normalize(func):\n", " def np_norm(x, sigma, p, g):\n", " xnorm = 
(x*g)/(np.expand_dims(np.linalg.norm(x, ord=p, axis=1), axis=1)+sigma)\n", " return xnorm\n", " # Compare func against a reference implementation based on np.linalg.norm\n", " incorrect_message = \"Normalize function incorrect\"\n", " test_x = np.random.rand(3, 3)\n", " # Test 1\n", " assert np.allclose(np_norm(test_x, 1, 1, 1), func(test_x, 1, 1, 1)), incorrect_message\n", " # Test 2\n", " assert np.allclose(np_norm(test_x, 2, 0.3, 1.2), func(test_x, 2, 0.3, 1.2)), incorrect_message\n", " # Test 3\n", " assert np.allclose(np_norm(test_x, 0.1, 3, 2), func(test_x, 0.1, 3, 2)), incorrect_message\n", " # Test 4\n", " assert np.allclose(np_norm(test_x, 2.4, 3.2, 1.5), func(test_x, 2.4, 3.2, 1.5)), incorrect_message\n", " print('Normalize function works correctly!')\n", "\n", "check_normalize(normalize)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_implement_normalization\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Interactive Demo 2.1\n", "\n", "Let's explore the effect of the smoothing factor ($\\sigma$), the p-norm order ($p$) and the scaling factor ($g$) in our normalization function:\n", "\n", "$$\\mathbf{x}_{norm} = \\frac{g \\mathbf{x}}{\\sigma + \\sqrt[p]{\\Sigma_{i = 1}^{N} |x_{i}|^{p}}}$$\n", "\n", "We will apply the normalization to points sampled from a 2-dimensional normal distribution.\n", "\n", "Take a minute to play around with the values, then discuss what you observe with your group."
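] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "A quick numerical check can make the roles of $\\sigma$ and $g$ concrete. The cell below is a minimal sketch using a hypothetical helper, `normalize_sketch`, written directly from the equation above (it is not the tutorial's own `normalize`) and assuming inputs of shape `(n_samples, n_dim)`. Because the $p$-norm is homogeneous, $\\|c\\,\\mathbf{x}\\|_p = |c|\\,\\|\\mathbf{x}\\|_p$, so each normalized vector has $p$-norm $g\\,\\|\\mathbf{x}\\|_p / (\\sigma + \\|\\mathbf{x}\\|_p)$. With $\\sigma = 0$ this equals $g$ for every sample, so all points land on the $p$-norm sphere of radius $g$; with $\\sigma > 0$ it stays below $g$ and shrinks towards zero for inputs with a small norm." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n",
"\n",
"def normalize_sketch(x, sigma, p, g):\n",
"    # Hypothetical helper: row-wise p-norm, kept as a column so it broadcasts over n_dim\n",
"    norm_p = np.sum(np.abs(x) ** p, axis=1, keepdims=True) ** (1.0 / p)\n",
"    return g * x / (sigma + norm_p)\n",
"\n",
"# Stand-in data similar to the demo: 2-D Gaussian points, fixed seed for reproducibility\n",
"rng = np.random.default_rng(0)\n",
"x_check = rng.normal(loc=0.0, scale=0.5, size=(1000, 2))\n",
"\n",
"for sigma_check in [0.0, 0.5]:\n",
"    x_norm = normalize_sketch(x_check, sigma=sigma_check, p=2, g=1.5)\n",
"    # 2-norm of every normalized sample\n",
"    out_norms = np.linalg.norm(x_norm, ord=2, axis=1)\n",
"    print(f'sigma={sigma_check}: output 2-norms range from {out_norms.min():.3f} to {out_norms.max():.3f}')"
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "With $\\sigma = 0$ both printed bounds should read 1.500 (that is, $g$); with $\\sigma = 0.5$ all norms stay below 1.5, and the smallest ones come from points close to the origin. Note that $\\sigma = 0$ would divide by zero for an all-zero input, which is one practical reason to keep $\\sigma$ slightly positive."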
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Effect of smoothing factor ($\\sigma$), p-norm ($p$) and scaling factor ($g$)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Effect of smoothing factor ($\\sigma$), p-norm ($p$) and scaling factor ($g$)\n", "\n", "n_points = 1000\n", "n_dim = 2\n", "x_sec21 = np.random.normal(loc=0.0, scale=0.5, size=(n_points, n_dim))\n", "\n", "@widgets.interact(sigma=widgets.FloatSlider(0.1, min=0, max=2, description='σ', layout=widgets.Layout(width='50%')),\\\n", " p=widgets.FloatSlider(1, min=0.1, max=5, description=r'p', layout=widgets.Layout(width='50%')), \\\n", " g=widgets.FloatSlider(1, min=0.1, max=2, description='g', layout=widgets.Layout(width='50%')))\n", "def visualize_normalization(sigma, p, g):\n", " x_ = normalize_implemented(x_sec21, sigma, p, g)\n", " # Create a figure and axis\n", " fig, ax = plt.subplots(figsize=(5, 5))\n", "\n", " # Set the spines (axes lines) to intersect at the center\n", " ax.spines['left'].set_position('zero')\n", " ax.spines['bottom'].set_position('zero')\n", " ax.spines['right'].set_color('none')\n", " ax.spines['top'].set_color('none')\n", "\n", " # Set the ticks\n", " ax.xaxis.set_ticks_position('bottom')\n", " ax.yaxis.set_ticks_position('left')\n", "\n", " # Bold ticks\n", " for tick in ax.get_xticklabels():\n", " tick.set_fontweight('bold')\n", " for tick in ax.get_yticklabels():\n", " tick.set_fontweight('bold')\n", "\n", " ax.plot(x_sec21[:, 0], x_sec21[:, 1], '.b', markersize=5, alpha=0.5, label='Original')\n", " ax.plot(x_[:, 0], x_[:, 1], '.r', markersize=5, alpha=0.75, label='Normalized')\n", " ax.set_xlabel('$x_{1}$', loc='right', fontsize=20, fontweight='bold')\n", " ax.set_ylabel('$x_{2}$', loc='top', rotation=0, fontsize=20, fontweight='bold')\n", " ax.set_xlim((-2, 2))\n", " ax.set_ylim((-2, 2))\n", " ax.legend()\n", " plt.show()" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Video 4: Effect of smoothing factor, p-norm and scaling factor\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Effect of smoothing factor, p-norm and scaling factor\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'v8lubAW9GKk'), ('Bilibili', 'BV1Ns421u7Lt')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in 
range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_hyperparameters_impact\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "## Subsection 2.2: Estimating latent properties" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In this subsection, we will use the normalization function to recover a target variable that has been corrupted by an unknown scaling factor." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 5: Normalization example\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 5: Normalization example\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", 
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'RiNXyNqH9YI'), ('Bilibili', 'BV1RJ4m1g7t4')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_normalization_example\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Interactive Demo 2.2.1" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "For this demo, we have a target image ($\\mathbf{x}$, the true reflectance), which we would like to infer, and a visible image ($\\mathbf{v}$), which is a scaled version of the target: $\\mathbf{v} = s \\mathbf{x}$, where $s$ is the illuminant intensity. We will generate 300 different examples (we will visualize only 25 of them) of 5-dimensional vectors $\\mathbf{x}$, with each component drawn from an exponential distribution with $\\lambda = 1$. 
The scaling factor $s$ is then drawn separately for each example, also from an exponential distribution with $\\lambda = 1$.\n", "\n", "Play around with different hyperparameter values to get the best R-squared value between the normalized visible image and the true reflectance."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "number_samples = 300 # Number of samples\n", "number_pixels = 5 # Number of pixels per sample\n", "\n", "# True reflectance\n", "reflectance = np.random.exponential(1, size=(number_samples, number_pixels))\n", "# Illuminant intensity\n", "illuminant_intensity = np.random.exponential(1, size=(number_samples, 1))\n", "# Visible image\n", "visible_image = np.repeat(illuminant_intensity, number_pixels, axis=1) * reflectance\n", "\n", "#################################################\n", "## TODO: Implement the normalization example equation ##\n", "# Fill remove the following line of code one you have completed the exercise:\n", "raise NotImplementedError(\"Student exercise: choose your parameters values.\")\n", "#################################################\n", "\n", "# Normalized visible image\n", "norm_visible_image = normalize(\n", " visible_image,\n", " sigma = ...,\n", " p = ...,\n", " g = ...\n", ")\n", "\n", "# Visualize the images\n", "visualize_images_sec22(\n", " [reflectance, illuminant_intensity, visible_image, norm_visible_image],\n", " ['Reflectance', 'Illuminant intensity', 'Visible image', 'Normalized visible image'],\n", " 25\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D5_Microcircuits/solutions/W1D5_Tutorial2_Solution_9b3b7306.py)\n", "\n", "*Example output:*\n", "\n", "Solution hint\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot correlation of estimated reflectance with true reflectance\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plot correlation of estimated reflectance with true reflectance\n", "\n", "with plt.xkcd():\n", 
" fig, ax = plt.subplots(1, 2, figsize=(15, 5))\n", " # Plot visible image vs. reflectance\n", " x_ = reflectance.reshape((-1, 1)).squeeze(-1)\n", " y_ = visible_image.reshape((-1, 1)).squeeze(-1)\n", " sns.regplot(x=x_, y=y_, ax=ax[0])\n", " ax[0].set_xlabel('Reflectance')\n", " ax[0].set_ylabel('Visible image')\n", " # Calculate R-squared\n", " result = scipy.stats.linregress(x_, y_)\n", " ax[0].set_title(f'$R^{2}$ = {(result.rvalue**2):.2f}')\n", "\n", " # Plot normalized visible image vs. reflectance\n", " x_ = reflectance.reshape((-1, 1)).squeeze(-1)\n", " y_ = norm_visible_image.reshape((-1, 1)).squeeze(-1)\n", " sns.regplot(x=x_, y=y_, ax=ax[1])\n", " ax[1].set_xlabel('Reflectance')\n", " ax[1].set_ylabel('Normalized visible image')\n", " ax[1].set_ylim((-0.05, 1.05))\n", " # Calculate R-squared\n", " result = scipy.stats.linregress(x_, y_)\n", " ax[1].set_title(f'$R^{2}$ = {(result.rvalue**2):.2f}')\n", "\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_infer_visible_image\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Interactive Demo 2.2.2" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, use the interactive sliders below to explore how changing the hyperparameters ($\\sigma$, $p$ and $g$) affects the R-squared value."
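] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The number in each plot title is the $R^2$ of a least-squares line through the (reflectance, pixel value) pairs, computed with `scipy.stats.linregress`. For simple linear regression this equals the squared Pearson correlation. The cell below is a minimal sketch, not part of the original tutorial, that regenerates data the same way as above (with its own `numpy` random generator) and checks this identity. It also shows the best case: if the illuminant intensity were known, dividing it out would recover the reflectance exactly, giving $R^2 = 1$. The normalization function tries to approach this without knowing the illuminant. The names `true_reflectance`, `illuminant`, `visible` and `r_squared` are hypothetical, chosen so the tutorial's own variables are not overwritten." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "import scipy.stats\n", "\n", "# Sketch: regenerate data as in the exercise above, with a separate random generator\n", "rng = np.random.default_rng(0)\n", "n_samples, n_pixels = 300, 5\n", "true_reflectance = rng.exponential(1, size=(n_samples, n_pixels))\n", "illuminant = rng.exponential(1, size=(n_samples, 1))\n", "visible = illuminant * true_reflectance  # broadcasting, equivalent to np.repeat above\n", "\n", "def r_squared(x, y):\n", "    # R^2 of the least-squares line through the flattened (x, y) pairs\n", "    return scipy.stats.linregress(x.ravel(), y.ravel()).rvalue ** 2\n", "\n", "r2_raw = r_squared(true_reflectance, visible)\n", "r2_pearson = np.corrcoef(true_reflectance.ravel(), visible.ravel())[0, 1] ** 2\n", "r2_oracle = r_squared(true_reflectance, visible / illuminant)  # illuminant known exactly\n", "print(f'R^2 raw: {r2_raw:.2f}, squared Pearson r: {r2_pearson:.2f}, known illuminant: {r2_oracle:.2f}')"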
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Effect of smoothing factor ($\\sigma$), p-norm ($p$) and scaling factor ($g$)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Effect of smoothing factor ($\\sigma$), p-norm ($p$) and scaling factor ($g$)\n", "@widgets.interact(sigma=widgets.FloatSlider(0.1, min=0, max=2, description='σ', layout=widgets.Layout(width='50%')),\\\n", " p=widgets.FloatSlider(1, min=0.1, max=5, description=r'p', layout=widgets.Layout(width='50%')), \\\n", " g=widgets.FloatSlider(1, min=0.1, max=2, description='g', layout=widgets.Layout(width='50%')))\n", "def plot_correlation_vhat_x(sigma, p, g):\n", " # Plot xh vs. x\n", " x_ = reflectance.reshape((-1, 1)).squeeze(-1)\n", " y_ = normalize_implemented(visible_image, sigma, p, g).reshape((-1, 1)).squeeze(-1)\n", " sns.regplot(x=x_, y=y_)\n", " plt.xlabel('Reflectance')\n", " plt.ylabel('Normalized visible image')\n", " plt.ylim((-0.05, 1.05))\n", " # Calculate R-squared and p-value\n", " result = scipy.stats.linregress(x_, y_)\n", " plt.title(f'$R^{2}$ = {(result.rvalue**2):.2f}')\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Video 6: Summary\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 6: Summary\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = 
f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'uEY8s-sxRMA'), ('Bilibili', 'BV1cn4y1X7sF')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_latent_properties_section_summary\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Subsection 2.3: Layer Normalization" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In this subsection, we apply normalization to a real task: image classification on a subset of the CIFAR10 dataset whose images have been scaled in intensity."
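] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Layer normalization, as provided by PyTorch's `nn.LayerNorm`, standardizes each sample using the mean and variance computed over all of the normalized dimensions, here a whole (channels, height, width) feature map. The cell below is a minimal NumPy sketch of that computation for illustration only. The helper name `layer_norm` is hypothetical, and the sketch omits the learnable element-wise scale and shift that `nn.LayerNorm` applies afterwards (initialized to 1 and 0). Following the PyTorch documentation, it uses the biased variance and `eps = 1e-5`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def layer_norm(x, eps=1e-5):\n", "    # Standardize each sample over all non-batch axes, e.g. (C, H, W) for feature maps,\n", "    # like nn.LayerNorm((C, H, W)) without its learnable affine part\n", "    axes = tuple(range(1, x.ndim))\n", "    mean = x.mean(axis=axes, keepdims=True)\n", "    var = x.var(axis=axes, keepdims=True)  # biased variance (ddof=0)\n", "    return (x - mean) / np.sqrt(var + eps)\n", "\n", "# Toy batch: 4 feature maps with 16 channels of 8x8, the shape of conv1's output in CIFARNet\n", "rng = np.random.default_rng(0)\n", "feature_maps = rng.normal(loc=3.0, scale=2.0, size=(4, 16, 8, 8))\n", "normed = layer_norm(feature_maps)\n", "print('per-sample mean:', normed.mean(axis=(1, 2, 3)).round(6))  # approximately 0\n", "print('per-sample std: ', normed.std(axis=(1, 2, 3)).round(6))  # approximately 1"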
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 7: Image recognition example\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 7: Image recognition example\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'vMGOlm94iP0'), ('Bilibili', 'BV1Gz42187yV')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, 
video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_image_recognition_example\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "First, we fetch the dataset and keep only three classes: plane, car, and cat. Each image is resized to 8×8 pixels, and in the scaled version each image is additionally multiplied by a random factor drawn uniformly between 0.05 and 0.25." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get CIFAR3 scaled and unscaled dataloader\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Get CIFAR3 scaled and unscaled dataloader\n", "\n", "set_seed(seed=42, seed_torch=True)\n", "\n", "batch_size = 512\n", "\n", "# CIFAR-10 transform: resize to 8x8 and standardize each channel\n", "cifar_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Resize((8, 8)),\n", " transforms.Normalize(\n", " (0.49139968, 0.48215827, 0.44653124),\n", " (0.24703233, 0.24348505, 0.26158768))\n", " ])\n", "\n", "cifar10_classes = ('plane', 'car', 'bird', 'cat',\n", " 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", "\n", "class CIFAR3(torch.utils.data.Dataset):\n", " def __init__(self, dataset, classes, scale=False, scaling_limits=(0.05, 0.25)):\n", " self.dataset = dataset\n", " self.classes = classes\n", "\n", " # Keep the indices of the first 80% of images from each chosen class\n", " subset_ratio = 0.8\n", " self.indices = []\n", " for cls in self.classes:\n", " imgs_ = [i for i, label in enumerate(self.dataset.targets) if label == cls]\n", " self.indices += 
imgs_[:int(len(imgs_)*subset_ratio)]\n", "\n", " self.scale = scale\n", " self.scaling_limits = scaling_limits\n", " if self.scale:\n", " self.scale_distribution = 
torch.distributions.Uniform(*self.scaling_limits)\n", " self.scale_values = self.scale_distribution.sample((len(self.indices),))\n", "\n", " def __getitem__(self, index_cifar3):\n", " index_cifar10 = self.indices[index_cifar3]\n", " image, label = self.dataset[index_cifar10]\n", " label = self.classes[label]\n", " if self.scale:\n", " scalar = self.scale_values[index_cifar3]\n", " image = image * scalar\n", " return image, label\n", "\n", " def __len__(self):\n", " return len(self.indices)\n", "\n", "cifar3_classes = {\n", " 0: 0, # plane\n", " 1: 1, # car\n", " 3: 2, # cat\n", "}\n", "\n", "# CIFAR-10 datasets\n", "cifar10_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", " download=True, transform=cifar_transform)\n", "cifar10_testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n", " download=True, transform=cifar_transform)\n", "\n", "# CIFAR3\n", "cifar3_trainset = CIFAR3(cifar10_trainset, cifar3_classes)\n", "cifar3_trainloader = torch.utils.data.DataLoader(cifar3_trainset, batch_size=batch_size,\n", " shuffle=True)\n", "cifar3_testset = CIFAR3(cifar10_testset, cifar3_classes)\n", "cifar3_testloader = torch.utils.data.DataLoader(cifar3_testset, batch_size=batch_size,\n", " shuffle=False)\n", "\n", "# CIFAR3 Scaled\n", "cifar3scaled_trainset = CIFAR3(cifar10_trainset, cifar3_classes, scale=True)\n", "cifar3scaled_trainloader = torch.utils.data.DataLoader(cifar3scaled_trainset, batch_size=batch_size,\n", " shuffle=True)\n", "cifar3scaled_testset = CIFAR3(cifar10_testset, cifar3_classes, scale=True)\n", "cifar3scaled_testloader = torch.utils.data.DataLoader(cifar3scaled_testset, batch_size=batch_size,\n", " shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let us take a look at the original (latent) images and their scaled versions." 
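] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Before looking at the images, it helps to see why normalization could help here. Multiplying an image by a positive factor $c$ multiplies both its mean and its standard deviation by $c$, so standardizing each image cancels the factor. The sketch below checks this with random arrays standing in for CIFAR images and a hypothetical NumPy `layer_norm` helper: the raw pixels change a lot, while the layer-normalized outputs barely change. The small remaining difference comes from `eps`, which matters more for small factors such as 0.05. Note that in `CIFARNet` the `LayerNorm` layers come after convolutions whose bias terms are not scaled, so there the invariance should only hold approximately." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def layer_norm(x, eps=1e-5):\n", "    # Per-sample standardization over all non-batch axes (no learnable affine part)\n", "    axes = tuple(range(1, x.ndim))\n", "    mean = x.mean(axis=axes, keepdims=True)\n", "    var = x.var(axis=axes, keepdims=True)\n", "    return (x - mean) / np.sqrt(var + eps)\n", "\n", "rng = np.random.default_rng(42)\n", "images = rng.normal(size=(8, 3, 8, 8))  # stand-in for 8 normalized 8x8 RGB images\n", "scales = rng.uniform(0.05, 0.25, size=(8, 1, 1, 1))  # one factor per image, as in CIFAR3(scale=True)\n", "scaled_images = images * scales\n", "\n", "raw_diff = np.abs(images - scaled_images).max()\n", "ln_diff = np.abs(layer_norm(images) - layer_norm(scaled_images)).max()\n", "print(f'max |difference| of raw pixels: {raw_diff:.3f}')\n", "print(f'max |difference| after layer norm: {ln_diff:.4f}')"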
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize Images\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize Images\n", "\n", "set_seed(seed=42, seed_torch=True)\n", "\n", "num_images_show = 8\n", "cifar_demo_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize(\n", " (0.49139968, 0.48215827, 0.44653124),\n", " (0.24703233, 0.24348505, 0.26158768))\n", " ])\n", "\n", "# Demo datasets\n", "cifar10_demo_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", " download=True, transform=cifar_demo_transform)\n", "cifar3_demo_trainset = CIFAR3(cifar10_demo_trainset, cifar3_classes)\n", "cifar3scaled_demo_trainset = CIFAR3(cifar10_demo_trainset, cifar3_classes, scale=True)\n", "\n", "# functions to show an image\n", "def imshow(img, title):\n", " with plt.xkcd():\n", " img = (img * 0.25) + 0.5 # unnormalize\n", " npimg = img.numpy()\n", " plt.title(title)\n", " plt.imshow(np.transpose(npimg, (1, 2, 0)))\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.show()\n", "\n", "# get some random training images\n", "random_idxs = np.random.choice(range(len(cifar3_demo_trainset)), num_images_show)\n", "cifar3_imgs = None\n", "cifar3scaled_imgs = None\n", "cifar3_labels_ = []\n", "\n", "for ridx in random_idxs:\n", " c4img, lbl_ = cifar3_demo_trainset[ridx]\n", " c4simg, slbl_ = cifar3scaled_demo_trainset[ridx]\n", " assert lbl_ == slbl_, 'Labels do not match'\n", " cifar3_labels_.append(lbl_)\n", " cifar3_imgs = c4img.unsqueeze(0) if (cifar3_imgs is None) else torch.cat((cifar3_imgs, c4img.unsqueeze(0)))\n", " cifar3scaled_imgs = c4simg.unsqueeze(0) if (cifar3scaled_imgs is None) else torch.cat((cifar3scaled_imgs, c4simg.unsqueeze(0)))\n", "\n", "# print labels\n", "print('Labels = ' + ', '.join(f'{cifar10_classes[list(cifar3_classes.keys())[cifar3_labels_[j]]]:5s}' for j in 
range(num_images_show)))\n", "# show images\n", "imshow(torchvision.utils.make_grid(cifar3_imgs), 'CIFAR3')\n", "# print scaling factor\n", "print('Scaling factors = ' + ', '.join([str(round(i.item(), 3) )for i in cifar3scaled_demo_trainset.scale_values[random_idxs]]))\n", "# show scaled images\n", "imshow(torchvision.utils.make_grid(cifar3scaled_imgs), 'CIFAR3 Scaled')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Here, we define the CNN model with an optional parameter for adding the `LayerNorm` layer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define CIFARNet model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Define CIFARNet model\n", "\n", "set_seed(seed=42, seed_torch=True)\n", "\n", "class CIFARNet(nn.Module):\n", " def __init__(self, layer_norm=False):\n", " super().__init__()\n", " self.features = nn.Sequential(OrderedDict([\n", " ('conv1', nn.Conv2d(3, 16, 5, padding='same')),\n", " ('norm1', nn.LayerNorm((16, 8, 8))),\n", " ('relu1', nn.ReLU()),\n", " ('maxpool1', nn.MaxPool2d(3, padding=1, stride=2)),\n", " ('conv2', nn.Conv2d(16, 32, 3, padding='same')),\n", " ('norm2', nn.LayerNorm((32, 4, 4))),\n", " ('relu2', nn.ReLU()),\n", " ('avgpool2', nn.AvgPool2d(3, padding=1, stride=2))\n", " ]))\n", " if not layer_norm:\n", " del self.features.norm1\n", " del self.features.norm2\n", " self.classifier = nn.Sequential(OrderedDict([\n", " ('fc1', nn.Linear(128, 64)),\n", " ('fc2', nn.Linear(64, 3)),\n", " ]))\n", " # Initialize weights\n", " nn.init.normal_(self.features.conv1.weight, mean=0.0, std=1e-4)\n", " nn.init.normal_(self.features.conv2.weight, mean=0.0, std=1e-4)\n", " nn.init.normal_(self.classifier.fc1.weight, mean=0.0, std=1e-1)\n", " nn.init.normal_(self.classifier.fc2.weight, mean=0.0, std=1e-1)\n", "\n", " def forward(self, x):\n", " x = self.features(x)\n", " x = 
torch.flatten(x, 1) # flatten all dimensions except batch\n", " x = self.classifier(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Training the four models below (with and without `LayerNorm`, on unscaled and scaled CIFAR3) takes around 3 minutes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training & Evaluating the models\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Training & Evaluating the models\n", "\n", "# Training\n", "n_epochs = 10\n", "learning_rate = 5e-2\n", "momentum = 0.9\n", "\n", "# Unscaled CIFAR3\n", "cifar3_net = CIFARNet(layer_norm=False).to(DEVICE)\n", "losses_iter, losses_epoch = train_cnns(cifar3_net, cifar3_trainloader, \\\n", " learning_rate, momentum, n_epochs)\n", "# With LayerNorm\n", "cifar3_net_LN = CIFARNet(layer_norm=True).to(DEVICE)\n", "losses_iter_LN, losses_epoch_LN = train_cnns(cifar3_net_LN, cifar3_trainloader, \\\n", " learning_rate, momentum, n_epochs)\n", "\n", "# Scaled CIFAR3\n", "cifar3scaled_net = CIFARNet(layer_norm=False).to(DEVICE)\n", "losses_iter_scaled, losses_epoch_scaled = train_cnns(cifar3scaled_net, cifar3scaled_trainloader, \\\n", " learning_rate, momentum, n_epochs)\n", "# With LayerNorm\n", "cifar3scaled_net_LN = CIFARNet(layer_norm=True).to(DEVICE)\n", "losses_iter_scaled_LN, losses_epoch_scaled_LN = train_cnns(cifar3scaled_net_LN, cifar3scaled_trainloader, \\\n", " learning_rate, momentum, n_epochs)\n", "\n", "with plt.xkcd():\n", " # Plot training losses\n", " fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", "\n", " # Plot 1 - Unscaled CIFAR3\n", " # Plot loss per epoch\n", " ax[0].plot(range(1, len(losses_epoch)+1), np.log(losses_epoch), '-b', label='Without LN')\n", " ax[0].plot(range(1, len(losses_epoch_LN)+1), np.log(losses_epoch_LN), '-r', label='With LN')\n", " ax[0].legend()\n", " 
ax[0].set_xlabel('Epochs')\n", " ax[0].set_ylabel('Log 
Cross Entropy Loss')\n", " ax[0].set_title('Unscaled CIFAR3: Training Loss')\n", "\n", " # Plot 2 - Scaled CIFAR3\n", " # Plot loss per epoch\n", " ax[1].plot(range(1, len(losses_epoch_scaled)+1), np.log(losses_epoch_scaled), '-b', label='Without LN')\n", " ax[1].plot(range(1, len(losses_epoch_scaled_LN)+1), np.log(losses_epoch_scaled_LN), '-r', label='With LN')\n", " ax[1].legend()\n", " ax[1].set_xlabel('Epochs')\n", " ax[1].set_ylabel('Log Cross Entropy Loss')\n", " ax[1].set_title('Scaled CIFAR3: Training Loss')\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "# Training evaluation\n", "training_loss, training_accuracy = evaluate_cnns(cifar3_net, cifar3_trainloader)\n", "training_loss_LN, training_accuracy_LN = evaluate_cnns(cifar3_net_LN, cifar3_trainloader)\n", "training_loss_scaled, training_accuracy_scaled = evaluate_cnns(cifar3scaled_net, cifar3scaled_trainloader)\n", "training_loss_scaled_LN, training_accuracy_scaled_LN = evaluate_cnns(cifar3scaled_net_LN, cifar3scaled_trainloader)\n", "\n", "# Test evaluation\n", "test_loss, test_accuracy = evaluate_cnns(cifar3_net, cifar3_testloader)\n", "test_loss_LN, test_accuracy_LN = evaluate_cnns(cifar3_net_LN, cifar3_testloader)\n", "test_loss_scaled, test_accuracy_scaled = evaluate_cnns(cifar3scaled_net, cifar3scaled_testloader)\n", "test_loss_scaled_LN, test_accuracy_scaled_LN = evaluate_cnns(cifar3scaled_net_LN, cifar3scaled_testloader)\n", "\n", "with plt.xkcd():\n", " # Plot training and test accuracies as 2x2 grids (dataset x model)\n", " fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "\n", " training_cm = [[training_accuracy, training_accuracy_LN], [training_accuracy_scaled, training_accuracy_scaled_LN]]\n", " testing_cm = [[test_accuracy, test_accuracy_LN], [test_accuracy_scaled, test_accuracy_scaled_LN]]\n", " # Shared color limits across the training and testing panels\n", " all_cm = np.array(training_cm+testing_cm).flatten()\n", " vmin = np.min(all_cm)\n", " vmax 
= np.max(all_cm)\n", "\n", " training_disp = ConfusionMatrixDisplay(confusion_matrix=np.array(training_cm),\n", " display_labels=['Unscaled CIFAR3', 'Scaled 
CIFAR3'])\n", " training_disp.plot(cmap=plt.cm.Greys, ax=axs[0])\n", " axs[0].images[-1].set_clim(vmin, vmax)\n", " axs[0].set_yticks([0, 1], ['Unscaled CIFAR3', 'Scaled CIFAR3'], rotation=90)\n", " axs[0].set_xticks([0, 1], ['without LayerNorm', 'with LayerNorm'])\n", " axs[0].set_xlabel('Model')\n", " axs[0].set_ylabel('Dataset')\n", " axs[0].set_title('Training Accuracy')\n", "\n", " testing_disp = ConfusionMatrixDisplay(confusion_matrix=np.array(testing_cm),\n", " display_labels=['Unscaled CIFAR3', 'Scaled CIFAR3'])\n", " testing_disp.plot(cmap=plt.cm.Greys, ax=axs[1])\n", " axs[1].images[-1].set_clim(vmin, vmax)\n", " axs[1].set_yticks([0, 1], ['Unscaled CIFAR3', 'Scaled CIFAR3'], rotation=90)\n", " axs[1].set_xticks([0, 1], ['without LayerNorm', 'with LayerNorm'])\n", " axs[1].set_xlabel('Model')\n", " axs[1].set_ylabel('Dataset')\n", " axs[1].set_title('Testing Accuracy')\n", " plt.tight_layout()\n", " # Remove the xkcd-style shadows from the text inside the accuracy grids\n", " for txt in training_disp.text_.flatten():\n", " txt.set_path_effects([path_effects.Normal()])\n", " for txt in testing_disp.text_.flatten():\n", " txt.set_path_effects([path_effects.Normal()])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let us also check whether normalization helps the models trained on scaled CIFAR3 generalize to test images scaled by factors outside the training range (0.05 to 0.25)."
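] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Before running the test, it helps to see how different these test sets are from the training data. In raw pixel space, images scaled by factors between 0.8 and 1 have a much larger spread than anything seen during training; after per-sample standardization, all ranges look alike. The sketch below illustrates this with random arrays standing in for CIFAR images and a hypothetical NumPy `layer_norm` helper; it is not part of the tutorial's evaluation code, and the range names simply mirror the keys used in the test cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def layer_norm(x, eps=1e-5):\n", "    # Per-sample standardization over all non-batch axes (no learnable affine part)\n", "    axes = tuple(range(1, x.ndim))\n", "    mean = x.mean(axis=axes, keepdims=True)\n", "    var = x.var(axis=axes, keepdims=True)\n", "    return (x - mean) / np.sqrt(var + eps)\n", "\n", "rng = np.random.default_rng(0)\n", "images = rng.normal(size=(256, 3, 8, 8))  # stand-in for normalized 8x8 CIFAR images\n", "scaling_ranges = {\n", "    '0.05-0.25 (training range)': (0.05, 0.25),\n", "    '0.4-0.6': (0.4, 0.6),\n", "    '0.8-1': (0.8, 1.0),\n", "}\n", "\n", "raw_std, ln_std = {}, {}\n", "for name, (low, high) in scaling_ranges.items():\n", "    scales = rng.uniform(low, high, size=(len(images), 1, 1, 1))  # one factor per image\n", "    scaled = images * scales\n", "    # Average per-image standard deviation, before and after layer normalization\n", "    raw_std[name] = scaled.std(axis=(1, 2, 3)).mean()\n", "    ln_std[name] = layer_norm(scaled).std(axis=(1, 2, 3)).mean()\n", "    print(f'{name}: raw std {raw_std[name]:.3f}, after layer norm {ln_std[name]:.3f}')"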
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test Generalization\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Test Generalization\n", "\n", "n_tests = 5\n", "\n", "cifar3_scaling_performances = {}\n", "cifar3_scaling_performances_LN = {}\n", "\n", "# scaling_limits_tests = [[0.3, 0.5], [0.5, 0.75], [0.75, 1]]\n", "scaling_limits_tests = [[0.4, 0.6], [0.8, 1]]\n", "\n", "for test_idx in range(n_tests):\n", " cifar3_scaled_testloaders = {\n", " '0.05-0.25 (training range)': cifar3scaled_testloader\n", " }\n", "\n", " for sc_ in scaling_limits_tests:\n", " sc_str = '-'.join([str(i) for i in sc_])\n", " testset_ = CIFAR3(cifar10_testset, cifar3_classes, scale=True, scaling_limits=sc_)\n", " testloader_ = torch.utils.data.DataLoader(testset_, batch_size=batch_size,\n", " shuffle=False)\n", " cifar3_scaled_testloaders[sc_str] = testloader_\n", "\n", " for scaling_lims, scaling_testloader in cifar3_scaled_testloaders.items():\n", " _, noLN_accuracy = evaluate_cnns(cifar3scaled_net, scaling_testloader)\n", " _, LN_accuracy = evaluate_cnns(cifar3scaled_net_LN, scaling_testloader)\n", " if scaling_lims in cifar3_scaling_performances.keys():\n", " cifar3_scaling_performances[scaling_lims] += [noLN_accuracy]\n", " cifar3_scaling_performances_LN[scaling_lims] += [LN_accuracy]\n", " else:\n", " cifar3_scaling_performances[scaling_lims] = [noLN_accuracy]\n", " cifar3_scaling_performances_LN[scaling_lims] = [LN_accuracy]\n", "\n", "with plt.xkcd():\n", " plt.figure(figsize=(5, 5))\n", " mean_cifar3_scaling_performances = [np.mean(i) for i in cifar3_scaling_performances.values()]\n", " std_cifar3_scaling_performances = [np.std(i) for i in cifar3_scaling_performances.values()]\n", " plt.plot(range(1, len(cifar3_scaled_testloaders)+1), mean_cifar3_scaling_performances, \\\n", " '-.b', label='Without LN')\n", " plt.errorbar(range(1, 
len(cifar3_scaled_testloaders)+1), mean_cifar3_scaling_performances, \\\n", " yerr=std_cifar3_scaling_performances, color='b', capsize=5, capthick=2)\n", " mean_cifar3_scaling_performances_LN = [np.mean(i) for i in cifar3_scaling_performances_LN.values()]\n", " std_cifar3_scaling_performances_LN = [np.std(i) for i in cifar3_scaling_performances_LN.values()]\n", " plt.plot(range(1, len(cifar3_scaled_testloaders)+1), mean_cifar3_scaling_performances_LN, \\\n", " '-.r', label='With LN')\n", " plt.errorbar(range(1, len(cifar3_scaled_testloaders)+1), mean_cifar3_scaling_performances_LN, \\\n", " yerr=std_cifar3_scaling_performances_LN, color='r', capsize=5, capthick=2)\n", " plt.xticks(range(1, len(cifar3_scaled_testloaders)+1), labels=cifar3_scaled_testloaders.keys())\n", " plt.xlabel('CIFAR3 scaling range')\n", " plt.ylabel('Test accuracy')\n", " plt.title('Testing Generalization for scaled CIFARNet')\n", " plt.legend()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In this example, adding `LayerNorm` layers makes training converge faster, and the resulting models reach higher test accuracy as well as better accuracy on out-of-distribution test images whose scaling factors lie outside the training range."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_layer_normalization_example\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 8: Section summary\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 8: Section summary\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " 
display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '2ME2uKnRisw'), ('Bilibili', 'BV1zw4m1Y7W4')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_last_section_summary\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "*Estimated timing of tutorial: 50 minutes*\n", "\n", "In this tutorial, we saw that normalization is a useful inductive bias. We implemented a normalization function and used it to recover surface reflectance from images whose brightness was scaled by an unknown illuminant. We then saw that adding layer normalization to a CNN speeds up training and improves generalization to image intensities outside the training range."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 9: Tutorial summary\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 9: Tutorial summary\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'gsMsRPWLJj8'), ('Bilibili', 'BV15z421b7uf')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", 
"display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_tutorial_summary\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Bonus Section: Benefits of using normalization - Efficient Coding" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Non-linearities are critical for computation, but they can also lose information. Here we look at a very simple example of how normalization can help preserve information through a network. In the exercise below, complete the functions `HardTanh` and `LeakyHardTanh` (the inverse of the latter is already provided).\n", "\n", "`HardTanh` is the function $f(x)$ defined as follows:\n", "\n", "$$f(x) = \\begin{cases}\n", "1, & \\text{if } x > 1\\\\\n", "x, & \\text{if } -1 \\leq x \\leq 1\\\\\n", "-1, & \\text{if } x < -1\n", "\\end{cases}$$\n", "\n", "while `LeakyHardTanh` is $f(x) = \\text{HardTanh}(x) + \\text{leak\\_slope} \\cdot x$."
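] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Why is the inverse of `LeakyHardTanh` available in closed form? Writing $a$ for the leak slope (`leak_slope` in the code), the function is piecewise linear:\n", "\n", "$$f(x) = \\text{HardTanh}(x) + a x = \\begin{cases}\n", "1 + a x, & \\text{if } x > 1\\\\\n", "(1 + a)\\, x, & \\text{if } -1 \\leq x \\leq 1\\\\\n", "-1 + a x, & \\text{if } x < -1\n", "\\end{cases}$$\n", "\n", "For $a > 0$ every piece is strictly increasing, and the break points $x = \\pm 1$ map to $y = \\pm(1 + a)$. Solving each piece for $x$ gives\n", "\n", "$$f^{-1}(y) = \\begin{cases}\n", "\\dfrac{y - \\operatorname{sign}(y)}{a}, & \\text{if } |y| \\geq 1 + a\\\\\n", "\\dfrac{y}{1 + a}, & \\text{if } |y| < 1 + a\n", "\\end{cases}$$\n", "\n", "which is what `InverseLeakyHardTanh` computes with `np.where`. With a small slope such as $a = 0.03$, the outer branches divide by $a$, so noise added to $y$ in the saturating range is amplified by $1/a \\approx 33$ when inverting. This suggests why keeping inputs out of the saturating range, for example by normalizing them first, helps preserve information."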
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "#################################################\n", "## TODO: Implement the `HardTanh` and `LeakyHardTanh` functions ##\n", "# Remove the following line of code once you have completed the exercise:\n", "raise NotImplementedError(\"Student exercise: complete missing calculations in `HardTanh` and `LeakyHardTanh` functions.\")\n", "#################################################\n", "\n", "def HardTanh(x):\n", " \"\"\"\n", " Calculate the `HardTanh` output for the given input data (values clipped to [-1, 1]).\n", "\n", " Inputs:\n", " - x (np.ndarray): input data.\n", "\n", " Outputs:\n", " - output (np.ndarray): `HardTanh(x)`.\n", " \"\"\"\n", " min_val = -1\n", " max_val = 1\n", " output = np.copy(x)\n", " output[output>...] = ...\n", " output[output<...] = ...\n", " return output\n", "\n", "def LeakyHardTanh(x, leak_slope=0.03):\n", " \"\"\"\n", " Calculate the `LeakyHardTanh` output for the given input data: `HardTanh(x)` plus a linear leak term.\n", "\n", " Inputs:\n", " - x (np.ndarray): input data.\n", " - leak_slope (float, default = 0.03): slope of the leak term.\n", "\n", " Outputs:\n", " - output (np.ndarray): `HardTanh(x) + leak_slope * x`.\n", " \"\"\"\n", " output = np.copy(x)\n", " output = HardTanh(output) + ...*...\n", " return output\n", "\n", "def InverseLeakyHardTanh(y, leak_slope=0.03):\n", " \"\"\"\n", " Calculate the input of `LeakyHardTanh` that produces the given output.\n", "\n", " Inputs:\n", " - y (np.ndarray): output of `LeakyHardTanh`.\n", " - leak_slope (float, default = 0.03): slope of the leak term.\n", "\n", " Outputs:\n", " - output (np.ndarray): input into `LeakyHardTanh`.\n", " \"\"\"\n", " ycopy = np.copy(y)\n", " output = np.where(\n", " np.abs(ycopy) >= 1+leak_slope, \\\n", " (ycopy - np.sign(ycopy))/leak_slope, \\\n", " ycopy/(1+leak_slope)\n", " )\n", " return output" ] }, { "cell_type": "markdown", "metadata": { "cellView": 
"form", "colab_type": "text", "execution": {} }, "source": [ "[*Click for 
solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D5_Microcircuits/solutions/W1D5_Tutorial2_Solution_60908715.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize the functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize the functions\n", "\n", "# with plt.xkcd():\n", "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", "plot_vals = np.arange(-4, 4, 0.01)\n", "leak_slope = 0.03\n", "for i in range(3):\n", "    # Set the spines (axes lines) to intersect at the center\n", "    ax[i].spines['left'].set_position('zero')\n", "    ax[i].spines['bottom'].set_position('zero')\n", "    ax[i].spines['right'].set_color('none')\n", "    ax[i].spines['top'].set_color('none')\n", "    ax[i].set_xlabel('x', loc='right', fontsize=20)\n", "ax[0].plot(plot_vals, LeakyHardTanh(plot_vals, leak_slope), '-k')\n", "ax[0].set_title('LeakyHardTanh(x)', fontsize=14)\n", "ax[1].plot(plot_vals, InverseLeakyHardTanh(plot_vals, leak_slope), '-k')\n", "ax[1].set_title('InverseLeakyHardTanh(x)', fontsize=14)\n", "ax[2].plot(plot_vals, InverseLeakyHardTanh(LeakyHardTanh(plot_vals, leak_slope), leak_slope), '-k')\n", "ax[2].set_title('InverseLeakyHardTanh( LeakyHardTanh(x) )', fontsize=14)\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, let's define an $n$-dimensional vector $\\mathbf{x}$. This is our target latent variable, and we would like to preserve the information about it. However, $\\mathbf{x}$ is corrupted by a multiplicative scalar nuisance variable $s$: $\\mathbf{y} = s\\mathbf{x}$.\n", "\n", "Downstream computation uses $\\mathbf{y}$ by passing it through an element-wise non-linearity $f$ that saturates beyond a certain input range, and then adds noise. By doing so, we lose information, potentially a lot of it if $s$ is large enough to push the inputs into the saturating part of the non-linearity.\n", "\n", "If we knew $s$, we could simply divide it out. Although we don't know $s$, normalization provides an estimate of it: dividing $\\mathbf{y}$ by its norm $\\|\\mathbf{y}\\|_p = s\\|\\mathbf{x}\\|_p$ cancels $s$ (up to the smoothing constant $\\sigma$ introduced below). We can then invert the non-linearity to recover an approximation of the original $\\mathbf{x}$, up to a per-sample scale factor. Here we use `LeakyHardTanh`, which almost saturates but is technically invertible.\n", "\n", "Let's see if normalization helps. We will compute the correlation between $\\mathbf{x}$ and its estimate $\\hat{\\mathbf{x}}$ and compare it with and without normalization.\n", "\n", "Our information $\\mathbf{X} \\in \\mathbb{R}^{400 \\times 10}$ is a collection of 400 samples of 10-dimensional vectors, with each component drawn independently from $\\mathcal{N}(0, 1)$. Each vector in $\\mathbf{X}$ is multiplied by its own nuisance scaling factor, so $s \\in \\mathbb{R}^{400}$ has one entry per sample, with $s \\sim \\mathrm{Exp}(0.2)$ (rate $0.2$, i.e., mean $5$)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data\n", "\n", "def normalize(x, sigma, p, g):\n", "    \"\"\"\n", "    Divisively normalize each sample (row) of x by its L_p norm.\n", "\n", "    Inputs:\n", "    - x (np.ndarray): input array (n_samples * n_dim)\n", "    - sigma (float): smoothing factor\n", "    - p (int): order of the norm\n", "    - g (float): scaling factor\n", "\n", "    Outputs:\n", "    - xnorm (np.ndarray): normalized values.\n", "    \"\"\"\n", "    # Raise the absolute value of x to the power p\n", "    xp = np.power(np.abs(x), p)\n", "    # Sum over the dimensions (n_dim) axis\n", "    xp_sum = np.sum(xp, axis=1)\n", "    # Summing removed the n_dim axis; restore it so that xp_sum is a column vector\n", "    # of shape (n_samples, 1) that broadcasts against x\n", "    xp_sum = np.expand_dims(xp_sum, axis=1)\n", "    # Raise the sum to the power 1/p and add the smoothing factor (sigma)\n", "    denominator = sigma + np.power(xp_sum, 1/p)\n", "    # Scale the input data with a factor of g\n", "    numerator = x*g\n", "    # Calculate normalized x\n", "    xnorm = numerator/denominator\n", "    return xnorm\n", "\n", "# data\n", "n_samples = 400 # number of samples\n", "n_dim = 10 # dimensions of each sample\n", "latent_std = 1 # width of latent distribution\n", "\n", "# nuisance\n", "nuisance_scale = 5 # scale (mean) of the exponential nuisance distribution\n", "\n", "# normalization\n", "smoothing_factor = 0.1 # normalization smoothness - sigma\n", "norm_p = 2 # Lp norm\n", "norm_scale = 1 # normalization scale - g\n", "\n", "# noise\n", "noise_std = 0.05 # added noise standard deviation\n", "\n", "# non-linearity\n", "leak_slope = 0.001 # slope after leaky saturation\n", "\n", "# random nuisance scaling for each example vector\n", "nuisance = np.random.exponential(nuisance_scale, size=(n_samples, 1))\n", "x_sec31 = np.random.normal(loc=0.0, scale=latent_std, size=(n_samples, n_dim))\n", "y_sec31 = x_sec31 * nuisance # input vectors scaled by random nuisance\n", "ynorm_sec31 = normalize(y_sec31, smoothing_factor, norm_p, norm_scale) # normalized vectors (g is applied inside normalize)\n", "noise = np.random.normal(loc=0.0, scale=noise_std, size=(n_samples, n_dim))\n", "\n", "# without normalization\n", "transmit_noisy_x = LeakyHardTanh(y_sec31, leak_slope) + noise\n", "estimate_x = InverseLeakyHardTanh(transmit_noisy_x, leak_slope)\n", "# with normalization\n", "transmitNormalized_noisy_x = LeakyHardTanh(ynorm_sec31, leak_slope) + noise\n", "estimateNormalized_x = InverseLeakyHardTanh(transmitNormalized_noisy_x, leak_slope)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's take a look at one of the dimensions of $\\mathbf{x}$ and visualize it after nuisance scaling as well as after normalization. For a generic $N$-dimensional input vector $\\mathbf{x}$ (here, the input is $\\mathbf{y} = s\\mathbf{x}$), normalization computes\n", "\n", "$$\\mathbf{x}_{norm} = \\frac{g \\mathbf{x}}{\\sigma + 
\\sqrt[p]{\\sum_{i = 1}^{N} |x_{i}|^{p}}}$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize input\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize input\n", "with plt.xkcd():\n", "    sns.kdeplot(ynorm_sec31[:, 0], color='r', label=r'$(s \\mathbf{x})_{norm}$')\n", "    sns.kdeplot(x_sec31[:, 0], color='k', label=r'$\\mathbf{x}$')\n", "    sns.kdeplot(y_sec31[:, 0], color='b', label=r'$s \\mathbf{x}$')\n", "    plt.xlabel('Information (x)')\n", "    plt.legend()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, let's transmit this observable information through a network. In this example, the network is an element-wise `LeakyHardTanh`, and the transmission is corrupted by additive Gaussian noise $n \\sim \\mathcal{N}(0, 0.05^2)$ (standard deviation $0.05$).\n", "\n", "Hence, the transmitted signal is `LeakyHardTanh`($s\\mathbf{x}$) + $n$ without normalization and `LeakyHardTanh`($(s\\mathbf{x})_{norm}$) + $n$ with normalization."
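, "\n", "To see why normalization prevents saturation here, note that $|y_i| \\leq \\|\\mathbf{y}\\|_p < \\sigma + \\|\\mathbf{y}\\|_p$ for any $p > 0$ and $\\sigma > 0$. Every normalized entry therefore has magnitude below $g$. With $g = 1$, the normalized inputs never leave the linear range $[-1, 1]$ of `HardTanh`, while many raw entries $s x_i$ do. The cell below is a minimal, self-contained sketch of this check. It assumes the parameter values of the `Data` cell (400 samples, 10 dimensions, nuisance mean 5, $p = 2$, $\\sigma = 0.1$, $g = 1$) and draws its own data. The `_demo` variables and the `fraction_saturated` helper are hypothetical names that are not part of the tutorial."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [
"# Sketch (assumption: parameters mirror the Data cell; names ending in _demo are hypothetical)\n",
"import numpy as np\n",
"\n",
"rng_demo = np.random.default_rng(0)\n",
"x_demo = rng_demo.normal(0.0, 1.0, size=(400, 10))  # latent vectors\n",
"s_demo = rng_demo.exponential(5.0, size=(400, 1))  # one nuisance factor per vector\n",
"y_demo = s_demo * x_demo  # nuisance-scaled inputs\n",
"\n",
"# Divisive normalization of each vector by sigma plus its L2 norm (g = 1)\n",
"sigma_demo, g_demo = 0.1, 1.0\n",
"ynorm_demo = g_demo * y_demo / (sigma_demo + np.linalg.norm(y_demo, axis=1, keepdims=True))\n",
"\n",
"def fraction_saturated(v, limit=1.0):\n",
"    # Hypothetical helper: fraction of entries outside the linear range [-limit, limit] of HardTanh\n",
"    return float(np.mean(np.abs(v) > limit))\n",
"\n",
"frac_raw = fraction_saturated(y_demo)\n",
"frac_norm = fraction_saturated(ynorm_demo)\n",
"print(f'Fraction saturated without normalization: {frac_raw:.2f}')\n",
"print(f'Fraction saturated with normalization: {frac_norm:.2f}')"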
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize noisy transmitted signal\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize noisy transmitted signal\n", "with plt.xkcd():\n", "    plt.figure(figsize=(7.5, 7.5))\n", "    sns.kdeplot(LeakyHardTanh(y_sec31, leak_slope)[:, 0], linestyle='--', color='b', label=r'LeakyHardTanh$(s \\mathbf{x})$')\n", "    sns.kdeplot(transmit_noisy_x[:, 0], color='b', label=r'LeakyHardTanh$(s \\mathbf{x})$+noise')\n", "    sns.kdeplot(LeakyHardTanh(ynorm_sec31, leak_slope)[:, 0], linestyle='--', color='r', label=r'LeakyHardTanh$(s \\mathbf{x})_{norm}$')\n", "    sns.kdeplot(transmitNormalized_noisy_x[:, 0], color='r', label=r'LeakyHardTanh$(s \\mathbf{x})_{norm}$+noise')\n", "    plt.xlabel('Transmitted Signal')\n", "    plt.legend()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's estimate the true information by applying the inverse of the network (`InverseLeakyHardTanh`) to the transmitted signal." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize estimated information\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize estimated information\n", "with plt.xkcd():\n", "    sns.kdeplot(estimateNormalized_x[:, 0], color='r', label=r'$\\mathbf{\\hat{x}}_{norm}$')\n", "    sns.kdeplot(x_sec31[:, 0], color='k', label=r'$\\mathbf{x}$')\n", "    sns.kdeplot(estimate_x[:, 0], color='b', label=r'$\\mathbf{\\hat{x}}$')\n", "    plt.xlabel('Estimated information (x)')\n", "    plt.xlim(-50, 50)\n", "    plt.legend()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's quantify how well we can estimate the true information by calculating $R^2$ values between the true and estimated information."
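, "\n", "As a side note, for a one-variable linear fit, the $R^2$ reported below (computed as `scipy.stats.linregress(...).rvalue**2`) is simply the squared Pearson correlation between the true and estimated values. It therefore measures how linearly $\\hat{\\mathbf{x}}$ tracks $\\mathbf{x}$, regardless of any overall scale. The cell below is a minimal sketch on synthetic data (not the tutorial's data) that checks this equivalence. The `_r2` variable names are hypothetical."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [
"# Sketch on synthetic data (assumption): R^2 from linregress equals the squared Pearson correlation\n",
"import numpy as np\n",
"import scipy.stats\n",
"\n",
"rng_r2 = np.random.default_rng(1)\n",
"x_r2 = rng_r2.normal(size=1000)  # true latent values\n",
"xhat_r2 = x_r2 + rng_r2.normal(scale=0.5, size=1000)  # noisy estimate of the latent\n",
"\n",
"r_squared_linregress = scipy.stats.linregress(x_r2, xhat_r2).rvalue ** 2\n",
"r_squared_corrcoef = np.corrcoef(x_r2, xhat_r2)[0, 1] ** 2\n",
"print(f'R^2 from linregress: {r_squared_linregress:.4f}')\n",
"print(f'R^2 from corrcoef: {r_squared_corrcoef:.4f}')"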
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot correlation between estimated information and true information\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plot correlation between estimated information and true information\n", "\n", "with plt.xkcd():\n", "    fig, ax = plt.subplots(1, 2, figsize=(15, 5))\n", "\n", "    # Flatten samples and dimensions into 1D arrays\n", "    x_ = x_sec31.ravel()\n", "\n", "    # Plot x vs. estimated x\n", "    y_ = estimate_x.ravel()\n", "    sns.regplot(x=x_, y=y_, ax=ax[0], fit_reg=False)\n", "    # Calculate R-squared (the squared Pearson correlation)\n", "    result = scipy.stats.linregress(x_, y_)\n", "    ax[0].set_title(r'$\\hat{x} \\enspace vs. \\enspace x, \\enspace R^{2} = $'\n", "                    + f'{(result.rvalue**2):.2f}')\n", "    ax[0].set_ylabel(r'$\\hat{x}$', loc='bottom', fontsize=20)\n", "    ax[0].set_ylim((-5, 5))\n", "\n", "    # Plot x vs. estimated normalized x\n", "    y_ = estimateNormalized_x.ravel()\n", "    sns.regplot(x=x_, y=y_, ax=ax[1], fit_reg=False)\n", "    ax[1].set_ylim((-5, 5))\n", "    # Calculate R-squared (the squared Pearson correlation)\n", "    result = scipy.stats.linregress(x_, y_)\n", "    ax[1].set_title(r'$\\hat{x}_{norm} \\enspace vs. \\enspace x, \\enspace R^{2} = $'\n", "                    + f'{(result.rvalue**2):.2f}')\n", "    ax[1].set_ylabel(r'$\\hat{x}_{norm}$', loc='bottom', fontsize=20)\n", "\n", "    for i in range(2):\n", "        # Set the spines (axes lines) to intersect at the center\n", "        ax[i].spines['left'].set_position('zero')\n", "        ax[i].spines['bottom'].set_position('zero')\n", "        ax[i].spines['right'].set_color('none')\n", "        ax[i].spines['top'].set_color('none')\n", "        ax[i].set_xlabel('x', loc='right', fontsize=20)\n", "        ax[i].set_xlim((-4, 4))\n", "\n", "    plt.tight_layout()\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We see that normalization helps preserve the information through transmission by preventing saturation, keeping the inputs within a limited dynamic range." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Think 1.1\n", "\n", "1. We control the dynamic range of the normalized signal. Is there an optimal range?" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's control the range by varying the scaling factor $g$:\n", "\n", "$$\\mathbf{x}_{norm} = \\frac{g \\mathbf{x}}{\\sigma + \\sqrt[p]{\\sum_{i = 1}^{N} |x_{i}|^{p}}}$$\n", "\n", "We will plot the improvement in correlation against the scaling factor $g$, which sets the range of the normalized signal."
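, "\n", "Before the sweep, here is a minimal sketch of the trade-off it probes. Every normalized entry has magnitude below $g$. A small $g$ therefore keeps all inputs in the linear range of `HardTanh` but makes them small compared with the transmission noise (standard deviation $0.05$). A large $g$ raises the signal relative to the noise but pushes some entries past $\\pm 1$, into the saturating range. The cell assumes the data parameters of the `Data` cell with $p = 2$ and $\\sigma = 0.1$, draws its own data, and uses hypothetical variable names ending in `_g`."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [
"# Sketch (assumption: parameters mirror the Data cell; names ending in _g are hypothetical)\n",
"import numpy as np\n",
"\n",
"rng_g = np.random.default_rng(2)\n",
"x_g = rng_g.normal(size=(400, 10))  # latent vectors\n",
"y_g = rng_g.exponential(5.0, size=(400, 1)) * x_g  # nuisance-scaled inputs\n",
"denom_g = 0.1 + np.linalg.norm(y_g, axis=1, keepdims=True)  # sigma + L2 norm of each vector\n",
"noise_std_g = 0.05  # transmission noise standard deviation\n",
"\n",
"results_g = {}\n",
"for scale_g in [0.1, 1.0, 5.0]:\n",
"    ynorm_g = scale_g * y_g / denom_g\n",
"    snr_g = ynorm_g.std() / noise_std_g  # typical input amplitude relative to the noise\n",
"    frac_sat_g = float(np.mean(np.abs(ynorm_g) > 1.0))  # entries beyond the linear range of HardTanh\n",
"    results_g[scale_g] = (snr_g, frac_sat_g)\n",
"    print(f'g = {scale_g}: signal std / noise std = {snr_g:.1f}, fraction saturated = {frac_sat_g:.2f}')"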
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Effect of scaling normalization ($g$)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Effect of scaling normalization ($g$)\n", "norm_scales = np.arange(0.01, 5, 0.01)\n", "improvements = []\n", "\n", "# Correlation between x and its estimate without normalization (baseline)\n", "x_ = x_sec31.ravel()\n", "y_ = estimate_x.ravel()\n", "result = scipy.stats.linregress(x_, y_)\n", "nonnorm_r = result.rvalue\n", "\n", "for g_ in norm_scales:\n", "    ynorm_ = normalize(y_sec31, smoothing_factor, norm_p, g_) # normalized vectors\n", "    transmit_ = LeakyHardTanh(ynorm_, leak_slope) + noise\n", "    estimate_ = InverseLeakyHardTanh(transmit_, leak_slope)\n", "    result = scipy.stats.linregress(x_, estimate_.ravel())\n", "    norm_r = result.rvalue\n", "    # Ratio of correlation coefficients (with / without normalization)\n", "    improvements.append(norm_r / nonnorm_r)\n", "\n", "with plt.xkcd():\n", "    plt.figure(figsize=(5, 5))\n", "    plt.plot(norm_scales, improvements, '.')\n", "    plt.ylim((-0.05, 2.05))\n", "    plt.xlabel('Normalization scaling factor ($g$)')\n", "    plt.ylabel('Improvement')\n", "    plt.title(r'Improvement = $\\frac{r(x, \\hat{x}_\\mathrm{norm})}{r(x, \\hat{x})}$')\n", "    ax = plt.gca()\n", "    for line in ax.get_lines():\n", "        line.set_path_effects([path_effects.Normal()])\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "There is an optimal normalization range: if it is too narrow, noise dominates; if it is too wide, saturation destroys information." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Think 1.2\n", "\n", "1. 
Thinking deeper: here, we have used normalization only to preserve the information, essentially by avoiding most of the non-linearity. Do you think the computation can gain an advantage by using saturation? How?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_efficient_coding\")" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "e-l_z3KLBmb8", "Q-y5BZIKBmb9" ], "gpuType": "T4", "include_colab_link": true, "name": "W1D5_Tutorial2", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 4 }