{ "cells": [ { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 3: Statistical inference on representational geometries\n", "\n", "**Week 1, Day 3: Comparing Artificial And Biological Networks**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Veronica Bossio, Eivinas Butkus, Jasper van den Bosch\n", "\n", "__Content reviewers:__ Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Patrick Mineault, Hlib Solodzhuk\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "___\n", "\n", "\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 50 minutes*\n", "\n", "To evaluate alternative models of measured data, we need statistical inference that takes our uncertainty about relative model performance into account.\n", "In computational neuroscience, we want to statistically compare different models in terms of their ability to account for representations in brains.\n", "In AI, we can employ similar techniques to compare models to each other and understand their internal representations.\n", "\n", "By the end of this tutorial, you will be able to:\n", "\n", "1. Understand Representational Similarity Analysis (RSA), including its theoretical foundations, its practical applications, and its significance for machine learning and computational neuroscience.\n", "\n", "2. Extract neural network activations: understand how neural networks are structured, what role activations play in interpreting a network's decisions, and which practical techniques give access to these activations.\n", "\n", "3. Discuss frequentist model comparison: the basics of these methods, the principles underlying them, and their applications.\n", "\n", "4. Identify sources of estimation error and the motivation for model-comparative frequentist inference. You will learn about the three main sources of estimation error in statistical inference: measurement noise, stimulus sampling, and subject sampling. You will see how each source affects statistical inference, why together they motivate model-comparative frequentist inference, and how the 2-factor bootstrap accounts for our uncertainty about model performance during model comparison." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/uwn2g/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/uwn2g/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install numpy pandas torch torchvision matplotlib ipython Pillow plotly networkx requests vibecheck --quiet\n", "!pip install rsatoolbox==0.1.5 --quiet\n", "\n", "\n", "from
vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W1D3_T3\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import dependencies\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Import dependencies\n", "\n", "# Standard library imports\n", "import hashlib\n", "import logging\n", "import os\n", "import pathlib\n", "import warnings\n", "import zipfile\n", "from copy import deepcopy\n", "\n", "# Third-party library imports\n", "import ipywidgets as widgets\n", "import matplotlib.pyplot as plt\n", "import networkx as nx\n", "import numpy as np\n", "import pandas as pd\n", "import plotly.colors\n", "import plotly.graph_objects as go\n", "import requests\n", "import rsatoolbox as rsa\n", "import torch\n", "import torchvision\n", "from IPython.display import IFrame, clear_output, display\n", "from IPython.display import Image as IMG\n", "from matplotlib import cm, patches, transforms\n", "from matplotlib.path import Path\n", "from networkx.algorithms.clique import find_cliques as maximal_cliques\n", "from PIL import Image\n", "from plotly.express import colors\n", "from plotly.subplots import make_subplots\n", "# note: `eval` shadows the Python built-in; it is used as eval.eval_fixed below\n", "from rsatoolbox.inference import evaluate as eval\n", "from rsatoolbox.inference.bootstrap import (bootstrap_sample,\n", " bootstrap_sample_pattern,\n", " bootstrap_sample_rdm)\n", "from rsatoolbox.util.inference_util import all_tests, get_errorbars\n", "from rsatoolbox.util.rdm_utils import batch_to_vectors\n", "from scipy.spatial.distance import squareform\n", "from torchvision.models.feature_extraction import (create_feature_extractor,\n", " get_graph_node_names)\n", "\n", "# Enabling automatic reloading of modules before executing user code\n", "%reload_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting functions\n", "\n", "def traces_bar_and_scatter(eval_result, models, bar_color='blue'):\n", "\n", " evaluations = eval_result.evaluations.squeeze()\n", " subject_names = [f'Subject {i+1}' for i in range(evaluations.shape[1])]\n", " model_names = [model.name for model in models]\n", " df_evaluations = pd.DataFrame(data=evaluations,
index=model_names, columns=subject_names)\n", " means = df_evaluations.mean(axis=1)\n", "\n", " bars_trace = go.Bar(\n", " x=model_names,\n", " y=means,\n", " showlegend=False,\n", " marker_color=bar_color\n", " )\n", "\n", " # One scatter trace per subject: a white dot for each model's evaluation\n", " scatter_traces = []\n", " for subject in subject_names:\n", " scatter_traces.append(go.Scatter(\n", " x=df_evaluations.index,\n", " y=df_evaluations[subject],\n", " mode='markers',\n", " marker=dict(size=5,\n", " color='white',\n", " line=dict(width=1)),\n", " showlegend=False\n", " ))\n", " # Legend-only entry: the trace plots no data and only labels the subject dots\n", " blank_trace = go.Scatter(\n", " x=[None],\n", " y=[None],\n", " mode='markers',\n", " marker=dict(size=5, color='white', line=dict(width=1)),\n", " name='Each dot represents
a subject'\n", " )\n", " return bars_trace, scatter_traces, blank_trace\n", "\n", "def plot_bars_and_scatter_from_trace(bars_trace, scatter_traces, blank_trace):\n", "\n", " fig = go.Figure()\n", " fig.add_trace(bars_trace)\n", " for trace in scatter_traces:\n", " fig.add_trace(trace)\n", " fig.add_trace(blank_trace)\n", " fig.update_layout(\n", " title=\"\",\n", " xaxis_title=\"Model\",\n", " yaxis_title=\"Cosine Similarity to Data RDMs\",\n", " legend_title=\"\",\n", " width=700,\n", " height=500,\n", " template=\"simple_white\"\n", " )\n", " return fig\n", "\n", "def convert_result_to_list_of_dicts(result):\n", " means = result.get_means()\n", " sems = result.get_sem()\n", " p_zero = result.test_zero()\n", " p_noise = result.test_noise()\n", " model_names = [model.name for model in result.models]\n", "\n", " results_list = []\n", " for i, model_name in enumerate(model_names):\n", " result_dict = {\n", " \"Model\": model_name,\n", " \"Eval±SEM\": f\"{means[i]:.3f} ± {sems[i]:.3f}\",\n", " \"p (against 0)\": \"< 0.001\" if p_zero[i] < 0.001 else f\"{p_zero[i]:.3f}\",\n", " \"p (against NC)\": \"< 0.001\" if p_noise[i] < 0.001 else f\"{p_noise[i]:.3f}\"\n", " }\n", " results_list.append(result_dict)\n", "\n", " return results_list\n", "\n", "def print_results_table(table_trace):\n", "\n", " fig = go.Figure()\n", " fig.add_trace(table_trace)\n", "\n", " return fig\n", "\n", "def get_trace_for_table(eval_result):\n", "\n", " results_list = convert_result_to_list_of_dicts(eval_result)\n", "\n", " table_trace = go.Table(\n", " header=dict(values=[\"Model\", \"Eval ± SEM\", \"p (against 0)\", \"p (against NC)\"]),\n", " cells=dict(\n", " values=[\n", " [result[\"Model\"] for result in results_list], # Correctly accesses each model name\n", " [result[\"Eval±SEM\"] for result in results_list], # Correctly accesses the combined Eval and SEM value\n", " [result[\"p (against 0)\"] for result in results_list], # Accesses p-value against 0\n", " [result[\"p (against NC)\"] for 
result in results_list] # Accesses p-value against noise ceiling\n", " ],\n", " font=dict(size=12), # Smaller font size for the cells\n", " height=27 # Smaller height for the cell rows\n", " )\n", " )\n", " return table_trace\n", "\n", "def get_trace_for_noise_ceiling(noise_ceiling):\n", "\n", " noise_lower = np.nanmean(noise_ceiling[0])\n", " noise_upper = np.nanmean(noise_ceiling[1])\n", "\n", " noise_rectangle = dict(\n", " # Rectangle reference to the axes\n", " type=\"rect\",\n", " xref=\"x domain\", # Use 'x domain' to span the whole x-axis\n", " yref=\"y\", # Use specific y-values for the height\n", " x0=0, # Starting at the first x-axis value\n", " y0=noise_lower, # Bottom of the rectangle\n", " x1=1, # Ending at the last x-axis value (in normalized domain coordinates)\n", " y1=noise_upper, # Top of the rectangle\n", " fillcolor=\"rgba(128, 128, 128, 0.4)\", # Light grey fill with some transparency\n", " line=dict(\n", " width=0,\n", " )\n", "\n", " )\n", " return noise_rectangle\n", "\n", "def plot_bars_and_scatter_with_table(eval_result, models, method, color='blue', table = True):\n", "\n", " if method == 'cosine':\n", " method_name = 'Cosine Similarity'\n", " elif method == 'corr':\n", " method_name = 'Pearson correlation'\n", " else:\n", " method_name = method # fall back to the raw comparison-method name\n", "\n", " if table:\n", " cols = 2\n", " subplot_titles=[\"Model Evaluations\", \"Model Statistics\"]\n", " # A go.Table trace needs a 'domain'-type subplot next to the 'xy' bar plot\n", " specs = [[{\"type\": \"xy\"}, {\"type\": \"domain\"}]]\n", " else:\n", " cols = 1\n", " subplot_titles=[\"Model Evaluations\"]\n", " specs = None\n", "\n", " fig = make_subplots(rows=1, cols=cols,\n", " subplot_titles=subplot_titles,\n", " specs=specs,\n", " )\n", "\n", " bars_trace, scatter_traces, blank_trace = traces_bar_and_scatter(eval_result, models, bar_color=color)\n", "\n", " fig.add_trace(bars_trace, row=1, col=1)\n", "\n", " for trace in scatter_traces:\n", "
fig.add_trace(trace, row=1, col=1)\n", "\n", " if table:\n", " table_trace = get_trace_for_table(eval_result)\n", " fig.add_trace(table_trace, row=1, col=2)\n", "\n", " width = 600*cols\n", "\n", " fig.update_layout(\n", " yaxis_title=f\"RDM prediction accuracy
(across subject mean of {method_name})\",\n", " width=width,\n", " height=600,\n", " template=\"plotly_white\"\n", " )\n", "\n", " return fig\n", "\n", "def add_noise_ceiling_to_plot(fig, noise_ceiling):\n", "\n", " rectangle = get_trace_for_noise_ceiling(noise_ceiling)\n", " fig.add_shape(rectangle, row=1, col=1)\n", " return fig\n", "\n", "\n", "def bar_bootstrap_interactive(human_rdms, models_to_compare, method):\n", "\n", " color = 'orange'\n", "\n", " button = widgets.Button(\n", " description=\"New Bootstrap Sample\",\n", " layout=widgets.Layout(width='auto', height='auto') # Adjust width and height as needed\n", " )\n", "\n", " button.style.font_weight = 'bold'\n", " button.layout.width = '300px' # Make the button wider\n", " button.layout.height = '48px' # Increase the height for a squarer appearance\n", " button.layout.margin = '0 0 0 0' # Adjust margins as needed\n", " button.layout.border_radius = '12px' # Rounded corners for the button\n", "\n", " output = widgets.Output(layout={'border': '1px solid black'})\n", "\n", " def generate_plot(bootstrap=False):\n", " if bootstrap:\n", " # Resample subjects with replacement and re-evaluate all models\n", " boot_rdms, idx = bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')\n", " result = eval.eval_fixed(models_to_compare, boot_rdms, method=method)\n", " title = \"Performance of model layers for a random bootstrap sample of subjects\"\n", " else:\n", " result = eval.eval_fixed(models_to_compare, human_rdms, method=method)\n", " title = \"Performance of model layers for all subjects\"\n", "\n", " with output:\n", " clear_output(wait=True) # Make sure to clear previous output first\n", "\n", " fig = plot_bars_and_scatter_with_table(result, models_to_compare, method, color)\n", " fig.update_layout(height=600, width=1150,\n", " title=dict(text=title,\n", " x=0.5, y=0.95,\n", " font=dict(size=20)))\n", " fig.show() # Display the figure within the `with` context\n", "\n",
"\n", " # Now, let's create a VBox to arrange the button above the output\n", " vbox_layout = widgets.Layout(\n", " display='flex',\n", " flex_flow='column',\n", " align_items='stretch',\n", " width='100%',\n", " )\n", "\n", "\n", " output = widgets.Output(layout={'border': '1px solid black'})\n", " button.on_click(lambda b: generate_plot(bootstrap=True)) # Generate plot on button click\n", " vbox = widgets.VBox([button, output], layout=vbox_layout)\n", "\n", " # Display everything\n", " display(button, output)\n", "\n", " generate_plot(bootstrap=False)\n", "\n", "def show_rdm_plotly(rdms, pattern_descriptor=None, cmap='Greys',\n", " rdm_descriptor=None, n_column=None, n_row=None,\n", " show_colorbar=False, gridlines=None, figsize=(None, None),\n", " vmin=None, vmax=None):\n", " # Determine the number of matrices\n", " mats = rdms.get_matrices()\n", " n_matrices = mats.shape[0]\n", "\n", "\n", " # Determine the number of subplots\n", " if n_row is None or n_column is None:\n", " # Calculate rows and columns to fit all matrices in a roughly square layout\n", " n_row = 1\n", " n_column = n_matrices\n", "\n", " # n_side = int(n_matrices ** 0.5)\n", " # n_row = n_side if n_side ** 2 >= n_matrices else n_side + 1\n", " # n_column = n_row if n_row * (n_row - 1) < n_matrices else n_row - 1\n", "\n", " subplot_size = 150\n", " fig_width = n_column * subplot_size\n", " fig_height = n_row * subplot_size\n", " subplot_titles = [f'{rdm_descriptor } {rdms.rdm_descriptors[rdm_descriptor][i]}' for i in range(n_matrices)] if rdm_descriptor else None\n", " # Create subplots\n", " fig = make_subplots(rows=n_row, cols=n_column,\n", " subplot_titles=subplot_titles,\n", " shared_xaxes=True, shared_yaxes=True,\n", " horizontal_spacing=0.02, vertical_spacing=0.1)\n", "\n", " # Iterate over RDMs and add them as heatmaps\n", " for index in range(n_matrices):\n", " row, col = divmod(index, n_column)\n", " fig.add_trace(\n", " go.Heatmap(z=mats[index],\n", " colorscale=cmap,\n", " 
showscale=show_colorbar,\n", " zmin=vmin, zmax=vmax),\n", " row=row+1, col=col+1\n", " )\n", "\n", " fig.update_layout(height=290, width=fig_width)\n", " fig.update_xaxes(showticklabels=False)\n", " fig.update_yaxes(showticklabels=False)\n", "\n", "\n", " #fig.show()\n", " return fig\n", "\n", "def show_rdm_plotly_interactive_bootstrap_patterns(rdms, pattern_descriptor=None, cmap='Greys',\n", " rdm_descriptor=None, n_column=None, n_row=None,\n", " show_colorbar=False, gridlines=None, figsize=(None, None),\n", " vmin=None, vmax=None):\n", "\n", "\n", " button = widgets.Button(\n", " description=\"New Bootstrap Sample\",\n", " layout=widgets.Layout(width='auto', height='auto') # Adjust width and height as needed\n", " )\n", "\n", " #button.style.button_color = 'lightblue' # Change the button color as you like\n", " button.style.font_weight = 'bold'\n", " button.layout.width = '300px' # Make the button wider\n", " button.layout.height = '48px' # Increase the height for a squarer appearance\n", " button.layout.margin = '0 0 0 0' # Adjust margins as needed\n", " button.layout.border_radius = '12px' # Rounded corners for the button\n", "\n", " #output = widgets.Output(layout={'border': '1px solid black'})\n", " output = widgets.Output()\n", "\n", " def generate_plot(bootstrap=False):\n", " if bootstrap:\n", " im_boot_rdms, pattern_idx = bootstrap_sample_pattern(rdms, pattern_descriptor='index')\n", " else:\n", " im_boot_rdms = rdms\n", "\n", " with output:\n", " clear_output(wait=True) # Make sure to clear previous output first\n", "\n", " fig = show_rdm_plotly(im_boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')\n", " fig.update_layout(title=dict(text = f\"Bootstrapped sample of patterns\",\n", " x=0.5, y=0.95,\n", " font=dict(size=20)))\n", " fig.show()\n", "\n", " def on_button_clicked(b):\n", " generate_plot(bootstrap=True)\n", "\n", " # Now, let's create a VBox to arrange the button above the output\n", " vbox_layout = widgets.Layout(\n", " 
display='flex',\n", " flex_flow='column',\n", " align_items='stretch',\n", " width='100%',\n", " )\n", "\n", " button.on_click(lambda b: generate_plot(bootstrap=True)) # Generate plot on button click\n", " vbox = widgets.VBox([button, output], layout=vbox_layout)\n", "\n", " display(vbox)\n", "\n", " generate_plot(bootstrap=False)\n", "\n", "def plot_model_comparison_trans(result, sort=False, colors=None,\n", " alpha=0.01, test_pair_comparisons=True,\n", " multiple_pair_testing='fdr',\n", " test_above_0=True,\n", " test_below_noise_ceil=True,\n", " error_bars='sem',\n", " test_type='t-test'):\n", "\n", "\n", " # Prepare and sort data\n", " evaluations = result.evaluations\n", " models = result.models\n", " noise_ceiling = result.noise_ceiling\n", " method = result.method\n", " model_var = result.model_var\n", " diff_var = result.diff_var\n", " noise_ceil_var = result.noise_ceil_var\n", " dof = result.dof\n", "\n", " while len(evaluations.shape) > 2:\n", " evaluations = np.nanmean(evaluations, axis=-1)\n", "\n", " evaluations = evaluations[~np.isnan(evaluations[:, 0])]\n", " n_bootstraps, n_models = evaluations.shape\n", " perf = np.mean(evaluations, axis=0)\n", "\n", " noise_ceiling = np.array(noise_ceiling)\n", " sort = 'unsorted'\n", " # run tests\n", " if any([test_pair_comparisons,\n", " test_above_0, test_below_noise_ceil]):\n", " p_pairwise, p_zero, p_noise = all_tests(\n", " evaluations, noise_ceiling, test_type,\n", " model_var=model_var, diff_var=diff_var,\n", " noise_ceil_var=noise_ceil_var, dof=dof)\n", "\n", " if error_bars:\n", " limits = get_errorbars(model_var, evaluations, dof, error_bars,\n", " test_type)\n", " if error_bars.lower() == 'sem':\n", " limits = limits[0,:]\n", "\n", " #return limits, perf\n", "\n", " fig = make_subplots(rows=2, cols=1,\n", " row_heights=[0.3, 0.7],\n", " vertical_spacing=0.05,\n", " subplot_titles=(\"Model Evaluations\", ''),\n", " shared_xaxes=True,\n", " )\n", "\n", " n_colors_needed = len(models)\n", " # Sample 
n_colors_needed colors from the Bluered color scale\n", " colorscale = plotly.colors.get_colorscale('Bluered') # Retrieve the color scale\n", " color_indices = np.linspace(0, 1, n_colors_needed) # Evenly spaced indices between 0 and 1\n", " sampled_colors = plotly.colors.sample_colorscale(colorscale, color_indices) # Sample colors\n", "\n", " for i, (perf_val, model) in enumerate(zip(perf, models)):\n", " name = model.name\n", "\n", " fig.add_trace(\n", " go.Bar(\n", " x=[name], # x-axis position\n", " y=[perf_val], # Performance value\n", " error_y=dict(type='data',\n", " array=[limits[i]], visible=True, color='black'), # This model's error bar\n", " marker_color=sampled_colors[i], # One color per model\n", " name=name\n", " ),\n", " row=2, col=1 # Bottom panel holds the bar plot\n", " )\n", "\n", "\n", " fig.update_layout(width=600, height=700, showlegend=False, template='plotly_white')\n", "\n", "\n", " # Mark models that perform significantly above 0 (Bonferroni-corrected)\n", " model_significant = p_zero < alpha / n_models\n", " significant_indices = [i for i, significant in enumerate(model_significant) if significant]\n", " symbols = {'dewdrops': 'circle', 'icicles': 'diamond-tall'}\n", "\n", " fig.add_trace(\n", " go.Scatter(\n", " x=[models[i].name for i in significant_indices], # X positions of significant models\n", " y=[0.0005] * len(significant_indices), # Just above 0, at the base of the bars\n", " mode='markers',\n", " marker=dict(symbol=symbols['dewdrops'],\n", " size=9,\n", " color='white'),\n", " showlegend=False\n", " ),\n", " row=2, col=1\n", " )\n", "\n", " # Plot noise ceiling\n", " if noise_ceiling is not None:\n", "\n", " noise_lower = np.nanmean(noise_ceiling[0])\n", " noise_upper = np.nanmean(noise_ceiling[1])\n", " model_names = [model.name for model in models]\n", "\n", " fig.add_shape(\n", " # Rectangle reference to the axes\n", " type=\"rect\",\n", " xref=\"x domain\", # Use 'x domain'
to span the whole x-axis\n", " yref=\"y\", # Use specific y-values for the height\n", " x0=0, # Starting at the first x-axis value\n", " y0=noise_lower, # Bottom of the rectangle\n", " x1=1, # Ending at the last x-axis value (in normalized domain coordinates)\n", " y1=noise_upper, # Top of the rectangle\n", " fillcolor=\"rgba(128, 128, 128, 0.5)\", # Light grey fill with some transparency\n", " line=dict(\n", " color='gray',\n", " ),\n", " opacity=0.5,\n", " layer=\"below\", # Ensure the shape is below the data points\n", " row=2, col=1 # Specify the subplot where the shape should be added\n", "\n", " )\n", "\n", " # Mark models that perform significantly below the lower bound of the noise ceiling\n", " model_below_lower_bound = p_noise < (alpha / n_models)\n", "\n", " significant_indices_below = [i for i, below in enumerate(model_below_lower_bound) if below]\n", "\n", " # A downward triangle just below the lower noise-ceiling bound flags each such model\n", " symbol = 'triangle-down'\n", " y_positions_below = [noise_lower - 0.005] * len(significant_indices_below)\n", " fig.add_trace(\n", " go.Scatter(\n", " x=[models[i].name for i in significant_indices_below], # X positions of significant models\n", " y=y_positions_below, # Y positions slightly below noise_lower\n", " mode='markers',\n", "
marker=dict(symbol=symbol, size=7, color='gray'), # Customizing marker appearance\n", " showlegend=False\n", " ),\n", " row=2, col=1\n", " )\n", "\n", " # Pairwise model comparisons\n", " if test_pair_comparisons:\n", " if test_type == 'bootstrap':\n", " model_comp_descr = 'Model comparisons: two-tailed bootstrap, '\n", " elif test_type == 't-test':\n", " model_comp_descr = 'Model comparisons: two-tailed t-test, '\n", " elif test_type == 'ranksum':\n", " model_comp_descr = 'Model comparisons: two-tailed Wilcoxon-test, '\n", " n_tests = int((n_models ** 2 - n_models) / 2)\n", " if multiple_pair_testing is None:\n", " multiple_pair_testing = 'uncorrected'\n", " if multiple_pair_testing.lower() == 'bonferroni' or \\\n", " multiple_pair_testing.lower() == 'fwer':\n", " significant = p_pairwise < (alpha / n_tests)\n", " elif multiple_pair_testing.lower() == 'fdr':\n", " ps = batch_to_vectors(np.array([p_pairwise]))[0][0]\n", " ps = np.sort(ps)\n", " criterion = alpha * (np.arange(ps.shape[0]) + 1) / ps.shape[0]\n", " k_ok = ps < criterion\n", " if np.any(k_ok):\n", " k_max = np.max(np.where(ps < criterion)[0])\n", " crit = criterion[k_max]\n", " else:\n", " crit = 0\n", " significant = p_pairwise < crit\n", " else:\n", " if 'uncorrected' not in multiple_pair_testing.lower():\n", " raise ValueError(\n", " 'plot_model_comparison: Argument ' +\n", " 'multiple_pair_testing is incorrectly defined as ' +\n", " multiple_pair_testing + '.')\n", " significant = p_pairwise < alpha\n", " model_comp_descr = _get_model_comp_descr(\n", " test_type, n_models, multiple_pair_testing, alpha,\n", " n_bootstraps, result.cv_method, error_bars,\n", " test_above_0, test_below_noise_ceil)\n", "\n", "\n", " # new_fig_nili = plot_nili_bars_plotly(fig, significant, models, version=1)\n", " # new_fig_gol = plot_golan_wings_plotly(fig, significant, perf, models)\n", "\n", " new_fig_metro = plot_metroplot_plotly(fig, significant, perf, models, sampled_colors)\n", "\n", " return new_fig_metro\n", 
"\n", "def plot_golan_wings_plotly(original_fig, significant, perf, models):\n", " with plt.xkcd():\n", " # First, create a deep copy of the original figure to preserve its state\n", " fig = deepcopy(original_fig)\n", "\n", " n_models = len(models)\n", " model_names = [m.name for m in models]\n", " # Use the Plotly qualitative color palette\n", " colors = plotly.colors.qualitative.Plotly\n", "\n", " k = 1 # Vertical position tracker\n", " marker_size = 8 # Size of the markers\n", " for i in range(n_models):\n", "\n", " js = np.where(significant[i, :])[0] # Indices of models significantly different from model i\n", " if len(js) > 0:\n", " for j in js:\n", " # Ensure cycling through the color palette\n", " color = colors[i % len(colors)]\n", " fig.add_trace(go.Scatter(x=[model_names[i], model_names[j]],\n", " y=[k, k],\n", " mode='lines',\n", " line=dict(color=color, width=2)\n", " ),\n", " row=1, col=1)\n", " fig.add_trace(go.Scatter(x=[model_names[i]], y=[k],\n", " mode='markers',\n", " marker=dict(symbol='circle', color=color, size=10,\n", " line=dict(color=color, width=2))\n", " ),\n", " row=1, col=1)\n", "\n", " if perf[i] > perf[j]:\n", " # Draw downward feather\n", " fig.add_trace(go.Scatter(x=[model_names[j]],\n", " y=[k],\n", " mode='markers',\n", " marker=dict(symbol='triangle-right', color=color, size=marker_size,\n", " line=dict(color=color, width=2))\n", " ),\n", " row=1, col=1)\n", " elif perf[i] < perf[j]:\n", " # Draw upward feather\n", " fig.add_trace(go.Scatter(x=[model_names[i], model_names[j]],\n", " y=[k, k],\n", " mode='lines',\n", " line=dict(color=color, width=2)\n", " ),\n", " row=1, col=1)\n", " fig.add_trace(go.Scatter(x=[model_names[j]], y=[k],\n", " mode='markers',\n", " marker=dict(symbol='triangle-left', color=color, size=marker_size,\n", " line=dict(color=color, width=2))\n", " ),\n", " row=1, col=1)\n", " k += 1 # Increment vertical position after each model's wings are drawn\n", "\n", " # Update y-axis to fit the wings\n", " 
fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", " fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", "\n", " return fig\n", "\n", "\n", "def plot_metroplot_plotly(original_fig, significant, perf, models, sampled_colors):\n", " with plt.xkcd():\n", " # First, create a deep copy of the original figure to preserve its state\n", " fig = deepcopy(original_fig)\n", "\n", " n_models = len(models)\n", " model_names = [m.name for m in models]\n", " # Use the Plotly qualitative color palette\n", " colors = plotly.colors.qualitative.Antique\n", "\n", " k = 1 # Vertical position tracker\n", " marker_size = 8 # Size of the markers\n", " for i, (model, color) in enumerate(zip(model_names,sampled_colors)):\n", "\n", " js = np.where(significant[i, :])[0] # Indices of models significantly different from model i\n", " j_worse = np.where(perf[i] > perf)[0]\n", "\n", " worse_models = [model_names[j] for j in j_worse] # Model names that performed worse\n", " metropoints = worse_models + [model] # Model names to plot on the y-axis\n", " marker_colors = ['white' if point != model else color for point in metropoints] # Fill color for markers\n", "\n", " fig.add_trace(go.Scatter(\n", " y = np.repeat(model, len(metropoints)),\n", " x = metropoints,\n", " mode = 'lines+markers',\n", " marker = dict(\n", " color = marker_colors,\n", " symbol = 'circle',\n", " size = 10,\n", " line = dict(width=2, color=color)\n", " ),\n", " line=dict(width=2, color=color),\n", " showlegend = False),\n", " row = 1, col = 1,\n", "\n", " )\n", "\n", " # Update y-axis to fit the wings\n", " fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", " fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", "\n", " return fig\n", "\n", "def plot_nili_bars_plotly(original_fig, significant, models, version=1):\n", "\n", " with plt.xkcd():\n", "\n", " fig = deepcopy(original_fig)\n", "\n", " k = 1 # Vertical position tracker\n", " ns_col = 
'rgba(128, 128, 128, 0.5)' # Non-significant comparison color\n", " w = 0.2 # Width for nonsignificant comparison tweaks\n", " model_names = [m.name for m in models]\n", "\n", " for i in range(significant.shape[0]):\n", " drawn1 = False\n", " for j in range(i + 1, significant.shape[0]):\n", " if version == 1 and significant[i, j]:\n", " # Draw a line for significant differences\n", " fig.add_shape(type=\"line\",\n", " x0=i, y0=k, x1=j, y1=k,\n", " line=dict(color=\"black\", width=2),\n", " xref=\"x1\", yref=\"y1\",\n", " row=1, col=1)\n", " k += 1\n", " drawn1 = True\n", " elif version == 2 and not significant[i, j]:\n", " # Draw a line for non-significant differences\n", " fig.add_shape(type=\"line\",\n", " x0=i, y0=k, x1=j, y1=k,\n", " line=dict(color=ns_col, width=2),\n", " xref=\"x1\", yref=\"y1\",\n", " row=1, col=1)\n", " # Additional visual tweaks for non-significant comparisons\n", " fig.add_annotation(x=(i+j)/2, y=k, text=\"n.s.\",\n", " showarrow=False,\n", " font=dict(size=8, color=ns_col),\n", " xref=\"x1\", yref=\"y1\",\n", " row=1, col=1)\n", " k += 1\n", " drawn1 = True\n", "\n", " if drawn1:\n", " k += 1 # Increase vertical position after each row of comparisons\n", "\n", " fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", " fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)\n", "\n", " fig.update_layout(height=700) # Adjust as necessary\n", " return fig\n", "\n", "\n", "def _get_model_comp_descr(test_type, n_models, multiple_pair_testing, alpha,\n", " n_bootstraps, cv_method, error_bars,\n", " test_above_0, test_below_noise_ceil):\n", " \"\"\"Constructs the statistics description from its parts.\n", "\n", " Args:\n", " test_type : String\n", " n_models : integer\n", " multiple_pair_testing : String\n", " alpha : float\n", " n_bootstraps : integer\n", " cv_method : String\n", " error_bars : String\n", " test_above_0 : Bool\n", " test_below_noise_ceil : Bool\n", "\n", " Returns:\n", " model_comp_descr : String\n", "\n", "
\"\"\"\n", " if test_type == 'bootstrap':\n", " model_comp_descr = 'Model comparisons: two-tailed bootstrap, '\n", " elif test_type == 't-test':\n", " model_comp_descr = 'Model comparisons: two-tailed t-test, '\n", " elif test_type == 'ranksum':\n", " model_comp_descr = 'Model comparisons: two-tailed Wilcoxon-test, '\n", " n_tests = int((n_models ** 2 - n_models) / 2)\n", " if multiple_pair_testing is None:\n", " multiple_pair_testing = 'uncorrected'\n", " if multiple_pair_testing.lower() == 'bonferroni' or \\\n", " multiple_pair_testing.lower() == 'fwer':\n", " model_comp_descr = (model_comp_descr\n", " + 'p < {:<.5g}'.format(alpha)\n", " + ', Bonferroni-corrected for '\n", " + str(n_tests)\n", " + ' model-pair comparisons')\n", " elif multiple_pair_testing.lower() == 'fdr':\n", " model_comp_descr = (model_comp_descr +\n", " 'FDR q < {:<.5g}'.format(alpha) +\n", " ' (' + str(n_tests) +\n", " ' model-pair comparisons)')\n", " else:\n", " if 'uncorrected' not in multiple_pair_testing.lower():\n", " raise ValueError(\n", " 'plot_model_comparison: Argument ' +\n", " 'multiple_pair_testing is incorrectly defined as ' +\n", " multiple_pair_testing + '.')\n", " model_comp_descr = (model_comp_descr +\n", " 'p < {:<.5g}'.format(alpha) +\n", " ', uncorrected (' + str(n_tests) +\n", " ' model-pair comparisons)')\n", " if cv_method in ['bootstrap_rdm', 'bootstrap_pattern',\n", " 'bootstrap_crossval']:\n", " model_comp_descr = model_comp_descr + \\\n", " '\\nInference by bootstrap resampling ' + \\\n", " '({:<,.0f}'.format(n_bootstraps) + ' bootstrap samples) of '\n", " if cv_method == 'bootstrap_rdm':\n", " model_comp_descr = model_comp_descr + 'subjects. '\n", " elif cv_method == 'bootstrap_pattern':\n", " model_comp_descr = model_comp_descr + 'experimental conditions. '\n", " elif cv_method in ['bootstrap', 'bootstrap_crossval']:\n", " model_comp_descr = model_comp_descr + \\\n", " 'subjects and experimental conditions. 
'\n", " if error_bars[0:2].lower() == 'ci':\n", " model_comp_descr = model_comp_descr + 'Error bars indicate the'\n", " if len(error_bars) == 2:\n", " CI_percent = 95.0\n", " else:\n", " CI_percent = float(error_bars[2:])\n", " model_comp_descr = (model_comp_descr + ' ' +\n", " str(CI_percent) + '% confidence interval.')\n", " elif error_bars.lower() == 'sem':\n", " model_comp_descr = (\n", " model_comp_descr +\n", " 'Error bars indicate the standard error of the mean.')\n", " elif error_bars.lower() == 'sem':\n", " model_comp_descr = (model_comp_descr +\n", " 'Dots represent the individual model evaluations.')\n", " if test_above_0 or test_below_noise_ceil:\n", " model_comp_descr = (\n", " model_comp_descr +\n", " '\\nOne-sided comparisons of each model performance ')\n", " if test_above_0:\n", " model_comp_descr = model_comp_descr + 'against 0 '\n", " if test_above_0 and test_below_noise_ceil:\n", " model_comp_descr = model_comp_descr + 'and '\n", " if test_below_noise_ceil:\n", " model_comp_descr = (\n", " model_comp_descr +\n", " 'against the lower-bound estimate of the noise ceiling ')\n", " if test_above_0 or test_below_noise_ceil:\n", " model_comp_descr = (model_comp_descr +\n", " 'are Bonferroni-corrected for ' +\n", " str(n_models) + ' models.')\n", " return model_comp_descr" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Data retrieval\n", "\n", "def download_file(fname, url, expected_md5):\n", " \"\"\"\n", " Downloads a file from the given URL and saves it locally.\n", " \"\"\"\n", " if not os.path.isfile(fname):\n", " try:\n", " r = requests.get(url)\n", " except requests.ConnectionError:\n", " print(\"!!! Failed to download data !!!\")\n", " return\n", " if r.status_code != requests.codes.ok:\n", " print(\"!!! 
Failed to download data !!!\")\n", " return\n", " if hashlib.md5(r.content).hexdigest() != expected_md5:\n", " print(\"!!! Data download appears corrupted !!!\")\n", " return\n", " with open(fname, \"wb\") as fid:\n", " fid.write(r.content)\n", "\n", "def extract_zip(zip_fname):\n", " \"\"\"\n", " Extracts a ZIP file to the current directory.\n", " \"\"\"\n", " with zipfile.ZipFile(zip_fname, 'r') as zip_ref:\n", " zip_ref.extractall(\".\")\n", "\n", "# Details for the zip files to be downloaded and extracted\n", "zip_files = [\n", " {\n", " \"fname\": \"fmri_patterns.zip\",\n", " \"url\": \"https://osf.io/7jc3n/download\",\n", " \"expected_md5\": \"c21395575573c62129dc7e9d806f0b5e\"\n", " },\n", " {\n", " \"fname\": \"images.zip\",\n", " \"url\": \"https://osf.io/zse8u/download\",\n", " \"expected_md5\": \"ecb0d1a487e90be908ac24c2b0b10fc3\"\n", " }\n", "]\n", "\n", "# New addition for other files to be downloaded, specifically non-zip files\n", "image_files = [\n", " {\n", " \"fname\": \"NSD.png\",\n", " \"url\": \"https://osf.io/69tj8/download\",\n", " \"expected_md5\": \"a5ff07eb016d837da2624d8e511193ca\"\n", " }\n", "]\n", "\n", "# Process zip files: download and extract\n", "for zip_file in zip_files:\n", " download_file(zip_file[\"fname\"], zip_file[\"url\"], zip_file[\"expected_md5\"])\n", " extract_zip(zip_file[\"fname\"])\n", "\n", "# Process image files: download only\n", "for image_file in image_files:\n", " download_file(image_file[\"fname\"], image_file[\"url\"], image_file[\"expected_md5\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Tutorial Introduction\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Tutorial Introduction\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import 
display\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '54yjuJ0kd9U'), ('Bilibili', 'BV1L6421Z7hR')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_tutorial_introduction\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: Data 
Acquisition" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In this section, we are going to load and explore the data used in the tutorial. \n", "\n", "We will use data from the [Natural Scenes Dataset](https://naturalscenesdataset.org/) (NSD). NSD is a large 7T fMRI dataset in which 8 adults viewed photos of natural scenes drawn from a pool of more than 73,000 images. We have taken a small subset of 90 images from NSD and have pre-extracted the fMRI data for V1 and the Fusiform Face Area (FFA) from all 8 subjects. Both of these areas are part of the visual cortex: V1 is known to respond to low-level visual features, while the FFA is famously responsive to high-level features, in particular faces." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define constants\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Define constants\n", "SUBJECTS = list(range(1, 9)) # There are 8 subjects\n", "ROIS = [\"V1\", \"FFA\"] # Regions of interest in fMRI data\n", "IMAGES_DIR = pathlib.Path('images')\n", "FMRI_PATTERNS_DIR = pathlib.Path('fmri_patterns')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Show image\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Show image\n", "\n", "display(IMG(filename=\"NSD.png\"))" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Loading the images" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "First, let's load the 90 image files with the Pillow Image class."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the images and get image size\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Load the images and get image size\n", "\n", "image_paths = sorted(IMAGES_DIR.glob(\"*.png\")) # Find all PNG file paths in the image directory\n", "images = [Image.open(p).convert('RGB') for p in image_paths] # Load them as RGB Image objects\n", "np.array(images[0]).shape # Dimensions of the image array: height x width x channels (RGB)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, let's take a look at these images.\n", "Notice that the first 45 images we selected have no faces, while the other 45 do have faces in them!\n", "So, we should expect to see a 2x2 block pattern in the Fusiform Face Area (FFA) representational dissimilarity matrices (RDMs)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize images\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize images\n", "\n", "with plt.xkcd():\n", "    fig, ax = plt.subplots(9, 10, figsize=(10, 10))\n", "\n", "    for i, img in enumerate(images):\n", "        ax[i//10, i%10].imshow(img)\n", "        ax[i//10, i%10].axis('off')\n", "        ax[i//10, i%10].text(0, 0, str(i+1), color='black', fontsize=12) # 1-based image number\n", "    plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Loading fMRI patterns from the NSD dataset" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's now load the fMRI patterns from the NSD dataset for these 90 images.\n", "We have pre-extracted the patterns, so we just need to load NumPy arrays from the `.npy` files."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading fMRI data\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Loading fMRI data\n", "fmri_patterns = {}\n", "\n", "for subject in SUBJECTS:\n", "    fmri_patterns[subject] = {}\n", "\n", "    for roi in ROIS:\n", "        # One array per subject and ROI: number of images x number of voxels\n", "        fmri_patterns[subject][roi] = np.load(FMRI_PATTERNS_DIR / f\"subj{subject}_{roi}.npy\")\n", "\n", "# This is how we can index into subject 5's V1 patterns for all the images\n", "fmri_patterns[5][\"V1\"].shape # Number of images x number of voxels" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's now take a look at the pattern of responses for two non-face images and two face images." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def plot_fmri_pattern(subject, roi, image_idx, ax):\n", "    pattern = fmri_patterns[subject][roi][image_idx]\n", "    ax.plot(pattern)\n", "    ax.set_title(f\"Subject {subject}, ROI {roi}\")\n", "    ax.set_xlabel(\"Voxel #\")\n", "    ax.set_ylabel(\"Activation\")\n", "    ax.set_xlim([200, 400])\n", "    ax.set_ylim([-3, 3])\n", "\n", "\n", "plt.figure(figsize=(8, 4))\n", "ax = plt.gca()\n", "subject = 1\n", "roi = \"FFA\"\n", "\n", "# non-face images (0-based indices; images 0-44 contain no faces)\n", "plot_fmri_pattern(subject, roi, 1, ax)\n", "plot_fmri_pattern(subject, roi, 3, ax)\n", "\n", "# face images (0-based indices; images 45-89 contain faces)\n", "plot_fmri_pattern(subject, roi, 57, ax)\n", "plot_fmri_pattern(subject, roi, 75, ax)\n", "\n", "plt.legend(['non-face 1', 'non-face 2', 'face 1', 'face 2'])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The activity is complex, but we clearly see several voxels (e.g., voxel 275) that have higher activation for faces than for non-faces. 
This is as expected for the face-selective FFA." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_data_acquisition\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Get artificial neural network activations\n", "\n", "Estimated timing to here from start of tutorial: 15 minutes\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now that we have fMRI patterns, we want to explain this data using computational models.\n", "\n", "In this tutorial, we will take our models to be **layers of AlexNet**." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "![](https://d2l.ai/_images/alexnet.svg)\n", "\n", "*Comparing LeNet architecture to AlexNet. Image from [Dive Into Deep Learning book](https://d2l.ai/).*" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We load a version of AlexNet that is already pre-trained on ImageNet. This step may take a minute; feel free to read ahead." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load AlexNet model pretrained on ImageNet\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Load AlexNet model pretrained on ImageNet\n", "alexnet = torchvision.models.alexnet(weights=\"IMAGENET1K_V1\").eval() # eval mode disables dropout, so the extracted activations are deterministic" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "To pass images through the model, we need to _preprocess_ them into the same format as the images the model saw during training.\n", "\n", "For AlexNet, this means resizing the images to 224x224 pixels and normalizing each color channel with the ImageNet mean and standard deviation. We also need to turn them into PyTorch tensors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess NSD images as input to AlexNet\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Preprocess NSD images as input to AlexNet\n", "\n", "# Use the same input format that AlexNet was trained with\n", "transform = torchvision.transforms.Compose([\n", "    torchvision.transforms.Resize((224, 224)), # Resize the images to 224x224 pixels\n", "    torchvision.transforms.ToTensor(), # Convert the images to PyTorch tensors with values in [0, 1]\n", "    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalize each color channel with the ImageNet mean and std\n", "])\n", "\n", "images_tensor = torch.stack([transform(img) for img in images])\n", "print(images_tensor.shape) # (number of images, channels, height, width)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's inspect the AlexNet architecture to select some of its layers as our models."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspect architecture\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Inspect architecture\n", "print(\"Architecture of AlexNet:\")\n", "print(alexnet)\n", "\n", "node_names = get_graph_node_names(alexnet) # returns a tuple: (node names in train mode, node names in eval mode)\n", "print(\"\\nGraph node names (layers) in train mode:\")\n", "print(node_names[0]) # train-mode node names" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We extract activations from several layers of AlexNet while it processes the same images that were shown to the participants in the NSD experiment." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a feature extractor for different layers of AlexNet\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Create a feature extractor for different layers of AlexNet\n", "# Keys are torchvision graph node names; values are the labels we use for each layer.\n", "# conv1, conv2 and conv5 are max-pooling outputs, conv3 and conv4 are ReLU outputs,\n", "# and fc6-fc8 are outputs of the linear layers (before any ReLU).\n", "return_nodes = {\n", "    \"features.2\": \"conv1\",\n", "    \"features.5\": \"conv2\",\n", "    \"features.7\": \"conv3\",\n", "    \"features.9\": \"conv4\",\n", "    \"features.12\": \"conv5\",\n", "    \"classifier.1\": \"fc6\",\n", "    \"classifier.4\": \"fc7\",\n", "    \"classifier.6\": \"fc8\"\n", "}\n", "feature_extractor = create_feature_extractor(alexnet, return_nodes=return_nodes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extract activations from AlexNet\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Extract activations from AlexNet\n", "alexnet_activations = feature_extractor(images_tensor)\n", "\n", "# Convert to NumPy arrays\n", "for layer, activations in 
alexnet_activations.items():\n", "    # Flatten each layer's output into one row per image\n", "    alexnet_activations[layer] = activations.detach().numpy().reshape(len(images), -1)\n", "\n", "alexnet_activations['conv1'].shape # number of images x number of units in the conv1 output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_ann_activations\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 3: Create representational dissimilarity matrices (RDMs)\n", "\n", "Estimated timing to here from start of tutorial: 20 minutes\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now that we have fMRI patterns and AlexNet activations, the first step in representational similarity analysis (RSA) is to compute representational dissimilarity matrices (RDMs). RSA characterizes the representational geometry of a brain region of interest (ROI) by estimating the representational distance for each pair of experimental conditions (e.g., different images).\n", "\n", "An RDM captures how dissimilar the neural activity patterns (or model activations) are for each pair of stimuli. In our case, these will be 90x90 image-by-image matrices whose entry in row i and column j is the dissimilarity between the fMRI patterns (or AlexNet layer activations) for images i and j.\n", "\n", "For instance, in FFA we expect large distances between the 45 face and the 45 non-face images, which should show up as a 2x2 block pattern in the RDM.\n", "\n", "## Creating RSA toolbox datasets\n", "\n", "First, let's wrap our neural and model data in `Dataset` objects to use the RSA toolbox."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create RSA datasets for each subject and ROI\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Create RSA datasets for each subject and ROI\n", "fmri_datasets = {}\n", "\n", "for subject in SUBJECTS:\n", "    fmri_datasets[subject] = {}\n", "\n", "    for roi in ROIS:\n", "        measurements = fmri_patterns[subject][roi]\n", "        # Rows are images (observations) and columns are voxels (channels).\n", "        # The 'image' descriptor assumes every subject's patterns are stored in the same image order.\n", "        fmri_datasets[subject][roi] = rsa.data.Dataset(measurements=measurements,\n", "                                                       descriptors={'subject': subject, 'roi': roi},\n", "                                                       obs_descriptors={'image': np.arange(measurements.shape[0])},\n", "                                                       channel_descriptors={'voxel': np.arange(measurements.shape[1])})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create RSA datasets for AlexNet activations\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Create RSA datasets for AlexNet activations\n", "alexnet_datasets = {}\n", "\n", "for layer, activations in alexnet_activations.items():\n", "    measurements = activations\n", "    alexnet_datasets[layer] = rsa.data.Dataset(measurements=measurements,\n", "                                               descriptors={'layer': layer},\n", "                                               obs_descriptors={'image': np.arange(measurements.shape[0])},\n", "                                               channel_descriptors={'channel': np.arange(measurements.shape[1])})" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Computing the RDMs\n", "\n", "Let's compute RDMs for fMRI patterns and AlexNet activations."
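, "\n",
"\n",
"To make the idea concrete, the next cell is a minimal NumPy sketch of how an RDM can be computed from scratch: for every pair of images, it measures the distance between their activity patterns. The helper `squared_euclidean_rdm` and the simulated patterns are hypothetical and serve only as an illustration; they are not NSD data. The sketch divides the squared Euclidean distance by the number of channels, which we assume mirrors the default `'euclidean'` method of `rsa.rdm.calc_rdm`; please check the rsatoolbox documentation before relying on that." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n",
"\n",
"def squared_euclidean_rdm(patterns):\n",
"    # Hypothetical helper: squared Euclidean distance between every pair of rows\n",
"    # (observations), divided by the number of columns (channels).\n",
"    diffs = patterns[:, None, :] - patterns[None, :, :]  # shape (n, n, channels)\n",
"    return (diffs ** 2).sum(axis=-1) / patterns.shape[1]  # shape (n, n), zeros on the diagonal\n",
"\n",
"# Simulated data (not NSD): 10 'images' x 50 'voxels'; the last 5 images share an offset,\n",
"# loosely mimicking a face / non-face split.\n",
"toy_rng = np.random.default_rng(0)\n",
"toy_patterns = toy_rng.normal(size=(10, 50))\n",
"toy_patterns[5:] += 1.0\n",
"\n",
"toy_rdm = squared_euclidean_rdm(toy_patterns)\n",
"toy_upper = np.triu_indices(5, k=1)  # pairs within a 5 x 5 block, excluding the diagonal\n",
"toy_within = np.concatenate([toy_rdm[:5, :5][toy_upper], toy_rdm[5:, 5:][toy_upper]])\n",
"toy_between = toy_rdm[:5, 5:].ravel()\n",
"\n",
"# The between-group mean is expected to exceed the within-group mean: a 2x2 block structure\n",
"print(toy_rdm.shape)\n",
"print(f'mean within-group distance:  {toy_within.mean():.2f}')\n",
"print(f'mean between-group distance: {toy_between.mean():.2f}')"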
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compute RDMs for each subject and ROI\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Compute RDMs for each subject and ROI\n", "fmri_rdms = {}\n", "fmri_rdms_list = []\n", "\n", "for subject in SUBJECTS:\n", "    fmri_rdms[subject] = {}\n", "\n", "    for roi in ROIS:\n", "        fmri_rdms[subject][roi] = rsa.rdm.calc_rdm(fmri_datasets[subject][roi])\n", "        fmri_rdms_list.append(fmri_rdms[subject][roi])" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 1: RDMs of AlexNet\n", "\n", "Use the RSA toolbox to compute the RDMs for the layers of AlexNet. This works the same way as for the fMRI patterns above." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "#################################################\n", "## TODO for students: fill in the missing variables ##\n", "# Fill out function and remove\n", "raise NotImplementedError(\"Student exercise: fill in the missing variables\")\n", "#################################################\n", "\n", "# Compute RDMs for each layer of AlexNet\n", "alexnet_rdms_dict = {}\n", "for layer, dataset in alexnet_datasets.items():\n", "    alexnet_rdms_dict[layer] = ..."
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial3_Solution_5ab03b03.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_rdms_of_alexnet\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Visualizing human RDMs" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Here we use methods on the `rsatoolbox` RDM object to select a subset of the RDMs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "fmri_rdms = rsa.rdm.concat(fmri_rdms_list)\n", "ffa_rdms = fmri_rdms.subset('roi', 'FFA')\n", "show_rdm_plotly(ffa_rdms, rdm_descriptor='subject')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "As predicted above, we can see a 2x2 block-like pattern in the FFA fMRI pattern RDMs.\n", "\n", "This is because we have 45 non-face images followed by 45 face images.\n", "\n", "The lighter regions indicate larger representational distances." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# the same RDMs, using a different visualization method\n", "fmri_rdms = rsa.rdm.concat(fmri_rdms_list)\n", "fig = rsa.vis.rdm_plot.show_rdm(ffa_rdms, rdm_descriptor='subject')[0]" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 2: Human RDMs\n", "\n", "Visualize the RDMs for the fMRI patterns from the V1 region." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "fmri_rdms = rsa.rdm.concat(fmri_rdms_list)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "#################################################\n", "## TODO for students: fill in the missing variables ##\n", "# Fill out function and remove\n", "raise NotImplementedError(\"Student exercise: fill in the missing variables\")\n", "#################################################\n", "v1_rdms = fmri_rdms.subset('roi', ...)\n", "show_rdm_plotly(v1_rdms, rdm_descriptor='subject')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/neuromatch/NeuroAI_Course/tree/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/solutions/W1D3_Tutorial3_Solution_147d3932.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_human_rdms\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Visualizing AlexNet RDMs" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's look at RDMs for different layers of AlexNet." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "alexnet_rdms = rsa.rdm.concat(alexnet_rdms_dict.values())\n", "fig = rsa.vis.rdm_plot.show_rdm(alexnet_rdms, rdm_descriptor='layer')[0]" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Each RDM contains the dissimilarities between the activation patterns of one layer for every pair of images.\n", "\n", "A similar pattern emerges in the fully connected layers `fc6`, `fc7`, and `fc8`, which tend to group face and non-face images separately.\n", "\n", "AlexNet seems to represent faces differently than non-faces, at least to some extent." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# same RDMs, different visualization method\n", "show_rdm_plotly(alexnet_rdms, rdm_descriptor='layer')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_create_rdms\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 4: RSA. Model comparison and statistical inference\n", "\n", "Estimated timing to here from start of tutorial: 35 minutes\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In the second step of RSA, each model is evaluated by how accurately it predicts the data RDM. To this end, we will use the RDMs we computed for each model representation.\n", "\n", "Each model’s prediction of the data RDM is evaluated using an RDM comparator. Here, we will use the Pearson correlation coefficient between the model RDM and the data RDM, computed over their off-diagonal entries."
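, "\n",
"\n",
"As a minimal sketch of this comparator idea, the next cell correlates the entries above the diagonal of two RDMs (an RDM is symmetric with zeros on the diagonal, so these entries carry all of its information). The helper `rdm_correlation` and the toy RDMs are hypothetical illustrations; rsatoolbox's `'corr'` method is based on the same idea, but its implementation details may differ." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n",
"\n",
"def rdm_correlation(rdm_a, rdm_b):\n",
"    # Hypothetical helper: Pearson correlation between the upper-triangular\n",
"    # (off-diagonal) entries of two square, symmetric RDMs.\n",
"    upper = np.triu_indices(rdm_a.shape[0], k=1)\n",
"    return np.corrcoef(rdm_a[upper], rdm_b[upper])[0, 1]\n",
"\n",
"# Toy 4 x 4 RDMs (illustration only): a two-block 'model' RDM and a noisy 'data' RDM\n",
"toy_model_rdm = np.array([[0, 0, 1, 1],\n",
"                          [0, 0, 1, 1],\n",
"                          [1, 1, 0, 0],\n",
"                          [1, 1, 0, 0]], dtype=float)\n",
"toy_noise = np.random.default_rng(1).normal(scale=0.2, size=(4, 4))\n",
"toy_data_rdm = toy_model_rdm + (toy_noise + toy_noise.T) / 2  # symmetrize the noise\n",
"np.fill_diagonal(toy_data_rdm, 0.0)  # an RDM has zero self-dissimilarity\n",
"\n",
"print(rdm_correlation(toy_model_rdm, toy_model_rdm))  # identical RDMs: correlation of 1 (up to rounding)\n",
"print(rdm_correlation(toy_model_rdm, toy_data_rdm))   # noisy version: typically high but below 1"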
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "First, let's look at the performance of different AlexNet layers across all subjects." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get the Model objects to use the rsa toolbox for model comparisons\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Get the Model objects to use the rsa toolbox for model comparisons\n", "# One fixed model (a single, parameter-free RDM) per AlexNet layer\n", "models = [rsa.model.ModelFixed(rdm=rdm, name=layer) for layer, rdm in alexnet_rdms_dict.items()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize AlexNet performance\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize AlexNet performance\n", "\n", "roi = 'FFA'\n", "human_rdms = fmri_rdms.subset('roi', roi)\n", "models_to_compare = models\n", "\n", "method = 'corr'\n", "result = rsa.inference.evaluate.eval_fixed(models_to_compare, human_rdms, method=method) # performance of each model on the FFA RDMs of all 8 subjects\n", "\n", "fig = plot_bars_and_scatter_with_table(result, models_to_compare, method, table = False)\n", "fig.update_layout(title=dict(text = f\"Performance of AlexNet layers on stimuli
in {roi} ROI for original set of subjects\",\n", " x=0.5, y=0.95,\n", " font=dict(size=15)))\n", "add_noise_ceiling_to_plot(fig, result.noise_ceiling)\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In the plot, each data point corresponds to an individual subject: it shows the correlation between that subject's data RDM and the RDM of one model layer. Each bar shows the mean of these correlations across subjects for one model layer. The spread of the points reflects how consistently each model (layer) predicts the representational geometry across individuals.\n", "\n", "Our goal is to determine how these results might generalize to a new cohort of subjects and new sets of stimuli (assuming the new subjects and stimuli are sampled from the same respective populations). Since we cannot practically rerun the experiment countless times with fresh subjects and stimuli, we turn to computational simulations.\n", "\n", "To achieve this, we will employ bootstrap resampling: a statistical technique that resamples our existing dataset with replacement to generate many simulated samples. This approach allows us to mimic the process of conducting the experiment anew with different subjects and/or different stimuli.\n", "\n", "First, we'll focus on generalization to new subjects. For each bootstrap sample of subjects, we compute the predictive accuracy of our models on the sampled subjects' RDMs. After many such simulations, we obtain a distribution of mean-accuracy estimates that approximates the distribution we might have obtained had we actually repeated the experiment many times. This distribution lets us make statistical inferences about how well our models generalize to new subjects. 
Later, we will address the problem of generalizing to new stimuli as well." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's simulate a *new* sample of subjects by bootstrap resampling using the RSA toolbox." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "boot_rdms, idx = rsa.inference.bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now, we plot the RDMs of the bootstrapped sample.\n", "\n", "Each RDM belongs to one subject (note that in the bootstrapped sample some subjects may appear more than once and others may be missing)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize RDMs\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize RDMs\n", "fig1 = show_rdm_plotly(fmri_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')\n", "fig1.update_layout(title=dict(text = \"Original sample of subjects\",\n", " x=0.5, y=0.95,\n", " font=dict(size=20)))\n", "fig2 = show_rdm_plotly(boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')\n", "fig2.update_layout(title=dict(text = \"Bootstrapped sample of subjects\",\n", " x=0.5, y=0.95,\n", " font=dict(size=20)))\n", "fig1.show()\n", "fig2.show()" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "As before, each RDM contains the pairwise dissimilarities between activation patterns in the specified brain region (in this case, FFA) for one subject.\n", "\n", "The first row shows the RDMs of the 8 original subjects.\n", "The second row shows 8 RDMs sampled with replacement from the 8 original RDMs. This is bootstrap resampling: it simulates a \"new\" set of 8 subjects' RDMs. 
Notice that in this new set some subjects are missing and others appear more than once.\n", "\n", "The idea is to see how the results of the model evaluations would change had our data come from this simulated sample of subjects." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's look at the model performance on different bootstrap-resampled sets of subjects." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize model evaluations\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Visualize model evaluations\n", "\n", "# draw a new bootstrap sample of subjects and evaluate all models on it\n", "boot_rdms, idx = rsa.inference.bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')\n", "eval_result = rsa.inference.evaluate.eval_fixed(models_to_compare, boot_rdms, method=method)\n", "fig = plot_bars_and_scatter_with_table(eval_result, models_to_compare, method, color='blue', table = False)\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "If you run the cell above again, you will see the model performance for a new bootstrap sample of subjects." ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Discussion\n", "\n", "1. Explore the results for a few simulated new cohorts: run the cell above a few times and watch how the points (individual subjects) and bars (means) change.\n", "\n", "**Hint**: For different simulated sets of subjects, the mean correlation between their RDMs and each layer's RDM may look quite different. For example, if a given subject's RDM correlates very closely with the RDM of layer `conv4`, and that subject is overrepresented in the bootstrap sample, layer `conv4` might score higher than the other layers in that sample." ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Comparing representations statistically\n", "\n", "How can we know if these differences in model performance are statistically significant?\n", "\n", "That is what the third and final step of RSA does: we conduct inferential comparisons between models based on their accuracy in predicting the representational dissimilarity matrices (RDMs).\n", "\n", "We use the variability of the performance estimates across bootstrap samples to conduct statistical tests. These tests determine whether the differences in RDM prediction accuracy between models are statistically significant." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "plot_model_comparison_trans(result)" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The Y-axis, again, shows RDM prediction accuracy, as measured by the correlation between the data RDMs and the model RDMs.\n", "The error bars reflect the variability of this estimate across bootstrap samples (over subjects). The gray bar indicates the noise ceiling, which is a measure of how well the best possible model (capturing the true data-generating process) could do, given the noise and intersubject variability in the data. The gray arrows indicate the models that don't perform significantly better than the noise ceiling. The white dots at the bottom of the bars indicate the models whose correlation with the data RDMs is significantly greater than zero.\n", "\n", "#### Details of the figure above\n", "\n", "Model comparisons are performed using a two-tailed t-test, FDR q < 0.01. Error bars indicate the standard error of the mean. One-sided comparisons of each model's performance against 0 and against the lower-bound estimate of the noise ceiling are Bonferroni-corrected for the number of models."
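] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "To make the logic of the subject bootstrap concrete, here is a minimal NumPy sketch of it. It is an illustration under simplifying assumptions, not the RSA toolbox's implementation: each RDM is represented by the vector of its upper-triangle dissimilarities, a model's accuracy for a subject is the Pearson correlation between the two vectors, and the data are synthetic (8 simulated subjects and 2 hypothetical model RDMs). The helper names `pearson_rows` and `bootstrap_subjects` are hypothetical, and the data variables are prefixed with `toy_` to avoid overwriting tutorial variables such as `models` and `result`.\n", "\n", "Because a model's accuracy for one subject does not depend on the other subjects, the sketch computes the per-subject accuracies once and then resamples them. Comparing two models through a percentile interval of their paired bootstrap difference is one simple choice made for this sketch; the figure details above describe the t-tests used by the toolbox." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def pearson_rows(model_vecs, data_vecs):\n", "    # Pearson correlation between every model RDM vector and every data RDM vector.\n", "    # model_vecs: (n_models, n_pairs), data_vecs: (n_subjects, n_pairs) -> (n_models, n_subjects)\n", "    m = model_vecs - model_vecs.mean(axis=1, keepdims=True)\n", "    d = data_vecs - data_vecs.mean(axis=1, keepdims=True)\n", "    m = m / np.linalg.norm(m, axis=1, keepdims=True)\n", "    d = d / np.linalg.norm(d, axis=1, keepdims=True)\n", "    return m @ d.T\n", "\n", "def bootstrap_subjects(model_vecs, data_vecs, n_boot=1000, seed=0):\n", "    # Returns the observed mean accuracy per model and the (n_boot, n_models) bootstrap means.\n", "    rng = np.random.default_rng(seed)\n", "    scores = pearson_rows(model_vecs, data_vecs)  # per-subject accuracies, computed once\n", "    n_subj = data_vecs.shape[0]\n", "    boot_means = np.empty((n_boot, model_vecs.shape[0]))\n", "    for b in range(n_boot):\n", "        idx = rng.integers(0, n_subj, size=n_subj)  # subjects drawn with replacement\n", "        boot_means[b] = scores[:, idx].mean(axis=1)\n", "    return scores.mean(axis=1), boot_means\n", "\n", "# Synthetic data: 10 stimuli give 45 stimulus pairs; 8 noisy subjects; 2 hypothetical models\n", "toy_rng = np.random.default_rng(1)\n", "toy_true = toy_rng.random(45)\n", "toy_data = toy_true + 0.5 * toy_rng.standard_normal((8, 45))\n", "toy_models = np.stack([toy_true + 0.2 * toy_rng.standard_normal(45),  # similar to the true RDM\n", "                       toy_rng.random(45)])  # unrelated to the true RDM\n", "\n", "toy_mean, toy_boot = bootstrap_subjects(toy_models, toy_data)\n", "toy_diff = toy_boot[:, 0] - toy_boot[:, 1]  # paired: both models scored on the same resample\n", "print('mean accuracy per model:', toy_mean)\n", "print('bootstrap standard error per model:', toy_boot.std(axis=0, ddof=1))\n", "print('95% bootstrap interval of model 0 minus model 1:', np.percentile(toy_diff, [2.5, 97.5]))"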
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Generalization to new images" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "We have applied a method that lets us infer how well the models might predict the RDMs of a new cohort of subjects. However, this approach does not yet account for the variability that would arise when replicating the experiment with a new sample of stimuli.\n", "\n", "To make statistical inferences that generalize to new stimuli, we will once again use a bootstrapping procedure, this time resampling the stimuli rather than the subjects.\n", "\n", "To do this, we keep the original cohort of subjects and resample the stimulus set with replacement. That is, for each subject, we build a new RDM from that subject's original RDM: it contains the pairwise dissimilarities between the response patterns to a set of stimuli sampled with replacement from the original set of stimuli." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get the RDMs for a bootstrap sample of the images\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Get the RDMs for a bootstrap sample of the images\n", "# resample the stimuli (patterns) with replacement, keeping the subjects fixed\n", "im_boot_rdms, pattern_idx = rsa.inference.bootstrap_sample_pattern(human_rdms, pattern_descriptor='index')\n", "\n", "# plot RDMs\n", "fig = show_rdm_plotly(im_boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "As before, rerunning the cell above will show you the RDMs for a new set of bootstrap-resampled stimuli each time.\n", "\n", "Note that the block-like structure from before is generally preserved. 
This is because the stimuli are sorted so that face stimuli remain adjacent to other face stimuli, and non-face stimuli remain adjacent to other non-face stimuli. What changes is which stimuli enter the RDM: in each bootstrap resample, some stimuli may appear more than once and others may be left out." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Let's see the inferential model comparisons based on 1000 bootstrap samples of the image set." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# evaluate all models on 1000 bootstrap resamples of the stimuli (subjects kept fixed)\n", "result = rsa.inference.eval_bootstrap_pattern(models, human_rdms, theta=None, method='corr', N=1000,\n", " pattern_descriptor='index', rdm_descriptor='index',\n", " boot_noise_ceil=True)\n", "\n", "plot_model_comparison_trans(result)" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The Y-axis shows RDM prediction accuracy, as measured by the correlation between the data RDMs and the model RDMs.\n", "The error bars reflect the variability of this estimate across bootstrap samples (over stimuli). The gray bar indicates the noise ceiling, which is a measure of how well any model could do, given the noise in the data. The gray arrows indicate the models that don't perform significantly better than the noise ceiling. The white dots at the bottom of the bars indicate the models whose correlation with the data RDMs is significantly greater than zero.\n", "\n", "#### Details of the figure above\n", "\n", "Model comparisons are performed using a two-tailed t-test, FDR q < 0.01. Error bars indicate the standard error of the mean. One-sided comparisons of each model's performance against 0 and against the lower-bound estimate of the noise ceiling are Bonferroni-corrected for the number of models."
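] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Here is a minimal NumPy sketch of the stimulus bootstrap on synthetic data. It is an illustration under simplifying assumptions, not the RSA toolbox's implementation. The subjects stay fixed. In each iteration, one set of stimulus indices is drawn with replacement and applied to every model and data RDM. Pairs formed by a stimulus and its own duplicate have a trivial zero dissimilarity, so this sketch simply drops them. The helper names `pearson_rows`, `resampled_pairs` and `bootstrap_stimuli` are hypothetical, and the data variables are prefixed with `toy_` to avoid overwriting the tutorial's variables." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def pearson_rows(a, b):\n", "    # Pearson correlation between every row of a and every row of b -> (len(a), len(b))\n", "    a = a - a.mean(axis=1, keepdims=True)\n", "    b = b - b.mean(axis=1, keepdims=True)\n", "    a = a / np.linalg.norm(a, axis=1, keepdims=True)\n", "    b = b / np.linalg.norm(b, axis=1, keepdims=True)\n", "    return a @ b.T\n", "\n", "def resampled_pairs(rdm, idx):\n", "    # Upper-triangle dissimilarities among the resampled stimuli idx,\n", "    # skipping pairs made of a stimulus and its own duplicate.\n", "    i, j = np.triu_indices(len(idx), k=1)\n", "    keep = idx[i] != idx[j]\n", "    return rdm[idx[i][keep], idx[j][keep]]\n", "\n", "def bootstrap_stimuli(model_rdms, data_rdms, n_boot=1000, seed=0):\n", "    # model_rdms: (n_models, n_stim, n_stim), data_rdms: (n_subjects, n_stim, n_stim)\n", "    rng = np.random.default_rng(seed)\n", "    n_stim = data_rdms.shape[1]\n", "    boot_means = np.empty((n_boot, model_rdms.shape[0]))\n", "    for b in range(n_boot):\n", "        idx = rng.integers(0, n_stim, size=n_stim)  # stimuli drawn with replacement\n", "        m = np.stack([resampled_pairs(r, idx) for r in model_rdms])\n", "        d = np.stack([resampled_pairs(r, idx) for r in data_rdms])\n", "        boot_means[b] = pearson_rows(m, d).mean(axis=1)  # mean accuracy over subjects\n", "    return boot_means\n", "\n", "# Synthetic data: 12 stimuli x 50 features; RDMs are correlation distances (1 - r)\n", "toy_rng = np.random.default_rng(2)\n", "toy_signal = toy_rng.standard_normal((12, 50))\n", "toy_data = np.stack([1 - np.corrcoef(toy_signal + toy_rng.standard_normal((12, 50)))\n", "                     for _ in range(8)])  # 8 noisy subjects\n", "toy_models = np.stack([1 - np.corrcoef(toy_signal + 0.5 * toy_rng.standard_normal((12, 50))),\n", "                       1 - np.corrcoef(toy_rng.standard_normal((12, 50)))])  # 2 hypothetical models\n", "\n", "toy_boot = bootstrap_stimuli(toy_models, toy_data)\n", "print('bootstrap mean accuracy per model:', toy_boot.mean(axis=0))\n", "print('bootstrap standard error per model:', toy_boot.std(axis=0, ddof=1))"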
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_model_comparison_statistical_inference\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 5: Model Comparison Using Two-factor Bootstrap\n", "\n", "*Estimated timing to here from start of tutorial: 45 minutes*\n" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "To generalize across both the subject and the stimulus populations, we can use a two-factor bootstrap method. For an in-depth discussion of this technique, refer to [Schütt et al., 2023](https://elifesciences.org/articles/82566).\n", "\n", "We can use the RSA toolbox to bootstrap-resample subjects and stimuli simultaneously. Note that a naive two-factor bootstrap triple-counts the variance contributed by measurement noise (see Schütt et al. for an explanation of why). Fortunately, the RSA toolbox's implementation corrects for this overestimation.\n", "\n", "Let's evaluate the performance of the models with simultaneous bootstrap resampling of the subjects and stimuli."
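] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Before running the toolbox, here is a minimal NumPy sketch of the *naive* two-factor bootstrap on synthetic data. It shows what resampling both factors at once means: each iteration draws subjects and stimuli with replacement, then averages the models' accuracies over the resampled subjects. As noted above, this naive version overestimates the variance contributed by measurement noise. The correction implemented in the RSA toolbox is not reproduced here, so use `rsa.inference.eval_dual_bootstrap` for actual inference. The helper names `pearson_rows` and `naive_dual_bootstrap` and the `toy_` variables are hypothetical. Dropping pairs of a stimulus with its own duplicate is a simplifying choice of this sketch." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def pearson_rows(a, b):\n", "    # Pearson correlation between every row of a and every row of b -> (len(a), len(b))\n", "    a = a - a.mean(axis=1, keepdims=True)\n", "    b = b - b.mean(axis=1, keepdims=True)\n", "    a = a / np.linalg.norm(a, axis=1, keepdims=True)\n", "    b = b / np.linalg.norm(b, axis=1, keepdims=True)\n", "    return a @ b.T\n", "\n", "def naive_dual_bootstrap(model_rdms, data_rdms, n_boot=1000, seed=0):\n", "    # model_rdms: (n_models, n_stim, n_stim), data_rdms: (n_subjects, n_stim, n_stim)\n", "    # No correction for the over-counted measurement-noise variance is applied.\n", "    rng = np.random.default_rng(seed)\n", "    n_subj, n_stim = data_rdms.shape[0], data_rdms.shape[1]\n", "    iu, ju = np.triu_indices(n_stim, k=1)\n", "    boot_means = np.empty((n_boot, model_rdms.shape[0]))\n", "    for b in range(n_boot):\n", "        subj = rng.integers(0, n_subj, size=n_subj)  # factor 1: subjects\n", "        stim = rng.integers(0, n_stim, size=n_stim)  # factor 2: stimuli\n", "        keep = stim[iu] != stim[ju]  # skip pairs of a stimulus with its own duplicate\n", "        i, j = stim[iu][keep], stim[ju][keep]\n", "        m = model_rdms[:, i, j]  # (n_models, n_kept_pairs)\n", "        d = data_rdms[subj][:, i, j]  # (n_subjects, n_kept_pairs)\n", "        boot_means[b] = pearson_rows(m, d).mean(axis=1)\n", "    return boot_means\n", "\n", "# Synthetic data: 12 stimuli x 50 features, 8 noisy subjects, 2 hypothetical models\n", "toy_rng = np.random.default_rng(3)\n", "toy_signal = toy_rng.standard_normal((12, 50))\n", "toy_data = np.stack([1 - np.corrcoef(toy_signal + toy_rng.standard_normal((12, 50))) for _ in range(8)])\n", "toy_models = np.stack([1 - np.corrcoef(toy_signal + 0.5 * toy_rng.standard_normal((12, 50))),\n", "                       1 - np.corrcoef(toy_rng.standard_normal((12, 50)))])\n", "\n", "toy_boot = naive_dual_bootstrap(toy_models, toy_data)\n", "print('naive two-factor bootstrap standard error per model:', toy_boot.std(axis=0, ddof=1))"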
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "eval_result = rsa.inference.eval_dual_bootstrap(models, fmri_rdms.subset('roi', 'FFA'), method='corr')\n", "print(eval_result)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "plot_model_comparison_trans(eval_result)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In the plot above, the statistical comparisons account for the variability we would expect from sampling both a different set of subjects and a different set of stimuli." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_model_comparison_two_factor_bootstrap\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "*Estimated timing of tutorial: 50 minutes*\n", "\n", "In this tutorial, we:\n", "\n", "1. Reviewed the principles of RSA in the context of machine learning and computational neuroscience and applied it to the problem of comparing representations between fMRI patterns and neural network models for a vision task.\n", "\n", "2. Explored the structure of AlexNet and extracted the activations of its different layers in response to a set of visual stimuli.\n", "\n", "3. Evaluated how well the different layers of AlexNet could explain neural responses from humans to the same stimuli, by comparing the representations extracted from these layers (which we treated as different models) to the activity patterns derived from fMRI data.\n", "\n", "4. Used frequentist statistical inference to compare the performance of the different model representations (i.e. 
layers).\n", "\n", "5. Addressed two sources of model-performance estimation error that statistical inference must account for in addition to the error due to measurement noise: stimulus sampling and subject sampling, using the 2-factor bootstrap method." ] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "W1D3_Tutorial3", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 4 }