
State-space models can learn in-context by gradient descent

This repository accompanies the first version of our paper State-space models can learn in-context by gradient descent. The model code is heavily influenced by the S5 repository, and the regression tasks are adapted from Transformers Learn In-Context by Gradient Descent.
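As a minimal, self-contained illustration of the idea in the title (a sketch only, not the repository's actual GD-SSM construction): for in-context linear regression over tokens (x_i, y_i), one gradient-descent step on the squared loss from a zero-initialised readout reduces to a cumulative sum over tokens, which is exactly the kind of linear recurrence a state-space model computes. In plain NumPy, with an illustrative step size eta and a random linear teacher:

    import numpy as np

    rng = np.random.default_rng(0)
    d, n, eta = 5, 32, 0.1

    # In-context examples (x_i, y_i) drawn from a random linear teacher.
    W_true = rng.normal(size=(1, d))
    X = rng.normal(size=(n, d))
    y = X @ W_true.T                 # shape (n, 1)

    # One explicit GD step on L(W) = 0.5 * sum_i ||W x_i - y_i||^2 from W = 0:
    # the gradient at W = 0 is -sum_i y_i x_i^T, so W_1 = eta * sum_i y_i x_i^T.
    W_gd = eta * sum(y[i:i+1].T @ X[i:i+1] for i in range(n))

    # The same update written as a linear recurrence over tokens,
    # S_t = S_{t-1} + y_t x_t^T -- the form a state-space model can realise.
    S = np.zeros((1, d))
    for t in range(n):
        S = S + y[t:t+1].T @ X[t:t+1]
    W_ssm = eta * S

    # Both views give the same prediction on a held-out query.
    x_query = rng.normal(size=(d, 1))
    assert np.allclose(W_gd @ x_query, W_ssm @ x_query)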

Installation
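Optionally, install into a fresh virtual environment first (a standard Python 3 setup, not specific to this repository):

    python3 -m venv .venv
    source .venv/bin/activate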

pip install -r requirements.txt

Reproducing the results

  1. For the linear regression task:

    python3 run_train.py --epochs=10000 --dataset=normal_token_vector --ssm_lr=0.0001 --analyse=True --n_layers=1 --lr_factor=2 --regression=linear --dataset_size=10

  2. For the non-linear regression task:

    python3 run_train.py --epochs=10000 --dataset=normal_token_vector --ssm_lr=0.0001 --analyse=True --n_layers=1 --lr_factor=2 --regression=non-linear --dataset_size=10 --activation_fn=half_glu2

For the argument details, please have a look at run_train.py.

For the normal-token setup, use the main branch; for the constructed-token setup, use the constructed-token branch.
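For example, assuming the repository has already been cloned:

    git checkout main                # normal-token setup
    git checkout constructed-token   # constructed-token setup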

Future work

We are working on a more efficient implementation of our GD-SSM model that can be applied to more advanced sequence-modelling tasks.
