Comparing networks#
Comparing networks: Characterizing computational similarity in task-trained recurrent neural networks
By Neuromatch Academy
Content creators: Chris Versteeg
Content reviewers: Chris Versteeg, Hannah Choi, Eva Dyer
Production editors: Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk
Background#
Researchers training artificial networks to perform tasks (e.g., image classification, motor control) have found that the activity in the artificial networks can resemble the activity of biological neurons from brain areas thought to perform similar tasks. Unfortunately, it is unclear whether a superficial similarity in neural activation necessarily translates to a conserved computational strategy. We need ways to assess how well different models capture underlying computational principles, which will require datasets where the ground-truth computations are known and where we can analyze the similarity between artificial and natural systems. The aim of this project is to explore ways to measure alignment between dynamical systems, and to study different approaches to quantifying how representations change across tasks and across architectures.
Install and import feedback gadget#
# @title Install and import feedback gadget
!pip install vibecheck datatops --quiet
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_neuroai",
"user_key": "wb2cxze8",
},
).render()
feedback_prefix = "Project_ComparingNetworks"
Project Background#
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_project_background")
Project slides#
If you want to download the slides: https://osf.io/download/vb3tw/
Project Template#
#@title Project Template
from IPython.display import Image, display
import os
from pathlib import Path
url = "https://github.com/neuromatch/NeuroAI_Course/blob/main/projects/project-notebooks/static/ComputationalSimilarityTemplate.png?raw=true"
display(Image(url=url))
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_project_template")
In this notebook, we are going to provide code to get you started on Q1-Q3 of this project!
The basic outline looks like this:
Section 1: Preparing the environment.
Section 2: Overview of the available tasks.
Section 3: Understanding the Three-Bit Flip-Flop task (3BFF).
Section 4: Training a model to perform 3BFF.
Section 5: Inspecting the performance of trained models.
Part 1: Visualizing latent activity
Part 2: Quantifying latent similarity with State R2.
Part 3: Visualizing Fixed-Point architectures.
Section 6: Introduction to Random Target task.
Importantly, we’ve put landmarks in the notebook to indicate:
Interactive exercises
❓❓❓
Cells that will take a decent amount of time to run (>5 mins)
⏳⏳⏳
Tutorial links
This project is mostly associated with the materials presented in W1D3, on comparing the activity of artificial and biological networks. One of the main techniques there, DSA, rounds out your toolbox by enabling comparisons of the dynamics of activity patterns. Tutorial 2 from W1D1 is the closest match in terms of model architecture and the idea of goal-driven networks.
Section 1: Preparing the environment#
Disclaimer: As an alternative to Google Colab, Kaggle, and local installation, we have prepared a Dockerfile with instructions for setting up a virtual environment. It lets you work locally without touching your already-installed packages (you will just need to install Docker and run two commands). Grab the Dockerfile, put it in the same folder as this notebook, and follow the instructions in the README file.
IF USING COLAB // KAGGLE:
Uncomment the lines of code below and run them in order. The first, second, and last cells only need to be run once, but the third cell (which defines envStr) needs to be re-run if the Colab // Kaggle notebook crashes. These cells install the needed dependencies and set up your environment. Note that the contents of the first and third cells depend on whether you use Colab or Kaggle.
⏳⏳⏳
Colab // Kaggle installation (Part 1)
# @markdown Colab // Kaggle installation (Part 1)
# ! git clone https://github.com/neuromatch/ComputationThruDynamicsBenchmark
# %cd ComputationThruDynamicsBenchmark
# ! pip install -e .
## RUN THIS CELL, THEN RESTART SESSION AS PROMPTED (BUTTON AT BOTTOM OF THIS CELL'S FINISHED OUTPUT). DO NOT NEED TO RUN AGAIN
## PLEASE RESTART THE ENVIRONMENT FOR KAGGLE MANUALLY (Run > Restart & clear cell outputs)
Colab // Kaggle installation (Part 2)
# @markdown Colab // Kaggle installation (Part 2)
# !pip uninstall -y torchaudio torchvision
Colab // Kaggle installation (Part 3)
# @markdown Colab // Kaggle installation (Part 3)
## GET BACK TO THE DIRECTORY AND CONFIGURE .env
################ COLAB #####################
# %cd /content/ComputationThruDynamicsBenchmark/
# envStr = """HOME_DIR=/content/ComputationThruDynamicsBenchmark/
# #Don't change these
# TRAIN_INPUT_FILE=train_input.h5\nEVAL_INPUT_FILE=eval_input.h5
# EVAL_TARGET_FILE=eval_target.h5
# """
#############################################
################ KAGGLE #####################
# %cd /kaggle/working/ComputationThruDynamicsBenchmark/
# envStr = """HOME_DIR=/kaggle/working/ComputationThruDynamicsBenchmark/
# #Don't change these
# TRAIN_INPUT_FILE=train_input.h5\nEVAL_INPUT_FILE=eval_input.h5
# EVAL_TARGET_FILE=eval_target.h5
# """
##############################################
################ COLAB // KAGGLE #####################
# with open('.env','w') as f:
# f.write(envStr)
##############################################
Colab // Kaggle installation (Part 4)
# @markdown Colab // Kaggle installation (Part 4)
# !git clone https://github.com/mitchellostrow/DSA
# %cd DSA/
# !pip install -e .
IF RUNNING LOCALLY:
Follow the instructions here to set up a separate environment for this project, or run the cell below for a general installation.
Local installation
# @markdown Local installation
import contextlib
import io
import os
dirname = "ComputationThruDynamicsBenchmark"
with contextlib.redirect_stdout(io.StringIO()): #to suppress output
if not os.path.isdir(dirname):
! git clone https://github.com/neuromatch/ComputationThruDynamicsBenchmark
%cd ComputationThruDynamicsBenchmark
! pip install -e .
envStr = """HOME_DIR=ComputationThruDynamicsBenchmark/
#Don't change these
TRAIN_INPUT_FILE=train_input.h5\nEVAL_INPUT_FILE=eval_input.h5
EVAL_TARGET_FILE=eval_target.h5
"""
with open('.env','w') as f:
f.write(envStr)
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_preparing_environment")
Section 2: Overview of the available tasks#
First, let’s take a high-level look at the tasks that we are going to use to understand computation in artificial networks!
We’ll start by loading in some packages.
# set the random seed for reproducibility
import random
import dotenv
import pathlib
import os
import logging
# comment the next three lines if you want to see all training logs
pl_loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict if 'pytorch_lightning' in name]
for pl_log in pl_loggers:
logging.getLogger(pl_log.name).setLevel(logging.WARNING)
random.seed(2024)
dotenv.load_dotenv(override=True)
HOME_DIR = os.getenv("HOME_DIR")
if HOME_DIR is None:
HOME_DIR = ""
print(HOME_DIR)
ComputationThruDynamicsBenchmark/
The Computation-Thru-Dynamics Benchmark has three distinct behavioral tasks.
These tasks are called:
Three-Bit Flip-Flop (3BFF) (see Sussillo & Barak 2013)
MultiTask (See Driscoll et al. 2023)
RandomTarget (See Codol et al. 2023)
We chose these tasks because they span a range of task complexities. We have a pretty good understanding of how the simpler tasks operate (3BFF), but we are only starting to scratch the surface of the more complex ones (RandomTarget).
Specifically, in the RandomTarget task, the actions that the model takes can affect its future inputs, making it an important test case for understanding the dynamics of interacting systems!
Each task (which we call a “task environment”) follows a standardized format that allows alternative task environments to be incorporated without any changes to the training pipeline.
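To make "standardized format" a bit more concrete, here is a toy, purely illustrative sketch of the kind of interface such an environment exposes: the input/output dimensions the training pipeline needs, plus a way to generate input/target trials. The class and method names below are our own invention for illustration; the real interface lives in ctd.task_modeling.task_env.
# A purely illustrative environment skeleton -- NOT the real ctd interface
import numpy as np
from types import SimpleNamespace

class ToyTaskEnv:
    """Hypothetical task environment: the training pipeline only needs to know
    the input/output dimensions and how to sample input/target trials."""

    def __init__(self, n_timesteps=500, noise=0.1):
        self.n_timesteps = n_timesteps
        self.noise = noise
        # Shapes that a model can use to size its input and output layers
        self.observation_space = SimpleNamespace(shape=(3,))
        self.action_space = SimpleNamespace(shape=(3,))

    def generate_dataset(self, n_samples):
        """Return (inputs, targets), each shaped (trials, time, channels)."""
        inputs = self.noise * np.random.randn(n_samples, self.n_timesteps, 3)
        targets = np.zeros_like(inputs)  # a real task defines meaningful targets here
        return inputs, targets

env = ToyTaskEnv()
inputs, targets = env.generate_dataset(n_samples=10)
print(inputs.shape, env.observation_space.shape[0])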
Here, we'll take a walk through the two tasks in the project template (3BFF and RandomTarget) and inspect the behavior of networks trained in these environments.
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_overview_of_the_available_tasks")
Section 3: Understanding the Three-Bit Flip-Flop task (3BFF)#
We’re going to start out with the task that launched a thousand Nature papers, the 3-Bit Flip-Flop. Sussillo & Barak 2013 used the three-bit flip-flop in their seminal attempts to understand how dynamics can give rise to computation!
The code snippet below instantiates a “TaskEnv” object, which contains the logic for the NBFF task.
❓❓❓
The default parameters are shown in ./interpretability/task_modeling/configs/env_task/NBFF.yaml, but try changing the parameters below to see how that affects trials generated from the environment. Note that this task is modular in the number of bits as well, so it provides an easy way to scale the dimensionality of a very simple dynamical system.
❓❓❓
from ctd.task_modeling.task_env.task_env import NBitFlipFlop
n = 3 # The number of bits in the flip-flop (default: 3)
trial_length = 500 # The number of time steps in each trial (default: 500)
switch_prob = 0.015 # The probability of an input pulse (default: 0.015 pulses/channel / time step)
noise = 0.15 # The standard deviation of the Gaussian noise added to the input (default: 0.15)
# This line creates the NBitFlipFlop environment. See ctd.task_modeling.task_env.task_env.NBitFlipFlop for more information.
env_3bff = NBitFlipFlop(
n = n,
n_timesteps=trial_length,
switch_prob=switch_prob,
noise=noise
)
# Renders a random trial from the environment
env_3bff.render()
Above, we are plotting the inputs and outputs of the 3BFF task. One trial is 500 time steps, and at each time step every one of the 3 input channels has a 1.5% probability of receiving an "up" or "down" pulse. When a channel receives an "up" pulse, the state corresponding to that channel moves from zero to one (if it isn't already at one), and when a state at one receives a "down" pulse, it goes back to zero. In this way, the system acts as 3 bits of memory, encoding 2^3 = 8 potential system states. We add noise to the inputs so that the task better reflects the realistic computations that a neural circuit might perform.
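If you'd like to see the flip-flop logic spelled out, here is a minimal NumPy sketch of how pulses map onto bit states. This is our own illustration of the rule described above, not the NBitFlipFlop implementation (which also adds input noise and handles the output format).
import numpy as np

rng = np.random.default_rng(0)
n_bits, n_timesteps, switch_prob = 3, 500, 0.015

# Each channel independently gets an "up" (+1) or "down" (-1) pulse with probability switch_prob
pulses = rng.choice([0, 1, -1], size=(n_timesteps, n_bits),
                    p=[1 - switch_prob, switch_prob / 2, switch_prob / 2])

# The memory states: latch to 1 on an up pulse, to 0 on a down pulse, otherwise hold
states = np.zeros((n_timesteps, n_bits))
for t in range(1, n_timesteps):
    states[t] = states[t - 1]
    states[t][pulses[t] == 1] = 1
    states[t][pulses[t] == -1] = 0

print("Fraction of time each bit spends at 1:", states.mean(axis=0))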
Try changing the parameters of your 3BFF environment to see how the behavior changes!
Another way to visualize this is to view the three states in 3D. Below, you can see that the 8 potential states appear as the vertices of a cube. Each trial is plotted as a column.
env_3bff.render_3d(n_trials=6)
Now that we can see the basic logic of the task, let’s do a basic overview of what task training is!
For task-training, we are simply training a model (e.g., an RNN) to produce a set of outputs given a set of inputs. This input/output relationship defines the task that the model is performing. In the case of 3BFF, an input pulse should cause the model’s output to change in a way that reflects the switching of a bit.
3BFF Training Objective:
3BFF models are trained to minimize the MSE between the desired output and the output of the model, with some additional components that pressure the solution to be smooth. If you're interested in the specifics, the implementation of the loss function can be found as the NBFFLoss object in ctd/task_modeling/task_env/loss_func.py.
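As a rough sketch of what such an objective looks like, the core of the loss is just an MSE on the outputs. The function below is our own illustration (with a placeholder activity penalty standing in for the smoothness terms), not the actual NBFFLoss:
import torch

def sketch_3bff_loss(model_output, target_states, latents=None, l2_weight=1e-6):
    """Illustrative loss: MSE on the outputs plus an optional activity penalty
    standing in for the smoothness pressure mentioned above."""
    mse = torch.mean((model_output - target_states) ** 2)
    activity_penalty = 0.0 if latents is None else l2_weight * torch.mean(latents ** 2)
    return mse + activity_penalty

# Example with dummy tensors shaped (trials, time, bits)
out = torch.rand(10, 500, 3)
targ = torch.randint(0, 2, (10, 500, 3)).float()
print(sketch_3bff_loss(out, targ))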
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_understanding_3bff")
Section 4: Training a model to perform 3BFF#
For this tutorial, we are using PyTorch Lightning to abstract much of the engineering away, allowing you to focus your full attention on the scientific questions you want to tackle!
This segment takes around 8 minutes to train, so I’d recommend planning your runs accordingly!
⏳⏳⏳
The cell below will create a recurrent neural network (RNN) model and use the 3BFF environment to generate samples on which the model will be trained!
Unfortunately, it generates a lot of output, so if you don’t care to see the model progress, set enable_progress_bar
to False below.
from ctd.task_modeling.model.rnn import GRU_RNN
from ctd.task_modeling.datamodule.task_datamodule import TaskDataModule
from ctd.task_modeling.task_wrapper.task_wrapper import TaskTrainedWrapper
from pytorch_lightning import Trainer
enable_progress_bar = False
# Step 1: Instantiate the model
rnn = GRU_RNN(latent_size = 128) # Look in ctd/task_modeling/models for alternative choices!
# Step 2: Instantiate the task environment
task_env = env_3bff
# Step 3: Instantiate the task datamodule
task_datamodule = TaskDataModule(task_env, n_samples = 1000, batch_size = 1000)
# Step 4: Instantiate the task wrapper
task_wrapper = TaskTrainedWrapper(learning_rate=1e-3, weight_decay = 1e-8)
# Step 5: Initialize the model with the input and output sizes (3 inputs, 3 outputs, in this case)
rnn.init_model(
input_size = task_env.observation_space.shape[0],
output_size = task_env.action_space.shape[0]
)
# Step 6: Set the environment and model in the task wrapper
task_wrapper.set_environment(task_env)
task_wrapper.set_model(rnn)
# Step 7: Define the PyTorch Lightning Trainer object
trainer = Trainer(max_epochs=500, enable_progress_bar=enable_progress_bar)
# Step 8: Fit the model
trainer.fit(task_wrapper, task_datamodule)
Now, we use pickle
to save the trained model and datamodule for future analyses!
❓❓❓
Once you get this model trained, feel free to try changing the hyperparameters to see if you can get the model to train faster!
❓❓❓
import pickle
# save model as .pkl
save_dir = pathlib.Path(HOME_DIR) / "models_GRU_128"
save_dir.mkdir(exist_ok=True)
with open(save_dir / "model.pkl", "wb") as f:
pickle.dump(task_wrapper, f)
# save datamodule as .pkl
with open(save_dir / "datamodule_sim.pkl", "wb") as f:
pickle.dump(task_datamodule, f)
So that we can start comparing models, we're going to train a second model to perform the 3BFF task, except this time we'll use an alternative architecture called a Neural ODE (NODE)!
Notice that we’re using the same datamodule as for the first model, meaning that we can directly compare the two models trial-by-trial.
Again, this will take around 10 minutes to train!
⏳⏳⏳
from ctd.task_modeling.model.node import NODE
enable_progress_bar = False
rnn = NODE(latent_size = 3, num_layers = 3, layer_hidden_size=64) # Look in ctd/task_modeling/models for alternative choices!
task_wrapper = TaskTrainedWrapper(learning_rate=1e-3, weight_decay = 1e-10)
rnn.init_model(
input_size = task_env.observation_space.shape[0],
output_size = task_env.action_space.shape[0]
)
task_wrapper.set_environment(task_env)
task_wrapper.set_model(rnn)
trainer = Trainer(max_epochs=500, enable_progress_bar=enable_progress_bar)
trainer.fit(task_wrapper, task_datamodule)
save_dir = pathlib.Path(HOME_DIR) / "models_NODE_3"
save_dir.mkdir(exist_ok=True)
with open(save_dir / "model.pkl", "wb") as f:
pickle.dump(task_wrapper, f)
# save datamodule as .pkl
with open(save_dir / "datamodule_sim.pkl", "wb") as f:
pickle.dump(task_datamodule, f)
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_training_a_model_to_perform_3bff")
Section 5: Inspecting the performance of trained models#
Now that the models have been trained, let’s see if we can determine how similar their computational strategies are to each other!
To make your life easier, we’ve provided an “analysis” object that abstracts away much of the data handling, allowing you to work more easily with the data from the models.
The analysis object also offers visualization tools that can help to see how well the trained model learned to perform the task!
For example, plot_trial_io is a function that plots (for a specified number of trials):
Latent activity
Controlled output
Target output
Noisy inputs to model
❓❓❓
Try changing trials that are plotted. Do the models capture all of the states equally well?
❓❓❓
Part 1: Visualizing latent activity#
from ctd.comparison.analysis.tt.tt import Analysis_TT
fpath_GRU_128 = HOME_DIR + "models_GRU_128/"
# Create the analysis object:
analysis_GRU_128 = Analysis_TT(
run_name = "GRU_128_3bff",
filepath = fpath_GRU_128)
analysis_GRU_128.plot_trial_io(num_trials = 2)
fpath_NODE = HOME_DIR + "models_NODE_3/"
# Create the analysis object:
analysis_NODE = Analysis_TT(
run_name = "NODE_3_3bff",
filepath = fpath_NODE)
analysis_NODE.plot_trial_io(num_trials = 2)
There are also useful data visualization functions, such as a scree plot of the latent activity.
A scree plot shows the percentage of variance explained by each of the top principal components. From this plot, we can see that the GRU has the majority of its variance in the first 3 PCs, but that significant variance remains in the later PCs!
analysis_GRU_128.plot_scree()
array([0.48745748, 0.23808762, 0.12121433, 0.02579751, 0.02215279,
0.02057775, 0.01052088, 0.00839948, 0.00731314, 0.00564797])
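If you want to compute these numbers yourself rather than reading them off plot_scree, the same quantities can be obtained by running sklearn's PCA on the flattened latents. A sketch (the exact values may differ slightly from plot_scree depending on how many components it keeps):
import numpy as np
from sklearn.decomposition import PCA

lats = analysis_GRU_128.get_latents().detach().numpy()   # (trials, time, latent_dim)
lats_flat = lats.reshape(-1, lats.shape[-1])              # stack trials and time

pca = PCA(n_components=10)
pca.fit(lats_flat)
print(pca.explained_variance_ratio_)             # fraction of variance per PC
print(np.cumsum(pca.explained_variance_ratio_))  # cumulative variance explained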
Importantly, the analysis object also provides functions that give access to the raw latent activity, predicted outputs, etc. of the trained models! All of these functions accept a “phase” variable that designates whether to return the training and/or validation datasets. These functions are:
get_latents(): Returns the latent activity of the trained model
get_inputs(): Returns the inputs to the model (for 3BFF, the input pulses)
get_model_output(): Returns a dict that contains all model outputs:
controlled - the variable that the model is controlling
latents - the latent activity
actions - the output from the model (for RandomTarget only)
states - the state of the environment (for RandomTarget only)
joints - joint angles (for RandomTarget only)
print(f"All data shape: {analysis_GRU_128.get_latents().shape}")
print(f"Train data shape: {analysis_GRU_128.get_latents(phase = 'train').shape}")
print(f"Validation data shape: {analysis_GRU_128.get_latents(phase = 'val').shape}")
All data shape: torch.Size([1000, 500, 128])
Train data shape: torch.Size([800, 500, 128])
Validation data shape: torch.Size([200, 500, 128])
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_visualizing_latent_activity")
Part 2: Using affine transformations to compare latent activity#
Now that we have the latent activity from the 128D GRU and the 3D NODE models trained on 3BFF, we can investigate how similar their latent activity is.
One problem: The models may be arbitrarily rotated, scaled, and translated relative to each other!
This means that we need a way to find the best "fit" between the two models that doesn't fail when they are equivalent up to an "affine" transformation (meaning a linear transformation and/or a translation).
Luckily, we have a tool that can solve this problem for us: linear regression.
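To convince yourself that plain linear regression is enough to undo an affine transformation, here is a small self-contained check on synthetic data (not the model latents): we apply a random affine map to some fake "latents" and verify that the regression recovers it almost perfectly.
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

rng = np.random.default_rng(0)
true_latents = rng.standard_normal((5000, 3))   # pretend latent activity
A = rng.standard_normal((3, 3))                 # random linear map
b = rng.standard_normal(3)                      # random translation
transformed = true_latents @ A + b              # an affine copy of the latents

reg = LinearRegression().fit(true_latents, transformed)
preds = reg.predict(true_latents)
print(r2_score(transformed, preds, multioutput="variance_weighted"))  # ~1.0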
In this code, we are:
Getting the latent activity from each model
Performing PCA on the latent activity (to get the dimensions ordered by their variance)
Fitting a linear regression from one set of latent activity to the other.
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.decomposition import PCA
source = analysis_GRU_128
target = analysis_NODE
# Get the latent activity from the training and validation phases for each model:
latents_source = source.get_latents(phase='train').detach().numpy()
latents_targ = target.get_latents(phase='train').detach().numpy()
latents_source_val = source.get_latents(phase='val').detach().numpy()
latents_targ_val = target.get_latents(phase='val').detach().numpy()
n_trials, n_timesteps, n_latent_source = latents_source.shape
n_trials, n_timesteps, n_latent_targ = latents_targ.shape
n_trials_val, n_timesteps_val, n_latent_source_val = latents_source_val.shape
n_trials_val, n_timesteps_val, n_latent_targ_val = latents_targ_val.shape
print(f"Latent shape for source model: {latents_source.shape}"
f"\nLatent shape for target model: {latents_targ.shape}")
Latent shape for source model: (800, 500, 128)
Latent shape for target model: (800, 500, 3)
# Perform PCA on both latent spaces to find axes of highest variance
pca_source = PCA()
pca_targ = PCA()
lats_source_pca = pca_source.fit_transform(latents_source.reshape(-1, n_latent_source)).reshape((n_trials, n_timesteps, -1))
lats_source_pca_val = pca_source.transform(latents_source_val.reshape(-1, n_latent_source)).reshape((n_trials_val, n_timesteps_val, -1))
lats_targ_pca = pca_targ.fit_transform(latents_targ.reshape(-1, n_latent_targ)).reshape((n_trials, n_timesteps, -1))
lats_targ_pca_val = pca_targ.transform(latents_targ_val.reshape(-1, n_latent_targ_val)).reshape((n_trials_val, n_timesteps_val, -1))
# Fit a linear regression model to predict the target latents from the source latents
reg = LinearRegression().fit(lats_source_pca.reshape(-1, n_latent_source), lats_targ_pca.reshape(-1, n_latent_targ))
# Get the R2 of the fit
preds = reg.predict(lats_source_pca_val.reshape(-1, n_latent_source_val))
r2s = r2_score(lats_targ_pca_val.reshape((-1, n_latent_targ_val)), preds, multioutput = "raw_values")
r2_var = r2_score(lats_targ_pca_val.reshape((-1, n_latent_targ_val)), preds, multioutput = "variance_weighted")
print(f"R2 of linear regression fit: {r2s}")
print(f"Variance-weighted R2 of linear regression fit: {r2_var}")
R2 of linear regression fit: [0.80023307 0.52220986 0.77564701]
Variance-weighted R2 of linear regression fit: 0.7076211324642876
So, the variance-weighted R2 from the source to the target is ~0.71 (your exact value will vary from run to run).
Importantly, we had to pick a “direction” to compute this R2 value. What happens if we switch the source and targets?
❓❓❓
Try reversing the direction (the source and targets) and see how well the model fits!
❓❓❓
One final tool that is provided to you is the comparison object, which makes many of these direct comparisons within the object itself. Here is one example visualization that shows how similar the latent activities of two example trials are for these two models!
This function has the affine transformation "built-in," so you can use it to see what your R2 value above looks like in the first 3 PCs!
from ctd.comparison.comparison import Comparison
comp = Comparison()
comp.load_analysis(analysis_GRU_128, reference_analysis=True)
comp.load_analysis(analysis_NODE)
comp.plot_trials_3d_reference(num_trials=2)
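Since the DSA package was cloned during setup, you could also quantify how similar the two models' dynamics are with Dynamical Similarity Analysis (covered in W1D3). The sketch below follows the usage pattern from the DSA README; the constructor arguments, the hyperparameter values (n_delays, rank, delay_interval), and the accepted data shapes are assumptions on our part, so double-check the README of the version you installed.
# Sketch of a DSA comparison between the two models (verify the API against the DSA README)
from DSA import DSA

lats_gru = analysis_GRU_128.get_latents(phase="val").detach().numpy()
lats_node = analysis_NODE.get_latents(phase="val").detach().numpy()

# Hypothetical hyperparameter choices; tune these for your data
dsa = DSA(lats_gru, lats_node, n_delays=10, rank=20, delay_interval=5,
          iters=1000, lr=1e-2, verbose=True)
score = dsa.fit_score()   # a dissimilarity score: lower means more similar dynamics
print(score)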
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_using_affine_transformations_to_compare_latent_activity")
Part 3: Fixed-point finding#
Finally, we can use fixed-point finding to inspect the linearized dynamics of the trained model.
What are fixed-points?
Fixed points are points in the dynamics for which the flow field is zero, meaning that points at that location do not move.
The fixed point structure for the 3BFF task was first shown in the original Sussillo and Barak paper.
We can see that the fixed-points are at the vertices of the cube above, drawing the activity towards them and keeping it there until an input pushes it out!
We use a modified version of a fixed point finder released by Golub et al. 2018 to search the flow field for these zero points.
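The core idea behind this kind of fixed point finder can be sketched in a few lines: treat the hidden state as a free parameter and minimize the speed q(x) = ½‖F(x, u=0) − x‖² of the RNN update with gradient descent from many random initializations. The cell below is a simplified illustration on a stand-in GRUCell, not the Golub et al. implementation used by plot_fps (which adds tolerance checks, de-duplication, and Jacobian analysis).
# Simplified fixed-point search on a stand-in GRUCell (illustration only)
import torch

cell = torch.nn.GRUCell(input_size=3, hidden_size=64)   # in practice, use the trained RNN cell
inputs = torch.zeros(1024, 3)                           # inputs clamped to zero
x = torch.randn(1024, 64, requires_grad=True)           # many random candidate states
opt = torch.optim.Adam([x], lr=1e-2)

for step in range(2000):
    opt.zero_grad()
    q = 0.5 * ((cell(inputs, x) - x) ** 2).sum(dim=1)   # "speed" of each candidate
    q.sum().backward()
    opt.step()

with torch.no_grad():
    q = 0.5 * ((cell(inputs, x) - x) ** 2).sum(dim=1)
print(f"Candidates with q < 1e-5: {(q < 1e-5).sum().item()} of {x.shape[0]}")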
❓❓❓
Try changing some of these parameters:
How quickly are the fixed-points found in the model?
How many initializations are needed to find the fixed points?
Do the stability properties tell us anything about the underlying computation?
❓❓❓
Importantly, from Driscoll et al. 2022 we know that changes in the inputs can have large effects on the fixed-point architecture, so we're going to set the inputs to zero in this optimization.
import torch
import contextlib
import io
with contextlib.redirect_stdout(io.StringIO()): #to suppress output
fps = analysis_GRU_128.plot_fps(
inputs= torch.zeros(3),
n_inits=1024,
learning_rate=1e-3,
noise_scale=0.0,
max_iters=20000,
seed=0,
compute_jacobians=True,
q_thresh=1e-5,
)
import matplotlib.pyplot as plt
q_thresh = 1e-6
q_vals = fps.qstar
x_star = fps.xstar[q_vals < q_thresh]
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x_star[:, 0], x_star[:, 1], x_star[:, 2], c='r', marker='o')
fig.show()
❓❓❓
What can you find out about the FPs of the trained models? Can you modify the FP finding to get more interpretable results?
What can we learn about the computational solution built in this 3BFF network from these fixed-point architectures?
❓❓❓
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_fixed_point_finding")
Section 6: Introducing the Random Target task#
Now that we’ve developed intuition on a simple, well-understood task, let’s move up the ladder of complexity!
The second task is a random-target reaching task performed by an RNN controlling a 2-joint musculoskeletal model of an arm actuated by 6 Mujoco muscles. This environment was built using MotorNet, a package developed by Oli Codol et al. that provides environments for training RNNs to control biomechanical models!
Here is a short clip of what this task looks like when performed by a well-trained model:
Behaviorally, the task has the following structure:
A random initial hand position is sampled from a range of reachable locations; the model is instructed to maintain that hand position.
A random target position is chosen from the range of reachable locations and fed to the model.
After a random delay period, a go-cue is fed to the model, which prompts the model to generate muscle activations that drive the hand to the target location.
On 20% of trials, the go-cue is never supplied (“catch” trials)
On 50% of trials, a randomly directed bump perturbation (5-10 N, 150-300 ms duration) is applied to the hand.
50% of these bumps occur in a small window after the go-cue
50% of these bumps occur at a random time in the trial
The model is trained to:
Minimize the MSE between the hand position and the desired hand position
Minimize the squared muscle activation
with each loss term being weighted by a scalar.
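As a rough sketch of that combined objective (our own illustration with placeholder weights, not the actual loss used in the benchmark):
import torch

def sketch_random_target_loss(hand_pos, target_pos, muscle_act,
                              pos_weight=1.0, act_weight=1e-2):
    """Illustrative combined loss: position error plus an effort penalty.
    The weights are placeholders, not the benchmark's values."""
    pos_loss = torch.mean((hand_pos - target_pos) ** 2)   # MSE on hand position
    effort_loss = torch.mean(muscle_act ** 2)             # squared muscle activation
    return pos_weight * pos_loss + act_weight * effort_loss

# Dummy tensors shaped (trials, time, dims): 2D hand position, 6 muscles
hand, targ, act = torch.rand(8, 300, 2), torch.rand(8, 300, 2), torch.rand(8, 300, 6)
print(sketch_random_target_loss(hand, targ, act))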
from ctd.task_modeling.task_env.task_env import RandomTarget
from motornet.effector import RigidTendonArm26
from motornet.muscle import MujocoHillMuscle
# Create the analysis object:
rt_task_env = RandomTarget(effector = RigidTendonArm26(muscle = MujocoHillMuscle()))
⏳⏳⏳
Now, to train the model! We use the same procedure as for 3BFF above; however, this model will take a bit longer to train: because of the serial nature of this task, the parallelization offered by GPUs doesn't help speed up our training!
⏳⏳⏳
from ctd.task_modeling.model.rnn import GRU_RNN
from ctd.task_modeling.datamodule.task_datamodule import TaskDataModule
from ctd.task_modeling.task_wrapper.task_wrapper import TaskTrainedWrapper
from pytorch_lightning import Trainer
# Step 1: Instantiate the model
rnn = GRU_RNN(latent_size = 128) # Look in ctd/task_modeling/models for alternative choices!
# Step 2: Instantiate the task environment
task_env = rt_task_env
# Step 3: Instantiate the task datamodule
task_datamodule = TaskDataModule(task_env, n_samples = 1000, batch_size = 256)
# Step 4: Instantiate the task wrapper
task_wrapper = TaskTrainedWrapper(learning_rate=1e-3, weight_decay = 1e-8)
# Step 5: Initialize the model with the input and output sizes
rnn.init_model(
input_size = task_env.observation_space.shape[0] + task_env.context_inputs.shape[0],
output_size = task_env.action_space.shape[0]
)
# Step 6: Set the environment and model in the task wrapper
task_wrapper.set_environment(task_env)
task_wrapper.set_model(rnn)
# Step 7: Define the PyTorch Lightning Trainer object (put `enable_progress_bar=True` to observe training progress)
trainer = Trainer(accelerator= "cpu",max_epochs=500,enable_progress_bar=False)
# Step 8: Fit the model
trainer.fit(task_wrapper, task_datamodule)
Importantly, this task is distinct from the previous two tasks because the outputs of the model affect the subsequent inputs!
Visualizing the latent dynamics of models trained on MotorNet tasks, we can see that there are complex features in the state space, but we’ll leave that to you to figure out what they mean!
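One way to get started (a sketch that mirrors the Section 5 workflow; the folder name models_GRU_RT is our own choice) is to save the trained RandomTarget model the same way as the 3BFF models, build an Analysis_TT object, and project its latents onto their top PCs:
import pickle
import pathlib
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from ctd.comparison.analysis.tt.tt import Analysis_TT

# Save the trained RandomTarget model and datamodule, as in Section 4
save_dir = pathlib.Path(HOME_DIR) / "models_GRU_RT"
save_dir.mkdir(exist_ok=True)
with open(save_dir / "model.pkl", "wb") as f:
    pickle.dump(task_wrapper, f)
with open(save_dir / "datamodule_sim.pkl", "wb") as f:
    pickle.dump(task_datamodule, f)

# Load the run into an analysis object and look at the latents in PC space
analysis_rt = Analysis_TT(run_name="GRU_RT", filepath=str(save_dir) + "/")
lats = analysis_rt.get_latents(phase="val").detach().numpy()
n_trials, n_time, n_latent = lats.shape
lats_pca = PCA(n_components=3).fit_transform(lats.reshape(-1, n_latent)).reshape(n_trials, n_time, 3)

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection="3d")
for trial in range(3):  # a few validation trials
    ax.plot(lats_pca[trial, :, 0], lats_pca[trial, :, 1], lats_pca[trial, :, 2], lw=0.5)
plt.show()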
In the later questions, we will ask you to modify the environments in MotorNet to test how well your models can generalize to new tasks!
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_introducing_the_random_target_task")
Conclusion#
That’s it!
To recap, in this tutorial, we learned:
The basics of two tasks, the Three-Bit Flip-Flop and the Random Target task.
How to train recurrent neural network models on these tasks
Methods of visualizing and quantifying differences between these task-trained models.
As you begin to extend beyond this tutorial, you will likely need to make your own environments, or modify existing environments to test the ability of models to generalize. We’ve tried to document the code-base to make this as easy as possible, but feel free to reach out if you have any questions!
Submit your feedback#
# @title Submit your feedback
content_review(f"{feedback_prefix}_conclusion")