{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"execution": {},
"id": "view-in-github"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"# Tutorial 5: Replay\n",
"\n",
"**Week 2, Day 4: Macro-Learning**\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"__Content creators:__ Hlib Solodzhuk, Ximeng Mao, Grace Lindsay\n",
"\n",
"__Content reviewers:__ Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk, Ximeng Mao, Grace Lindsay\n",
"\n",
"__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"___\n",
"\n",
"\n",
"# Tutorial Objectives\n",
"\n",
"*Estimated timing of tutorial: 40 minutes*\n",
"\n",
"In this tutorial, you will discover what replay is and how it helps with continual learning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @markdown\n",
"from IPython.display import IFrame, display\n",
"from ipywidgets import widgets\n",
"\n",
"# OSF identifier of the slide deck; both URLs below are built from it,\n",
"# so the deck can be swapped by editing a single value.\n",
"link_id = \"t36w8\"\n",
"\n",
"out = widgets.Output()\n",
"with out:\n",
"    # Everything printed/displayed inside this context is captured by the Output widget.\n",
"    print(f\"If you want to download the slides: https://osf.io/download/{link_id}/\")\n",
"    display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n",
"display(out)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install and import feedback gadget\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Install and import feedback gadget\n",
"\n",
"!pip install numpy matplotlib scikit-learn ipywidgets jupyter-ui-poll torch vibecheck --quiet\n",
"\n",
"from vibecheck import DatatopsContentReviewContainer\n",
"def content_review(notebook_section: str):\n",
"    \"\"\"\n",
"    Build a DatatopsContentReviewContainer for the given section and render it.\n",
"\n",
"    Inputs:\n",
"    - notebook_section (str): identifier of the notebook section passed to the container.\n",
"\n",
"    Returns:\n",
"    - whatever the container's `render()` call returns.\n",
"    \"\"\"\n",
"    datatops_config = {\n",
"        \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n",
"        \"name\": \"neuromatch_neuroai\",\n",
"        \"user_key\": \"wb2cxze8\",\n",
"    }\n",
"    text_prompt = \"\"  # No text prompt\n",
"    container = DatatopsContentReviewContainer(text_prompt, notebook_section, datatops_config)\n",
"    return container.render()\n",
"\n",
"\n",
"feedback_prefix = \"W2D4_T5\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Imports\n",
"\n",
"#working with data\n",
"import numpy as np\n",
"import random\n",
"\n",
"#plotting\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix\n",
"\n",
"#interactive display\n",
"import ipywidgets as widgets\n",
"from IPython.display import display, clear_output\n",
"from jupyter_ui_poll import ui_events\n",
"import time\n",
"from tqdm.notebook import tqdm\n",
"\n",
"#modeling\n",
"import copy\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Figure settings\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Figure settings\n",
"\n",
"logging.getLogger('matplotlib.font_manager').disabled = True\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots\n",
"plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Plotting functions\n",
"\n",
"def plot_rewards(rewards, max_rewards):\n",
"    \"\"\"\n",
"    Draw the obtained and the maximum rewards against the time step.\n",
"\n",
"    Inputs:\n",
"    - rewards (list): reward obtained at each time step.\n",
"    - max_rewards (list): maximum reward at each time step.\n",
"    \"\"\"\n",
"    obtained_steps = range(len(rewards))\n",
"    maximum_steps = range(len(max_rewards))\n",
"\n",
"    with plt.xkcd():\n",
"        plt.plot(obtained_steps, rewards, marker='o', label = \"Obtained Reward\")\n",
"        plt.plot(maximum_steps, max_rewards, marker='*', label = \"Maximum Reward\")\n",
"\n",
"        plt.title('Reward Over Time')\n",
"        plt.xlabel('Time Step')\n",
"        plt.ylabel('Reward Value')\n",
"\n",
"        # One x tick per entry of `rewards`; y ticks are fixed at 0, 1, 2, 3, 4.\n",
"        plt.xticks(np.arange(0, len(rewards), 1))\n",
"        plt.yticks(np.arange(0, 5, 1))\n",
"\n",
"        plt.legend()\n",
"        plt.show()\n",
"\n",
"def plot_confusion_matrix(rewards, max_rewards, mode = 1):\n",
"    \"\"\"\n",
"    Plots the confusion matrix for the chosen rewards and the maximum ones.\n",
"\n",
"    Rows correspond to the maximum-reward color and columns to the chosen color,\n",
"    both listed in the order of `mode_colors[mode]`. Colors of the mode that never\n",
"    occur in the data still get an (all-zero) row and column.\n",
"\n",
"    Inputs:\n",
"    - rewards (list): list containing the rewards at each time step.\n",
"    - max_rewards (list): list containing the maximum rewards at each time step.\n",
"    - mode (int, default = 1): mode of the environment.\n",
"    \"\"\"\n",
"    all_colors = list(mode_colors[mode])\n",
"    # Reward value of every color of this mode, in display order.\n",
"    reward_labels = [color_names_rewards[color_name] for color_name in all_colors]\n",
"\n",
"    with plt.xkcd():\n",
"        # Passing `labels` fixes both the size and the ordering of the matrix to the\n",
"        # colors of the mode. This replaces the former manual zero-padding at index\n",
"        # `cls - 1`, which was only valid when the mode's reward values were 1..k\n",
"        # (it misplaced or failed on mode 2, whose values are 1, 3, 4). Entries whose\n",
"        # value is not one of the mode's rewards are left out of the matrix.\n",
"        cm = confusion_matrix(max_rewards, rewards, labels=reward_labels)\n",
"\n",
"        cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=all_colors)\n",
"        cm_display.plot()\n",
"        plt.xlabel(\"Chosen color\")\n",
"        plt.ylabel(\"Maximum-reward color\")\n",
"        plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Helper functions\n",
"\n",
"def run_dummy_agent(env):\n",
"    \"\"\"\n",
"    Run the dummy agent: it starts with action 0 and, after every step,\n",
"    switches between actions 0 and 1 with probability 0.5.\n",
"\n",
"    Inputs:\n",
"    - env (ChangingEnv): An environment.\n",
"\n",
"    Returns:\n",
"    - rewards (list): obtained rewards, preceded by an initial 0.\n",
"    - max_rewards (list): maximum rewards, preceded by an initial 0.\n",
"    \"\"\"\n",
"    current_action = 0\n",
"    obtained, best_possible = [0], [0]\n",
"\n",
"    # `num_trials` is read from the notebook's global scope.\n",
"    for _ in range(num_trials):\n",
"        _, reward, max_reward = env.step(current_action)\n",
"        obtained.append(reward)\n",
"        best_possible.append(max_reward)\n",
"\n",
"        # dummy policy: flip the action on a fair coin toss\n",
"        should_switch = np.random.random() < 0.5\n",
"        current_action = 1 - current_action if should_switch else current_action\n",
"\n",
"    return obtained, best_possible\n",
"\n",
"# Reward value attached to each color name.\n",
"color_names_rewards = dict(red=1, yellow=2, green=3, blue=4)\n",
"\n",
"# Three-component value for each color name (they look like 0-255 RGB triples).\n",
"color_names_values = dict(\n",
"    red=[255, 0, 0],\n",
"    yellow=[255, 255, 0],\n",
"    green=[0, 128, 0],\n",
"    blue=[0, 0, 255],\n",
")\n",
"\n",
"# Colors used by each mode; `mode_colors` maps the mode number to its list.\n",
"first_mode = [\"red\", \"yellow\", \"green\"]\n",
"second_mode = [\"red\", \"green\", \"blue\"]\n",
"mode_colors = dict(zip((1, 2), (first_mode, second_mode)))\n",
"\n",
"def game():\n",
" \"\"\"\n",
" Create interactive game for this tutorial.\n",
" \"\"\"\n",
"\n",
" total_reward = 0\n",
" message = \"Start of the game!\"\n",
"\n",
" left_button = widgets.Button(description=\"Left\")\n",
" right_button = widgets.Button(description=\"Right\")\n",
" button_box = widgets.HBox([left_button, right_button])\n",
"\n",
" def define_choice(button):\n",
" \"\"\"\n",
" Change `choice` variable with respect to the pressed button.\n",
" \"\"\"\n",
" nonlocal choice\n",
" display(widgets.HTML(f\"