Proof-of-concept: Faster PyTorch #306


Closed

Conversation

DavidSlayback

Description

While I share a lot of the excitement about the speed boosts that JAX provides, I've found that additional performance can be extracted from PyTorch. This PR is mostly a proof-of-concept, using PPO as a test case, to show how various levels of optimization can be applied to improve PyTorch speed. There are a few more things I can add, and I still need to do runs with larger networks (where the speed improvement is greater), but this is a start for discussing how much the codebase can be changed to improve speed without losing readability.

Types of changes

  • Bug fix
  • New feature
  • New algorithm
  • Documentation

Checklist:

  • I've read the CONTRIBUTION guide (required).
  • I have ensured pre-commit run --all-files passes (required).
  • I have updated the documentation and previewed the changes via mkdocs serve.
  • I have updated the tests accordingly (if applicable).

If you are adding new algorithms or your change could result in a performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.

  • I have contacted vwxyzjn to obtain access to the openrlbenchmark W&B team (required).
  • I have tracked applicable experiments in openrlbenchmark/cleanrl with --capture-video flag toggled on (required).
  • I have added additional documentation and previewed the changes via mkdocs serve.
    • I have explained note-worthy implementation details.
    • I have explained the logged metrics.
    • I have added links to the original paper and related papers (if applicable).
    • I have added links to the PR related to the algorithm.
    • I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
    • I have added the learning curves (in PNG format with width=500 and height=300).
    • I have added links to the tracked experiments.
    • I have updated the overview sections at the docs and the repo
  • I have updated the tests accordingly (if applicable).

Below are the results of various levels of optimization applied to CartPole. A larger set of environments (using the benchmark utility on the same hardware for all runs) can be found in this wandb report.

[image: CartPole SPS benchmark results across optimization levels]

  • L0 is the baseline.
  • L1 uses TorchScript on the Sequential modules and optimizer.zero_grad(set_to_none=True).
  • L2 additionally uses TorchScript for the advantage and normalization functions, as well as in-place activations.
  • L3 additionally uses TorchScript for the probability computations and action sampling.
  • L4 (not shown; needs new runs) uses TorchScript for the full PPO loss function.
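For concreteness, the L1-level changes can be sketched roughly as follows. This is a minimal, hypothetical example, not the PR's actual code; the layer sizes and learning rate are illustrative CartPole-scale values:

```python
import torch
import torch.nn as nn

# Illustrative actor network (CartPole: 4 observations, 2 actions).
actor = nn.Sequential(
    nn.Linear(4, 64),
    nn.Tanh(),
    nn.Linear(64, 2),
)

# torch.jit.script compiles the Sequential module to TorchScript,
# removing Python overhead and enabling operator fusion.
scripted_actor = torch.jit.script(actor)

optimizer = torch.optim.Adam(actor.parameters(), lr=2.5e-4)

obs = torch.randn(8, 4)
logits = scripted_actor(obs)  # scripted module shares parameters with `actor`
loss = logits.sum()
loss.backward()
optimizer.step()

# set_to_none=True frees gradient tensors instead of zero-filling them,
# skipping one memset per parameter on every update.
optimizer.zero_grad(set_to_none=True)
```

Note that `zero_grad(True)` and `zero_grad(set_to_none=True)` are the same call; the keyword form is just more explicit.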

I'm seeing large benefits from L3 in particular, and it's something I could potentially apply to any of the PyTorch algorithms, but it's also the first level of optimization where the readability really changes. Interested in starting a discussion over what is/is not worth it!
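The L3-style sampling change might look something like the sketch below. This is a hedged illustration rather than the PR's implementation: `torch.distributions.Categorical` does not compile under TorchScript, so the distribution math is written out by hand (the function name and signature are hypothetical):

```python
from typing import Tuple

import torch

@torch.jit.script
def sample_action(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Manual categorical sampling + log-prob + entropy, all scriptable.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    action = torch.multinomial(probs, 1).squeeze(-1)
    logp = log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    entropy = -(probs * log_probs).sum(-1)
    return action, logp, entropy
```

For uniform logits over 3 actions, every sampled action has log-probability −ln 3 and the entropy is ln 3, which makes the function easy to sanity-check against `torch.distributions.Categorical`.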


@vwxyzjn
Owner

vwxyzjn commented Oct 31, 2022

Thank you @DavidSlayback for these great prototyping efforts. It's really nice that you JITed the action sampling process by implementing your own sampling function! Some thoughts:

  1. If the speed-up is considerable (e.g., 30% overall training time reduction), then it's worth including the implementation.
    • For the sake of maintenance, maybe we can add the result of this PR as a variant like ppo_atari_jit.py. I am a little hesitant to use the JITed implementations in place of our regular implementations because the regular ones are still easier to debug.
    • Once you feel comfortable with the prototypes, would you mind giving the best of them a try in Atari games? I have plenty of data points to compare (see report here).
  2. Would it be possible to JIT the linear learning-rate annealing? Previously, with JAX, JITing the linear rate annealing improved the overall speed in MuJoCo by 2x (PPO + JAX + EnvPool + MuJoCo #217 (comment)).
  3. New tech is coming, such as https://github.com/pytorch/torchdynamo and https://github.com/metaopt/torchopt. How will they affect these optimization techniques?
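On point 2, the split between what can and cannot be scripted might look like the sketch below. This is an illustrative example assuming CleanRL-style linear annealing; the names `linear_anneal` and `num_updates` are made up for the sketch. The annealing arithmetic itself scripts fine, but the `param_groups` assignment has to stay in eager Python:

```python
import torch

@torch.jit.script
def linear_anneal(update: int, num_updates: int, initial_lr: float) -> float:
    # Linearly decay from initial_lr at update 1 toward 0 at num_updates + 1.
    frac = 1.0 - (update - 1.0) / num_updates
    return frac * initial_lr

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=2.5e-4)

for update in range(1, 4):
    lr = linear_anneal(update, 10, 2.5e-4)
    # torch.optim optimizers are not scriptable, so the learning rate
    # must still be set eagerly on each update.
    optimizer.param_groups[0]["lr"] = lr
```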

@DavidSlayback
Author

No problem!

  1. Yeah, I think it makes sense to do them as separate files. Once I determine the sweet spot of optimizations (performance without becoming unreadable), I was planning on doing a comparable JIT version for each algorithm (particularly interested in recurrent layers). I can definitely try them in some Atari games, I just needed something that would run quickly and demonstrate that the episodic performance matches.
  2. Unfortunately, I don't think that's doable with the base torch.optim optimizers. Like the torch.distributions classes, they don't play nicely with JIT. I can use a built-in LRScheduler, but those seem to do under the hood what you already do. I could JIT the annealing function itself, but I would still have to set the underlying learning rate eagerly.
  3. I hadn't seen torchopt before; I'm somewhat familiar with functorch and torchdynamo but was waiting for them to mature. It looks like functorch and torchdynamo have already moved into PyTorch, so I'll see if I can use those to wrap the optimization.

So next steps:

  1. Test torchdynamo/torchopt/functorch techniques, settle on best "return-on-optimization"
  2. Apply the chosen techniques to ppo_atari and ppo_atari_lstm

@vwxyzjn
Owner

vwxyzjn commented Jan 5, 2023

Hey @DavidSlayback, thanks for doing the investigation. I did a quick prototype in JAX to see the speed difference.

Not using a GPU

[image: CPU benchmark comparison]

https://wandb.ai/costa-huang/cleanRL/reports/Pytorch-JIT-vs-JAX-JIT--VmlldzozMjY4MjQ2

Jax source code https://wandb.ai/costa-huang/cleanRL/runs/j5k5vdl7/code?workspace=user-costa-huang

The SPS of ppo_jax.py can be further improved by removing the compilation time via #328; the real SPS is about 15k.

Using a GPU

[image: GPU benchmark comparison]

The real SPS is about 8k.

@vwxyzjn
Owner

vwxyzjn commented Mar 26, 2023

Hey @DavidSlayback, thanks again for this PR. I was thinking a more suitable place for these experiments is probably a separate repository, and we'd be happy to refer to it in our docs for more advanced users :)

Closing this PR now.

@vwxyzjn vwxyzjn closed this Mar 26, 2023