Jax c51 contrib #224

Merged
merged 20 commits into from
Dec 30, 2022

Conversation

kinalmehta
Collaborator

@kinalmehta kinalmehta commented Jun 29, 2022

Description

JAX implementation for C51
Implementation for #221

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted @vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
  • I have updated the tests accordingly (if applicable).

@kinalmehta kinalmehta linked an issue Aug 24, 2022 that may be closed by this pull request
@kinalmehta
Collaborator Author

Results on classical gym environments can be checked here.
https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Classical-Gym-Environments--VmlldzoyNDQ3OTk5

We see a speed-up of about 30% in the JAX version compared to PyTorch.

@kinalmehta
Collaborator Author

Here is the benchmark report on Atari environments:
https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Atari-Environments--VmlldzoyNjkyNDY0

Important observations:

  • BeamRider performance is poor compared to the PyTorch version.
  • Breakout performance almost matches the PyTorch variant but is still slightly lower.
  • For Pong, the performance matches perfectly for 2 seeds, but the reward remains zero for one of the seeds.

I need to look in more detail at the differences between the PyTorch and JAX implementations to fix the issues mentioned above.

@joaogui1
Collaborator

How does it compare to Dopamine's version?

@kinalmehta
Collaborator Author

> How does it compare to Dopamine's version?

I haven't checked Dopamine yet. I will have a look and update here, though it might take some time.

@vwxyzjn
Owner

vwxyzjn commented Sep 27, 2022

FYI dopamine has a benchmark, but its x-axis is not the environment steps... Any clue on how we can compare those results? @joaogui1
[image: Dopamine benchmark plot]

@kinalmehta
Collaborator Author

Atari Fixed

After months of procrastination and debugging various aspects, I finally found the cause of the performance degradation.
It was an incorrect epsilon value: I missed this detail and used the default value $10^{-8}$ from optax.
However, the C51-PyTorch version uses $0.01 / \text{batch\_size}$. I couldn't find any motivation for using this value.

Reading up more on this, I concluded that this is a common issue in NLP and CV as well.
More about this hyperparameter can be read here.
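
To illustrate why the epsilon value matters, here is a minimal sketch in plain Python (not the code from this PR) of the first Adam update for a scalar gradient. The learning rate `2.5e-4`, batch size `32`, gradient magnitude `1e-4`, and the helper name `adam_first_step` are assumptions chosen for illustration. In optax this term corresponds to the `eps` argument of `optax.adam`, and in PyTorch to the `eps` argument of `torch.optim.Adam`.

```python
import math


def adam_first_step(grad, lr=2.5e-4, b1=0.9, b2=0.999, eps=1e-8):
    """Return the first Adam update (t=1) for a scalar gradient.

    Hypothetical helper for illustration only; lr, b1 and b2 are assumed values.
    """
    m = (1 - b1) * grad          # first-moment estimate after one step
    v = (1 - b2) * grad ** 2     # second-moment estimate after one step
    m_hat = m / (1 - b1)         # bias correction for t=1
    v_hat = v / (1 - b2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)


lr = 2.5e-4        # assumed learning rate
batch_size = 32    # assumed batch size
small_grad = 1e-4  # assumed small gradient magnitude

# Step size relative to the learning rate for each epsilon choice.
ratio_default = adam_first_step(small_grad, lr=lr, eps=1e-8) / lr
ratio_large = adam_first_step(small_grad, lr=lr, eps=0.01 / batch_size) / lr

print(f"eps=1e-8:            step / lr = {ratio_default:.4f}")
print(f"eps=0.01/batch_size: step / lr = {ratio_large:.4f}")
```

With the default epsilon, the step stays close to the full learning rate even for a tiny gradient. With the larger epsilon, the same gradient yields roughly a quarter of that step, so small and possibly noisy gradients are damped.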

Benchmarking classical envs on CPU

I have updated the plots of classical gym environments (CartPole, Acrobot, MountainCar) by benchmarking on CPU. We see a significant speed-up compared to the PyTorch version on CPU.

Comparison with dopamine

Based on the BeamRider plot shared above, the table below summarizes the final scores:

| implementation | score |
| --- | --- |
| dopamine | 5000-7000 |
| cleanrl-pytorch | ~9500 |
| cleanrl-jax-old | ~2500 |
| cleanrl-jax-fixed | ~9500 |

Reports link

Conclusion

The updated plots are available at the links above.
The PR looks ready to be merged once the documentation is updated. Am I missing anything else here, @vwxyzjn?

@vwxyzjn
Owner

vwxyzjn commented Dec 29, 2022

The results look incredible. Great job @kinalmehta. Thanks for chasing down the cause of the issue. The code also looks great to me. Feel free to start adding documentation. You should also move the experiments to the openrlbenchmark/cleanrl namespace.

@kinalmehta
Collaborator Author

I've added the documentation, and now I believe this PR is ready for the final review.

Owner

@vwxyzjn vwxyzjn left a comment

Some minor comments. Great work! Feel free to merge once they are addressed.

One quick note is that #292 is getting merged. Would you mind submitting a PR to #292, adding model eval and uploading trained models to huggingface (for just one random seed)?

docs/rl-algorithms/c51.md
Comment on lines +228 to +238
Learning curves:
<div class="grid-container">
<img src="../c51/jax/BeamRiderNoFrameskip-v4.png">
<img src="../c51/jax/BeamRiderNoFrameskip-v4-time.png">

<img src="../c51/jax/BreakoutNoFrameskip-v4.png">
<img src="../c51/jax/BreakoutNoFrameskip-v4-time.png">

<img src="../c51/jax/PongNoFrameskip-v4.png">
<img src="../c51/jax/PongNoFrameskip-v4-time.png">
</div>
Owner

These look great! A minor note for the future: we no longer need to export the wandb curves manually. The rlops utility should give us the compare.png and compare-time.png, which we can use directly to save manual labor :)

@kinalmehta kinalmehta merged commit 67b7f0d into vwxyzjn:master Dec 30, 2022
@kinalmehta kinalmehta deleted the jax-c51-contrib branch January 14, 2023 15:55
Development

Successfully merging this pull request may close these issues.

JAX + C51
3 participants