{ "cells": [ { "cell_type": "markdown", "id": "26e53498-57e6-477e-8e43-7b8eb1d95882", "metadata": { "execution": {} }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "id": "d207f0be-cbd8-4b93-a823-61af51421e2a", "metadata": { "execution": {} }, "source": [ "# Tutorial 1: Generalization in AI\n", "\n", "**Week 1, Day 1: Generalization**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Samuele Bolotta & Patrick Mineault\n", "\n", "__Content reviewers:__ Samuele Bolotta, Lily Chamakura, RyeongKyung Yoon, Yizhou Chen, Ruiyi Zhang, Aakash Agrawal, Alish Dipani, Hossein Rezaei, Yousef Ghanbari, Mostafa Abdollahi, Hlib Solodzhuk\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n" ] }, { "cell_type": "markdown", "id": "35d71fc6-0728-41bd-b84e-ac78a622ece4", "metadata": { "execution": {} }, "source": [ "___\n", "\n", "\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 75 minutes*\n", "\n", "This tutorial will introduce you to generalization in the context of modern AI systems. We'll look at a particular system trained for handwriting recognition–TrOCR. We'll review what makes that model tick–the transformer architecture–and explore what goes on into training and finetuning large-scale models. We'll look at how augmentations can bake in tolerance to certain transformations like scaling and cropping.\n", "\n", "Our learning objectives for today are:\n", "\n", "1. Identify and articulate common objectives pursued by developers of operational AI systems, such as:\n", "\n", "- OOD robustness; Latency; Size, Weight, Power, and Cost (SWaP-C)\n", "- Explainability and understanding\n", "\n", "2. Explain at least three strategies for enhancing the generalization capabilities of AI systems, including the contemporary trend of training generic large-scale models on extensive datasets, commonly referred to as the [\"bitter lesson.\"]((http://www.incompleteideas.net/IncIdeas/BitterLesson.html))\n", "\n", "3. Gain practical experience with the fundamentals of deep learning and PyTorch.\n", "\n", "**Important note**: this tutorial leverages GPU acceleration. Using a GPU runtime in colab will make the the tutorial run 10x faster.\n", "\n", "Let's get started!" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2ae629d9-c104-4716-8eb9-8e53e5151395", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/79523/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/79523/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "id": "ee444563-8ef0-4b20-b8d7-92eeace85488", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "id": "57acdfc5-c864-40a0-b648-be385d5c3eb5", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install vibecheck Pillow matplotlib torch torchvision transformers gradio protobuf sentencepiece gradio torchmetrics --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W1D1_T1\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "id": "40270953", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Import dependencies\n", "\n", "# Standard Libraries for file and operating system operations, security, and web requests\n", "import os\n", "import functools\n", "import hashlib\n", "import requests\n", "import logging\n", "import io\n", "import re\n", "import time\n", "\n", "# Core python data science and image processing libraries\n", "import numpy as np\n", "from PIL import Image as IMG\n", "from PIL import ImageDraw, ImageFont\n", "import matplotlib.pyplot as plt\n", "import tqdm\n", "\n", "# Deep Learning and model specific libraries\n", "import torch\n", "import torchmetrics.functional.text as fm\n", "import transformers\n", "from torchvision import transforms\n", "from transformers import TrOCRProcessor, VisionEncoderDecoderModel\n", "\n", "# Utility and interface libraries\n", "import gradio as gr\n", "from IPython.display import IFrame, display, Image\n", "import sentencepiece\n", "import zipfile\n", "import pandas as pd\n", "\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e3fa95f5", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "# @markdown\n", "\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "%matplotlib inline\n", "%config 
InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting functions\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1bf34b9a-1dd5-458a-b390-0fa12609d532", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting functions\n", "\n", "def display_image(image_path):\n", " \"\"\"Display an image from a given file path.\n", "\n", " Inputs:\n", " - image_path (str): The path to the image file.\n", " \"\"\"\n", " # Open the image with PIL (imported above as IMG)\n", " image = IMG.open(image_path)\n", " if image.mode != 'RGB':\n", " image = image.convert('RGB')\n", "\n", " # Display the image\n", " plt.imshow(image)\n", " plt.axis('off') # Turn off the axis\n", " plt.show()\n", "\n", "def display_transformed_images(image, transformations):\n", " \"\"\"\n", " Apply a list of transformations to an image and display them.\n", "\n", " Inputs:\n", " - image (Tensor): The input image as a tensor.\n", " - transformations (list): A list of torchvision transformations to apply.\n", " \"\"\"\n", " # Convert tensor image to PIL Image for display\n", " pil_image = transforms.ToPILImage()(image)\n", "\n", " fig, axs = plt.subplots(len(transformations) + 1, 1, figsize=(5, 15))\n", " axs[0].imshow(pil_image, cmap='gray')\n", " axs[0].set_title('Original')\n", " axs[0].axis('off')\n", "\n", " for i, transform in enumerate(transformations):\n", " # Apply transformation if it's not the placeholder\n", " if transform != \"Custom ElasticTransform Placeholder\":\n", " transformed_image = transform(image)\n", " # Convert transformed tensor image to PIL Image for display\n", " transformed_pil = transforms.ToPILImage()(transformed_image)\n", " axs[i+1].imshow(transformed_pil, cmap='gray')\n", " axs[i+1].set_title(transform.__class__.__name__)\n", " axs[i+1].axis('off')\n", " else:\n", " axs[i+1].text(0.5, 0.5, 'ElasticTransform Placeholder', ha='center')\n", " axs[i+1].axis('off')\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "def display_original_and_transformed_images(original_tensor, transformed_tensor):\n", " \"\"\"\n", " Display the original and transformed images side by side.\n", "\n", " Inputs:\n", " - original_tensor (Tensor): The original image as a tensor.\n", " - transformed_tensor (Tensor): The transformed image as a tensor.\n", " \"\"\"\n", " fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "\n", " # Display original image\n", " original_image = original_tensor.permute(1, 2, 0) # Convert from (C, H, W) to (H, W, C)\n", " axs[0].imshow(original_image, cmap='gray')\n", " axs[0].set_title('Original')\n", " axs[0].axis('off')\n", "\n", " # Display transformed image\n", " transformed_image = transformed_tensor.permute(1, 2, 0) # Convert from (C, H, W) to (H, W, C)\n", " axs[1].imshow(transformed_image, cmap='gray')\n", " axs[1].set_title('Transformed')\n", " axs[1].axis('off')\n", "\n", " plt.show()\n", "\n", "def display_generated_images(generator):\n", " \"\"\"\n", " Display images generated from strings.\n", "\n", " Inputs:\n", " - generator (GeneratorFromStrings): A generator that produces images from strings.\n", " \"\"\"\n", " plt.figure(figsize=(15, 3))\n", " for i, (text_img, lbl) in enumerate(generator, 1):\n", " ax = plt.subplot(1, len(generator.strings) * generator.count // len(generator.strings), i)\n", " 
plt.imshow(text_img)\n", " plt.title(f\"Example {i}\")\n", " plt.axis('off')\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", "# Function to generate an image with text\n", "def generate_image(text, font_path, space_width=2, skewing_angle=8):\n", " \"\"\"Generate an image with text.\n", "\n", " Args:\n", " text (str): Text to be rendered in the image.\n", " font_path (str): Path to the font file.\n", " space_width (int): Space width between characters.\n", " skewing_angle (int): Angle to skew the text image.\n", " \"\"\"\n", " image_size = (350, 50)\n", " background_color = (255, 255, 255)\n", " speckle_threshold = 0.05\n", " speckle_color = (200, 200, 200)\n", " background = np.random.rand(image_size[1], image_size[0], 1) * 64 + 191\n", " background = np.tile(background, [1, 1, 4])\n", " background[:, :, -1] = 255\n", " image = IMG.fromarray(background.astype('uint8'), 'RGBA')\n", " image2 = IMG.new('RGBA', image_size, (255, 255, 255, 0))\n", " draw = ImageDraw.Draw(image2)\n", " font = ImageFont.truetype(font_path, size=36)\n", " text_size = draw.textlength(text, font=font)\n", " text_position = ((image_size[0] - text_size) // 2, (image_size[1] - font.size) // 2)\n", " draw.text(text_position, text, font=font, fill=(0, 0, 0), spacing=space_width)\n", " image2 = image2.rotate(skewing_angle)\n", " image.paste(image2, mask=image2)\n", " return image\n", "\n", "# Function to generate images for multiple strings\n", "def image_generator(strings, font_path, space_width=2, skewing_angle=8):\n", " \"\"\"Generate images for multiple strings.\n", "\n", " Args:\n", " strings (list): List of strings to generate images for.\n", " font_path (str): Path to the font file.\n", " space_width (int): Space width between characters.\n", " skewing_angle (int): Angle to skew the text image.\n", " \"\"\"\n", " for text in strings:\n", " yield generate_image(text, font_path, space_width, skewing_angle)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval\n" ] }, { "cell_type": "code", "execution_count": null, "id": "046a34ac-fa41-4e90-ab45-82c14384a83e", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data retrieval\n", "\n", "def download_file(fname, url, expected_md5):\n", " \"\"\"\n", " Downloads a file from the given URL and saves it locally.\n", " Verifies the integrity of the file using an MD5 checksum.\n", "\n", " Args:\n", " - fname (str): The local filename/path to save the downloaded file.\n", " - url (str): The URL from which to download the file.\n", " - expected_md5 (str): The expected MD5 checksum to verify the integrity of the downloaded data.\n", " \"\"\"\n", " if not os.path.isfile(fname):\n", " try:\n", " r = requests.get(url)\n", " r.raise_for_status() # Raises an HTTPError for bad responses\n", " except (requests.ConnectionError, requests.HTTPError) as e:\n", " print(f\"!!! Failed to download {fname} due to: {str(e)} !!!\")\n", " return\n", " if hashlib.md5(r.content).hexdigest() == expected_md5:\n", " with open(fname, \"wb\") as fid:\n", " fid.write(r.content)\n", " print(f\"{fname} has been downloaded successfully.\")\n", " else:\n", " print(f\"!!! 
Data download appears corrupted, {hashlib.md5(r.content).hexdigest()} !!!\")\n", "\n", "def extract_zip(zip_fname, folder='.'):\n", " \"\"\"\n", " Extracts a ZIP file to the specified folder.\n", "\n", " Args:\n", " - zip_fname (str): The filename/path of the ZIP file to be extracted.\n", " - folder (str): Destination folder where the ZIP contents will be extracted.\n", " \"\"\"\n", " if zipfile.is_zipfile(zip_fname):\n", " with zipfile.ZipFile(zip_fname, 'r') as zip_ref:\n", " zip_ref.extractall(folder)\n", " print(f\"Extracted {zip_fname} to {folder}.\")\n", " else:\n", " print(f\"Skipped extraction for {zip_fname} as it is not a zip file.\")\n", "\n", "# Define the list of files to download, including both ZIP files and other file types\n", "file_info = [\n", " (\"Dancing_Script.zip\", \"https://osf.io/32yed/download\", \"d59bd3201b58a37d0d3b4cd0b0ec7400\", '.'),\n", " (\"lines.zip\", \"https://osf.io/8a753/download\", \"6815ed3987f8eb2fd3bc7678c11f2e9e\", 'lines'),\n", " (\"transcripts.csv\", \"https://osf.io/9hgr8/download\", \"d81d9ade10db55603cc893345debfaa2\", None),\n", " (\"neuroai_hello_world.png\", \"https://osf.io/zg4w5/download\", \"f08b81e47f2fe66b5f25b2ccc204c780\", None), # New image file\n", " (\"sample0.png\", \"https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/sample_0.png?raw=true\", '920ae567f707bfee0be29dc854f804ed', None),\n", " (\"sample1.png\", \"https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/sample_1.png?raw=true\", 'cd28623a829b40d0a1dd8c0f17e9ebd7', None),\n", " (\"sample2.png\", \"https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/sample_2.png?raw=true\", 'c189c09abf989eac4e1a8d493bd362d7', None),\n", " (\"sample3.png\", \"https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/sample_3.png?raw=true\", 'dcffc678266952f18af1fc1242127e98', None)\n", "]\n", "\n", "import contextlib\n", "import io\n", "\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " # Process the downloads and extractions\n", " for fname, url, expected_md5, folder in file_info:\n", " download_file(fname, url, expected_md5)\n", " if folder is not None:\n", " extract_zip(fname, folder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Overview\n" ] }, { "cell_type": "code", "execution_count": null, "id": "192ee23b-4ebe-40c1-a6f3-fb1a1ba9d294", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Overview\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at 
https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'PgA7wfi2eDo'), ('Bilibili', 'BV1Bm421L7zB')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6dff30da-aa91-4ada-babd-1b1e652b8589", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_overview_video\")" ] }, { "cell_type": "markdown", "id": "ef020325-5f43-4e07-bb44-5a5a4d816219", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: Motivation: building a handwriting recognition app with AI\n", "\n", "Let’s put ourselves into the mindset of an AI developer who wants to build a note app featuring handwriting recognition." ] }, { "cell_type": "markdown", "id": "febbaa92-2a2d-4e42-a236-0bae90c7f753", "metadata": { "cellView": "form", "execution": {} }, "source": [ "![Picture which shows the goal of the day.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/W1D1_goal.png?raw=true)" ] }, { "cell_type": "markdown", "id": "d0367ba9-a7ff-49c1-8bd1-0860130e7f39", "metadata": { "execution": {} }, "source": [ "Our intrepid developer doesn't want to start from scratch, so searches for a pretrained model. They find a suitable model hosted on HuggingFace, the largest repository of pretrained natural language and vision models. [TrOCR](https://huggingface.co/docs/transformers/en/model_doc/trocr) is a Transformer-based model that performs Optical Character Recognition and handwriting transcription. Several checkpoints are available, finetuned for different downstream applications like handwriting transcription and printed character recognition. Our relieved developer draws a deep sigh: they don't have to start from scratch." ] }, { "cell_type": "markdown", "id": "e1d7fdfb-4333-42b0-9906-aa7a54b7bba8", "metadata": { "cellView": "form", "execution": {} }, "source": [ "![Picture which shows trocr architecture.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/trocr_architecture.png?raw=true)" ] }, { "cell_type": "markdown", "id": "4bcd7a5a-b3cb-4999-a373-3dda956ba77f", "metadata": { "execution": {} }, "source": [ "In this tutorial, we'll look at the design considerations that go into training and deploying a model like TrOCR, what goes on inside the model's transformers, and how it achieves good or bad out-of-distribution generalization. 
While the NeuroAI course as a whole will explore new ideas at the frontier of neuroscience and AI, we'll first want to understand one of the bread-and-butter building blocks used in industrial AI: the transformer.\n", "\n", "Let's try out this model ourselves!\n", "\n", "## Interactive demo 1: TrOCR\n", "\n", "We load a pretrained TrOCR checkpoint from HuggingFace. The `transformers` package from HuggingFace allows us to download a PyTorch model definition and a preprocessing class, and to load a pretrained checkpoint in just a few lines of code." ] }, { "cell_type": "code", "execution_count": null, "id": "b00326ca-807d-460f-adf4-767e94bc0ccc", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Load the pre-trained TrOCR model and processor\n", "with contextlib.redirect_stdout(io.StringIO()):\n", " model = VisionEncoderDecoderModel.from_pretrained(\"microsoft/trocr-base-handwritten\")\n", " model.to(device=device)\n", " processor = TrOCRProcessor.from_pretrained(\"microsoft/trocr-base-handwritten\", use_fast=False)" ] }, { "cell_type": "markdown", "id": "1e20e1f0-c052-42f5-a33a-87d89548c90c", "metadata": { "execution": {} }, "source": [ "We now write a callback function that calls the preloaded model to decode a particular image." ] }, { "cell_type": "code", "execution_count": null, "id": "1d1debc6-b3b1-4173-839c-f2a4240efb2d", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Define the function to recognize text from an image\n", "def recognize_text(processor, model, image):\n", " \"\"\"\n", " This function takes an image as input and uses a pre-trained language model to generate text from the image.\n", "\n", " Inputs:\n", " - processor: The processor to use\n", " - model: The model to use\n", " - image (PIL Image or Tensor): The input image containing text to be recognized.\n", "\n", " Outputs:\n", " - text (str): The recognized text extracted from the input image.\n", " \"\"\"\n", " print(image)\n", " pixel_values = processor(images=image, return_tensors=\"pt\").pixel_values\n", " generated_ids = model.generate(pixel_values.to(device))\n", " text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n", " return text" ] }, { "cell_type": "markdown", "id": "88658698-5e50-4740-8b46-087a8417d189", "metadata": { "execution": {} }, "source": [ "We build a simple interface in `gradio` to try out the model interactively. Go ahead and try some example text to see how it works. You can use images from the internet, or scan your own handwriting. Just make sure that the text fits on one line. Observe the result of the recognized text." ] }, { "cell_type": "code", "execution_count": null, "id": "a73d5f88-0296-42b2-86af-aa7f55e8ba11", "metadata": { "execution": {} }, "outputs": [], "source": [ "import gradio as gr\n", "import functools\n", "\n", "with gr.Blocks() as demo:\n", " gr.HTML(\"

<h1>Interactive demo: TrOCR</h1>

\")\n", " gr.Markdown(\"Upload a single image or click one of the examples to try this.\")\n", "\n", " # Define the examples\n", " examples = [\n", " 'neuroai_hello_world.png',\n", " 'sample1.png',\n", " 'sample2.png',\n", " 'sample3.png',\n", " ]\n", "\n", " # Create the image input component\n", " image_input = gr.Image(type=\"pil\", label=\"Upload Image\")\n", "\n", " # Create the example gallery\n", " example_gallery = gr.Examples(\n", " examples,\n", " image_input,\n", " )\n", "\n", " # Create the submit button\n", " with gr.Row():\n", " submit_button = gr.Button(\"Recognize Text\", scale=1)\n", "\n", " # Create the text output component\n", " text_output = gr.Textbox(label=\"Recognized Text\", scale=2)\n", "\n", " # Define the event listeners\n", " submit_button.click(\n", " fn=functools.partial(recognize_text, processor, model),\n", " inputs=image_input,\n", " outputs=text_output\n", " )\n", "\n", "# Launch the interface\n", "demo.launch(height=650)\n", "_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "5763d833-7740-42a2-9118-b691988b456c", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Interactive_Demo_1\")" ] }, { "cell_type": "markdown", "id": "5479e8f3-33ae-4ce5-87f6-a2c34b25152a", "metadata": { "execution": {} }, "source": [ "### Discussion point 1\n", "\n", "How effective is the model's performance? Does it exhibit generalization beyond its training vocabulary?" ] }, { "cell_type": "markdown", "id": "ee0bf6d7", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial1_Solution_21c15345.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8df5e247-a5ee-4e96-bd44-b5a02f832965", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Discussion_Point_1\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Video 2: OOD Generalization\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f40ee421-3014-4749-947b-221f44f2dd05", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: OOD Generalization\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, 
rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'pPljFAsgzA8'), ('Bilibili', 'BV1jx4y1b7Xh')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c341b09d-1785-4aaa-8460-8f4c1723902e", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_ood_generalization\")" ] }, { "cell_type": "markdown", "id": "93083419-5d48-4858-9bd7-979d5238e9d7", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Measuring out-of-distribution generalization in TrOCR\n", "\n", "How well does TrOCR work in practice? Our developer needs to know!\n", "\n", "Something you will see a lot of in machine learning papers are tables filled with benchmarks. The tables in the [TrOCR official paper](https://arxiv.org/abs/2109.10282) include measures of performance on different benchmark datasets, including IAM, a handwriting database assembled in 1999. The base and large model variants (334M and 558M parameters) display **character error rates (CER) of 3.42 and 2.89, respectively**. That means it gets 97% of characters correct.\n", "\n", "\"Wow!\", our developer thinks, \"That's probably good enough for my notes app! Guess I can go ahead and deploy it\"." ] }, { "cell_type": "markdown", "id": "5a422ca8-d407-4207-9921-d18ad82ee3a4", "metadata": { "execution": {} }, "source": [ "## Think! 1\n", "\n", "What are some reasons why the character error rate measured on IAM might be too optimistic?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "10b2fbd6-a130-4998-b7ee-e99d99098537", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Think_1\")" ] }, { "cell_type": "markdown", "id": "98ff99d0-b420-40a9-a77f-913d41f7b1d8", "metadata": { "execution": {} }, "source": [ "## Coding activity 1: Measuring out-of-distribution generalization\n", "\n", "Our developer reads through the fine print in the paper and realizes that the TrOCR is both *trained* on IAM and *tested* on IAM, on a different set of subjects. To be clear, the train and test splits are *distinct*; but samples come from the same underlying distribution. Our developer realizes that the reported error rates might be too optimistic:\n", "\n", "* IAM was recorded on a tablet. Our developer wants to be able to recognize lines of text handwritten on paper.\n", "* IAM is 25 years old. Maybe people write differently now compared to in the past. 
Do they even write in cursive anymore?\n", "* The sentences in IAM are based on a widely published corpus. Maybe TrOCR has memorized that corpus.\n", "\n", "The more the developer thinks about it, the more they realize that the paper is really estimating *in-distribution generalization*. However, what they care about is how well the model will work when it's deployed *in the wild*, which is closer to **out-of-distribution generalization**.\n", "\n", "In this coding activity, you'll measure out-of-distribution generalization on a small subset of the CVL database:\n", "\n", "> Kleber, F., Fiel, S., Diem, M., & Sablatnig, R. (2018). [CVL Database - An Off-line Database for Writer Retrieval, Writer Identification and Word Spotting [Data set]. Zenodo.](https://doi.org/10.5281/zenodo.1492267)\n", "\n", "Let's first have a look at this new out-of-distribution dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run this cell to visualize the dataset.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a88a369a-d65c-401f-8ef5-72a38d6d2eec", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Run this cell to visualize the dataset.\n", "def get_images_and_transcripts(df, subject):\n", " df_ = df[df.subject == subject]\n", " transcripts = df_.transcript.values.tolist()\n", "\n", " # Load the corresponding images\n", " images = []\n", " for _, row in df_.iterrows():\n", " images.append(IMG.open(row.filename))\n", "\n", " return images, transcripts\n", "\n", "def visualize_images_and_transcripts(images, transcripts):\n", " for img in images:\n", " display(img)\n", "\n", " for transcript in transcripts:\n", " print(transcript)\n", "\n", "df = pd.read_csv('transcripts.csv')\n", "df['filename'] = df.apply(lambda x: f\"lines/{x.subject:04}-{x.line}.jpg\", axis=1)\n", "df" ] }, { "cell_type": "markdown", "id": "48c9bbd1-5ddf-4851-aa13-b7a9916b4027", "metadata": { "execution": {} }, "source": [ "This is a small test set with 94 lines sampled from 10 different subjects. Let's have a look at the data from subject 52." ] }, { "cell_type": "code", "execution_count": null, "id": "3d64ebb4-2e02-4c20-a889-614fddd4c72a", "metadata": { "execution": {} }, "outputs": [], "source": [ "images, true_transcripts = get_images_and_transcripts(df, 52)\n", "visualize_images_and_transcripts(images, true_transcripts)" ] }, { "cell_type": "markdown", "id": "be0c13ca-1409-47a9-a0f2-6026371b855b", "metadata": { "execution": {} }, "source": [ "The text is transcribed from a passage in the novel [Flatland by Edwin Abbott Abbott](https://en.wikipedia.org/wiki/Flatland). The data is conceptually similar to the IAM database, with single isolated lines of text, but it was recorded on paper less than 10 years ago, so it should be more representative of how people write on paper today.\n", "\n", "How well does the model recognize the text? Run this cell to find out."
] }, { "cell_type": "code", "execution_count": null, "id": "441bdfcd-4dab-4864-8f22-fb64d6bcb680", "metadata": { "execution": {} }, "outputs": [], "source": [ "def transcribe_images(all_images, model, processor):\n", " \"\"\"\n", " Transcribe a batch of images using an OCR model.\n", "\n", " Args:\n", " all_images: a list of PIL images.\n", " model: the model to do image-to-token ids\n", " processor: the processor which maps token ids to text\n", "\n", " Returns:\n", " a list of the transcribed text.\n", " \"\"\"\n", " pixel_values = processor(images=all_images, return_tensors=\"pt\").pixel_values\n", " generated_ids = model.generate(pixel_values.to(device))\n", " decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=True)\n", " return decoded_text\n", "\n", "transcribed_text = transcribe_images(images, model, processor)\n", "print(transcribed_text)" ] }, { "cell_type": "markdown", "id": "d9b114a8-4aa5-420e-a490-c3b40b9b66f8", "metadata": { "execution": {} }, "source": [ "### Code exercise 1.1: Calculate CER and WER\n", "\n", "The model is not perfect but it performs far better than chance. Let's measure the character and the word error rates on this subject's data. \n", "\n", "The character error rate between a reference string `ref` and a predicted string `pred` is defined as:\n", "\n", "$$CharErrorRate = \\frac{S+D+I}{N}$$\n", "\n", "* $N$ is the number of characters in the reference string\n", "* $S$ is the number of substitutions to transform the predicted string to the reference string\n", "* $D$ is the number of deletions to transform the predicted string to the reference string\n", "* $I$ is the number of insertions to transform the predicted string to the reference string\n", "\n", "For example, to transform `3nanas` to `banana`, we'd need to replace `3` with `b`, insert an `a`, and delete the `s`. The character error rate would $(1+1+1)/6=0.5$. The word error rate is defined similarly, but at the single word rather than the character level.\n", "\n", "Thankfully, we can use a library function to help us out! `torchmetrics.functional.text.char_error_rate(preds, refs)` calculates the average character error rate over a list of predictions and references. `torchmetrics.functional.text.word_error_rate(preds, refs)` does the same for the average word error rate.\n", "\n", "Fill in missing code to measure character and word error rates on this dataset." 
] }, { "cell_type": "code", "execution_count": null, "id": "978ce371-2074-4258-a6a1-b73a0776a301", "metadata": { "execution": {} }, "outputs": [], "source": [ "import torchmetrics.functional.text as fm\n", "\n", "def clean_string(input_string):\n", " \"\"\"\n", " Clean string prior to comparison\n", "\n", " Args:\n", " input_string (str): the input string\n", "\n", " Returns:\n", " (str) a cleaned string, lowercase, alphabetical characters only, no double spaces\n", " \"\"\"\n", "\n", " # Convert all characters to lowercase\n", " lowercase_string = input_string.lower()\n", "\n", " # Remove non-alphabetic characters\n", " alpha_string = re.sub(r'[^a-z\\s]', '', lowercase_string)\n", "\n", " # Remove double spaces and start and end spaces\n", " return re.sub(r'\\s+', ' ', alpha_string).strip()\n", "\n", "\n", "def calculate_mismatch(estimated_text, reference_text):\n", " \"\"\"\n", " Calculate mismatch (character and word error rates) between estimated and true text.\n", "\n", " Args:\n", " estimated_text: a list of strings\n", " reference_text: a list of strings\n", "\n", " Returns:\n", " A tuple, (CER and WER)\n", " \"\"\"\n", " # Lowercase the text and remove special characters for the comparison\n", " estimated_text = [clean_string(x) for x in estimated_text]\n", " reference_text = [clean_string(x) for x in reference_text]\n", "\n", " ############################################################\n", " # Fill in this code to calculate character error rate and word error rate.\n", " # Hint: have a look at the torchmetrics documentation for the proper\n", " # metrics (type the proper metric name in the search bar).\n", " #\n", " # https://lightning.ai/docs/torchmetrics/stable/\n", " raise NotImplementedError(\"Student has to fill in these lines\")\n", " ############################################################\n", "\n", " # Calculate the character error rate and word error rates. They should be\n", " # raw floats, not tensors.\n", " cer = ...\n", " wer = ...\n", " return (cer, wer)" ] }, { "cell_type": "markdown", "id": "4933f09f-eaa2-42ec-8443-6cf7addb052b", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial1_Solution_4d36b048.py)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "54d82e89-04d2-4fae-a983-b483a6907103", "metadata": { "execution": {} }, "outputs": [], "source": [ "cer, wer = calculate_mismatch(transcribed_text, true_transcripts)\n", "assert isinstance(cer, float)\n", "cer, wer" ] }, { "cell_type": "markdown", "id": "94dda4a0-4f9d-4dba-9045-815d8ab3829e", "metadata": { "execution": {} }, "source": [ "For this particular subject, the character error rate is 3.3%, while the word error rate is 10%. Not bad, and in line with the results in the paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1b529eb7-cc55-40a9-97f2-0415570449af", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Code_Exercise_1.1\")" ] }, { "cell_type": "markdown", "id": "f0116627-1f93-4f00-8cb5-8bac5cbcbb44", "metadata": { "execution": {} }, "source": [ "### Code exercise 1.2: Calculate CER and WER across all subjects\n", "\n", "Let's measure the same metric, this time across all subjects. 
Note: If you run this code on the CPU, it might take around 5 minutes to complete. " ] }, { "cell_type": "code", "execution_count": null, "id": "eb4be1ee-635e-487f-8708-d5e6287fde8b", "metadata": { "execution": {} }, "outputs": [], "source": [ "def calculate_all_mismatch(df, model, processor):\n", " \"\"\"\n", " Calculate CER and WER for all subjects in a dataset\n", "\n", " Args:\n", " df: a dataframe containing information about images and transcripts\n", " model: an image-to-text model\n", " processor: a processor object\n", "\n", " Returns:\n", " a list of dictionaries containing a per-subject breakdown of the\n", " results\n", " \"\"\"\n", " subjects = df.subject.unique().tolist()\n", "\n", " results = []\n", "\n", " # Calculate CER and WER for all subjects\n", " for subject in tqdm.tqdm(subjects):\n", " ############################################################\n", " # Fill in the section to calculate the cer and wer for a\n", " # single subject. Look up at other sections to see how it's\n", " # done.\n", " raise NotImplementedError(\"Student exercise\")\n", " ############################################################\n", "\n", " # Load images and labels for a given subject\n", " images, true_transcripts = ...\n", "\n", " # Transcribe the images to text\n", " transcribed_text = ...\n", "\n", " # Calculate the CER and WER\n", " cer, wer = ...\n", "\n", " results.append({\n", " 'subject': subject,\n", " 'cer': cer,\n", " 'wer': wer,\n", " })\n", " return results" ] }, { "cell_type": "markdown", "id": "327b1cb3-cc7b-415f-a790-06054cfd53a5", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial1_Solution_cbbb272d.py)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4fdede9c-1958-476f-94da-de6483cab32e", "metadata": { "execution": {} }, "outputs": [], "source": [ "results = calculate_all_mismatch(df, model, processor)\n", "df_results = pd.DataFrame(results)\n", "df_results" ] }, { "cell_type": "markdown", "id": "d7b1ed00-87ef-4dc4-993e-0ff0118b9baa", "metadata": { "execution": {} }, "source": [ "Not all subjects are as easy to transcribe as subject 52! Let's check out subject 57, who has high CER and WER." ] }, { "cell_type": "code", "execution_count": null, "id": "a678275e-d2b3-4046-9898-59fa92f85e5a", "metadata": { "execution": {} }, "outputs": [], "source": [ "print(\"A subject that's harder to read\")\n", "images, true_transcripts = get_images_and_transcripts(df, 57)\n", "visualize_images_and_transcripts(images, true_transcripts)" ] }, { "cell_type": "markdown", "id": "45aec500-3b20-4738-90a0-185f8e427c4e", "metadata": { "execution": {} }, "source": [ "Indeed, this text seems harder to read." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3caf97a2-dbf4-4cac-94de-350b76228317", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Code_Exercise_1.2\")" ] }, { "cell_type": "markdown", "id": "0369b40f-f00c-4c14-9ac5-fa14c73ff301", "metadata": { "execution": {} }, "source": [ "### Code exercise 1.3: Measure OOD generalization\n", "\n", "What we've done thus far is to measure the empirical loss–the character error rate–for each subject. 
The empirical loss is defined as:\n", "\n", "$$R^e(\\theta) = \\mathbb{E}^e[ L(y, f(x, \\theta)) ] $$\n", "\n", "Here:\n", "\n", "* The environment $e$ is the training distribution\n", "* $R^e(\\theta)$ is the empirical risk in an environment\n", "* $\\theta$ are the learned parameters of the TrOCR model\n", "* $x$ is the conditioning data, that is, the images\n", "* $f$ is the function approximated by the TrOCR model, which maps images to probabilities of certain tokens\n", "* $L$ is the loss (or metric–not necessarily differentiable) for a single line of text, the character error rate (CER)\n", "* $\\mathbb{E}^e$ is the expectation taken over all the samples\n", "\n", "A single environment $e$ corresponds to a single subject. The out-of-distribution generalization is instead given by:\n", "\n", "$$R^{OOD} = \\max_{e \\in \\mathcal{E}_{all}} R^e(\\theta) $$\n", "\n", "It's the worst-case empirical loss over the out-of-distribution environments ${e \\in \\mathcal{E}_{all}}$ we wish to deploy on. In other words, the character error rate for the subject with the most difficult-to-read handwriting.\n", "\n", "Intuitively, our AI developer's vision of robustness might be: my note transcription app is robust and generalizes if it works well even when someone has illegible handwriting. The app is only as good as how well it works in the worst-case scenario. Let's measure that." ] }, { "cell_type": "code", "execution_count": null, "id": "cca28eb2-40ba-4b3b-8b89-adddcbe63a48", "metadata": { "execution": {} }, "outputs": [], "source": [ "def calculate_mean_max_cer(df_results):\n", " \"\"\"\n", " Calculate the mean character-error-rate across subjects as\n", " well as the maximum (that is, the OOD risk).\n", "\n", " Args:\n", " df_results: a dataframe containing results\n", "\n", " Returns:\n", " A tuple, (mean_cer, max_cer)\n", " \"\"\"\n", " ############################################################\n", " # Fill in the section to calculate the mean and max cer\n", " # across subjects.\n", " raise NotImplementedError(\"Student exercise\")\n", " ############################################################\n", "\n", " # Calculate the mean CER across test subjects.\n", " mean_subjects = ...\n", "\n", " # Calculate the max CER across test subjects.\n", " max_subjects = ...\n", " return mean_subjects, max_subjects" ] }, { "cell_type": "markdown", "id": "0b2f1759-816f-4af7-ab6e-cd1cb14e3d53", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial1_Solution_7cf70ea7.py)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ee896b4a-181b-4b49-9439-ad44373af99e", "metadata": { "execution": {} }, "outputs": [], "source": [ "mean_subjects, max_subjects = calculate_mean_max_cer(df_results)\n", "mean_subjects, max_subjects" ] }, { "cell_type": "markdown", "id": "b9c20d9b-5ebb-4e86-8c4b-2c16bab3477c", "metadata": { "execution": {} }, "source": [ "We see that:\n", "\n", "* when measured on this (admittedly small) out-of-distribution dataset, the average character error rate is about 5.8%, larger than the 3.4% reported for IAM\n", "* the out-of-distribution character error rate is 12%\n", "\n", "Whether that's good enough for our AI developer depends on the use case." 
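] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see why the worst case, rather than the average, is the natural summary of robustness here, consider a small illustration with made-up numbers: a model that looks better on average can still have the worse out-of-distribution (worst-case) risk." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Toy illustration (made-up numbers): per-environment character error rates\n", "# for two hypothetical models evaluated on three environments.\n", "risk_A = {'env1': 0.01, 'env2': 0.02, 'env3': 0.20}\n", "risk_B = {'env1': 0.08, 'env2': 0.08, 'env3': 0.09}\n", "\n", "for name, risks in [('A', risk_A), ('B', risk_B)]:\n", "    mean_risk = sum(risks.values()) / len(risks)\n", "    ood_risk = max(risks.values())  # R^OOD: the worst case over environments\n", "    print(f'Model {name}: mean risk {mean_risk:.3f}, OOD risk {ood_risk:.3f}')"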
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f776acfc-443e-400d-8560-cf0dac777213", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Code_Exercise_1.3\")" ] }, { "cell_type": "markdown", "id": "37296654-60d7-4a58-b93d-2e27458bcc7b", "metadata": { "execution": {} }, "source": [ "## Discussion\n", "\n", "Numbers in tables filled with benchmarks don't tell the whole story: often, we care about OOD robustness. Our developer benchmarked the TrOCR model for their use case and found a worst-case character error rate above 10%. Whether or not that's acceptable is a judgment call, and it's not the only metric the developer might care about. They might also need to meet other constraints:\n", "\n", "- Memory, FLOPs, latency, cost of inference: the deployment environment might not be able to support very large-scale models because of memory or compute constraints, or those would run too slowly for the use case. Cloud inference might not be practical with limited internet access.\n", "- SWaP-C: if the model is embodied in a physical device, the Size, Weight, Power and Cost of that device will ultimately be important. More powerful models can require bigger, heavier, more power-hungry hardware.\n", "- Latency of development: a bespoke model developed from scratch might take a long time to develop; our busy developer might prefer to adopt a pretrained, sub-optimal architecture than using a custom architecture\n", "- Cost of upkeep: machine learning systems can be notoriously difficult to keep running. Our developer might prefer to use a suboptimal system managed by somebody else rather than taking on the burden of dealing with the upkeep themselves.\n", "\n", "Our intrepid developer wants to ship this app soon! They decide on a strategy: the model is good enough to get started. They'll deploy the model as is, but they'll have an option in the app to report errors. They'll then label *those* errors and fine-tune the model. Before that though, they want to understand what's inside the model." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 3: TrOCR\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f3bd6315-fb5a-47ad-8dbd-69deb22f07fc", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: TrOCR\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'CFhBX4CL-88'), ('Bilibili', 'BV1iz421b7Qb')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c97e866e-d45c-4aa2-8746-0f36ff23f08d", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_trocr\")" ] }, { "cell_type": "markdown", "id": "fab3c4e0-7f5b-4518-9a01-2f930f8e81e2", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 3: Dissecting TrOCR\n", "\n", "TrOCR (transformer-based optical character recognition) is a model that performs printed optical character recognition and handwriting transcription on the basis of two transformers. But what's inside of it?" ] }, { "cell_type": "markdown", "id": "58e68dd4-b3dc-40ac-92af-8ad949fc0764", "metadata": { "cellView": "form", "execution": {} }, "source": [ "![Picture which shows trocr architecture.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/trocr_architecture.png?raw=true)\n", "\n" ] }, { "cell_type": "markdown", "id": "8eef6729-0f0b-4f4c-87b2-415cd792d865", "metadata": { "execution": {} }, "source": [ "TrOCR uses two transformers in an encoder-decoder architecture:\n", "\n", "1. An encoder, a vision transformer (ViT), maps 16x16 patches of the image to individual tokens\n", "2. A decoder, a text transformer, maps previously decoded text and the encoder's hidden state to the next token in the sequence to be decoded. 
This is known as causal language modeling." ] }, { "cell_type": "markdown", "id": "10c3ee14-f0cc-4639-bcb4-0ca4e0167a2c", "metadata": { "execution": {} }, "source": [ "## Section 3.1: A recap of transformers\n", "\n", "[We covered transformers in W2D5 of the DL course](https://deeplearning.neuromatch.io/tutorials/W2D5_AttentionAndTransformers/student/W2D5_Tutorial1.html). Let's quickly recap them. Transformers are a class of deep learning architectures that have become dominant in natural language processing (NLP) since their introduction in the paper \"Attention is All You Need\" by Vaswani et al. in 2017. Their success in natural language processing has led to their application across other domains, including computer vision, which is the case with TrOCR." ] }, { "cell_type": "markdown", "id": "7454f061-01c1-4e06-b878-405b4f270099", "metadata": { "cellView": "form", "execution": {} }, "source": [ "![Picture which shows one layer transformer.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/transformer_one_layer.png?raw=true)" ] }, { "cell_type": "markdown", "id": "AL2Wr4LNQPc6", "metadata": { "execution": {} }, "source": [ "\n", "\n", "*Illustration from Alammar, J (2018). The Illustrated Transformer. Retrieved from https://jalammar.github.io/illustrated-transformer/*\n", "\n", "Transformers are built on self-attention, allowing them to weigh the importance of different parts of the input data differently. This has proven useful for tasks that require an understanding of context, such as language translation, text summarization, and, as we will see, optical character recognition. Some key components of transformers are:\n", "\n", "- Tokenization: the input sequence (e.g. sentence) is split into different components (e.g. word pieces). Each component, or token, is embedded into a fixed dimensional space. In natural language processing, tokenization is done via a lookup table: every word piece is mapped to a fixed-dimensional vector. [See W3D1 of the DL course for a refresher on tokenization](https://deeplearning.neuromatch.io/tutorials/W3D1_TimeSeriesAndNaturalLanguageProcessing/student/W3D1_Tutorial2.html?highlight=word2vec#tokenizers).\n", "\n", "- Self-attention: A self-attention mechanism allows the tokens in the sequence to interact to form new representations. Specifically, queries and keys are derived from tokens; an inner product between queries and keys, followed by a softmax, forms the attention matrix. The attention matrix is multiplied by the value matrix to obtain a new representation.\n", "\n", "- Positional encoding: Positional encoding is added to the input to give the model information about the position of each token within the sequence. Unlike RNNs or CNNs, transformers **do not process data in order–without position encoding, they are permutation invariant**. We'll dig deeper into what this implies in the section on the inductive biases of transformers.\n", "\n", "- Layer Normalization and Residual Connections are used within the transformer architecture to stabilize the learning process and improve the model's ability to learn deep representations.\n", "\n", "One of the key advantages of transformers over previous architectures is a high degree of parallelism, which allows one to train larger, more capable models. Let's inspect TrOCR's architecture." 
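] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we do, here is a minimal, self-contained sketch of single-head scaled dot-product self-attention, the core operation described above. The dimensions are toy values chosen for illustration; this is a sketch of the mechanism, not TrOCR's actual implementation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# A minimal single-head scaled dot-product self-attention (illustrative sketch).\n", "# Toy sizes: a batch of one sequence with 5 tokens of dimensionality 16.\n", "seq_len, d_model = 5, 16\n", "tokens = torch.randn(1, seq_len, d_model)\n", "\n", "# Learned projections map the same tokens to queries, keys and values\n", "W_q = torch.nn.Linear(d_model, d_model, bias=False)\n", "W_k = torch.nn.Linear(d_model, d_model, bias=False)\n", "W_v = torch.nn.Linear(d_model, d_model, bias=False)\n", "Q, K, V = W_q(tokens), W_k(tokens), W_v(tokens)\n", "\n", "# Attention matrix: softmax over scaled query-key inner products\n", "attn = torch.softmax(Q @ K.transpose(-2, -1) / d_model**0.5, dim=-1)\n", "\n", "# New token representations: attention-weighted combination of the values\n", "out = attn @ V\n", "print(attn.shape, out.shape)  # torch.Size([1, 5, 5]) torch.Size([1, 5, 16])"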
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2be60c27-0974-4da3-a70a-4f33e37d1a91", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Recap_Transformers\")" ] }, { "cell_type": "markdown", "id": "0de8292e-1509-4f08-a2fc-fe444211669a", "metadata": { "execution": {} }, "source": [ "## Section 3.2: The encoder\n", "\n", "Let's dig in more specifically into the **encoder** inside of TrOCR. It's a visual transformer (ViT), an adaptation of transformers for problems in vision. It proceeds as follows:\n", "\n", "1. It takes a raw image and resizes it to 384x384\n", "2. It chops it up into 16x16 patches\n", "3. It embeds each patch inside a fixed, 768-dimensional space, adding positional embeddings\n", "4. It passes the patches through self-attention layers.\n", "5. It ends up with one token for each patch, plus one for the class embedding, $577=(384/16)^2+1$. \n", "\n", "Let's look at the structure of the encoder:" ] }, { "cell_type": "code", "execution_count": null, "id": "70729ebc-9d99-4d06-8f71-b22656138065", "metadata": { "execution": {} }, "outputs": [], "source": [ "model.encoder" ] }, { "cell_type": "markdown", "id": "f97fe34c-c68c-468d-8afe-e27f87a6ebc4", "metadata": { "execution": {} }, "source": [ "### Code exercise 3.1: Understanding the inputs and outputs of the encoder\n", "\n", "Let's make sure we understand how the encoder operates by giving it a sample input and checking that its output matches the expected shape." ] }, { "cell_type": "code", "execution_count": null, "id": "11c3b9aa-0b35-42a7-adb8-a817526ed5cb", "metadata": { "execution": {} }, "outputs": [], "source": [ "def inspect_encoder(model):\n", " \"\"\"\n", " Inspect encoder to verify that it processes inputs in the expected way.\n", "\n", " Args:\n", " model: the TrOCR model\n", " \"\"\"\n", " ##################################################################\n", " # Feed the encoder an input and measure the output to understand\n", " # the role of the vision encoder.\n", " raise NotImplementedError(\"Student exercise\")\n", " #\n", " ##################################################################\n", " # Create an empty tensor (batch size of 1) to feed it to the encoder.\n", " # Remember that images should have 3 channels and have size 384x384\n", " # Recall that images are fed in pytorch with tensors of shape\n", " # batch x channels x height x width\n", " single_input = ...\n", "\n", " # Run the input through the encoder.\n", " output = ...\n", "\n", " # Measure the number of hidden tokens which are the output of the encoder\n", " hidden_shape = output['last_hidden_state'].shape\n", "\n", " assert hidden_shape[0] == 1\n", " assert hidden_shape[1] == 577\n", " assert hidden_shape[2] == 768" ] }, { "cell_type": "markdown", "id": "ca6177ce-7005-48ae-8faa-d4008ae5942d", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D1_Generalization/solutions/W1D1_Tutorial1_Solution_22613224.py)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ee2eec7c-6173-4113-b303-8722e6b356d9", "metadata": { "execution": {} }, "outputs": [], "source": [ "inspect_encoder(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": 
"code", "execution_count": null, "id": "28dc20be-dc74-4377-bde0-1231794fa37c", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Code_Exercise_3.1\")" ] }, { "cell_type": "markdown", "id": "1046051f-f33a-48bf-bf1b-a82de33ead17", "metadata": { "execution": {} }, "source": [ "The vision transformer acts much like a conventional encoder transformer in sequence-to-sequence tasks: it maps the input sequence to a hidden representation, the image tokens. This hidden representation is then attended during decoding using cross-attention." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f1026f2c-b9a2-42b2-b2b4-ca7d955436b7", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Encoder\")" ] }, { "cell_type": "markdown", "id": "4e1abb36-ed14-4873-a2de-b12e7633649d", "metadata": { "execution": {} }, "source": [ "## Section 3.3: The decoder" ] }, { "cell_type": "markdown", "id": "0c87c4b6-eec6-47f5-9e84-62d22dc4345b", "metadata": { "execution": {} }, "source": [ "The decoder is another transformer that attends to both the image tokens and string tokens. At a given point in the decoding, the decoder uses both the reference image and the string prefix to predict the next string to produce. In this fashion, the transcript is built one string at a time. \n", "\n", "We can inspect the decoder to find both self-attention layers `self_attn` that attend to the string prefix, as well as cross-attention layers `encoder_attn` that attend to the image. " ] }, { "cell_type": "code", "execution_count": null, "id": "415f93ae-1d5d-4b9d-9d76-fad1685ac923", "metadata": { "execution": {} }, "outputs": [], "source": [ "model.decoder" ] }, { "cell_type": "markdown", "id": "7a86b22e-af98-4d70-94e7-a832a9158824", "metadata": { "execution": {} }, "source": [ "Notice that `encoder_attn` layers have an input dimensionality of 768, which matches the shape of the visual tokens, while its output dimensionality is 1024, which matches the string tokens.\n", "\n", "To see how the decoder takes a visual input to generate a text caption, we can feed a sample image to the encoder to obtain its encoding, then pass it to the decoder, and inspect the outputs." ] }, { "cell_type": "code", "execution_count": null, "id": "9302252d-8ce9-4229-9a86-f5758d932bd6", "metadata": { "execution": {} }, "outputs": [], "source": [ "# The sample image\n", "images[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "0aaa5e65-ecc8-4ed4-a026-17f3be9bba38", "metadata": { "execution": {} }, "outputs": [], "source": [ "pixel_values = processor(images=[images[0]], return_tensors=\"pt\").pixel_values\n", "encoded_image = model.encoder(pixel_values.to(device))\n", "encoded_image.last_hidden_state.shape" ] }, { "cell_type": "markdown", "id": "ced6a1bd-262e-4612-9373-340921c32704", "metadata": { "execution": {} }, "source": [ "Consistent with what we found previously, the image is encoded into 577 tokens of dimensionality 768. 
Let's pass these to the decoder:" ] }, { "cell_type": "code", "execution_count": null, "id": "811c6ea9-36b7-4c0b-b8a6-8eb0981e75e1", "metadata": { "execution": {} }, "outputs": [], "source": [ "decoded = model.decoder.forward(\n", " input_ids=torch.Tensor([[0]]).to(device, dtype=int),\n", " encoder_hidden_states=encoded_image['last_hidden_state'],\n", ")\n", "print(decoded.logits.shape)\n", "decoded.logits.argmax()" ] }, { "cell_type": "markdown", "id": "6714187e-6b37-4a36-a11c-cc6a9adbdb3c", "metadata": { "execution": {} }, "source": [ "The decoder gives probabilities for all 50265 potential string tokens in the tokenizer's vocabulary. The most likely token has the number 31206. What does this correspond to? The `processor` can translate between token numbers and strings. Let's give it a whirl:" ] }, { "cell_type": "code", "execution_count": null, "id": "87cb3d82-51c1-4147-83dc-1164aa3ca263", "metadata": { "execution": {} }, "outputs": [], "source": [ "processor.tokenizer.decode(31206)" ] }, { "cell_type": "markdown", "id": "1f190edb-913d-42e7-996e-e289284c6980", "metadata": { "execution": {} }, "source": [ "It's the first word in the sentence! We can keep feeding the outputs of the decoder to itself to build a string decoding." ] }, { "cell_type": "code", "execution_count": null, "id": "a290963a-8f45-4824-a36b-7fd640cd577c", "metadata": { "execution": {} }, "outputs": [], "source": [ "decoded = model.decoder.forward(\n", " input_ids=torch.Tensor([[0, 31206]]).to(device, dtype=int),\n", " encoder_hidden_states=encoded_image['last_hidden_state'],\n", ")\n", "processor.tokenizer.decode(decoded.logits[:, -1, :].argmax().item())" ] }, { "cell_type": "markdown", "id": "d68e09d7-8b64-4002-843f-f99db33b257a", "metadata": { "execution": {} }, "source": [ "Continuing this process allows us to transcribe the entire image. Greedily choosing the most likely word can lead to suboptimal decoding, however. A common technique to improve this is to keep multiple likely decodings in memory, pruning as we process more of the sequence, only deciding the very best sequence at the end. This is known as a beam search. `model.generate` uses a beam search to get the best transcription." 
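] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Before reaching for `model.generate`, here is a minimal sketch of the greedy loop described above, mirroring the manual decoder calls from the previous cells: keep appending the most likely next token until the end-of-sequence token appears. The cap of 20 steps is an arbitrary safeguard for this sketch; the start token id 0 follows the cells above, and `model.generate` replaces the greedy choice with a beam search.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Minimal greedy decoding sketch (illustration only; model.generate is the proper tool)\n", "input_ids = torch.tensor([[0]], dtype=torch.long, device=device)  # start token, as above\n", "eos_id = processor.tokenizer.eos_token_id\n", "\n", "for _ in range(20):  # small cap on the number of decoding steps for this sketch\n", "    decoded = model.decoder(\n", "        input_ids=input_ids,\n", "        encoder_hidden_states=encoded_image['last_hidden_state'],\n", "    )\n", "    next_id = decoded.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy choice\n", "    input_ids = torch.cat([input_ids, next_id], dim=1)\n", "    if next_id.item() == eos_id:\n", "        break\n", "\n", "print(processor.tokenizer.decode(input_ids[0]))"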
] }, { "cell_type": "code", "execution_count": null, "id": "63414027-c1f7-4925-ad45-e0dbd45ac25d", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Check if CUDA is available and set the device accordingly\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Move the model to the appropriate device\n", "model.to(device)\n", "\n", "# move it to the same device\n", "pixel_values = pixel_values.to(device)\n", "\n", "# Generate the sequence using the model\n", "best_sequence = model.generate(pixel_values)\n", "\n", "# Decode the generated sequence\n", "decoded_sequence = processor.tokenizer.decode(best_sequence[0])\n", "print(decoded_sequence)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c6562a0a-5156-49b1-90b2-bdf5eba06ed7", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Decoder\")" ] }, { "cell_type": "markdown", "id": "999f7253-d3d1-4795-b735-53190ce807c8", "metadata": { "execution": {} }, "source": [ "## Interactive exploration 3.2: What the model pays attention to" ] }, { "cell_type": "markdown", "id": "7f666c65-047f-439c-9e5c-0c82b39e21b9", "metadata": { "execution": {} }, "source": [ "We've just seen that these are two relatively large-scale transformers that are wired in the conventional encoder-decoder architecture. The transformers themselves are generic and have relatively weak built-in inductive biases. Has the model learned to process the sequence reasonably?\n", "\n", "One tool at our disposal to address this question is to look at the **attention pattern** in the decoding heads as we process an image. By looking at the output of the cross-attention heads, we can gain an intuitive understanding of what the model \"looks at.\" \n", "\n", "Let's look at how the model's attention evolves as we process more and more of the sequence." 
] }, { "cell_type": "code", "execution_count": null, "id": "f6707747-a439-4e69-9755-cd8d79680bdb", "metadata": { "execution": {} }, "outputs": [], "source": [ "decoded = model.decoder.forward(\n", " input_ids=best_sequence,\n", " encoder_hidden_states=encoded_image['last_hidden_state'],\n", " output_attentions=True\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "8a5d8225-b572-45ca-887e-c10e0eed4885", "metadata": { "execution": {} }, "outputs": [], "source": [ "import ipywidgets as widgets\n", "from ipywidgets import interact\n", "import matplotlib.pyplot as plt\n", "\n", "def visualize_attention(layer, head):\n", " plt.figure(figsize=(10, 10))\n", "\n", " image = images[0]\n", " for token in range(decoded.cross_attentions[layer].shape[2]):\n", " attention_pattern = decoded.cross_attentions[layer][0, head, token, 1:].reshape((24, 24))\n", " attention_pattern = attention_pattern.detach().cpu().numpy()\n", "\n", " print(processor.decode(best_sequence[0][:token+1]))\n", " plt.imshow((np.array(image).mean(axis=2)).astype(float), cmap='gray')\n", " plt.imshow(attention_pattern, extent=[0, image.width, 0, image.height], alpha=attention_pattern/attention_pattern.max(), cmap='YlOrRd')\n", " plt.axis('off')\n", " plt.gca().invert_yaxis()\n", " plt.show()\n", "\n", "\n", "# Create interactive widgets\n", "layer_slider = widgets.IntSlider(min=0, max=len(decoded.cross_attentions)-1, step=1, value=7, description='Layer')\n", "head_slider = widgets.IntSlider(min=0, max=decoded.cross_attentions[0].shape[1]-1, step=1, value=5, description='Head')\n", "\n", "# Create the interactive visualization\n", "interact(visualize_attention, layer=layer_slider, head=head_slider)" ] }, { "cell_type": "markdown", "id": "8b8c6b3d-4866-4db3-b5e5-d5530fe4b5a8", "metadata": { "execution": {} }, "source": [ "You'll notice that attention heads in intermediate layers seem to track the likely location of the next word in the input image, left to right. It is remarkable that the model has learned an important aspect of the spatial structure of sentences written in the Latin alphabet, which is that words are written left-to-right and that words have a certain characteristic width. \n", "\n", "**Positional encoding** allows the model to express spatial biases; without it, the model would be position invariant. However, nothing in the model **predetermines** that text must be left to right: it learns that structure through data. The model follows the standard recipe of modern AI: take a model with weak inductive biases that scales well, and train it on large-scale data to distill structure." 
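] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "To make the role of positional encoding more concrete, here is a minimal sketch of the classic sinusoidal positional encoding from Vaswani et al. (2017): each position in the sequence gets a distinct vector that is added to the token embeddings, breaking permutation invariance. This is purely illustrative; the exact positional embedding schemes used inside TrOCR differ, and the sizes below are arbitrary.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "def sinusoidal_positional_encoding(n_positions, d_model):\n", "    # Classic sinusoidal positional encoding (Vaswani et al., 2017)\n", "    positions = np.arange(n_positions)[:, None]   # (n_positions, 1)\n", "    dims = np.arange(d_model)[None, :]            # (1, d_model)\n", "    angles = positions / np.power(10000.0, (2 * (dims // 2)) / d_model)\n", "    pe = np.zeros((n_positions, d_model))\n", "    pe[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions use sine\n", "    pe[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions use cosine\n", "    return pe\n", "\n", "pe = sinusoidal_positional_encoding(n_positions=64, d_model=128)\n", "plt.imshow(pe, aspect='auto', cmap='RdBu')\n", "plt.xlabel('embedding dimension')\n", "plt.ylabel('position in sequence')\n", "plt.colorbar()\n", "plt.show()"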
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "404339d9-efad-4c1e-b38c-f4e6eeee10e3", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Interactive_Exploration_3.2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 4: Weak Inductive Biases\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a24b3023-2051-4d61-8233-b6d92ff6246b", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Weak Inductive Biases\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '4BTM5Mrb94Y'), ('Bilibili', 'BV1jw4m1v7g6')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "481ddc5b-1af3-4552-b07e-d691eb4ad138", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_weak_inductive_biases\")" ] }, { "cell_type": "markdown", "id": "fee81e82-ff71-4a2d-8451-77b59cd71966", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 4: The magic in the data\n", "\n", "It's straightforward to write down the encoder-decoder transformer used by TrOCR–it's conceptually quite similar to the original transformer as outlined in Vaswani et al. (2017). What makes the model tick (and potentially break) is a training pipeline that ensures good generalization. It's worth taking a look at the TrOCR paper to see the many different sources of data that are used to train the model:\n", "\n", "1. 
[The encoder is pretrained on masked image modeling on ImageNet-22k](https://huggingface.co/docs/transformers/en/model_doc/beit)\n", "2. [The decoder is pretrained on masked language modeling on 160GB of raw text](https://arxiv.org/abs/1907.11692)\n", "3. The entire model is trained end-to-end on 648M text lines found in 2M PDF pages on the internet, with the fonts randomly swapped\n", "4. The model is then fine-tuned end-to-end on the IAM handwriting dataset, with heavy augmentation\n", "\n", "Let's look at a few of these pieces in turn." ] }, { "cell_type": "markdown", "id": "b6da9d1a-f84a-4a6c-99ba-5a3d2dad2823", "metadata": { "execution": {} }, "source": [ "## Section 4.1: Transfer learning\n", "\n", "Modern neural networks are often pre-trained on large datasets. For example, the TrOCR model's decoder is pretrained on masked language modeling on 160GB of raw text. The pretrained weights are then used to initialize the model, a form of transfer learning. The same principle applies to the encoder, which is pretrained on masked image modeling on ImageNet-22k. Although these tasks are quite different from the final task of handwriting recognition, the model learns useful features that can be transferred to the final task. \n", "\n", "The datasets involved in pre-training are often large: 160GB of raw text would take a human close to 1000 lifetimes to write! Yet, this is quite small by modern standards: [FineWeb](https://huggingface.co/spaces/HuggingFaceFW/blogpost-fineweb-v1?utm_source=ainews&utm_medium=email&utm_campaign=ainews-mamba-2-state-space-duality) is almost a thousand times larger. It's an incredible feat of engineering that we can build models that learn effective representations from such large-scale data. The \"bitter lesson\" of AI is that models trained on more data tend to perform better and generalize better. " ] }, { "cell_type": "markdown", "id": "1e6b0d3d", "metadata": { "execution": {} }, "source": [ "### Reflection\n", "\n", "What happens when we've trained on all the data we can find? What are other ways to improve generalization in conventional large-scale AI?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0d0c34f4-2dcf-43d7-aa2e-d318cf89dc8b", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_transfer_learning\")" ] }, { "cell_type": "markdown", "id": "84ce0081-636c-4f40-bc2f-fc8f2d22dee7", "metadata": { "execution": {} }, "source": [ "## Section 4.2: Generalization via augmentation\n", "\n", "Another important ingredient in this model is the use of multiple augmentations of the data. When data is sparse, this can improve generalization. Thus, we take an expressive model with few built-in inductive biases and, through demonstrations, let it learn the structure of the data, encouraging generalization.\n", "\n", "By applying various transformations to images and displaying the results, you can visually understand how augmentation works and its impact on model performance. Let's look at parts of the TrOCR recipe." ] }, { "cell_type": "markdown", "id": "87b4dcd8-f11a-447a-bf33-564f90239456", "metadata": { "execution": {} }, "source": [ "Let's start with loading and visualizing our chosen image."
] }, { "cell_type": "markdown", "id": "b327840a-93df-45c7-9435-ea9d4269b9eb", "metadata": { "execution": {} }, "source": [ "![Picture which shows neuroai_hello_world.](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/static/neuroai_hello_world.png?raw=true)" ] }, { "cell_type": "markdown", "id": "68adf7a5-9cf0-457d-b645-38a26cddec10", "metadata": { "execution": {} }, "source": [ "Now, we will apply a few transformations to this image. You can play around with the input values!" ] }, { "cell_type": "code", "execution_count": null, "id": "bed0a320-46f3-4f84-8c51-c8cd07cea776", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Convert PIL Image to Tensor\n", "image = IMG.open(\"neuroai_hello_world.png\")\n", "image = transforms.ToTensor()(image)\n", "\n", "# Define each transformation separately\n", "# RandomAffine: applies random rotations, translations, and scaling. Here, it rotates by up to\n", "# ±5 degrees, translates by up to 10%, and scales between 90% and 110%.\n", "affine = transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1))\n", "\n", "# ElasticTransform: applies elastic distortions to the image. The 'alpha' parameter controls\n", "# the intensity of the distortion.\n", "elastic = transforms.ElasticTransform(alpha=25.0)\n", "\n", "# RandomPerspective: applies random perspective transformations with a specified distortion scale.\n", "perspective = transforms.RandomPerspective(distortion_scale=0.2, p=1.0)\n", "\n", "# RandomErasing: randomly erases a rectangular area in the image.\n", "erasing = transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random', inplace=False)\n", "\n", "# GaussianBlur: applies Gaussian blur with the specified kernel size and sigma range.\n", "gaussian_blur = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.8, 5))" ] }, { "cell_type": "markdown", "id": "6c5ee77d-988c-474a-bc76-582b7315e436", "metadata": { "execution": {} }, "source": [ "Let's now combine them in a single list and display the images." ] }, { "cell_type": "code", "execution_count": null, "id": "7e691d06-e1bc-45fc-b289-89827eaf0317", "metadata": { "execution": {} }, "outputs": [], "source": [ "# A list of all transformations for iteration\n", "transformations = [affine, elastic, perspective, erasing, gaussian_blur]\n", "\n", "# Display\n", "display_transformed_images(image, transformations)" ] }, { "cell_type": "markdown", "id": "ea3a5322-0b23-49be-9a2c-a2c1a023b873", "metadata": { "execution": {} }, "source": [ "The transformations applied to the image include:\n", "\n", "1. Original: the baseline image without any modifications.\n", "2. RandomAffine: applies random affine transformations to the image, which include translation, scaling, rotation, and shearing. This helps the model become invariant to such transformations in the input data.\n", "3. ElasticTransform: introduces random elastic deformations, simulating non-linear transformations that might occur naturally. It is useful for tasks where we expect such distortions, like medical image analysis.\n", "4. RandomPerspective: changes the perspective from which the image is viewed, simulating the effect of viewing the object from different angles.\n", "5. RandomErasing: randomly erases a rectangular region of the image and fills it with arbitrary pixel values. It can make the model robust against occlusions in the input data.\n", "6. GaussianBlur: applies a Gaussian blur to the image, smoothing it. 
This can help the model cope with out-of-focus images.\n", "\n", "All of these augmentations, which are part of this model's training recipe, help prevent overfitting and improve the generalization of the model to new, unseen images. We can compose these to create new challenging training images:" ] }, { "cell_type": "code", "execution_count": null, "id": "89c000dc-2c1e-41e6-a988-ab75819b0b39", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Combine all the transformations\n", "all_transforms = transforms.Compose([\n", " affine,\n", " elastic,\n", " perspective,\n", " erasing,\n", " gaussian_blur\n", "])\n", "\n", "# Apply combined transformation\n", "augmented_image_tensor = all_transforms(image)\n", "\n", "display_original_and_transformed_images(image, augmented_image_tensor)" ] }, { "cell_type": "markdown", "id": "5f13dff2-5e74-46d3-be3d-deafe27a7383", "metadata": { "execution": {} }, "source": [ "All these transformations create a challenging curriculum that forces the model to generalize. Note that we're limited by our imagination in creating these augmentations. Some real-world invariances, for example invariance to the style of handwritten characters, can be hard to simulate with this approach." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3b03601c-741d-4c59-82eb-3ee0690afa3c", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Augmentation\")" ] }, { "cell_type": "markdown", "id": "7bb05cb5-9d0c-427d-a246-f80c09bea0ec", "metadata": { "execution": {} }, "source": [ "## Section 4.3: Generalization via synthetic data\n", "\n", "When augmentation is not enough, we can further improve generalization by training on synthetic data. **Data augmentation** creates variations of existing data without changing its inherent properties, while **synthetic data** generation creates entirely new data that mimics the characteristics of real data.\n", "\n", "As it turns out, generating new text is tractable–text can be rendered in a wide range of cursive fonts to simulate real data. Here, we'll showcase this idea by defining strings and generating synthetic images." ] }, { "cell_type": "code", "execution_count": null, "id": "2Yee_i85LExa", "metadata": { "execution": {} }, "outputs": [], "source": [ "# Define strings\n", "strings = ['Hello world', 'This is the first tutorial', 'For Neuromatch NeuroAI']\n", "\n", "# Specify font path\n", "font_path = \"DancingScript-VariableFont_wght.ttf\" # Ensure this path is correct\n", "\n", "# Create a generator with the specified parameters\n", "generator = image_generator(strings, font_path, space_width=2, skewing_angle=3)\n", "\n", "i = 1\n", "for img in generator:\n", " plt.imshow(img, cmap='gray')\n", " plt.title(f\"Example {i}\")\n", " plt.axis('off')\n", " plt.show()\n", " i += 1" ] }, { "cell_type": "markdown", "id": "552e6f2f-672d-41dc-9020-16d9f12259ff", "metadata": { "execution": {} }, "source": [ "### Discussion point\n", "\n", "What does this type of synthetic data capture that wouldn’t be easy to capture through data augmentation?"
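] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "As a point of reference for the discussion, here is a minimal sketch of what a synthetic text-image renderer boils down to: draw a string onto a blank canvas with some font, then perturb it (here with a small rotation). The `image_generator` helper used above layers many fonts, spacings, and skew angles on top of this basic idea; the canvas size and PIL's built-in default font below are assumptions for this sketch, not what the helper actually uses.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Minimal synthetic text rendering sketch (illustration only)\n", "from PIL import Image as IMG, ImageDraw, ImageFont\n", "\n", "text = 'A synthetic sample'\n", "font = ImageFont.load_default()  # assumption: PIL's built-in font instead of DancingScript\n", "\n", "canvas = IMG.new('L', (300, 40), color=255)  # blank white grayscale canvas\n", "ImageDraw.Draw(canvas).text((10, 10), text, fill=0, font=font)\n", "canvas = canvas.rotate(3, expand=True, fillcolor=255)  # small skew, like skewing_angle above\n", "\n", "plt.imshow(canvas, cmap='gray')\n", "plt.axis('off')\n", "plt.show()"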
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7d5ccec7-3247-4055-9838-ec8f08b0325b", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Synthetic_Data\")" ] }, { "cell_type": "markdown", "id": "f1c7071b-4230-481f-9abd-b2967f53f279", "metadata": { "execution": {} }, "source": [ "### Interactive demo 4.1: Generating handwriting style data\n", "\n", "We can take this idea further and generate handwriting-style data. We will use an embedded `calligrapher.ai` model to generate new snippets of writing data. This generator is based off of a recurrent neural network trained on the same corpus of handwritten data as the TrOCR model, the IAM dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "29ca19de-44b0-492d-af5b-8fa78fc5e51e", "metadata": { "execution": {} }, "outputs": [], "source": [ "IFrame(\"https://www.calligrapher.ai/\", width=800, height=600)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a592d515-dad1-4db3-a435-e2183422606d", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Generate_Handwriting\")" ] }, { "cell_type": "markdown", "id": "c4c0c558-a0be-46f2-8e65-9babe22ff4f4", "metadata": { "execution": {} }, "source": [ "# Conclusion\n", "\n", "We train models to minimize a loss function. Oftentimes, however, what we care about is something different, like how well the model will generalize when it's deployed. Our intrepid developer got a rude awakening when comparing the OOD robustness of the model to its empirical loss on the train set: the character error rate was several times larger than expected. Motivated by other factors, like engineering complexity, our developer decided to move forward and deploy a handwriting transcription system, hoping it could be fine-tuned based on users' data later.\n", "\n", "There's a lot that goes into the training of robust AI models that generalize well. Generic high-capacity models with weak inductive biases, like transformers, are trained on large-scale data. Pretraining, augmentations, and synthetic data can all be part of the recipe for learning structure that might be hard to express mathematically, such as the fact that text is written left to right. Because large-scale models can often require a significant amount of computation to train, in practice, models that have been trained for other purposes are adapted and re-used, preventing the need to learn from scratch. These models embody what's known as [\"the bitter lesson\"](http://www.incompleteideas.net/IncIdeas/BitterLesson.html): general methods that leverage computation are ultimately the most effective, and by a large margin." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 5: Final Thoughts\n" ] }, { "cell_type": "code", "execution_count": null, "id": "691bfb73-f871-4651-9de6-3d5d0e85477f", "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 5: Final Thoughts\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'fxgIYvbU1Pg'), ('Bilibili', 'BV1ci421e7mq')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9e69e404-548a-4a58-8c0c-a9ded0eb2fc9", "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_final_thoughts\")" ] }, { "cell_type": "markdown", "id": "caf97875-e5a7-482d-8b08-92ca9fa5487e", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "* Artificial intelligence practitioners aim to maximize the performance of their systems under engineering constraints of size, weight, power, cost, latency, and maintenance.\n", "* In-distribution performance doesn't tell the whole story: out-of-distribution robustness can be measured to determine how well a model will perform when deployed in the real world.\n", "* Generic models with weak inductive biases, like transformers, can learn structure from large-scale data.\n", "* Several strategies can be used to build models that display better generalization, all of which hinge on distilling structure from ever larger amounts of data:\n", " * Transfer learning\n", " * Augmentations\n", " * Synthetic examples" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "gpuType": "T4", "include_colab_link": true, "name": "W1D1_Tutorial1", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", 
"name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }