(Published in Transactions on Machine Learning Research)
Paper URL: https://openreview.net/forum?id=lh6vOAHuvo
Abstract: In this paper we investigate transformer architectures designed for partially observable online reinforcement learning. The self-attention mechanism in the transformer architecture is capable of capturing long-range dependencies and it is the main reason behind its effectiveness in processing sequential data. Nevertheless, despite their success, transformers have two significant drawbacks that still limit their applicability in online reinforcement learning: (1) in order to remember all past information, the self-attention mechanism requires access to the whole history to be provided as context. (2) The inference cost in transformers is expensive. In this paper, we introduce recurrent alternatives to the transformer self-attention mechanism that offer context-independent inference cost, leverage long-range dependencies effectively, and performs well in online reinforcement learning task. We quantify the impact of the different components of our architecture in a diagnostic environment and assess performance gains in 2D and 3D pixel-based partially-observable environments (e.g. T-Maze, Mystery Path, Craftax, and Memory Maze). Compared with a state-of-the-art architecture, GTrXL, inference in our approach is at least 40% cheaper while reducing memory use more than 50%. Our approach either performs similarly or better than GTrXL, improving more than 37% upon GTrXL performance in harder tasks.
This repository contains the code for reproducing the results from the AGaLiTe: Approximate Gated Linear Transformers for Online Reinforcement Learning paper, published in Transactions on Machine Learning Research (TMLR). The experiments cover several challenging environments, and the repository is organized based on the specific implementations used for each.
-
T-Maze and Mystery Path experiments were conducted using the CleanRL implementation of PPO. The relevant code for these experiments can be found in the
src/
directory. -
Craftax experiments utilized the purejaxrl implementation of PPO, optimized for JAX-based environments. This implementation leverages
jax.lax.scan
for efficient environment stepping, and the code for these experiments is located in thesrc_pure/
directory.
Follow https://jax.readthedocs.io/en/latest/installation.html for installing Jax and Jaxlib. Then run the following command to install the dependencies:
# Python version 3 is required
$ pip install -r requirements.txt
Install Weights and Biases for logging from https://docs.wandb.ai/quickstart.
# AGaLiTe in T-Maze corridor length 160
$ python trainer.py +tmaze=agalite task.corridor_len=160
# GaLiTe
$ python trainer.py +tmaze=galite
# GTrXL256
$ python trainer.py +tmaze=gtrxl256
# GTrXL128
$ python trainer.py +tmaze=gtrxl128
# LSTM
$ python trainer.py +tmaze=lstm
# GRU
$ python trainer.py +tmaze=gru
# AGaLiTe (\eta=4) in MPGrid
$ python trainer.py +mysterypath=agalite4 task.env_name=MysteryPath-Grid-Easy-v0
# AGaLiTe (\eta=4) in MP
$ python trainer.py +mysterypath=agalite4 task.env_name=MysteryPath-Easy-v0
# GTrXL128 in MPGrid
$ python trainer.py +mysterypath=gtrxl128 task.env_name=MysteryPath-Grid-Easy-v0
# GTrXL64 in MP
$ python trainer.py +mysterypath=gtrxl64 task.env_name=MysteryPath-Easy-v0
# GTrXL32 in MP
$ python trainer.py +mysterypath=gtrxl32 task.env_name=MysteryPath-Easy-v0
# AGaLiTe in Craftax Symbolic
$ python trainer_pure.py +craftax=arelit
# GTrXL128 in Craftax Symbolic
$ python trainer_pure.py +craftax=gtrxl128
The training script uses Hydra configuration management, the list of available configurations could be invoked using:
$ python trainer.py +<TASK_NAME>=<BASE_CONFIG_NAME> --help
Alternatively, for the purejaxrl implementation use:
$ python trainer_pure.py +<TASK_NAME>=<BASE_CONFIG_NAME> --help
- AGaLiTe implementation in Jax+Flax:
./src/models/agalite/agalite.py
or./src_pure/models/agalite.py
- GaLiTe implementation in Jax+Flax:
./src/models/agalite/galite.py
- GTrXL implementation in Jax+Flax:
./src/models/gtrxl.py
or./src_pure/models/gtrxl.py
- purejaxrl implementation:
./src_pure/purejaxrl/ppo_rnn.py
- Subhojeet Pramanik
- Esraa Elilemy
- Marlos C. Machado
- Adam White
@article{
pramanik2024agalite,
title={{AG}aLiTe: Approximate Gated Linear Transformers for Online Reinforcement Learning},
author={Subhojeet Pramanik and Esraa Elelimy and Marlos C. Machado and Adam White},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=lh6vOAHuvo},
note={}
}