Tutorial 3: Generalization in Cognitive Science#
Week 1, Day 1: Generalization
By Neuromatch Academy
Content creators: Samuele Bolotta, Patrick Mineault
Content reviewers: Samuele Bolotta, Lily Chamakura, RyeongKyung Yoon, Yizhou Chen, Ruiyi Zhang, Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk
Production editors: Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk
Tutorial Objectives#
Estimated timing of tutorial: 30 minutes
This tutorial will introduce you to generalization in the context of cognitive science. We’ll close the loop of our exploration of different views of handwriting with a model that combines aspects of the models we covered in the neuroscience and AI tutorials, including both generative and discriminative components.
In particular, we’ll be looking at the Omniglot dataset, and how it can be used to infer how humans and machines generalize in a handwritten symbol recognition task. We’ll try our hand at one-shot learning, and we’ll measure our sample complexity. We’ll then discuss how one cognitive model, Feinman and Lake (2020), attempts to solve the problem of handwritten symbol recognition using a neuro-symbolic method.
By the end of this tutorial, participants will be able to:
Explore the goals of cognitive science, such as unraveling the complexities of human cognition.
Define one-shot learning and sample complexity. Perform a task that involves one-shot learning.
Explore how a neurosymbolic model with strong inductive biases could explain one-shot learning on Omniglot.
Setup#
Install and import feedback gadget#
# @title Install and import feedback gadget
!pip install matplotlib numpy Pillow scipy ipywidgets vibecheck tqdm --quiet
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_neuroai",
"user_key": "wb2cxze8",
},
).render()
feedback_prefix = "W1D1_T3"
Import dependencies#
# @title Import dependencies
# Standard libraries
import hashlib
import logging
import os
import random
import requests
import shutil
import time
from importlib import reload
import zipfile
from zipfile import ZipFile
# Data handling and visualization
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import skimage
from skimage import io
from sklearn.model_selection import train_test_split
# Deep Learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.spatial.distance import cdist
# Interactive controls in Jupyter notebooks
from IPython.display import clear_output, display, update_display
import ipywidgets as widgets
# Utility for progress bars
from tqdm import tqdm
Figure settings#
# @title Figure settings
# @markdown
logging.getLogger('matplotlib.font_manager').disabled = True
%matplotlib inline
%config InlineBackend.figure_format = 'retina'  # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")
Plotting functions#
# @title Plotting functions
# @markdown
def display_images(probe, options):
# Open the probe image and the option images
probe_image = Image.open(probe)
option_images = [Image.open(img_path) for img_path in options]
# Create a figure with the probe and the 3x3 grid for the options directly below
fig = plt.figure(figsize=(15, 10)) # Adjust figure size as needed
# Add the probe image to the top of the figure with a red border
ax_probe = fig.add_subplot(4, 3, (1, 3)) # Span the probe across the top 3 columns
ax_probe.imshow(probe_image)
ax_probe.axis('off')
rect = patches.Rectangle((0, 0), probe_image.width-1, probe_image.height-1, linewidth=2, edgecolor='r', facecolor='none')
ax_probe.add_patch(rect)
# Position the 3x3 grid of option images directly below the probe image
for index, img in enumerate(option_images):
row = (index // 3) + 1 # Calculate row in the 3x3 grid, starting directly below the probe
col = (index % 3) + 1 # Calculate column in the 3x3 grid
ax_option = fig.add_subplot(4, 3, row * 3 + col) # Adjust grid position to directly follow the probe
ax_option.imshow(img)
ax_option.axis('off')
plt.tight_layout()
plt.show()
Data retrieval for zip files#
# @title Data retrieval for zip files
def handle_file_operations(fname, url, expected_md5, extract_to='data'):
"""Handles downloading, verifying, and extracting a file."""
# Define helper functions for download, verify, and extract operations
def download_file(url, filename):
"""Downloads file from the given URL and saves it locally."""
try:
r = requests.get(url, stream=True)
r.raise_for_status()
with open(filename, "wb") as fid:
for chunk in r.iter_content(chunk_size=8192):
fid.write(chunk)
print("Download successful.")
return True
except requests.RequestException as e:
print(f"!!! Failed to download data: {e} !!!")
return False
def verify_file_md5(filename, expected_md5):
"""Verifies the file's MD5 checksum."""
hash_md5 = hashlib.md5()
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
if hash_md5.hexdigest() == expected_md5:
print("MD5 checksum verified.")
return True
else:
print("!!! Data download appears corrupted !!!")
return False
def extract_zip_file(filename, extract_to):
"""Extracts the ZIP file to the specified directory."""
try:
with zipfile.ZipFile(filename, 'r') as zip_ref:
zip_ref.extractall(extract_to)
print(f"File extracted successfully to {extract_to}")
except zipfile.BadZipFile:
print("!!! The ZIP file is corrupted or not a zip file !!!")
# Main operation
if not os.path.isfile(fname) or not verify_file_md5(fname, expected_md5):
if download_file(url, fname) and verify_file_md5(fname, expected_md5):
extract_zip_file(fname, extract_to)
else:
print(f"File '{fname}' already exists and is verified. Proceeding to extraction.")
extract_zip_file(fname, extract_to)
# Example usage
file_info = [
{"fname": "omniglot-py.zip", "url": "https://osf.io/bazxp/download", "expected_md5": "f7a4011f5c25460c6d95ee1428e377ed"},
]
import contextlib
import io
with contextlib.redirect_stdout(io.StringIO()):
for file in file_info:
handle_file_operations(**file)
#Current directory
base_dir = os.getcwd()
Data retrieval for torch models#
# @title Data retrieval for torch models
def download_file(url, filename):
"""
Download a file from a given URL and save it in the specified directory.
"""
filepath = os.path.join(base_dir, filename) # Ensure the file is saved in base_dir
response = requests.get(url)
response.raise_for_status() # Check for HTTP request errors
with open(filepath, 'wb') as f:
f.write(response.content)
def verify_checksum(filename, expected_checksum):
"""
Verify the MD5 checksum of a file
Parameters:
filename (str): Path to the file
expected_checksum (str): Expected MD5 checksum
Returns:
bool: True if the checksum matches, False otherwise
"""
md5 = hashlib.md5()
with open(filename, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
return md5.hexdigest() == expected_checksum
def load_models(model_files, directory, map_location='cpu'):
"""
Load multiple models from a specified directory.
"""
models = {}
for model_file in model_files:
full_path = os.path.join(directory, model_file) # Correctly join paths
models[model_file] = torch.load(full_path, map_location=map_location)
return models
def verify_models_in_destination(model_files, destination_directory):
"""
Verify the presence of model files in the specified directory.
Parameters:
model_files (list of str): Filenames of the models to verify.
destination_directory (str): The directory where the models are supposed to be.
Returns:
bool: True if all models are found in the directory, False otherwise.
"""
missing_files = []
for model_file in model_files:
# Construct the full path to where the model should be
full_path = os.path.join(destination_directory, model_file)
# Check if the model exists at the location
if not os.path.exists(full_path):
missing_files.append(model_file)
if missing_files:
print(f"Missing model files in destination: {missing_files}")
return False
else:
print("All models are correctly located in the destination directory.")
return True
# URLs and checksums for the models
models_info = {
'location_model.pt': ('https://osf.io/zmd7y/download', 'dfd51cf7c3a277777ad941c4fcc23813'),
'stroke_model.pt': ('https://osf.io/m6yc7/download', '511ea7bd12566245d5d11a85d5a0abb0'),
'terminate_model.pt': ('https://osf.io/dsmhc/download', '2f3e26cfcf36ce9f9172c15d8b1079d1')
}
destination_directory = base_dir
# Define model_files based on the keys of models_info to ensure we have the filenames
model_files = list(models_info.keys())
with contextlib.redirect_stdout(io.StringIO()):
# Iterate over the models to download and verify
for model_name, (url, checksum) in models_info.items():
download_file(url, model_name) # Downloads directly into base_dir
if verify_checksum(os.path.join(base_dir, model_name), checksum):
print(f"Successfully verified {model_name}")
else:
print(f"Checksum does not match for {model_name}. Download might be corrupted.")
with contextlib.redirect_stdout(io.StringIO()):
# Verify the presence of the models in the destination directory
if verify_models_in_destination(model_files, destination_directory):
print("Verification successful: All models are in the correct directory.")
else:
print("Verification failed: Some models are missing from the destination directory.")
# Load the models from the destination directory
models = load_models(model_files, destination_directory, map_location='cpu')
Helper functions#
# @title Helper functions
def select_random_images_within_alphabet(base_path, alphabet_path, exclude_character_path, num_images=8):
# Initialize an empty list to store the paths of the chosen images
chosen_images = []
# Get a list of all character directories within the alphabet_path, excluding the directory specified by exclude_character_path
all_characters = [
char for char in os.listdir(alphabet_path)
if os.path.isdir(os.path.join(alphabet_path, char)) and os.path.join(alphabet_path, char) != exclude_character_path
]
# Keep selecting images until we have the desired number of images (num_images)
while len(chosen_images) < num_images:
# If there are no more characters to choose from, exit the loop
if not all_characters:
break
# Randomly select a character directory from the list of all characters
character = random.choice(all_characters)
# Construct the full path to the selected character directory
character_path = os.path.join(alphabet_path, character)
# Get a list of all image files (with .png extension) in the selected character directory
all_images = [
img for img in os.listdir(character_path)
if img.endswith('.png')
]
# If there are no images in the selected character directory, continue to the next iteration
if not all_images:
continue
# Randomly select an image file from the list of image files
image_file = random.choice(all_images)
# Construct the full path to the selected image file
image_path = os.path.join(character_path, image_file)
# Add the selected image path to the list of chosen images
chosen_images.append(image_path)
# Return the list of paths to the chosen images
return chosen_images
def run_trial_interactive(base_path, output):
# Context manager to direct output to the provided widget
with output:
# Initialize and display the score widget
score_widget = widgets.Label(value=f'Score: {total_score}/{total_trials}', disabled=True)
display(score_widget)
# List all directories (languages) within the base path
languages = [lang for lang in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, lang))]
# Randomly select a language directory
selected_language = random.choice(languages)
# Construct the path to the selected language directory
language_path = os.path.join(base_path, selected_language)
# List all directories (characters) within the selected language path
characters = [char for char in os.listdir(language_path) if os.path.isdir(os.path.join(language_path, char))]
# Randomly select a character directory
selected_character = random.choice(characters)
# Construct the path to the selected character directory
character_path = os.path.join(language_path, selected_character)
# List all .png files (images) in the selected character directory
images = [img for img in os.listdir(character_path) if img.endswith('.png')]
# Randomly select two images: one as the probe image and one as the correct answer
probe_image_path, correct_answer_image_path = random.sample(images, 2)
# Construct full paths to the probe image and the correct answer image
probe_image_path = os.path.join(character_path, probe_image_path)
correct_answer_image_path = os.path.join(character_path, correct_answer_image_path)
# Select a number of wrong answer images from other characters within the same language
wrong_answers = select_random_images_within_alphabet(base_path, language_path, character_path, num_images=8)
# Create the options list, which includes the wrong answers and the correct answer
options = wrong_answers
# Insert the correct answer at a random position within the options list
options.insert(random.randint(0, len(options)), correct_answer_image_path)
# Display a label indicating the reference image
display(widgets.Label(value='Reference image'))
# Display the probe image
display(widgets.Image(value=open(probe_image_path, 'rb').read(), format='png'))
# Create a grid of image widgets for the options
image_grid = widgets.GridBox([widgets.Image(value=open(opt, 'rb').read(), format='png', layout=widgets.Layout(width='100px', height='100px'))
for opt in options], layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))
# Create a grid of numbered buttons corresponding to the images
button_grid = widgets.GridBox([widgets.Button(description=str(i+1), layout=widgets.Layout(width='auto', height='auto'))
for i in range(len(options))], layout=widgets.Layout(grid_template_columns='repeat(3, 100px)'))
# Combine the image grid and the button grid into a single grid layout
global_grid = widgets.GridBox([image_grid, button_grid], layout=widgets.Layout(grid_template_columns='repeat(2, 300px)'))
# Display a label prompting the user to match the reference image
display(widgets.Label(value='Which of these images match the reference? '))
time.sleep(.2)
# Display the combined grid of images and buttons
display(global_grid)
# Attach click event handlers to the buttons
for b in button_grid.children:
b.on_click(lambda b: on_button_clicked(b, options, correct_answer_image_path, score_widget))
Video 1: Overview#
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_Overview")
Section 1: How people recognize new characters#
Let’s put ourselves in the mindset of a cognitive scientist studying handwriting. We’re interested in how people learn to recognize new characters. Humans display low sample complexity when learning new visual concepts: they seem to grasp new concepts from very few presentations and generalize effortlessly. In AI, learning from \(k\) labeled examples per class is known as \(k\)-shot learning; one-shot and few-shot learning refer to learning from one or a few labeled examples.
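To make the \(k\)-shot setting concrete, here is a minimal sketch (independent of the experiment below) of how an \(n\)-way, \(k\)-shot episode is typically assembled: the learner receives \(k\) labeled examples of each of \(n\) classes as a support set and must then classify a held-out query. The class names and file names used here are hypothetical placeholders.
import random
def make_k_shot_episode(images_by_class, n_way=5, k_shot=1):
    """Assemble a hypothetical n-way, k-shot episode.
    images_by_class: dict mapping class name -> list of image paths
    (placeholder data with at least k_shot + 1 images per class).
    """
    classes = random.sample(list(images_by_class), n_way)
    query_class = random.choice(classes)
    support, query = [], None
    for c in classes:
        examples = random.sample(images_by_class[c], k_shot + 1)
        # k labeled examples per class form the support set
        support += [(img, c) for img in examples[:k_shot]]
        if c == query_class:
            # one held-out image of one class serves as the query
            query = (examples[k_shot], c)
    return support, query
# Toy usage: a 5-way, 1-shot episode over made-up classes
toy = {f"char_{i}": [f"char_{i}_{j}.png" for j in range(20)] for i in range(10)}
support_set, (query_image, query_label) = make_k_shot_episode(toy)
print(len(support_set), query_label)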
A good dataset for investigating one-shot learning is the Omniglot dataset. Omniglot has sometimes been described as MNIST, transposed: instead of thousands of examples of each of 10 digit classes, Omniglot contains 20 instances of each of 1,623 character classes. These character classes are sourced from 50 alphabets, both natural (e.g., Cherokee or Greek) and constructed (e.g., the alien alphabet from the TV show Futurama).
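The retrieval cell above unzips the data into one folder per alphabet, each containing one folder per character, each holding PNG drawings of that character. Here is a quick sketch to check the counts yourself, assuming the download cell ran successfully; it uses the same images_background path as the interactive below (this split contains a subset of the 50 alphabets).
import os
background_path = os.path.join("data", "omniglot-py", "images_background")
alphabets = sorted(os.listdir(background_path))
characters = [
    (a, c)
    for a in alphabets
    for c in os.listdir(os.path.join(background_path, a))
]
n_images = sum(
    len(os.listdir(os.path.join(background_path, a, c))) for a, c in characters
)
print(f"{len(alphabets)} alphabets, {len(characters)} characters, "
      f"{n_images / max(len(characters), 1):.0f} drawings per character")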
Let’s see whether you’re a good one-shot classifier by trying the Omniglot task yourself. Observing human behavior in the lab to infer the strategies people use is an important way cognitive scientists make progress in understanding human cognition.
Your task is to conduct a series of trials to explore the Omniglot dataset. Here’s how the experiment goes:
Click Start
Look at the reference character at the top
Look at 9 different potential matches at the bottom. These include one more instance of the reference character class, and 8 distractors from other characters of the same alphabet.
Click the button corresponding to the best match. The selection buttons are on the right of the grid.
Repeat for multiple trials. Complete 10 or 20 trials to get an estimate of how well you perform.
# Paths
base_path = "data/omniglot-py/images_background"
total_score = 0
total_trials = 0
output = widgets.Output()
btn = None
def start(b):
global total_score, total_trials
total_score = 0
total_trials = 0
output.clear_output(wait=True)
run_trial_interactive(base_path, output)
btn.description = 'Reset the interactive'
def on_button_clicked(b, options, correct_answer_image_path, score_widget):
global total_score, total_trials
if options[int(b.description) - 1] == correct_answer_image_path:
total_score += 1
total_trials += 1
output.clear_output(wait=True)
run_trial_interactive(base_path, output)
def display_start_button():
global btn
btn = widgets.Button(description='Start the interactive')
display(btn, output)
btn.on_click(start)
display_start_button()
How well did you do?
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_Omniglot_Dataset")
Reflection activity 1.1#
Sample complexity \(N(\varepsilon, \delta)\) is formally defined as:
the number of examples \(N\) that a learner must see in order to perform a task with an error rate smaller than \(\varepsilon\) with probability at least \(1-\delta\).
Based on this definition, what is your sample complexity on the Omniglot task?
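If you ran the interactive above, the minimal sketch below relates your empirical error rate to a target \(\varepsilon\); it assumes total_score and total_trials from that cell are still in scope. Since each trial showed you exactly one labeled example of the probe character, reliably staying below \(\varepsilon\) would correspond to a sample complexity of 1 at that performance level.
epsilon = 0.2  # hypothetical target error rate; pick your own
if total_trials > 0:
    error_rate = 1 - total_score / total_trials
    print(f"Empirical error rate over {total_trials} trials: {error_rate:.2f}")
    # Each trial gave you a single example of the probe character,
    # so consistently beating epsilon suggests N(epsilon, delta) = 1 here.
    print(f"Below epsilon = {epsilon}? {error_rate < epsilon}")
else:
    print("Run a few trials of the interactive first.")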
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_Reflection_Activity_1")
Reflection activity 1.2#
How do you think you, as a human, are performing a task like Omniglot?
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_Reflection_Activity_2")
Section 2: Model of one-shot learning#
Video 2: GNS Model#
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_GNS_Model")
Feinman and Lake (2020) propose a cognitive model to explain how humans perform one-shot recognition tasks like Omniglot. Their model is based on the insight that handwritten characters are highly structured: if we could infer how a character was generated, we could recover the writer’s intent and perform one-shot recognition of new characters.
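To see how “recognition by inferring the generative process” could look in code, here is a schematic sketch, not Feinman and Lake’s actual implementation: infer a latent motor program from the single reference image, then score each candidate by how plausibly the same program could have produced it. The infer_program and score_under_program callables are hypothetical stand-ins for the model’s inference and likelihood machinery.
def classify_one_shot(reference_image, candidate_images,
                      infer_program, score_under_program):
    """Schematic one-shot classification by generative inference."""
    # Infer a latent motor program (strokes, locations, ...) from the
    # single reference example.
    program = infer_program(reference_image)
    # Score each candidate by how likely the same program, re-executed
    # with motor noise, is to have produced it.
    scores = [score_under_program(program, img) for img in candidate_images]
    # Return the index of the best-scoring candidate.
    return max(range(len(scores)), key=lambda i: scores[i])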
When we write down a character on a piece of paper or a screen, we might implicitly perform a sequence of steps:
Prepare a global motor plan to write a character based on prior experience
Decide where to put down the pen for the first stroke
Draw a stroke in an appropriate direction.
a. Look at the sheet of paper while writing to adjust the direction of the stroke
Find a location for the second stroke, and so on…
When satisfied, stop drawing strokes
Feinman and Lake (2020) propose to embed these assumptions into a generative model for how a single character is generated from strokes.
The result is a highly structured Bayesian generative model containing both discrete components (e.g., strokes) and continuous components (e.g., the location of the next stroke is a continuous variable). It combines symbolic primitives (strokes) with standard artificial neural network (ANN) components. This combination of symbols and neural networks is known as a neuro-symbolic approach.
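The three PyTorch files loaded during setup (location_model.pt, stroke_model.pt, terminate_model.pt) correspond to the neural pieces of such a model. The sketch below shows how components with these roles might be composed into a character-generation loop; the calling conventions (the sample methods, the canvas representation, the render helper) are assumptions made for illustration, not the models’ actual interfaces.
def generate_character(location_model, stroke_model, terminate_model,
                       render, max_strokes=10):
    """Schematic neuro-symbolic generation loop (interfaces are assumed).
    Symbolic part: the character is an explicit list of strokes.
    Neural part: each decision (where to start, what trajectory to draw,
    whether to stop) is sampled from a learned model.
    """
    strokes = []   # discrete, symbolic description of the character
    canvas = None  # running image of what has been drawn so far
    for _ in range(max_strokes):
        # Decide where to put the pen down, given what is already drawn.
        start = location_model.sample(canvas)
        # Draw a stroke trajectory from that starting point.
        trajectory = stroke_model.sample(start, canvas)
        strokes.append((start, trajectory))
        canvas = render(strokes)  # hypothetical rasterizer supplied by the caller
        # Decide whether the character is complete.
        if terminate_model.sample(canvas):
            break
    return strokes, canvas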
This is an example of a model with strong inductive biases.
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_One_Shot_Learning")
Summary#
Cognitive science seeks to understand how human cognition works.
Humans display one-shot learning on Omniglot, a character recognition task. This requires extensive generalization.
Sample complexity measures the minimum number of examples needed to reach a specific performance with some probability; a sample complexity of 1 indicates one-shot learning at a specific performance level.
A generative neurosymbolic model with strong inductive biases exhibits human-level performance on Omniglot.