Profiling a jitted function #774
Comments
On CPU or GPU? On GPU, using nvprof or another timeline tool can be pretty useful as a sanity check for what XLA is doing with your code (a simpler sanity check is just making sure your nvidia-smi utilization is fairly high).
Both are interesting. At the moment, I'm doing most of the development on the CPU, and things actually run faster there than on the GPU. One of the reasons I'm interested in profiling is understanding these execution-time differences (and, of course, finding operations that can be further optimised). I've used nvidia-smi to check utilisation, and it looked OK (fluctuating around 60-70%). Regarding nvprof: can you explain how to use it to profile a jitted JAX function on the GPU?
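One way to get a first-order answer before reaching for nvprof is to time the jitted function per device from Python. A minimal sketch (the function `f` and the array size are arbitrary placeholders; `block_until_ready()` matters because JAX dispatches work asynchronously):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((1000, 1000))

# jax.devices() lists the default backend's devices; on a machine with a GPU
# you can compare jax.devices("cpu") against jax.devices("gpu") explicitly.
for dev in jax.devices():
    xd = jax.device_put(x, dev)
    f(xd).block_until_ready()          # warm up: trace + compile once
    t0 = time.perf_counter()
    for _ in range(10):
        f(xd).block_until_ready()      # block, since dispatch is async
    print(dev, (time.perf_counter() - t0) / 10)
```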
You should just be able to run your program under nvprof directly. XLA itself actually has some reasonable profiling tools that I suspect are what you want here in combination with what nvprof shows you, but unfortunately they currently do not work from JAX. (The profiling logic is missing from the "async" variants of the APIs, which JAX uses.) It is on my TODO list to fix this.
Thanks! I'll try that and will post here if there are any useful insights. As a general (very basic) question, just to be sure: it's not unreasonable that some JAX programs are faster on the CPU than on the GPU, right (even ignoring host-to-GPU memory-transfer overhead)? Specifically, my program has a for loop that I don't think can be parallelised; I suspect this is what makes the GPU version slower, and this is what I want to profile.
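That intuition is plausible: when each loop iteration depends on the previous one, the GPU cannot parallelise across steps and mostly pays per-step kernel-launch overhead while its cores sit idle. A hypothetical sketch of such a loop using `lax.fori_loop` (the loop body here is made up purely for illustration):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def sequential(x):
    # Each iteration reads the previous iteration's result, so there is no
    # parallelism across steps; with a small body like this, a GPU spends
    # most of its time on launch overhead rather than compute.
    def body(i, v):
        return v * 0.99 + jnp.sin(v)
    return lax.fori_loop(0, 1000, body, x)
```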
It's now possible (with the soon-to-be-released jaxlib 0.1.23) to use XLA's profiling features, which break down performance by individual HLO ops. This may be helpful in addition to the tools described at https://jax.readthedocs.io/en/latest/profiling.html. Hopefully these are enough for the moment.
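For reference, current JAX versions expose this through the `jax.profiler` API (newer than the jaxlib release mentioned above). A minimal sketch, assuming you have TensorBoard installed to view the resulting trace:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x @ x.T).sum()

x = jnp.ones((500, 500))
f(x).block_until_ready()  # compile outside the trace region

# Writes a TensorBoard-readable trace with per-op timings; view it with:
#   tensorboard --logdir /tmp/jax-trace
with jax.profiler.trace("/tmp/jax-trace"):
    f(x).block_until_ready()
```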
I'm trying to optimize the performance of a jitted function (which, in turn, calls other, smaller jitted functions), and wanted to know if there are any profiling tools suitable for this. I could line_profile the unjitted function, but from my very limited understanding of how XLA works, the jitted function might have entirely different computational bottlenecks/overhead/behavior. Is this correct?
Any references/suggestions on how to tackle this would be much appreciated. :-)
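As a sketch of why jitted and unjitted code can profile so differently: the first call to a jitted function includes tracing and XLA compilation, and subsequent calls dispatch asynchronously, so naive timing (or line_profiler on the Python source) can be misleading. The function `g` below is an arbitrary placeholder:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    return jnp.dot(x, x).sum()

x = jnp.ones((500, 500))

t0 = time.perf_counter()
g(x).block_until_ready()            # first call: trace + XLA compile + run
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
g(x).block_until_ready()            # later calls: compiled code only
steady_state = time.perf_counter() - t0
print(f"first call {first_call:.4f}s, steady state {steady_state:.6f}s")
```

Without `block_until_ready()`, both measurements would capture only the time to enqueue the computation, not to run it.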