
Tutorial 3: Generalization in Cognitive Science#

Week 1, Day 1: Generalization

By Neuromatch Academy

Content creators: Samuele Bolotta, Patrick Mineault

Content reviewers: Samuele Bolotta, Lily Chamakura, RyeongKyung Yoon, Yizhou Chen, Ruiyi Zhang, Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk

Production editors: Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk


Tutorial Objectives#

Estimated timing of tutorial: 30 minutes

This tutorial will introduce you to generalization in the context of cognitive science. We’ll close the loop of our exploration of different views of handwriting with a model that combines aspects of the models we covered in the neuroscience and AI tutorials, including both generative and discriminative components.

In particular, we’ll be looking at the Omniglot dataset, and how it can be used to infer how humans and machines generalize in a handwritten symbol recognition task. We’ll try our hand at one-shot learning, and we’ll measure our sample complexity. We’ll then discuss how one cognitive model, that of Feinman and Lake (2020), attempts to solve the problem of handwritten symbol recognition using a neuro-symbolic method.

By the end of this tutorial, participants will be able to:

  1. Explore the goals of cognitive science, such as unraveling the complexities of human cognition.

  2. Define one-shot learning and sample complexity. Perform a task that involves one-shot learning.

  3. Explore how a neurosymbolic model with strong inductive biases could explain one-shot learning on Omniglot.


Setup#

Install and import feedback gadget#

Hide code cell source
# @title Install and import feedback gadget

!pip install matplotlib numpy Pillow scipy ipywidgets vibecheck tqdm --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_neuroai",
            "user_key": "wb2cxze8",
        },
    ).render()


feedback_prefix = "W1D1_T3"

Import dependencies#

Hide code cell source
# @title Import dependencies

# Standard libraries
import hashlib
import logging
import os
import random
import requests
import shutil
import time
from importlib import reload
import zipfile
from zipfile import ZipFile

# Data handling and visualization
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import skimage
from skimage import io
from sklearn.model_selection import train_test_split

# Deep Learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import cdist

# Interactive controls in Jupyter notebooks
from IPython.display import clear_output, display, update_display
import ipywidgets as widgets

# Utility for progress bars
from tqdm import tqdm

Figure settings#

Hide code cell source
# @title Figure settings
# @markdown

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

Plotting functions#

Hide code cell source
# @title Plotting functions
# @markdown

def display_images(probe, options):
    # Open the probe image and the option images
    probe_image = Image.open(probe)
    option_images = [Image.open(img_path) for img_path in options]

    # Create a figure with the probe and the 3x3 grid for the options directly below
    fig = plt.figure(figsize=(15, 10))  # Adjust figure size as needed

    # Add the probe image to the top of the figure with a red border
    ax_probe = fig.add_subplot(4, 3, (1, 3))  # Span the probe across the top 3 columns
    ax_probe.imshow(probe_image)
    ax_probe.axis('off')
    rect = patches.Rectangle((0, 0), probe_image.width-1, probe_image.height-1, linewidth=2, edgecolor='r', facecolor='none')
    ax_probe.add_patch(rect)

    # Position the 3x3 grid of option images directly below the probe image
    for index, img in enumerate(option_images):
        row = (index // 3) + 1  # Calculate row in the 3x3 grid, starting directly below the probe
        col = (index % 3) + 1   # Calculate column in the 3x3 grid
        ax_option = fig.add_subplot(4, 3, row * 3 + col)  # Adjust grid position to directly follow the probe
        ax_option.imshow(img)
        ax_option.axis('off')

    plt.tight_layout()
    plt.show()

Data retrieval for zip files#

Hide code cell source
# @title Data retrieval for zip files

def handle_file_operations(fname, url, expected_md5, extract_to='data'):
    """Handles downloading, verifying, and extracting a file."""

    # Define helper functions for download, verify, and extract operations
    def download_file(url, filename):
        """Downloads file from the given URL and saves it locally."""
        try:
            r = requests.get(url, stream=True)
            r.raise_for_status()
            with open(filename, "wb") as fid:
                for chunk in r.iter_content(chunk_size=8192):
                    fid.write(chunk)
            print("Download successful.")
            return True
        except requests.RequestException as e:
            print(f"!!! Failed to download data: {e} !!!")
            return False

    def verify_file_md5(filename, expected_md5):
        """Verifies the file's MD5 checksum."""
        hash_md5 = hashlib.md5()
        with open(filename, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        if hash_md5.hexdigest() == expected_md5:
            print("MD5 checksum verified.")
            return True
        else:
            print("!!! Data download appears corrupted !!!")
            return False

    def extract_zip_file(filename, extract_to):
        """Extracts the ZIP file to the specified directory."""
        try:
            with zipfile.ZipFile(filename, 'r') as zip_ref:
                zip_ref.extractall(extract_to)
            print(f"File extracted successfully to {extract_to}")
        except zipfile.BadZipFile:
            print("!!! The ZIP file is corrupted or not a zip file !!!")

    # Main operation
    if not os.path.isfile(fname) or not verify_file_md5(fname, expected_md5):
        if download_file(url, fname) and verify_file_md5(fname, expected_md5):
            extract_zip_file(fname, extract_to)
    else:
        print(f"File '{fname}' already exists and is verified. Proceeding to extraction.")
        extract_zip_file(fname, extract_to)

# Example usage
file_info = [
    {"fname": "omniglot-py.zip", "url": "https://osf.io/bazxp/download", "expected_md5": "f7a4011f5c25460c6d95ee1428e377ed"},
]

import contextlib
import io

with contextlib.redirect_stdout(io.StringIO()):
    for file in file_info:
        handle_file_operations(**file)

# Current directory
base_dir = os.getcwd()

Data retrieval for torch models#

Hide code cell source
# @title Data retrieval for torch models

def download_file(url, filename):
    """
    Download a file from a given URL and save it in the specified directory.
    """
    filepath = os.path.join(base_dir, filename)  # Ensure the file is saved in base_dir

    response = requests.get(url)
    response.raise_for_status()  # Check for HTTP request errors

    with open(filepath, 'wb') as f:
        f.write(response.content)


def verify_checksum(filename, expected_checksum):
    """
    Verify the MD5 checksum of a file

    Parameters:
    filename (str): Path to the file
    expected_checksum (str): Expected MD5 checksum

    Returns:
    bool: True if the checksum matches, False otherwise
    """
    md5 = hashlib.md5()

    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            md5.update(chunk)

    return md5.hexdigest() == expected_checksum

def load_models(model_files, directory, map_location='cpu'):
    """
    Load multiple models from a specified directory.
    """
    models = {}
    for model_file in model_files:
        full_path = os.path.join(directory, model_file)  # Correctly join paths
        models[model_file] = torch.load(full_path, map_location=map_location)
    return models

def verify_models_in_destination(model_files, destination_directory):
    """
    Verify the presence of model files in the specified directory.

    Parameters:
    model_files (list of str): Filenames of the models to verify.
    destination_directory (str): The directory where the models are supposed to be.

    Returns:
    bool: True if all models are found in the directory, False otherwise.
    """
    missing_files = []
    for model_file in model_files:
        # Construct the full path to where the model should be
        full_path = os.path.join(destination_directory, model_file)
        # Check if the model exists at the location
        if not os.path.exists(full_path):
            missing_files.append(model_file)

    if missing_files:
        print(f"Missing model files in destination: {missing_files}")
        return False
    else:
        print("All models are correctly located in the destination directory.")
        return True

# URLs and checksums for the models
models_info = {
    'location_model.pt': ('https://osf.io/zmd7y/download', 'dfd51cf7c3a277777ad941c4fcc23813'),
    'stroke_model.pt': ('https://osf.io/m6yc7/download', '511ea7bd12566245d5d11a85d5a0abb0'),
    'terminate_model.pt': ('https://osf.io/dsmhc/download', '2f3e26cfcf36ce9f9172c15d8b1079d1')
}

destination_directory = base_dir

# Define model_files based on the keys of models_info to ensure we have the filenames
model_files = list(models_info.keys())

with contextlib.redirect_stdout(io.StringIO()):
    # Iterate over the models to download and verify
    for model_name, (url, checksum) in models_info.items():
        download_file(url, model_name)  # Downloads directly into base_dir
        if verify_checksum(os.path.join(base_dir, model_name), checksum):
            print(f"Successfully verified {model_name}")
        else:
            print(f"Checksum does not match for {model_name}. Download might be corrupted.")

with contextlib.redirect_stdout(io.StringIO()):
    # Verify the presence of the models in the destination directory
    if verify_models_in_destination(model_files, destination_directory):
        print("Verification successful: All models are in the correct directory.")
    else:
        print("Verification failed: Some models are missing from the destination directory.")

# Load the models from the destination directory
models = load_models(model_files, destination_directory, map_location='cpu')

Helper functions#

Hide code cell source
# @title Helper functions

def select_random_images_within_alphabet(base_path, alphabet_path, exclude_character_path, num_images=8):
    # Initialize an empty list to store the paths of the chosen images
    chosen_images = []

    # Get a list of all character directories within the alphabet_path, excluding the directory specified by exclude_character_path
    all_characters = [
        char for char in os.listdir(alphabet_path)
        if os.path.isdir(os.path.join(alphabet_path, char)) and os.path.join(alphabet_path, char) != exclude_character_path
    ]

    # Keep selecting images until we have the desired number of images (num_images)
    while len(chosen_images) < num_images:
        # If there are no more characters to choose from, exit the loop
        if not all_characters:
            break

        # Randomly select a character directory from the list of all characters
        character = random.choice(all_characters)
        # Construct the full path to the selected character directory
        character_path = os.path.join(alphabet_path, character)

        # Get a list of all image files (with .png extension) in the selected character directory
        all_images = [
            img for img in os.listdir(character_path)
            if img.endswith('.png')
        ]

        # If there are no images in the selected character directory, continue to the next iteration
        if not all_images:
            continue

        # Randomly select an image file from the list of image files
        image_file = random.choice(all_images)
        # Construct the full path to the selected image file
        image_path = os.path.join(character_path, image_file)

        # Add the selected image path to the list of chosen images
        chosen_images.append(image_path)

    # Return the list of paths to the chosen images
    return chosen_images

def run_trial_interactive(base_path, output):
    # Context manager to direct output to the provided widget
    with output:
        # Initialize and display the score widget
        score_widget = widgets.Label(value=f'Score: {total_score}/{total_trials}', disabled=True)
        display(score_widget)

        # List all directories (languages) within the base path
        languages = [lang for lang in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, lang))]
        # Randomly select a language directory
        selected_language = random.choice(languages)
        # Construct the path to the selected language directory
        language_path = os.path.join(base_path, selected_language)

        # List all directories (characters) within the selected language path
        characters = [char for char in os.listdir(language_path) if os.path.isdir(os.path.join(language_path, char))]
        # Randomly select a character directory
        selected_character = random.choice(characters)
        # Construct the path to the selected character directory
        character_path = os.path.join(language_path, selected_character)

        # List all .png files (images) in the selected character directory
        images = [img for img in os.listdir(character_path) if img.endswith('.png')]
        # Randomly select two images: one as the probe image and one as the correct answer
        probe_image_path, correct_answer_image_path = random.sample(images, 2)
        # Construct full paths to the probe image and the correct answer image
        probe_image_path = os.path.join(character_path, probe_image_path)
        correct_answer_image_path = os.path.join(character_path, correct_answer_image_path)

        # Select a number of wrong answer images from other characters within the same language
        wrong_answers = select_random_images_within_alphabet(base_path, language_path, character_path, num_images=8)
        # Create the options list, which includes the wrong answers and the correct answer
        options = wrong_answers
        # Insert the correct answer at a random position within the options list
        options.insert(random.randint(0, len(options)), correct_answer_image_path)

        # Display a label indicating the reference image
        display(widgets.Label(value='Reference image'))

        # Display the probe image
        display(widgets.Image(value=open(probe_image_path, 'rb').read(), format='png'))

        # Create a grid of image widgets for the options
        image_grid = widgets.GridBox([widgets.Image(value=open(opt, 'rb').read(), format='png', layout=widgets.Layout(width='100px', height='100px'))
                                      for opt in options], layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))

        # Create a grid of numbered buttons corresponding to the images
        button_grid = widgets.GridBox([widgets.Button(description=str(i+1), layout=widgets.Layout(width='auto', height='auto'))
                                       for i in range(len(options))], layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))

        # Combine the image grid and the button grid into a single grid layout
        global_grid = widgets.GridBox([image_grid, button_grid], layout=widgets.Layout(grid_template_columns='repeat(2, 300px)'))

        # Display a label prompting the user to match the reference image
        display(widgets.Label(value='Which of these images match the reference? '))

        time.sleep(.2)

        # Display the combined grid of images and buttons
        display(global_grid)

        # Attach click event handlers to the buttons
        for b in button_grid.children:
            b.on_click(lambda b: on_button_clicked(b, options, correct_answer_image_path, score_widget))

Video 1: Overview#

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_Overview")

Section 1: How people recognize new characters#

Let’s put ourselves in the mindset of a cognitive scientist studying handwriting. We’re interested in how people learn to recognize new characters. Indeed, humans display low sample complexity when learning new visual concepts: they seem to grasp new concepts after very few presentations, generalizing effortlessly. In AI, learning from \(k\) labeled examples per class is known as \(k\)-shot learning; one-shot and few-shot learning refer to learning from one or a few labeled examples.

A good dataset for investigating one-shot learning is the Omniglot dataset. Omniglot has sometimes been described as MNIST, transposed: instead of thousands of examples of each of 10 digit classes, Omniglot consists of 20 instances of each of 1623 character classes. These character classes are sourced from 50 alphabets, both natural (e.g. Cherokee or Greek) and constructed (e.g. the alien alphabet from the TV show Futurama).

Sample characters from the Omniglot dataset
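To get a concrete sense of how little data there is per class, the short sketch below walks the extracted dataset directory and counts alphabets, character classes, and images per class. It assumes the zip was extracted to data/omniglot-py/images_background by the data retrieval cell above; note that this is only the "background" split of Omniglot, so it covers a subset of the 50 alphabets.

# A quick look at the dataset's structure. Assumes the retrieval cell above
# extracted the zip to data/omniglot-py/images_background (the "background"
# split, i.e. a subset of the 50 alphabets).
import os

background_path = os.path.join("data", "omniglot-py", "images_background")

if os.path.isdir(background_path):
    alphabets = sorted(
        d for d in os.listdir(background_path)
        if os.path.isdir(os.path.join(background_path, d))
    )
    n_classes, n_images = 0, 0
    for alphabet in alphabets:
        alphabet_dir = os.path.join(background_path, alphabet)
        for character in os.listdir(alphabet_dir):
            character_dir = os.path.join(alphabet_dir, character)
            if not os.path.isdir(character_dir):
                continue
            n_classes += 1
            n_images += len([f for f in os.listdir(character_dir) if f.endswith(".png")])
    print(f"Alphabets: {len(alphabets)}")
    print(f"Character classes: {n_classes}")
    print(f"Average images per class: {n_images / max(n_classes, 1):.1f}")
else:
    print(f"Expected dataset directory not found: {background_path}")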

Let’s see if you’re a good one-shot classifier by trying the Omniglot task yourself. Observing human behavior in the lab to infer the strategies people use is an important way cognitive scientists make progress in understanding human cognition.

Your task is to conduct a series of trials to explore the Omniglot dataset. Here’s how the experiment goes:

  1. Click Start

  2. Look at the reference character at the top

  3. Look at 9 different potential matches at the bottom. These include one more instance of the reference character class, and 8 distractors from other characters of the same alphabet.

  4. Click the button corresponding to the best match. The selection buttons are on the right of the grid.

  5. Repeat for multiple trials. Complete 10 or 20 to get an estimate of how well you perform.

# Paths
base_path = "data/omniglot-py/images_background"

total_score = 0
total_trials = 0

output = widgets.Output()
btn = None

def start(b):
    global total_score, total_trials
    total_score = 0
    total_trials = 0
    output.clear_output(wait=True)
    run_trial_interactive(base_path, output)
    btn.description = 'Reset the interactive'

def on_button_clicked(b, options, correct_answer_image_path, score_widget):
    global total_score, total_trials
    if options[int(b.description) - 1] == correct_answer_image_path:
        total_score += 1
    total_trials += 1
    output.clear_output(wait=True)
    run_trial_interactive(base_path, output)

def display_start_button():
    global btn
    btn = widgets.Button(description='Start the interactive')
    display(btn, output)
    btn.on_click(start)

display_start_button()

How well did you do?
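For comparison, here is a rough, hedged baseline: it runs simulated trials with the same structure as the interactive above (one probe, one true match, and eight distractors from the same alphabet, reusing the helper defined in Setup) and answers each trial by nearest neighbor in raw pixel space. This is a naive one-shot classifier sketch, not the cognitive model discussed in Section 2.

# A naive pixel-space nearest-neighbor baseline on simulated trials.
# This is a sketch for comparison with your own score; it reuses base_path and
# select_random_images_within_alphabet from the cells above.

def load_as_array(path):
    """Load an Omniglot image as a flat float array."""
    return np.asarray(Image.open(path).convert("L"), dtype=np.float32).ravel()

def run_pixel_nn_trial(base_path):
    # Mirror run_trial_interactive: pick a random alphabet and character
    languages = [l for l in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, l))]
    language_path = os.path.join(base_path, random.choice(languages))
    characters = [c for c in os.listdir(language_path) if os.path.isdir(os.path.join(language_path, c))]
    character_path = os.path.join(language_path, random.choice(characters))

    images = [img for img in os.listdir(character_path) if img.endswith('.png')]
    probe, correct = (os.path.join(character_path, f) for f in random.sample(images, 2))

    # Eight distractors from other characters of the same alphabet, plus the true match
    options = select_random_images_within_alphabet(base_path, language_path, character_path, num_images=8)
    options.append(correct)

    # Answer with the option closest to the probe in Euclidean pixel distance
    probe_vec = load_as_array(probe)
    distances = [np.linalg.norm(load_as_array(opt) - probe_vec) for opt in options]
    return options[int(np.argmin(distances))] == correct

n_trials = 50
accuracy = np.mean([run_pixel_nn_trial(base_path) for _ in range(n_trials)])
print(f"Pixel nearest-neighbor accuracy over {n_trials} trials: {accuracy:.2f}")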

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_Omniglot_Dataset")

Reflection activity 1.1#

Sample complexity \(N(\varepsilon, \delta)\) is formally defined as:

the number of examples \(N\) that a learner must see in order to perform a task with an error rate smaller than \(\varepsilon\) with probability greater than \(1-\delta\).

Based on this definition, what is your sample complexity on the Omniglot task?
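To make the definition concrete, here is a small, hedged sketch relating your score from the interactive to it. Each trial is treated as an independent Bernoulli draw, and a one-sided Hoeffding bound gives a high-probability upper bound on your true error rate; the epsilon, delta, and score values below are illustrative, not part of the activity.

# A hedged sketch relating your score to the (epsilon, delta) definition above.
# Trials are treated as independent Bernoulli draws; a one-sided Hoeffding bound
# gives an upper bound on the true error rate that holds with probability >= 1 - delta.
import numpy as np

def error_upper_bound(n_correct, n_trials, delta=0.05):
    empirical_error = 1.0 - n_correct / n_trials
    margin = np.sqrt(np.log(1.0 / delta) / (2.0 * n_trials))
    return empirical_error + margin

# Illustrative numbers: 18 correct out of 20 trials
bound = error_upper_bound(n_correct=18, n_trials=20, delta=0.05)
print(f"With probability >= 0.95, the true error rate is below {bound:.2f}")
# With only 20 trials the bound is loose; more trials shrink the margin.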

Click for solution

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_Reflection_Activity_1")

Reflection activity 1.2#

How do you think you, as a human, are performing a task like Omniglot?

Click for solution

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_Reflection_Activity_2")

Section 2: Model of one-shot learning#

Video 2: GNS Model#

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_GNS_Model")

Feinman and Lake (2020) propose a cognitive model to explain how humans perform one-shot recognition tasks like Omniglot. Their model is based on the insight that handwritten characters are highly structured: if we can infer how a character was generated, we can recover the writer’s intent and perform one-shot recognition of new characters.

When we write down a character on a piece of paper or a screen, we might implicitly perform a sequence of steps:

  1. Prepare a global motor plan to write a character based on prior experience

  2. Decide where to put down the pen for the first stroke

  3. Draw a stroke in an appropriate direction.

    a. Look at the sheet of paper while writing to adjust the direction of the stroke

  4. Find a location for the second stroke, and so on…

  5. When satisfied, stop drawing strokes

Feinman and Lake (2020) propose embedding these assumptions into a generative model of how a single character is composed from strokes; a schematic sketch of this process appears at the end of this section.

Diagram of the generative neuro-symbolic model: https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/model_diagram.png?raw=true

The result is a highly structured Bayesian generative model containing both discrete components (e.g., strokes) and continuous components (e.g., the location of the next stroke). It combines symbolic primitives (strokes) with standard ANN components. This combination of symbols and neural networks is known as a neuro-symbolic approach.

This is an example of a model with strong inductive biases.
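The sketch below is purely schematic: it restates the step list above as a generative loop. The helper functions (sample_start_location, sample_stroke, should_terminate) are hypothetical placeholders standing in for the trained location, stroke, and termination networks; this is not the actual GNS code, nor the interface of the .pt files downloaded in Setup.

# Schematic sketch of the generative process described above.
# sample_start_location, sample_stroke, and should_terminate are hypothetical
# placeholders, NOT the actual GNS API or the downloaded models' interface.
import random

def generate_character(sample_start_location, sample_stroke, should_terminate,
                       max_strokes=10):
    """Generate one character as a list of strokes (each a sequence of points)."""
    strokes = []
    canvas = []  # everything drawn so far, conditioning the next decisions
    for _ in range(max_strokes):
        # Steps 1-2: decide where to put the pen down, given what is already drawn
        start = sample_start_location(canvas)
        # Step 3: draw a stroke from that location, adjusting as it unfolds
        stroke = sample_stroke(start, canvas)
        strokes.append(stroke)
        canvas.extend(stroke)
        # Step 5: decide whether the character is complete
        if should_terminate(canvas):
            break
    return strokes

# Dummy stand-ins just to make the sketch executable end-to-end
demo_strokes = generate_character(
    sample_start_location=lambda canvas: (random.random(), random.random()),
    sample_stroke=lambda start, canvas: [start, (random.random(), random.random())],
    should_terminate=lambda canvas: random.random() < 0.3,
)
print(f"Sampled a character with {len(demo_strokes)} stroke(s)")

Given such a generative model, one-shot classification can be framed as inference: the model asks which character’s generative process best explains the new image.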

Submit your feedback#

Hide code cell source
# @title Submit your feedback
content_review(f"{feedback_prefix}_One_Shot_Learning")

Summary#

  • Cognitive science seeks to understand how human cognition works.

  • Humans display one-shot learning on Omniglot, a character recognition task. This requires extensive generalization.

  • Sample complexity measures the minimum number of examples needed to reach a specific performance with some probability; a sample complexity of 1 indicates one-shot learning at a specific performance level.

  • A generative neurosymbolic model with strong inductive biases exhibits human-level performance on Omniglot.