The official implementation of Policy-Guided Diffusion, built by Matthew Jackson and Michael Matthews.
- Offline RL agents (TD3+BC, IQL)
- Trajectory-level U-Net diffusion model
- EDM diffusion training and sampling
- Runs on the D4RL benchmark
Diffusion and agent training are implemented entirely in JAX, with extensive JIT compilation and parallelization!
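For orientation, here is a minimal sketch of what a JIT-compiled, EDM-style denoising loss over trajectory tensors can look like in JAX, following Karras et al. (2022). This is an illustration, not the repository's actual code: `denoise_fn` and all hyperparameter values are hypothetical.

```python
# Minimal sketch of an EDM-style training step in JAX (Karras et al., 2022).
# Names and hyperparameters here are hypothetical and do not necessarily
# match this repository's code.
import jax
import jax.numpy as jnp

SIGMA_DATA = 0.5           # EDM data-scale hyperparameter (assumed value)
P_MEAN, P_STD = -1.2, 1.2  # log-normal noise-level distribution (EDM defaults)

def edm_loss(params, denoise_fn, traj, key):
    """EDM preconditioned denoising loss on a (batch, time, features) tensor."""
    k_sigma, k_noise = jax.random.split(key)
    # Sample one noise level per trajectory from a log-normal distribution.
    log_sigma = P_MEAN + P_STD * jax.random.normal(k_sigma, (traj.shape[0],))
    sigma = jnp.exp(log_sigma)[:, None, None]   # broadcast over time/features
    x_noisy = traj + sigma * jax.random.normal(k_noise, traj.shape)
    # EDM preconditioning coefficients.
    c_skip = SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
    c_out = sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
    c_in = 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
    c_noise = jnp.log(sigma[:, 0, 0]) / 4.0
    # The network sees a rescaled input; its output is mixed with a skip path.
    denoised = c_skip * x_noisy + c_out * denoise_fn(params, c_in * x_noisy, c_noise)
    # Weighted MSE with the EDM loss weighting lambda(sigma).
    weight = (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2
    return jnp.mean(weight * (denoised - traj) ** 2)

# JIT-compile loss and gradient together; denoise_fn is a static Python callable.
edm_grad = jax.jit(jax.value_and_grad(edm_loss), static_argnums=1)
```

Compiling the loss and its gradient as one function lets XLA fuse the whole update step, which is where most of the speed-up from JIT compilation comes from.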
Update (28/06/24): Added WandB report with diffusion and agent model training logs.
Diffusion and agent training are executed with `python3 train_diffusion.py` and `python3 train_agent.py` respectively, with all arguments found in `util/args.py`.
- `--log --wandb_entity [entity] --wandb_project [project]` enables logging to WandB.
- `--debug` disables JIT compilation.
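For example, a local (non-Docker) run might look like the following; the dataset name is taken from the Docker examples below, and the WandB entity and project names are placeholders:

```bash
# Train the trajectory diffusion model with WandB logging (placeholder names).
python3 train_diffusion.py --log --wandb_entity my-entity --wandb_project pgd --dataset_name walker2d-medium-v2

# Debug an agent run with JIT compilation disabled.
python3 train_agent.py --debug --dataset_name walker2d-medium-v2 --agent iql
```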
- Build the Docker image:

  ```bash
  cd docker && ./build.sh && cd ..
  ```
- (To enable WandB logging) Add your account key to `docker/wandb_key`:

  ```bash
  echo [KEY] > docker/wandb_key
  ```
Training scripts are then run inside the container with:

```bash
./run_docker.sh [GPU index] python3.9 [train_script] [args]
```
Diffusion training example:

```bash
./run_docker.sh 0 python3.9 train_diffusion.py --log --wandb_project diff --wandb_team flair --dataset_name walker2d-medium-v2
```

Agent training example:

```bash
./run_docker.sh 6 python3.9 train_agent.py --log --wandb_project agents --wandb_team flair --dataset_name walker2d-medium-v2 --agent iql
```
If you use this implementation in your work, please cite us with the following:
```bibtex
@misc{jackson2024policyguided,
  title={Policy-Guided Diffusion},
  author={Matthew Thomas Jackson and Michael Tryfan Matthews and Cong Lu and Benjamin Ellis and Shimon Whiteson and Jakob Foerster},
  year={2024},
  eprint={2404.06356},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```