
Add gymnasium support for DQN #370

Merged · 19 commits merged on May 3, 2023

Conversation

@vcharraut (Contributor) commented Apr 1, 2023

Description

This PR updates the DQN files to the latest version of Gymnasium, replacing gym (a minimal sketch of the API change follows the file list):

  • dqn.py
  • dqn_jax.py
  • dqn_atari.py
  • dqn_atari_jax.py
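
For context, a minimal sketch of the core API difference these scripts have to absorb when moving from gym to Gymnasium. This is not the PR's exact diff; the environment id and variable names are only illustrative:

    # Hedged sketch of the gym -> gymnasium API change, not the PR's actual code.
    import gymnasium as gym

    env = gym.make("CartPole-v1")

    # reset() now returns (observation, info) and takes the seed directly
    obs, info = env.reset(seed=1)

    done = False
    while not done:
        action = env.action_space.sample()
        # step() now returns five values: terminated/truncated replace the single done flag
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    env.close()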

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the tests accordingly (if applicable).
  • I have updated the documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers.

If you need to run benchmark experiments for a performance-impacting change:

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team.
  • I have used the benchmark utility to submit the tracked experiments to the openrlbenchmark/cleanrl W&B project, optionally with --capture-video.
  • I have performed RLops with python -m openrlbenchmark.rlops.
    • For new feature or bug fix:
      • I have used the RLops utility to understand the performance impact of the changes and confirmed there is no regression.
    • For new algorithm:
      • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves generated by the python -m openrlbenchmark.rlops utility to the documentation.
    • I have added links to the tracked experiments in W&B, generated by python -m openrlbenchmark.rlops ....your_args... --report, to the documentation.

Regression report

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'dqn_atari_jax?tag=rlops-pilot' \
        'dqn_atari_jax?tag=pr-370-atari-jax' \
    --env-ids Breakout-v5 BeamRider-v5 Pong-v5 \
    --check-empty-runs False \
    --ncols 5 \
    --ncols-legend 2 \
    --output-filename figures/0compare \
    --scan-history \
    --report
────────────────────────────────────────────────────────────────────────────────────── Runtime (m) (mean ± std) ──────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Environment  ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']}) ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Breakout-v5  │ 270.1473263136972                                                │ 538.7802477303775                                                     │
│ BeamRider-v5 │ 271.7741639644951                                                │ 538.6782197420808                                                     │
│ Pong-v5      │ 261.6593977599932                                                │ 522.4641281567034                                                     │
└──────────────┴──────────────────────────────────────────────────────────────────┴───────────────────────────────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────── Episodic Return (mean ± std) ────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Environment  ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']}) ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Breakout-v5  │ 365.77 ± 15.64                                                   │ 356.66 ± 5.64                                                         │
│ BeamRider-v5 │ 5888.53 ± 185.09                                                 │ 6058.41 ± 116.74                                                      │
│ Pong-v5      │ 20.39 ± 0.17                                                     │ 20.39 ± 0.02                                                          │
└──────────────┴──────────────────────────────────────────────────────────────────┴───────────────────────────────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────────── Runtime (m) Average ─────────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Environment                                                           ┃ Average Runtime   ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']})      │ 267.8602960127285 │
│ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) │ 533.3075318763872 │
└───────────────────────────────────────────────────────────────────────┴───────────────────┘

(figure: dqn_atari_jax learning curves, rlops-pilot vs. pr-370-atari-jax)

https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-dqn_atari_jax--Vmlldzo0MjQ5OTA2

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'dqn?tag=pr-370' \
        'dqn_jax?tag=pr-370-jax' \
        'dqn?tag=rlops-pilot' \
        'dqn_jax?tag=rlops-pilot' \
    --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
    --check-empty-runs False \
    --ncols 3 \
    --ncols-legend 2 \
    --output-filename figures/0compare \
    --scan-history \
    --report
────────────────────────────────────────────────────────────────────────────────────── Runtime (m) (mean ± std) ──────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃
┃ Environment    ┃ ['pr-370']})                               ┃ ['pr-370-jax']})                           ┃ ['rlops-pilot']})                          ┃ ['rlops-pilot']})                          ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CartPole-v1    │ 3.099431800075442                          │ 1.8901799905559769                         │ 2.229200170565302                          │ 2.0570977331846896                         │
│ Acrobot-v1     │ 4.185325574186605                          │ 3.3383588646594835                         │ 3.2403913728341207                         │ 3.005497937894226                          │
│ MountainCar-v0 │ 3.5431891388538053                         │ 2.2788801149391746                         │ 2.5699978012313105                         │ 2.3790336879432625                         │
└────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────── Episodic Return (mean ± std) ────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃
┃ Environment    ┃ ['pr-370']})                               ┃ ['pr-370-jax']})                           ┃ ['rlops-pilot']})                          ┃ ['rlops-pilot']})                          ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CartPole-v1    │ 486.82 ± 8.32                              │ 324.99 ± 212.99                            │ 486.82 ± 8.32                              │ 499.26 ± 1.05                              │
│ Acrobot-v1     │ -90.20 ± 1.84                              │ -90.81 ± 1.94                              │ -90.20 ± 1.84                              │ -90.44 ± 0.99                              │
│ MountainCar-v0 │ -194.73 ± 7.30                             │ -191.72 ± 9.33                             │ -194.73 ± 7.30                             │ -169.26 ± 23.75                            │
└────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────────── Runtime (m) Average ─────────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Environment                                                ┃ Average Runtime    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│ openrlbenchmark/cleanrl/dqn ({'tag': ['pr-370']})          │ 3.6093155043719505 │
│ openrlbenchmark/cleanrl/dqn_jax ({'tag': ['pr-370-jax']})  │ 2.502472990051545  │
│ openrlbenchmark/cleanrl/dqn ({'tag': ['rlops-pilot']})     │ 2.679863114876911  │
│ openrlbenchmark/cleanrl/dqn_jax ({'tag': ['rlops-pilot']}) │ 2.4805431196740595 │
└────────────────────────────────────────────────────────────┴────────────────────┘

(figure: dqn and dqn_jax learning curves, rlops-pilot vs. pr-370 / pr-370-jax)

https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-dqn_jax--Vmlldzo0MjUwMDM1


@vwxyzjn (Owner) left a comment

One small comment but otherwise LGTM. Feel free to start the RLops process.

(Review comment on cleanrl/dqn.py: outdated, resolved)
@pseudo-rnd-thoughts (Collaborator) left a comment

@vwxyzjn Line 208 of the JAX Atari script and line 180 of the JAX classic-control script use np rather than jnp:
https://github.com/vwxyzjn/cleanrl/blob/599f9adfec89d63721578b08b75ec38ab0209372/cleanrl/dqn_jax.py#L180

I'm guessing this is a simple mistake (it shouldn't affect performance); can we change it to jnp?
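
To illustrate the kind of one-line change being suggested (the variable names and values below are made up, not the actual code at the linked lines):

    # Hedged illustration of an np -> jnp swap; not the real dqn_jax.py code.
    import jax.numpy as jnp
    import numpy as np

    rewards = np.array([1.0, 0.0], dtype=np.float32)
    dones = np.array([0.0, 1.0], dtype=np.float32)
    next_q_value = jnp.array([0.5, 0.8])
    gamma = 0.99

    # before: host-side NumPy mixed into an otherwise JAX computation
    td_target_np = rewards + gamma * np.asarray(next_q_value) * (1.0 - dones)

    # suggested: keep the target computation in jnp so it stays as JAX arrays
    td_target = jnp.asarray(rewards) + gamma * next_q_value * (1.0 - jnp.asarray(dones))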

@pseudo-rnd-thoughts (Collaborator) commented

The error is due to needing stable-baselines3 == 2 (the Gymnasium-compatible release line).
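
If I read this correctly, the fix is to depend on the SB3 2.x line; an illustrative install command (check the repository's pyproject.toml for the actual constraint used, since 2.x was still in pre-release around this time):

    pip install "stable-baselines3>=2.0.0"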

@vcharraut vcharraut marked this pull request as ready for review April 29, 2023 20:28
@pseudo-rnd-thoughts pseudo-rnd-thoughts mentioned this pull request May 1, 2023
@vwxyzjn (Owner) commented May 3, 2023

No sign of regression, as shown in the PR description. Merging now.

@vwxyzjn vwxyzjn merged commit 39670fc into vwxyzjn:master May 3, 2023
@sdpkjc sdpkjc mentioned this pull request May 6, 2023
@ronuchit commented Jun 5, 2023

Hi @vwxyzjn @charraut, I'm wondering what part of this change forced us to add the following line:
assert args.num_envs == 1, "vectorized envs are not supported at the moment"

Vectorization was a useful feature earlier. Thank you!

@vwxyzjn (Owner) commented Jun 6, 2023

@ronuchit this is because SB3's replay buffer doesn't support num_envs > 1, I think.

@ronuchit commented Jun 6, 2023

I believe it does, actually: https://github.com/DLR-RM/stable-baselines3/blame/master/stable_baselines3/common/buffers.py#L162

We would just need to pass in n_envs=args.num_envs when we instantiate the ReplayBuffer. Perhaps there are other issues at play here?
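
For reference, a hedged sketch of what forwarding that argument could look like; the buffer size, device, and vector-env setup below are illustrative, not CleanRL's actual hyperparameters:

    # Hedged sketch, not CleanRL's actual code: forwarding n_envs to SB3's ReplayBuffer
    # so transitions from a vectorized env are stored correctly.
    import gymnasium as gym
    from stable_baselines3.common.buffers import ReplayBuffer

    num_envs = 4  # illustrative value
    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])

    rb = ReplayBuffer(
        buffer_size=10_000,
        observation_space=envs.single_observation_space,
        action_space=envs.single_action_space,
        device="cpu",
        n_envs=num_envs,  # the argument discussed above; defaults to 1
        handle_timeout_termination=False,
    )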

@vwxyzjn (Owner) commented Jun 6, 2023


I see. That’s interesting. Would you be interested in making a PR that optionally supports num_envs>1?

@ronuchit ronuchit mentioned this pull request Jun 6, 2023
@ronuchit commented Jun 6, 2023

Sure, done: #395

@vcharraut vcharraut deleted the dqn-gymnasium branch July 26, 2023 19:42
@sdpkjc sdpkjc mentioned this pull request Aug 26, 2024