Proof-of-concept: Faster PyTorch #306
Conversation
Thank you @DavidSlayback for these great prototyping efforts. It's really nice that you JITed the action sampling process by implementing your own sampling function! Some thoughts:
No problem!
So next steps:
Hey @DavidSlayback, thanks for doing the investigation. I did a quick prototype in JAX to see the speed difference (JAX source code: https://wandb.ai/costa-huang/cleanRL/runs/j5k5vdl7/code?workspace=user-costa-huang). Not using a GPU: https://wandb.ai/costa-huang/cleanRL/reports/Pytorch-JIT-vs-JAX-JIT--VmlldzozMjY4MjQ2. Using a GPU, the real SPS is about 8k.
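For reference, the kind of jitted sampling step JAX enables looks roughly like this. This is a minimal sketch with a toy linear policy, not the actual prototype from the run linked above:

```python
import jax
import jax.numpy as jnp

# jax.jit compiles the whole sampling step (forward pass + categorical draw)
# into one fused XLA computation, which is where the speedup comes from.
@jax.jit
def sample_action(params, obs, key):
    logits = obs @ params["w"] + params["b"]  # toy linear policy
    return jax.random.categorical(key, logits)

key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros(2)}
action = sample_action(params, jnp.ones((8, 4)), key)  # one action per env
```

With zero weights the logits are uniform, so the draw is random; the point is only that the entire per-step policy computation sits behind a single compiled function.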
Hey @DavidSlayback, thanks again for this PR. I was thinking a better-suited place for these experiments is probably a separate repository, and we are happy to refer advanced users to it in our docs :) Closing this PR now.
Description
While I share a lot of the excitement about the speed boosts that JAX provides, I've found that additional performance can be extracted from PyTorch. This commit is mostly a proof-of-concept, using PPO as a test case, to show how various levels of optimization can be applied to improve PyTorch speed. There are a few more things I can add, and I still need to do runs with larger networks (where the speed improvement is greater), but this is a start for discussing how much the codebase can be changed to improve speed without losing readability.
Types of changes
Checklist:
- `pre-commit run --all-files` passes (required).
- Documentation has been updated and previewed via `mkdocs serve`.
- If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
- Tracked experiments were run with the `--capture-video` flag toggled on (required).
- Documentation images are sized appropriately (e.g., `width=500` and `height=300`).

Below are the results of various levels of optimization applied to CartPole. A larger set of environments (using the benchmark utility on the same hardware for all runs) can be found in this wandb report.
- L0 is the baseline.
- L1 uses TorchScript on the Sequential modules and `optimizer.zero_grad(True)`.
- L2 additionally uses TorchScript for the advantage and normalization functions, as well as in-place activations.
- L3 additionally uses TorchScript for the probability computations and action sampling.
- L4 (not shown; needs new runs) uses TorchScript for the full PPO loss function.
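To make the L3-style change concrete, here is one way to script the probability computation and action sampling with TorchScript. The function name and shapes are hypothetical; the PR's actual implementation may differ:

```python
import torch

# L3-style optimization: move softmax, sampling, and log-prob lookup into a
# single TorchScript function instead of going through torch.distributions,
# which cannot be fully compiled.
@torch.jit.script
def sample_action(logits: torch.Tensor):
    probs = torch.softmax(logits, dim=-1)
    action = torch.multinomial(probs, 1).squeeze(-1)          # one draw per row
    logprob = torch.log_softmax(logits, dim=-1).gather(
        -1, action.unsqueeze(-1)
    ).squeeze(-1)                                             # log-prob of the drawn action
    return action, logprob

logits = torch.randn(8, 2)        # e.g. 8 envs, 2 discrete actions
action, logprob = sample_action(logits)
```

This is the change that buys the most speed but also the one that departs most visibly from the familiar `Categorical(logits=...)` idiom, which is the readability trade-off discussed below.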
I'm seeing large benefits from L3 in particular, and it's something I could potentially apply to any of the PyTorch algorithms, but it's also the first level of optimization where the readability really changes. Interested in starting a discussion over what is/is not worth it!
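For comparison, the L2-style change is less invasive: the advantage computation can be compiled on its own. Below is a self-contained sketch of a scripted GAE function with a simplified signature (assumed shapes of `(num_steps, num_envs)`); the PR's actual function may differ:

```python
import torch

# L2-style optimization: compile the backward GAE recursion with TorchScript
# so the Python-level loop over timesteps runs without interpreter overhead.
@torch.jit.script
def compute_gae(rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor,
                next_value: torch.Tensor, gamma: float, gae_lambda: float):
    num_steps = rewards.size(0)
    advantages = torch.zeros_like(rewards)
    lastgaelam = torch.zeros_like(next_value)
    for t in range(num_steps - 1, -1, -1):
        if t == num_steps - 1:
            nextvalues = next_value
        else:
            nextvalues = values[t + 1]
        nextnonterminal = 1.0 - dones[t]          # zero out bootstrap at episode ends
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages

# With gamma = lambda = 1, zero values, and no terminals, the advantage at each
# step is just the sum of remaining rewards.
advantages = compute_gae(torch.ones(3, 1), torch.zeros(3, 1), torch.zeros(3, 1),
                         torch.zeros(1), 1.0, 1.0)
```

Because the function body is plain tensor arithmetic, scripting it changes almost nothing about how the code reads, which is why L2 is an easier sell than L3.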