We provide a codebase for *Towards a "universal translator" for neural dynamics at single-cell, single-spike resolution*. Neuroscience research has made immense progress over the last decade, but our understanding of the brain remains fragmented and piecemeal: the dream of probing an arbitrary brain region and automatically reading out the information encoded in its neural activity remains out of reach. In this work, we build towards a first foundation model for neural spiking data that can solve a diverse set of tasks across multiple brain areas.
Create the conda environment:

```bash
conda env create -f env.yaml
```

Activate the environment:

```bash
conda activate ibl-fm
```
We have uploaded both the processed sessions from the IBL dataset and the pre-trained multi-session models to Hugging Face. You can access these resources here. The datasets are visible only to members of the organization, so please click to join the org to access them. Downloading is straightforward and lets you integrate the models and datasets directly into your projects.
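For programmatic downloads, a minimal sketch using the `huggingface_hub` client is shown below; the repository IDs are placeholders, so substitute the actual dataset and model repos listed on the organization page.

```python
from huggingface_hub import snapshot_download

# Minimal sketch for fetching the resources programmatically.
# The repo IDs are placeholders -- substitute the actual dataset/model repos
# from the organization page, and authenticate first (e.g., `huggingface-cli login`)
# since the datasets are only visible to org members.
data_dir = snapshot_download(
    repo_id="<org>/<processed-ibl-sessions>",          # placeholder dataset repo
    repo_type="dataset",
)
model_dir = snapshot_download(
    repo_id="<org>/<pretrained-multi-session-model>",  # placeholder model repo
    repo_type="model",
)
print(data_dir, model_dir)
```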
- Navigate to the script directory:

  ```bash
  cd script
  ```

- Start the training process:

  ```bash
  source train_sessions.sh
  ```
- Modify Model Configurations: To change the model used for training, update the YAML files in `src/configs` and adjust the settings in `src/train_sessions.py`.

- Example of Trainer and Model Configurations:

  ```python
  # Default setting
  # Load configuration
  kwargs = {
      "model": "include:src/configs/ndt1_stitching.yaml"
  }
  config = config_from_kwargs(kwargs)
  config = update_config("src/configs/ndt1_stitching.yaml", config)
  config = update_config("src/configs/ssl_sessions_trainer.yaml", config)
  ```
- Setting the Number of Sessions: To determine the number of sessions used for training, edit `ssl_sessions_trainer.yaml`. The paper used configurations of 1, 10, or 34 sessions.

  ```yaml
  num_sessions: 10 # Number of sessions to use in SSL training.
  ```
- Training Logs: Training logs will be uploaded to Weights & Biases (wandb) and saved in the `results` folder.
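If you want to pull the logged metrics back out of wandb programmatically (rather than browsing the UI), a minimal sketch using the wandb public API is shown below; the run path is a placeholder for your own entity/project/run.

```python
import wandb

# Illustrative sketch, not the repo's logging code: fetch logged training
# metrics from a finished run. Replace the path with your own run's path
# (shown in the wandb UI or the console output at launch).
api = wandb.Api()
run = api.run("<entity>/<project>/<run_id>")  # placeholder run path
history = run.history()                       # pandas DataFrame of logged metrics
print(history.columns.tolist())
```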
The scripts provided are designed for use in a High-Performance Computing (HPC) environment with Slurm. They allow fine-tuning and evaluation of the model on multiple test sessions.
- Script for Multiple Sessions: To submit jobs for all test sessions listed in `data/test_re_eids.txt` for fine-tuning and evaluation, use the following command:

  ```bash
  source run_finetune_multi_session.sh NDT1 all 10 train-eval
  ```
- Script for a Single Session: To execute fine-tuning and evaluation for a specific test session, use the command below, replacing the EID placeholder with the actual unique ID of the test session.

  ```bash
  source finetune_eval_multi_session.sh NDT1 all 10 5dcee0eb-b34d-4652-acc3-d10afc6eae68 train-eval
  ```
Script arguments:

- `MODEL_NAME`: The name of the model (e.g., NDT1, NDT2).
- `MASK_MODE`: The masking mode to apply (e.g., all, temporal).
- `NUM_TRAIN_SESSIONS`: Number of training sessions to be used (e.g., 1, 10, 34).
- `EID`: Unique identifier for a specific test session.
- `MODE`: The operation mode (e.g., train, eval, train-eval).
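To make the argument order explicit, here is a small convenience wrapper that is not part of the repo; it assumes `finetune_eval_multi_session.sh` behaves the same when run in a fresh shell as when `source`d.

```python
import subprocess

# Hypothetical convenience wrapper (not part of this repo): maps the parameters
# described above onto the positional arguments of finetune_eval_multi_session.sh.
def finetune_eval(model_name, mask_mode, num_train_sessions, eid, mode="train-eval"):
    cmd = [
        "bash", "finetune_eval_multi_session.sh",
        model_name,               # MODEL_NAME, e.g. "NDT1"
        mask_mode,                # MASK_MODE, e.g. "all"
        str(num_train_sessions),  # NUM_TRAIN_SESSIONS, e.g. 10
        eid,                      # EID of the test session
        mode,                     # MODE: "train", "eval", or "train-eval"
    ]
    subprocess.run(cmd, check=True)

finetune_eval("NDT1", "all", 10, "5dcee0eb-b34d-4652-acc3-d10afc6eae68")
```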
Both scripts load the pre-trained model from the `results` folder and save the evaluation results as `.npy` files.
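A quick way to inspect those outputs (the glob pattern below is a placeholder; point it at the files written under your own `results` folder):

```python
import glob
import numpy as np

# Sketch: list and inspect the saved evaluation outputs (.npy files).
for path in glob.glob("results/**/*.npy", recursive=True):
    arr = np.load(path, allow_pickle=True)  # allow_pickle in case dicts were saved
    print(path, arr.shape, arr.dtype)
```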
- Navigate to the script directory:

  ```bash
  cd script
  ```

- Run the visualization script:

  ```bash
  source draw.sh NUM_TRAIN_SESSIONS
  ```
This script outputs images visualizing the result metrics; the images are stored in the `results/table` folder.
Neural Data Transformer (NDT1) - re-implementation
The configuration for NDT1 is `src/configs/ndt1.yaml`. Set the number of neurons via:

```yaml
n_channels: NUM_NEURON # number of neurons recorded
```
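If you would rather set this programmatically than edit the file by hand, here is a small pyyaml sketch (not part of the repo). It searches the loaded config recursively because the exact nesting of `n_channels` inside `ndt1.yaml` is not assumed here, and it writes to a new file (hypothetical name), since dumping with pyyaml drops the original comments.

```python
import yaml

def set_key(node, key, value):
    """Recursively set every occurrence of `key` in a nested dict/list."""
    if isinstance(node, dict):
        for k, v in node.items():
            if k == key:
                node[k] = value
            else:
                set_key(v, key, value)
    elif isinstance(node, list):
        for item in node:
            set_key(item, key, value)

with open("src/configs/ndt1.yaml") as f:
    cfg = yaml.safe_load(f)

set_key(cfg, "n_channels", 384)  # 384 is an example neuron count, not a repo default

# Write to a new file (hypothetical name) to keep the commented original intact.
with open("src/configs/ndt1_custom.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```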