Paper link: https://arxiv.org/abs/2304.10573
Check out https://github.com/philippe-eecs/JaxDDPM for an implementation of DDPMs in JAX for continuous spaces!
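The training scripts below are named after DDPMs (`train_ddpm_iql_*`). For readers new to DDPMs on continuous spaces, here is a minimal sketch of the standard forward (noising) process and the epsilon-prediction loss in JAX. It assumes a linear beta schedule, and `eps_model` is a hypothetical stand-in for a noise-prediction network. This is not code from JaxDDPM or from this repo.

```python
import jax
import jax.numpy as jnp


def linear_beta_schedule(num_steps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Linear variance schedule (an assumption; other schedules such as cosine are common)."""
    betas = jnp.linspace(beta_start, beta_end, num_steps)
    alphas = 1.0 - betas
    alpha_bars = jnp.cumprod(alphas)  # abar_t = prod_{s<=t} alpha_s
    return betas, alpha_bars


def q_sample(x0, t, noise, alpha_bars):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    abar = alpha_bars[t][:, None]  # shape (batch, 1), broadcasts over the action dimension
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * noise


def ddpm_loss(eps_model, key, x0, alpha_bars):
    """Noise-prediction loss ||eps - eps_model(x_t, t)||^2, averaged over the batch."""
    t_key, n_key = jax.random.split(key)
    batch = x0.shape[0]
    t = jax.random.randint(t_key, (batch,), 0, alpha_bars.shape[0])  # one timestep per sample
    noise = jax.random.normal(n_key, x0.shape)
    x_t = q_sample(x0, t, noise, alpha_bars)
    pred = eps_model(x_t, t)
    return jnp.mean(jnp.sum((pred - noise) ** 2, axis=-1))


# Usage on a toy batch of 3-dimensional continuous actions in [-1, 1].
actions = jax.random.uniform(jax.random.PRNGKey(0), (8, 3), minval=-1.0, maxval=1.0)
_, alpha_bars = linear_beta_schedule(100)
zero_model = lambda x_t, t: jnp.zeros_like(x_t)  # hypothetical placeholder network
print(ddpm_loss(zero_model, jax.random.PRNGKey(1), actions, alpha_bars))
```

In a diffusion policy the network would also be conditioned on the state; that input is omitted here to keep the sketch short.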
Run the line below once for each variant index (0 through N). To change the hyperparameters and environments to sweep over, edit the launcher script at the path given in the command.
python3 launcher/examples/train_ddpm_iql_offline.py --variant 0...N
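The `--variant` flag selects one configuration out of a sweep. This README does not describe the launcher's internals, so the following is only a sketch, assuming the sweep is a grid of hyperparameters expanded into an indexed list of variant dictionaries. `SWEEP`, its keys and values, and `expand_grid` are hypothetical placeholders, not the repo's actual settings.

```python
import argparse
import itertools

# Hypothetical sweep grid; keys and values are placeholders, not the repo's settings.
SWEEP = {
    "env_name": ["halfcheetah-medium-v2", "hopper-medium-v2"],
    "seed": [0, 1, 2],
}


def expand_grid(sweep):
    """Expand a dict of lists into a list of variant dicts (Cartesian product, first key slowest)."""
    keys = list(sweep)
    return [dict(zip(keys, values)) for values in itertools.product(*(sweep[k] for k in keys))]


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--variant", type=int, required=True, help="index into the expanded sweep")
    args = parser.parse_args(argv)
    variants = expand_grid(SWEEP)
    if not 0 <= args.variant < len(variants):
        raise SystemExit(f"--variant must be in [0, {len(variants) - 1}]")
    variant = variants[args.variant]
    print(variant)  # a real launcher would hand this dict to the main run script
    return variant
```

For example, `main(["--variant", "5"])` selects the last of the six grid points. Indexing into a flat list keeps each job self-describing: a cluster array job only needs its index to know which configuration to run.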
To finetune, run
python3 launcher/examples/train_ddpm_iql_finetune.py --variant 0...N
Main run script, to which the variant dictionary is passed.
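The README does not show how the run script consumes the variant dictionary. A common pattern, sketched here as an assumption, is to overlay the variant on a set of defaults and reject unknown keys so that typos in a sweep fail loudly. `DEFAULTS`, `build_config`, and `run` are hypothetical names, and the hyperparameters listed are placeholders.

```python
from typing import Any, Dict

# Hypothetical defaults; the real script's hyperparameter names are not listed in this README.
DEFAULTS: Dict[str, Any] = {
    "env_name": "hopper-medium-v2",
    "seed": 0,
    "batch_size": 256,
    "max_steps": 1_000_000,
}


def build_config(variant: Dict[str, Any], defaults: Dict[str, Any] = DEFAULTS) -> Dict[str, Any]:
    """Overlay a variant dictionary on the defaults, rejecting keys the script does not know."""
    unknown = set(variant) - set(defaults)
    if unknown:
        raise KeyError(f"unknown hyperparameters in variant: {sorted(unknown)}")
    config = dict(defaults)  # copy so the shared defaults are never mutated
    config.update(variant)
    return config


def run(variant: Dict[str, Any]) -> Dict[str, Any]:
    config = build_config(variant)
    # A real script would build the environment, agent, and training loop from `config` here.
    return config
```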
To install, run
pip install --upgrade pip
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
See instructions for other versions of CUDA here.
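After installing, one way to check that JAX picked up the CUDA backend is to list its devices and run a small jitted function. This is a generic sanity check, not a script shipped with this repo; `square_sum` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can use; after a successful CUDA install, GPU devices should appear.
devices = jax.devices()
backend = jax.default_backend()  # typically "gpu" with CUDA, "cpu" if the CUDA wheel was not picked up
print("backend:", backend)
print("devices:", devices)


@jax.jit
def square_sum(x):
    """Tiny jitted computation to confirm the backend actually executes code."""
    return jnp.sum(x * x)


print(square_sum(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14
```

If the backend reports `cpu` on a machine with a GPU, the installed `jaxlib` most likely does not match the local CUDA version.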
Based on a re-implementation of https://github.com/ikostrikov/jaxrl
Cite this paper
@misc{hansenestruch2023idql,
title={IDQL: Implicit Q-Learning as an Actor-Critic Method with Diffusion Policies},
author={Philippe Hansen-Estruch and Ilya Kostrikov and Michael Janner and Jakub Grudzien Kuba and Sergey Levine},
year={2023},
eprint={2304.10573},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
If you use this repo, please also cite the JAXRL repo:
@misc{jaxrl,
author = {Kostrikov, Ilya},
doi = {10.5281/zenodo.5535154},
month = {10},
title = {{JAXRL: Implementations of Reinforcement Learning algorithms in JAX}},
url = {https://github.com/ikostrikov/jaxrl},
year = {2021}
}