
Profiling a jitted function #774

Closed
adamhaber opened this issue May 26, 2019 · 5 comments · Fixed by #1148
Labels
enhancement New feature or request

Comments

@adamhaber
I'm trying to optimize the performance of a jitted function (which, in turn, calls other, smaller jitted functions), and wanted to know if there are any profiling tools suitable for this. I could line_profile the unjitted function, but from my very limited understanding of how XLA works, the jitted function might have entirely different computational bottlenecks/overhead/behavior. Is this correct?

Any references/suggestions on how to tackle this would be much appreciated. :-)
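(Editor's note: one pitfall worth flagging up front is that jitted calls return asynchronously, so naive wall-clock timing mostly measures dispatch. A minimal sketch of how to time a jitted function correctly; `f` and the array shape here are purely illustrative:)

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.dot(x, x)


x = jnp.ones((500, 500))

# The first call includes tracing and XLA compilation; run it once
# beforehand so compilation time is excluded from the measurement.
f(x).block_until_ready()

# Block on the result before reading the clock; otherwise you measure
# only the async dispatch, not the actual computation.
start = time.perf_counter()
result = f(x).block_until_ready()
elapsed = time.perf_counter() - start
```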

@jekbradbury
Contributor

On CPU or GPU? On GPU, using nvprof or another timeline tool can be pretty useful as a sanity check for what XLA is doing with your code (a simpler sanity check is just making sure your nvidia-smi utilization is fairly high).

@adamhaber
Author

Both are interesting. At the moment, I'm doing most of the development on the CPU, and things actually run faster there compared to the GPU. One of the reasons I'm interested in profiling is understanding these execution time differences (and of course, finding operations that can be further optimised).

I've used nvidia-smi to check for utilisation, and it looked OK (fluctuating around 60-70%). Regarding nvprof - can you explain how to use it to profile a jitted jax function on the GPU?

@hawkinsp
Collaborator

You should just be able to run `nvprof python myscript.py` to get some profiling/tracing data from nvprof.
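(Editor's note: for a self-contained `myscript.py` to run under nvprof, something like the following sketch works; the function and shapes are made up for illustration:)

```python
# myscript.py -- minimal target for `nvprof python myscript.py`.
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # An arbitrary compute-heavy op so GPU kernels show up in the trace.
    return jnp.tanh(x @ x)


x = jnp.ones((1024, 1024))

# Block on the result so the kernels actually execute (and are recorded
# by nvprof) before the Python process exits.
out = step(x).block_until_ready()
```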

XLA itself actually has some reasonable profiling tools that I suspect are what you want here in combination with what nvprof shows you, but unfortunately they are currently not working from JAX. (The profiling logic is missing from "async" variants of the APIs, which JAX uses.) It is on my TODO list to fix this.

@adamhaber
Author

Thanks! I'll try that and will post here if there are any useful insights.

As a general (very basic) question, just to be sure - it's not unreasonable that some JAX programs are faster on the CPU than on the GPU, right (even ignoring host-to-GPU memory-transfer overhead)? Specifically, my program has a for loop that I don't think can be parallelised; I suspect this is what makes the GPU version slower, and this is what I want to profile.
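(Editor's note: yes, this is plausible. A toy sketch of the kind of loop-carried dependency described, written with `lax.fori_loop`; the body function is invented for illustration:)

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnums=1)
def sequential_scan(x0, n):
    # Each iteration reads the previous iteration's output, so the n
    # steps cannot run in parallel. On a GPU every step still pays
    # per-step overhead, which is one plausible reason a small
    # sequential loop like this can be faster on CPU.
    def body(i, x):
        return 0.5 * x + jnp.sin(x)

    return lax.fori_loop(0, n, body, x0)
```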

@hawkinsp
Collaborator

hawkinsp commented Aug 9, 2019

It's now possible (with the soon-to-be-released jaxlib 0.1.23) to use XLA's profiling features, which break down performance by individual HLO ops. This may be helpful in addition to, say, nvprof.

https://jax.readthedocs.io/en/latest/profiling.html

Hopefully these are enough for the moment.
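(Editor's note: the linked page now documents the `jax.profiler` API, which postdates this thread and may differ from the jaxlib 0.1.23-era tooling. A hedged sketch of its current shape, writing a TensorBoard-viewable trace:)

```python
import jax
import jax.numpy as jnp

# Collect a trace of everything executed inside the context manager;
# the output directory here is arbitrary.
with jax.profiler.trace("/tmp/jax-trace"):
    f = jax.jit(lambda a: a @ a)
    y = f(jnp.ones((256, 256)))
    y.block_until_ready()  # ensure the work lands inside the trace
```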
