A library for running model-based reinforcement learning (RL) experiments, written in JAX. It currently includes only a reimplementation of PETS; implementations of other algorithms are planned soon!
Warning: This is a work in progress and has not been evaluated on harder environments! Please let me know if you find any bugs.
Planned future work:

- Evaluate on harder environments (e.g. HalfCheetah, Ant).
- Implement other, more recent model-based algorithms.
A `Dockerfile` with all required dependencies is provided in the `/docker/` folder, together with an accompanying `docker-compose.yml` file.
Remember to add whatever mounts you need to the `docker-compose.yml` file!
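For instance, to make a host directory available inside the container for experiment outputs, the compose file would need a volume entry along these lines. This is only a sketch: the service name `experiments` and the paths are hypothetical, so adapt them to the service actually defined in the provided `docker-compose.yml`.

```yaml
# Hypothetical sketch -- adapt the service name and paths to the
# docker-compose.yml shipped in /docker/.
services:
  experiments:
    build: .
    volumes:
      # Mount a host directory where checkpoints and rollout
      # recordings can be written (see --logdir below).
      - ./results:/external
```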
A starter script for running an example experiment on cartpole is provided in `model_based_experiment.py`.
This script can be run via:

```
python3 model_based_experiment.py [options] ENV AGENT

  --logdir DIR            (optional) Directory for saving checkpoints and
                          rollout recordings.
  --save-every FREQ       (optional) Saving frequency. Defaults to 1 (i.e.
                          save after every iteration).
  --keep-all-checkpoints  (optional) Flag which enables saving of all
                          checkpoints (instead of only the most recent one).
  --iters ITERS           (optional) Number of training iterations to run.
                          Defaults to 100.
  -s SEED                 (optional) Experiment random seed. If not provided,
                          uniformly chosen in [0, 10000).
  env ENV                 (required) Experiment environment. Currently
                          supports [MujocoCartpole-v0, HalfCheetah-v3].
  agent_type AGENT        (required) Agent type. Choices: [PETS, Policy].
```
For example, to run PETS and save recordings of rollouts to `/external/`:
```
python3 model_based_experiment.py --logdir /external/ MujocoCartpole-v0 PETS
```
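The documented flags can be combined in the same way; for instance, this illustrative command (not from the original docs, and assuming the flags above compose freely) would run PETS on cartpole for 50 iterations with a fixed seed, saving every 5 iterations and keeping all checkpoints:

```
python3 model_based_experiment.py --logdir /external/ --iters 50 --save-every 5 --keep-all-checkpoints -s 42 MujocoCartpole-v0 PETS
```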