Commit 2dd73af

Using jax scan for PPO + atari + envpool XLA (#328)
jax.scan for ppo + atari + envpool and corresponding docs and tests
1 parent b558b2b commit 2dd73af
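
For readers unfamiliar with the technique named in the commit title, here is a minimal sketch of replacing a Python loop with a scan. It assumes "jax scan" refers to `jax.lax.scan`. The `step` function and the toy running-sum state are hypothetical stand-ins, not code from this commit; the actual script scans over environment steps and training updates.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    """Hypothetical per-step function: carry is the running state, x the per-step input."""
    total = carry + x
    return total, total  # (next carry, per-step output that scan stacks)


def rollout_python_loop(init, xs):
    # Baseline: an ordinary Python loop; under jit this would be unrolled step by step.
    carry, outputs = init, []
    for x in xs:
        carry, out = step(carry, x)
        outputs.append(out)
    return carry, jnp.stack(outputs)


@jax.jit
def rollout_scan(init, xs):
    # Same computation as one scan: `step` is traced once, whatever the sequence length.
    return jax.lax.scan(step, init, xs)


xs = jnp.arange(1.0, 6.0)  # toy inputs 1.0 .. 5.0
final_loop, outs_loop = rollout_python_loop(jnp.float32(0.0), xs)
final_scan, outs_scan = rollout_scan(jnp.float32(0.0), xs)
```

Both versions return the final carry and the stacked per-step outputs. The scan version keeps the traced program small, which generally reduces compilation time for long loops.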

9 files changed, +626 -3 lines changed

README.md

+1
@@ -122,6 +122,7 @@ You may also use a prebuilt development environment hosted in Gitpod:
 | | [`ppo_atari_lstm.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_lstmpy)
 | | [`ppo_atari_envpool.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy)
 | | [`ppo_atari_envpool_xla_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax.py), [docs](/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy)
+| | [`ppo_atari_envpool_xla_jax_scan.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_envpool_xla_jax_scan.py), [docs](/rl-algorithms/ppo/#ppo_atari_envpool_xla_jax_scanpy)
 | | [`ppo_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_procgen.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_procgenpy)
 | | [`ppo_atari_multigpu.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_multigpupy)
 | | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)

benchmark/ppo.sh

+7 -1
@@ -104,4 +104,10 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
     --command "poetry run python cleanrl/gymnasium_support/ppo_continuous_action.py --cuda False --track" \
     --num-seeds 3 \
     --workers 9
-
+
+poetry install --with envpool,jax
+python -m cleanrl_utils.benchmark \
+    --env-ids Pong-v5 BeamRider-v5 Breakout-v5 \
+    --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --capture-video" \
+    --num-seeds 3 \
+    --workers 1
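
To make the size of the added benchmark concrete, the sketch below expands the invocation into individual runs. It assumes the benchmark utility launches one run per environment/seed pair by appending `--env-id` and `--seed` to the command, with seeds starting at 1; that behaviour is not shown in this diff. `expand_runs` and `start_seed` are hypothetical names used only for illustration.

```python
from itertools import product


def expand_runs(command, env_ids, num_seeds, start_seed=1):
    """Hypothetical helper: build one command line per (environment, seed) pair."""
    seeds = range(start_seed, start_seed + num_seeds)
    return [
        f"{command} --env-id {env_id} --seed {seed}"
        for env_id, seed in product(env_ids, seeds)
    ]


runs = expand_runs(
    "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --capture-video",
    ["Pong-v5", "BeamRider-v5", "Breakout-v5"],
    num_seeds=3,
)
print(len(runs))  # 3 environments x 3 seeds
```

Under these assumptions the three environments and three seeds yield nine runs, which `--workers 1` would presumably execute one at a time.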
