{ "cells": [ { "cell_type": "markdown", "id": "195ff9f1-1b9c-4772-853a-f8498926ba63", "metadata": { "execution": {} }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "id": "d78e57f5-dcf1-4fea-a607-dc8b812be7eb", "metadata": { "execution": {} }, "source": [ "# Tutorial 3: Generalization in Cognitive Science\n", "\n", "**Week 1, Day 1: Generalization**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Samuele Bolotta, Patrick Mineault\n", "\n", "__Content reviewers:__ Samuele Bolotta, Lily Chamakura, RyeongKyung Yoon, Yizhou Chen, Ruiyi Zhang, Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n" ] }, { "cell_type": "markdown", "id": "aab73da2-b26e-457e-a1ae-0e7815c47592", "metadata": { "execution": {} }, "source": [ "___\n", "\n", "\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 30 minutes*\n", "\n", "This tutorial will introduce you to generalization in the context of cognitive science. We'll close the loop of our exploration of different views of handwriting with a model that combines aspects of the models we covered in the neuroscience and AI tutorials, including both generative and discriminative components. \n", "\n", "In particular, we'll be looking at the Omniglot dataset, and how it can be used to infer how humans and machines generalize in a handwritten symbol recognition task. We'll try our hand at one-shot learning, and we'll measure our sample complexity. We'll then discuss how one cognitive model, [Feinman and Lake](https://arxiv.org/abs/2006.14448) (2020), attempts to solve the problem of handwritten symbol recognition using a neuro-symbolic method. \n", "\n", "By the end of this tutorial, participants will be able to:\n", "\n", "1. Explore the goals of cognitive science. Understand the aims of cognitive science such as unraveling the complexities of human cognition.\n", "\n", "2. Define one-shot learning and sample complexity. Perform a task that involves one-shot learning.\n", "\n", "3. Explore how a neurosymbolic model with strong inductive biases could explain one-shot learning on Omniglot." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "450a02ee-d0ae-4923-ad3c-d978b9b4aa46", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/79523/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/79523/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "id": "9b51b72e-386a-49ad-94e8-648d62124e1f", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4fac5d8b-5d72-423b-bbb2-975e61668304", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install matplotlib numpy Pillow scipy ipywidgets vibecheck tqdm --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W1D1_T3\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "id": "627f2870-3fd6-4a63-8360-36763673490d", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Import dependencies\n", "\n", "# Standard libraries\n", "import hashlib\n", "import logging\n", "import os\n", "import random\n", "import requests\n", "import shutil\n", "import time\n", "from importlib import reload\n", "import zipfile\n", "from zipfile import ZipFile\n", "\n", "# Data handling and visualization\n", "import matplotlib.patches as patches\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from PIL import Image\n", "import skimage\n", "from skimage import io\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Deep Learning libraries\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from scipy.spatial.distance import cdist\n", "\n", "# Interactive controls in Jupyter notebooks\n", "from IPython.display import clear_output, display, update_display\n", "import ipywidgets as widgets\n", "\n", "# Utility for progress bars\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e870241d-73fb-455f-a8ef-c92f9b2bdaf3", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "# @markdown\n", "\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina' # perfrom high definition rendering for images and 
plots\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting functions\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4452c1b0-d79d-4c17-8738-b0b342180a09", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting functions\n", "# @markdown\n", "\n", "def display_images(probe, options):\n", " # Open the probe image and the option images\n", " probe_image = Image.open(probe)\n", " option_images = [Image.open(img_path) for img_path in options]\n", "\n", " # Create a figure with the probe and the 3x3 grid for the options directly below\n", " fig = plt.figure(figsize=(15, 10)) # Adjust figure size as needed\n", "\n", " # Add the probe image to the top of the figure with a red border\n", " ax_probe = fig.add_subplot(4, 3, (1, 3)) # Span the probe across the top 3 columns\n", " ax_probe.imshow(probe_image)\n", " ax_probe.axis('off')\n", " rect = patches.Rectangle((0, 0), probe_image.width-1, probe_image.height-1, linewidth=2, edgecolor='r', facecolor='none')\n", " ax_probe.add_patch(rect)\n", "\n", " # Position the 3x3 grid of option images directly below the probe image\n", " for index, img in enumerate(option_images):\n", " row = (index // 3) + 1 # Calculate row in the 3x3 grid, starting directly below the probe\n", " col = (index % 3) + 1 # Calculate column in the 3x3 grid\n", " ax_option = fig.add_subplot(4, 3, row * 3 + col) # Adjust grid position to directly follow the probe\n", " ax_option.imshow(img)\n", " ax_option.axis('off')\n", "\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval for zip files\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cab5deda-c9f0-4733-a01b-3d077a96bafb", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data retrieval for zip files\n", "\n", "def handle_file_operations(fname, url, expected_md5, extract_to='data'):\n", " \"\"\"Handles downloading, verifying, and extracting a file.\"\"\"\n", "\n", " # Define helper functions for download, verify, and extract operations\n", " def download_file(url, filename):\n", " \"\"\"Downloads file from the given URL and saves it locally.\"\"\"\n", " try:\n", " r = requests.get(url, stream=True)\n", " r.raise_for_status()\n", " with open(filename, \"wb\") as fid:\n", " for chunk in r.iter_content(chunk_size=8192):\n", " fid.write(chunk)\n", " print(\"Download successful.\")\n", " return True\n", " except requests.RequestException as e:\n", " print(f\"!!! Failed to download data: {e} !!!\")\n", " return False\n", "\n", " def verify_file_md5(filename, expected_md5):\n", " \"\"\"Verifies the file's MD5 checksum.\"\"\"\n", " hash_md5 = hashlib.md5()\n", " with open(filename, \"rb\") as f:\n", " for chunk in iter(lambda: f.read(4096), b\"\"):\n", " hash_md5.update(chunk)\n", " if hash_md5.hexdigest() == expected_md5:\n", " print(\"MD5 checksum verified.\")\n", " return True\n", " else:\n", " print(\"!!! 
Data download appears corrupted !!!\")\n", " return False\n", "\n", " def extract_zip_file(filename, extract_to):\n", " \"\"\"Extracts the ZIP file to the specified directory.\"\"\"\n", " try:\n", " with zipfile.ZipFile(filename, 'r') as zip_ref:\n", " zip_ref.extractall(extract_to)\n", " print(f\"File extracted successfully to {extract_to}\")\n", " except zipfile.BadZipFile:\n", " print(\"!!! The ZIP file is corrupted or not a zip file !!!\")\n", "\n", " # Main operation\n", " if not os.path.isfile(fname) or not verify_file_md5(fname, expected_md5):\n", " if download_file(url, fname) and verify_file_md5(fname, expected_md5):\n", " extract_zip_file(fname, extract_to)\n", " else:\n", " print(f\"File '{fname}' already exists and is verified. Proceeding to extraction.\")\n", " extract_zip_file(fname, extract_to)\n", "\n", "# Example usage\n", "file_info = [\n", " {\"fname\": \"omniglot-py.zip\", \"url\": \"https://osf.io/bazxp/download\", \"expected_md5\": \"f7a4011f5c25460c6d95ee1428e377ed\"},\n", "]\n", "\n", "import contextlib\n", "import io\n", "\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " for file in file_info:\n", " handle_file_operations(**file)\n", "\n", "#Current directory\n", "base_dir = os.getcwd()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval for torch models\n" ] }, { "cell_type": "code", "execution_count": null, "id": "593cf0b6-954a-414b-80b9-e356349bc934", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data retrieval for torch models\n", "\n", "def download_file(url, filename):\n", " \"\"\"\n", " Download a file from a given URL and save it in the specified directory.\n", " \"\"\"\n", " filepath = os.path.join(base_dir, filename) # Ensure the file is saved in base_dir\n", "\n", " response = requests.get(url)\n", " response.raise_for_status() # Check for HTTP request errors\n", "\n", " with open(filepath, 'wb') as f:\n", " f.write(response.content)\n", "\n", "\n", "def verify_checksum(filename, expected_checksum):\n", " \"\"\"\n", " Verify the MD5 checksum of a file\n", "\n", " Parameters:\n", " filename (str): Path to the file\n", " expected_checksum (str): Expected MD5 checksum\n", "\n", " Returns:\n", " bool: True if the checksum matches, False otherwise\n", " \"\"\"\n", " md5 = hashlib.md5()\n", "\n", " with open(filename, 'rb') as f:\n", " for chunk in iter(lambda: f.read(4096), b\"\"):\n", " md5.update(chunk)\n", "\n", " return md5.hexdigest() == expected_checksum\n", "\n", "def load_models(model_files, directory, map_location='cpu'):\n", " \"\"\"\n", " Load multiple models from a specified directory.\n", " \"\"\"\n", " models = {}\n", " for model_file in model_files:\n", " full_path = os.path.join(directory, model_file) # Correctly join paths\n", " models[model_file] = torch.load(full_path, map_location=map_location)\n", " return models\n", "\n", "def verify_models_in_destination(model_files, destination_directory):\n", " \"\"\"\n", " Verify the presence of model files in the specified directory.\n", "\n", " Parameters:\n", " model_files (list of str): Filenames of the models to verify.\n", " destination_directory (str): The directory where the models are supposed to be.\n", "\n", " Returns:\n", " bool: True if all models are found in the directory, False otherwise.\n", " \"\"\"\n", " missing_files = []\n", " for model_file in model_files:\n", " # Construct the full path to where the model should be\n", " full_path = os.path.join(destination_directory, 
model_file)\n", " # Check if the model exists at the location\n", " if not os.path.exists(full_path):\n", " missing_files.append(model_file)\n", "\n", " if missing_files:\n", " print(f\"Missing model files in destination: {missing_files}\")\n", " return False\n", " else:\n", " print(\"All models are correctly located in the destination directory.\")\n", " return True\n", "\n", "# URLs and checksums for the models\n", "models_info = {\n", " 'location_model.pt': ('https://osf.io/zmd7y/download', 'dfd51cf7c3a277777ad941c4fcc23813'),\n", " 'stroke_model.pt': ('https://osf.io/m6yc7/download', '511ea7bd12566245d5d11a85d5a0abb0'),\n", " 'terminate_model.pt': ('https://osf.io/dsmhc/download', '2f3e26cfcf36ce9f9172c15d8b1079d1')\n", "}\n", "\n", "destination_directory = base_dir\n", "\n", "# Define model_files based on the keys of models_info to ensure we have the filenames\n", "model_files = list(models_info.keys())\n", "\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " # Iterate over the models to download and verify\n", " for model_name, (url, checksum) in models_info.items():\n", " download_file(url, model_name) # Downloads directly into base_dir\n", " if verify_checksum(os.path.join(base_dir, model_name), checksum):\n", " print(f\"Successfully verified {model_name}\")\n", " else:\n", " print(f\"Checksum does not match for {model_name}. Download might be corrupted.\")\n", "\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " # Verify the presence of the models in the destination directory\n", " if verify_models_in_destination(model_files, destination_directory):\n", " print(\"Verification successful: All models are in the correct directory.\")\n", " else:\n", " print(\"Verification failed: Some models are missing from the destination directory.\")\n", "\n", "# Load the models from the destination directory\n", "models = load_models(model_files, destination_directory, map_location='cpu')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b3609707-5107-4b50-8b0a-a607890adadb", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper functions\n", "\n", "def select_random_images_within_alphabet(base_path, alphabet_path, exclude_character_path, num_images=8):\n", " # Initialize an empty list to store the paths of the chosen images\n", " chosen_images = []\n", "\n", " # Get a list of all character directories within the alphabet_path, excluding the directory specified by exclude_character_path\n", " all_characters = [\n", " char for char in os.listdir(alphabet_path)\n", " if os.path.isdir(os.path.join(alphabet_path, char)) and os.path.join(alphabet_path, char) != exclude_character_path\n", " ]\n", "\n", " # Keep selecting images until we have the desired number of images (num_images)\n", " while len(chosen_images) < num_images:\n", " # If there are no more characters to choose from, exit the loop\n", " if not all_characters:\n", " break\n", "\n", " # Randomly select a character directory from the list of all characters\n", " character = random.choice(all_characters)\n", " # Construct the full path to the selected character directory\n", " character_path = os.path.join(alphabet_path, character)\n", "\n", " # Get a list of all image files (with .png extension) in the selected character directory\n", " all_images = [\n", " img for img in os.listdir(character_path)\n", " if img.endswith('.png')\n", " ]\n", "\n", " # If there are 
no images in the selected character directory, continue to the next iteration\n", " if not all_images:\n", " continue\n", "\n", " # Randomly select an image file from the list of image files\n", " image_file = random.choice(all_images)\n", " # Construct the full path to the selected image file\n", " image_path = os.path.join(character_path, image_file)\n", "\n", " # Add the selected image path to the list of chosen images\n", " chosen_images.append(image_path)\n", "\n", " # Return the list of paths to the chosen images\n", " return chosen_images\n", "\n", "def run_trial_interactive(base_path, output):\n", " # Context manager to direct output to the provided widget\n", " with output:\n", " # Initialize and display the score widget\n", " score_widget = widgets.Label(value=f'Score: {total_score}/{total_trials}', disabled=True)\n", " display(score_widget)\n", "\n", " # List all directories (languages) within the base path\n", " languages = [lang for lang in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, lang))]\n", " # Randomly select a language directory\n", " selected_language = random.choice(languages)\n", " # Construct the path to the selected language directory\n", " language_path = os.path.join(base_path, selected_language)\n", "\n", " # List all directories (characters) within the selected language path\n", " characters = [char for char in os.listdir(language_path) if os.path.isdir(os.path.join(language_path, char))]\n", " # Randomly select a character directory\n", " selected_character = random.choice(characters)\n", " # Construct the path to the selected character directory\n", " character_path = os.path.join(language_path, selected_character)\n", "\n", " # List all .png files (images) in the selected character directory\n", " images = [img for img in os.listdir(character_path) if img.endswith('.png')]\n", " # Randomly select two images: one as the probe image and one as the correct answer\n", " probe_image_path, correct_answer_image_path = random.sample(images, 2)\n", " # Construct full paths to the probe image and the correct answer image\n", " probe_image_path = os.path.join(character_path, probe_image_path)\n", " correct_answer_image_path = os.path.join(character_path, correct_answer_image_path)\n", "\n", " # Select a number of wrong answer images from other characters within the same language\n", " wrong_answers = select_random_images_within_alphabet(base_path, language_path, character_path, num_images=8)\n", " # Create the options list, which includes the wrong answers and the correct answer\n", " options = wrong_answers\n", " # Insert the correct answer at a random position within the options list\n", " options.insert(random.randint(0, len(options)), correct_answer_image_path)\n", "\n", " # Display a label indicating the reference image\n", " display(widgets.Label(value='Reference image'))\n", "\n", " # Display the probe image\n", " display(widgets.Image(value=open(probe_image_path, 'rb').read(), format='png'))\n", "\n", " # Create a grid of image widgets for the options\n", " image_grid = widgets.GridBox([widgets.Image(value=open(opt, 'rb').read(), format='png', layout=widgets.Layout(width='100px', height='100px'))\n", " for opt in options], layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))\n", "\n", " # Create a grid of numbered buttons corresponding to the images\n", " button_grid = widgets.GridBox([widgets.Button(description=str(i+1), layout=widgets.Layout(width='auto', height='auto'))\n", " for i in range(len(options))], 
layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))\n", "\n", " # Combine the image grid and the button grid into a single grid layout\n", " global_grid = widgets.GridBox([image_grid, button_grid], layout=widgets.Layout(grid_template_columns='repeat(2, 300px)'))\n", "\n", " # Display a label prompting the user to match the reference image\n", " display(widgets.Label(value='Which of these images match the reference? '))\n", "\n", " time.sleep(.2)\n", "\n", " # Display the combined grid of images and buttons\n", " display(global_grid)\n", "\n", " # Attach click event handlers to the buttons\n", " for b in button_grid.children:\n", " b.on_click(lambda b: on_button_clicked(b, options, correct_answer_image_path, score_widget))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Overview\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ce12a7d9-b6f6-4159-8ebe-ecc160f469ab", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Overview\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'MOmT5NDjD6A'), ('Bilibili', 'BV1Nz42187jL')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7262af7b-b567-4d67-91d0-a865fd6b1793", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Overview\")" ] }, { "cell_type": "markdown", "id": "9711413f-c4e5-4415-a1e3-78e3f776066d", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: How people recognize new characters\n", "\n", "Let's put ourselves in the mindset of a cognitive scientist studying handwriting. We're interested in how people learn to recognize new characters. 
Indeed, humans display low **sample complexity** when learning new visual concepts: they seem to grasp new concepts with very few presentations, generalizing effortlessly. In AI, learning from $k$ labeled examples is known as $k$-shot learning; one-shot and few-shot learning refer to learning from one or a few labeled examples.\n", "\n", "A good dataset to investigate one-shot learning is the Omniglot dataset. Omniglot has sometimes been described as *MNIST, transposed*. Instead of **thousands** of examples from **10** digit classes, Omniglot consists of **20** instances from **1623** character classes. These character classes are sourced from 50 alphabets, both natural (e.g. Cherokee or Greek) and constructed (e.g. the alien alphabet from the TV show Futurama). \n", "\n", "![Sample characters from the Omniglot dataset](https://github.com/brendenlake/omniglot/raw/master/omniglot_grid.jpg)\n", "\n", "Let's see if you're a good one-shot classifier by trying the Omniglot task yourself. Observing human behavior in the lab to infer their strategies is an important way that cognitive scientists make progress in understanding human cognition." ] }, { "cell_type": "markdown", "id": "624859f8-1735-4611-918e-29dafc295ab5", "metadata": { "execution": {} }, "source": [ "Your task is to conduct a series of trials to explore the Omniglot dataset. Here's how the experiment goes:\n", "\n", "1. **Click Start**\n", "\n", "2. **Look at the reference character at the top**\n", "\n", "3. **Look at 9 different potential matches at the bottom**. These include one more instance of the reference character class, and 8 distractors from other characters of the same alphabet.\n", "\n", "4. **Click the button corresponding to the best match**. The selection buttons are on the right of the grid.\n", "\n", "5. **Repeat for multiple trials**. Get to 10 or 20 to get an estimate of how well you perform." ] }, { "cell_type": "code", "execution_count": null, "id": "57baab32-15c1-4977-8775-ee0e51e159f1", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Paths\n", "base_path = \"data/omniglot-py/images_background\"\n", "\n", "total_score = 0\n", "total_trials = 0\n", "\n", "output = widgets.Output()\n", "btn = None\n", "\n", "def start(b):\n", " global total_score, total_trials\n", " total_score = 0\n", " total_trials = 0\n", " output.clear_output(wait=True)\n", " run_trial_interactive(base_path, output)\n", " btn.description = 'Reset the interactive'\n", "\n", "def on_button_clicked(b, options, correct_answer_image_path, score_widget):\n", " global total_score, total_trials\n", " if options[int(b.description) - 1] == correct_answer_image_path:\n", " total_score += 1\n", " total_trials += 1\n", " output.clear_output(wait=True)\n", " run_trial_interactive(base_path, output)\n", "\n", "def display_start_button():\n", " global btn\n", " btn = widgets.Button(description='Start the interactive')\n", " display(btn, output)\n", " btn.on_click(start)\n", "\n", "display_start_button()" ] }, { "cell_type": "markdown", "id": "4b5cfa75-b63f-44b5-95d9-92aedad4c4c7", "metadata": { "execution": {} }, "source": [ "How well did you do?" 
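, "\n", "\n", "The next cell is a small optional sketch for turning your score into numbers: it assumes the `total_score` and `total_trials` globals set by the interactive above (so run a few trials first) and uses a rough normal-approximation confidence interval. Since each trial shows you only a single new instance of the reference character, the resulting error rate is a rough estimate of your one-shot performance.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Optional sketch: estimate your one-shot accuracy from the interactive above.\n", "# Assumes the setup cells have been run and total_trials > 0.\n", "if total_trials > 0:\n", "    accuracy = total_score / total_trials\n", "    error_rate = 1 - accuracy\n", "    # Rough 95% confidence interval (normal approximation to the binomial).\n", "    half_width = 1.96 * np.sqrt(accuracy * (1 - accuracy) / total_trials)\n", "    print(f'{total_score}/{total_trials} correct: accuracy ~ {accuracy:.2f} (+/- {half_width:.2f}),'\n", "          f' error rate ~ {error_rate:.2f}')\n", "else:\n", "    print('Run a few trials of the interactive above first.')"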
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1579c4cb-d83a-49c0-a1c1-758cd8f67b68", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Omniglot_Dataset\")" ] }, { "cell_type": "markdown", "id": "3db39c9a-e4f7-4a8b-9be5-d04f017b9faf", "metadata": { "execution": {} }, "source": [ "## Reflection activity 1.1" ] }, { "cell_type": "markdown", "id": "b13e078b-17bc-4da1-9470-d70c534edc33", "metadata": { "execution": {} }, "source": [ "Sample complexity $N(\\varepsilon, \\delta)$ is formally defined as:\n", "\n", "> the number of examples $N$ that a learner must see in order to perform a task with an error rate smaller than $\\varepsilon$ with probability greater than $1-\\delta$. \n", "\n", "Based on this definition, what is your sample complexity on the Omniglot task?" ] }, { "cell_type": "markdown", "id": "96cd18b3-9912-4307-b2e9-5cf0260983dc", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial3_Solution_9e44e6ca.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "adbc5db1-c13c-47b8-96c0-4bd264a4a58b", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Reflection_Activity_1\")" ] }, { "cell_type": "markdown", "id": "22c68c49-e7e6-4104-9d76-b7bed3cbf084", "metadata": { "execution": {} }, "source": [ "## Reflection activity 1.2\n", "\n", "How do you think you, as a human, are performing a task like Omniglot?" 
] }, { "cell_type": "markdown", "id": "af834cb3-7089-46f1-ab9b-41394ac1410f", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial3_Solution_dbbeabd0.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e7f01bd9-8016-4ddf-9009-981f4e0dbf4d", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Reflection_Activity_2\")" ] }, { "cell_type": "markdown", "id": "47666813-c5dd-45fa-a331-308c14582956", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Model of one-shot learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: GNS Model\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c6610cd5-0542-4178-8faf-2c925a3b39fd", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: GNS Model\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '2q1Q8l3Lg9c'), ('Bilibili', 'BV1DZ421u7vk')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4e132277-5f80-4d87-a33a-615dbb39bfa0", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_GNS_Model\")" ] }, { "cell_type": "markdown", "id": "aa31cc80-3fbf-4a65-9ea2-178c981b7fd9", "metadata": { "execution": {} }, "source": [ "Feinman and Lake (2020) propose a cognitive model to explain how humans perform one-shot recognition tasks like Omniglot. 
Their model is based on the insight that handwriting characters are highly structured: if we could infer how a character was generated, we could figure out the writer's intent, and perform one-shot recognition of characters. \n", "\n", "When we write down a character on a piece of paper or a screen, we might implicitly perform a sequence of steps:\n", "\n", "1. Prepare a global motor plan to write a character based on prior experience\n", "2. Decide where to put down the pen for the first stroke\n", "3. Draw a stroke in an appropriate direction.\n", "\n", " a. Look at the sheet of paper while writing to adjust the direction of the stroke\n", "4. Find a location for the second stroke, and so on...\n", "5. When satisfied, stop drawing strokes\n", "\n", "Feinman and Lake (2020) propose to embed these assumptions into a generative model for how a single character is generated from strokes.\n", "\n", "\n", "\n", "The result is a highly structured Bayesian generative model containing both discrete components (e.g. strokes) and continuous components (e.g. the location of the next stroke is a continuous variable). It combines symbolic primitives (strokes) with standard artificial neural network (ANN) components. This combination of symbols and neural networks is known as a **neuro-symbolic** approach.\n", "\n", "This is an example of a model with **strong inductive biases**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1424a819-df4a-4f9a-8f11-2a2145c717e9", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_One_Shot_Learning\")" ] }, { "cell_type": "markdown", "id": "206f46f9-5600-454f-8be7-76cbd47ba43d", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "* Cognitive science seeks to understand how human cognition works.\n", "* Humans display one-shot learning on Omniglot, a character recognition task. This requires extensive generalization.\n", "* Sample complexity measures the minimum number of examples needed to reach a specific performance with some probability; a sample complexity of 1 indicates one-shot learning at a specific performance level.\n", "* A generative neuro-symbolic model with strong inductive biases exhibits human-level performance on Omniglot." ] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "W1D1_Tutorial3", "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }