This repository contains the first version of the code for our paper "State-space models can learn in-context by gradient descent". The model code is heavily influenced by the S5 repository, and the regression tasks are influenced by "Transformers Learn In-Context by Gradient Descent".
pip install -r requirements.txt
For the linear regression task:
python3 run_train.py --epochs=10000 --dataset=normal_token_vector --ssm_lr=0.0001 --analyse=True --n_layers=1 --lr_factor=2 --regression=linear --dataset_size=10
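For intuition, the following is a minimal sketch of how an in-context linear regression sequence can be constructed in the style of "Transformers Learn In-Context by Gradient Descent"; the function name, shapes, and zero-padded query token are illustrative assumptions, not the repository's actual data pipeline:

```python
# Hypothetical sketch of in-context linear regression data generation;
# not the repository's code: names and shapes are illustrative only.
import jax
import jax.numpy as jnp

def sample_task(key, n_examples=10, x_dim=10, y_dim=1):
    k_w, k_x, k_q = jax.random.split(key, 3)
    W = jax.random.normal(k_w, (y_dim, x_dim))        # ground-truth weights
    xs = jax.random.normal(k_x, (n_examples, x_dim))  # context inputs
    ys = xs @ W.T                                     # context targets
    x_q = jax.random.normal(k_q, (x_dim,))            # query input
    # Each token concatenates (x, y); the query token carries a zero label.
    context = jnp.concatenate([xs, ys], axis=-1)
    query = jnp.concatenate([x_q, jnp.zeros((y_dim,))])
    seq = jnp.concatenate([context, query[None, :]], axis=0)
    target = W @ x_q                                  # value the model should predict
    return seq, target

seq, target = sample_task(jax.random.PRNGKey(0))
print(seq.shape, target.shape)  # (11, 11) (1,)
```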
For the non-linear regression task:
python3 run_train.py --epochs=10000 --dataset=normal_token_vector --ssm_lr=0.0001 --analyse=True --n_layers=1 --lr_factor=2 --regression=non-linear --dataset_size=10 --activation_fn=half_glu2
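If it follows the S5 codebase's convention, the half_glu2 option is a gated activation in which a GELU is applied only to the input of the sigmoid gate. A rough functional sketch (dropout omitted; the explicit weight matrix stands in for a learned Dense layer) could be:

```python
# Rough sketch of a half-GLU gate in the style of S5's "half_glu2";
# dropout is omitted and W2/b2 stand in for a Dense projection.
import jax
import jax.numpy as jnp

def half_glu2(x, W2, b2):
    # GELU feeds only the gate's projection; the residual path is
    # multiplied by the resulting sigmoid gate.
    gate = jax.nn.sigmoid(jax.nn.gelu(x) @ W2 + b2)
    return x * gate

d = 8
x = jax.random.normal(jax.random.PRNGKey(0), (4, d))
W2 = jax.random.normal(jax.random.PRNGKey(1), (d, d)) / jnp.sqrt(d)
b2 = jnp.zeros(d)
print(half_glu2(x, W2, b2).shape)  # (4, 8)
```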
For details on the arguments, please have a look at run_train.py.
For the normal token setup, please use the main branch; for the constructed token setup, please use the constructed-token branch.
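For example, to switch between the two setups:

git checkout main

git checkout constructed-token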
We are working on a more efficient implementation of our GD-SSM model that can be used for more advanced sequence modelling tasks.