Jax c51 contrib #224
Results on classical gym environments can be checked here. We see a speed-up of about 30% in the JAX version compared to PyTorch.
Here is the benchmark report on Atari environments. Important observations:
I need to look in more detail at the differences between the PyTorch and JAX implementations to fix the above-mentioned issues.
How does it compare to Dopamine's version?
Atari Fixed
After months of procrastination and debugging various aspects, I finally stumbled upon the cause of the performance degradation. Reading up more on this led to the conclusion that it is a common issue in NLP and CV as well.
Benchmarking classical envs on CPU
I have updated the plots for the classical gym environments (CartPole, Acrobot, MountainCar) by benchmarking on CPU. We see a significant speed-up compared to the PyTorch version on CPU.
Comparison with Dopamine
Based on the BeamRider plot shared above, the table below summarizes the final score comparison.
Reports link
Conclusion
The updated plots are available at the links above.
The results look incredible. Great job @kinalmehta. Thanks for chasing down the cause of the issue. The code also looks great to me. Feel free to start adding documentation. You should also move the experiments to the
I've added the documentation, and now I believe this PR is ready for the final review.
Learning curves:
<div class="grid-container">
<img src="../c51/jax/BeamRiderNoFrameskip-v4.png">
<img src="../c51/jax/BeamRiderNoFrameskip-v4-time.png">

<img src="../c51/jax/BreakoutNoFrameskip-v4.png">
<img src="../c51/jax/BreakoutNoFrameskip-v4-time.png">

<img src="../c51/jax/PongNoFrameskip-v4.png">
<img src="../c51/jax/PongNoFrameskip-v4-time.png">
</div>
These look great! A minor note for the future: we no longer need to export the wandb curves manually. The rlops utility should give us the compare.png and compare-time.png that we can use directly to save manual labor :)
Description
JAX implementation for C51
Implementation for #221
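For readers new to C51, the core of the algorithm is the categorical projection of the Bellman-updated return distribution back onto a fixed support of atoms. The sketch below is not the code from this PR: it is a plain-Python reference version, written as a loop for readability, where the function name `project_distribution` and the defaults for `gamma`, `v_min`, and `v_max` are assumptions chosen for illustration. A JAX version would typically vectorize the same arithmetic over atoms and over the batch.

```python
import math


def project_distribution(next_probs, reward, done, gamma=0.99, v_min=-10.0, v_max=10.0):
    """Project the shifted/scaled next-state distribution onto the fixed atoms.

    next_probs: probabilities over the atoms for the chosen next action.
    Hypothetical helper for illustration only; names and defaults are assumed.
    """
    n_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    target = [0.0] * n_atoms
    # No bootstrapping from the next state on terminal transitions.
    discount = 0.0 if done else gamma
    for j, p in enumerate(next_probs):
        z_j = v_min + j * delta_z
        # Bellman update of the atom, clipped to the support range.
        tz = min(max(reward + discount * z_j, v_min), v_max)
        # Fractional index of the updated atom, guarded against rounding drift.
        b = min(max((tz - v_min) / delta_z, 0.0), float(n_atoms - 1))
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            # Lands exactly on an atom: assign all of the mass there.
            target[lower] += p
        else:
            # Split the mass between the two neighbouring atoms.
            target[lower] += p * (upper - b)
            target[upper] += p * (b - lower)
    return target
```

The explicit `lower == upper` branch matters: without it, an update that lands exactly on an atom gets weight zero on both sides and its probability mass is silently dropped.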
Types of changes
Checklist:
[...] pre-commit run --all-files passes (required).
[...] mkdocs serve.
If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
[...] --capture-video flag toggled on (required).
[...] mkdocs serve.
[...] width=500 and height=300).