This repository provides the codebase for *Towards a "universal translator" for neural dynamics at single-cell, single-spike resolution*. Neuroscience research has made immense progress over the last decade, but our understanding of the brain remains fragmented and piecemeal: the dream of probing an arbitrary brain region and automatically reading out the information encoded in its neural activity remains out of reach. In this work, we build towards a first foundation model for neural spiking data that can solve a diverse set of tasks across multiple brain areas.
Create the conda environment:

```bash
conda env create -f env.yaml
```

Activate the environment:

```bash
conda activate ibl-fm
```
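As a quick sanity check (a minimal sketch; it assumes nothing beyond the environment activated above), you can confirm that Python now resolves to the `ibl-fm` environment:

```python
import sys

# The interpreter path should point inside the ibl-fm conda environment,
# e.g. .../envs/ibl-fm/bin/python.
print(sys.executable)
```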
We have uploaded both the processed sessions from the IBL dataset and the pre-trained multi-session models to Hugging Face. You can access these resources here. Downloading them is straightforward and allows you to integrate these models and datasets seamlessly into your projects.
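As an illustrative (not authoritative) way to fetch them programmatically, the `huggingface_hub` client can download both the dataset and the model snapshots; the repository IDs below are placeholders, so substitute the ones listed on our Hugging Face page:

```python
from huggingface_hub import snapshot_download

# Placeholder repository IDs -- replace them with the dataset and model
# repositories listed on the Hugging Face page.
data_dir = snapshot_download(repo_id="<org>/<ibl-sessions-dataset>", repo_type="dataset")
model_dir = snapshot_download(repo_id="<org>/<pretrained-multi-session-model>", repo_type="model")

print("Data downloaded to:", data_dir)
print("Model downloaded to:", model_dir)
```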
- Navigate to the script directory:

  ```bash
  cd script
  ```

- Start the training process:

  ```bash
  source train_sessions.sh
  ```
- Modify Model Configurations: To change the model used for training, update the YAML files in `src/configs` and adjust the settings in `src/train_sessions.py`.

- Example of Trainer and Model Configurations:

  ```python
  # Default setting
  # Load configuration
  kwargs = {
      "model": "include:src/configs/ndt1_stitching.yaml"
  }
  config = config_from_kwargs(kwargs)
  config = update_config("src/configs/ndt1_stitching.yaml", config)
  config = update_config("src/configs/ssl_sessions_trainer.yaml", config)
  ```
- Setting the Number of Sessions: To set the number of sessions used for training, edit `ssl_sessions_trainer.yaml`; the paper uses configurations of 1, 10, or 34 sessions (a quick way to check the configured value is sketched after this list).

  ```yaml
  num_sessions: 10 # Number of sessions to use in SSL training.
  ```
- Training Logs: Training logs will be uploaded to Weights & Biases (wandb) and saved in the `results` folder.
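As referenced in the session-count item above, here is a minimal sketch for reading back the configured value; it assumes `num_sessions` is a top-level key in the trainer YAML, as in the snippet shown there:

```python
import yaml

# Read the SSL trainer configuration and report how many sessions will be used.
with open("src/configs/ssl_sessions_trainer.yaml") as f:
    trainer_cfg = yaml.safe_load(f)

print("num_sessions:", trainer_cfg.get("num_sessions"))
```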
The scripts provided are designed for use in a High-Performance Computing (HPC) environment with Slurm. They allow fine-tuning and evaluation of the model on multiple test sessions.
- Script for Multiple Sessions: To submit fine-tuning and evaluation jobs for all test sessions listed in `data/test_re_eids.txt`, use the following command:

  ```bash
  source run_finetune_multi_session.sh NDT1 all 10 train-eval
  ```
- Script for a Single Session: To execute fine-tuning and evaluation for a specific test session, use the command below, replacing the EID placeholder with the actual unique ID of the test session:

  ```bash
  source finetune_eval_multi_session.sh NDT1 all 10 5dcee0eb-b34d-4652-acc3-d10afc6eae68 train-eval
  ```
The script arguments are:

- `MODEL_NAME`: The name of the model (e.g., NDT1, NDT2).
- `MASK_MODE`: The masking mode to apply (e.g., all, temporal).
- `NUM_TRAIN_SESSIONS`: The number of training sessions to use (e.g., 1, 10, 34).
- `EID`: The unique identifier of a specific test session.
- `MODE`: The operation mode (e.g., train, eval, train-eval).
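If you prefer to drive the per-session script yourself instead of going through Slurm, the loop below is a rough sketch of what `run_finetune_multi_session.sh` effectively does; it assumes `data/test_re_eids.txt` lists one EID per line and that the script can be invoked directly from the `script` directory:

```python
import subprocess
from pathlib import Path

# Read the test-session EIDs (assumed: one EID per line).
eids = [line.strip() for line in Path("data/test_re_eids.txt").read_text().splitlines() if line.strip()]

# Fine-tune and evaluate NDT1 (pre-trained on 10 sessions) on each test session.
for eid in eids:
    subprocess.run(
        ["bash", "finetune_eval_multi_session.sh", "NDT1", "all", "10", eid, "train-eval"],
        cwd="script",
        check=True,
    )
```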
Both scripts load the pre-trained model from the `results` folder and save the evaluation results as `.npy` files.
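The exact file names depend on the model, session, and mode, so the snippet below is only a hedged sketch with a placeholder path; `.npy` files that wrap a Python dictionary can be unwrapped with `.item()`:

```python
import numpy as np

# Placeholder path -- point this at one of the .npy files written under results/.
result = np.load("results/<eval_result_file>.npy", allow_pickle=True)

# If the file stores a pickled dictionary, numpy returns a 0-d object array.
if result.dtype == object and result.ndim == 0:
    result = result.item()

print(result)
```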
- Navigate to the script directory:

  ```bash
  cd script
  ```

- Run the visualization script:

  ```bash
  source draw.sh NUM_TRAIN_SESSIONS
  ```
This script outputs images visualizing the result metrics, which are stored in the `results/table` folder.
Neural Data Transformer (NDT1) - re-implementation

The configuration for NDT1 is `src/configs/ndt1.yaml`. Set the number of recorded neurons via:

```yaml
n_channels: NUM_NEURON # number of neurons recorded
```
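As a minimal sketch for doing this programmatically (it assumes `n_channels` sits at the top level of `ndt1.yaml`; adjust the key path if it is nested under a model sub-section), you can rewrite the value with PyYAML:

```python
import yaml

CFG_PATH = "src/configs/ndt1.yaml"
NUM_NEURON = 512  # example value: number of neurons recorded in your session

with open(CFG_PATH) as f:
    cfg = yaml.safe_load(f)

cfg["n_channels"] = NUM_NEURON  # assumed to be a top-level key

with open(CFG_PATH, "w") as f:
    yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False)
```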