Fix 20x slowdown of FP6 kernel due to device properties query #1092
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1092
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit c1732e6 with merge base 0b71b8d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
nice
awesome! possible to cherry-pick this into the next release? cc @drisspg
The numbers look good already. But I wonder if we can cache this result? (e.g. use a static variable or global variable - I know, frowned upon by C++ devs 🤣). Would using a PyTorch API be better? (Again, the current state looks good already. Just something I wonder.)
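For reference, one way to cache the result is a function-local static, which C++11 initializes exactly once in a thread-safe way. A minimal sketch of that idea (the helper name is made up for illustration and is not part of the PR):

```cpp
#include <cuda_runtime.h>

// Hypothetical helper (not in the PR): cache the attribute query in a
// function-local static so the CUDA call happens only once per process.
static int cached_sm_major() {
    static const int sm_major = [] {
        int device = 0;
        cudaGetDevice(&device);
        int major = 0;
        cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
        return major;
    }();
    return sm_major;
}
```

One caveat with this approach: it caches the value for whichever device is current on the first call, so it would need adjustment for processes that switch between GPUs with different compute capabilities.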
@HDCharles do you have a good sense of the state of our benchmarking story? Right now perf regressions are only ever caught manually for llama in your spot checks. Are we in a place we can start failing if kernels like the fp6 one become 20x slower? |
@gau-nernst just FYI, FlashAttention uses
Replace `cudaGetDeviceProperties` with `cudaDeviceGetAttribute`
Fix 20x slowdown of FP6 kernel due to device properties query (#1092) Replace `cudaGetDeviceProperties` with `cudaDeviceGetAttribute` Co-authored-by: Tobias van der Werff <[email protected]>
While working on the FP6 kernel, I noticed a large slowdown of FP6 compared to FP16 (e.g. 1125 ms FP6 vs. 90 ms FP16 for m=1, k=8192, n=8192, as shown in the table below). This is not supposed to happen, since FP6 is expected to be faster than FP16 for most of the smaller matrix multiply shapes.
The slowdown has been present ever since the FP6-for-SM75 PR (#942). That change was tested on SM75 (a T4 GPU), but not on an A100 (SM80). Now that I've tried it on an A100, I noticed the slowdown. After digging for a bit, the problem seems to be the code that checks the GPU compute capability to determine the launch parameters:
`ao/torchao/csrc/cuda/fp6_llm/fp6_linear.cu`, lines 29 to 43 in 0b71b8d
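For context, the check is roughly of this shape (a simplified sketch, not the exact code in `fp6_linear.cu`): the entire property struct gets fetched just to read the compute capability.

```cpp
#include <cuda_runtime.h>

// Simplified sketch of the old pattern: cudaGetDeviceProperties() fills the
// whole cudaDeviceProp struct, some fields of which are expensive to query,
// even though only the compute capability is needed to pick launch parameters.
static bool running_on_sm75() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, device);
    return props.major == 7 && props.minor == 5;
}
```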
When I initially wrote this, I did a quick web search to check whether this code would incur any significant overhead from querying the device properties. The answer came back negative, but it turns out this was wrong: calling `cudaGetDeviceProperties()` can incur significant runtime overhead, as discussed in this NVIDIA forum post. The problem is that this call returns all device properties, even those that require expensive PCIe reads. Instead, it is better to query only the specific device property that is needed, using `cudaDeviceGetAttribute()`, which is what this PR does.
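A minimal sketch of the replacement, assuming the only properties needed are the major and minor compute capability (the helper name is illustrative, not the exact code in the PR):

```cpp
#include <cuda_runtime.h>

// Sketch of the replacement: cudaDeviceGetAttribute() reads one attribute at a
// time and avoids the expensive full-property query.
static void query_compute_capability(int device, int* major, int* minor) {
    cudaDeviceGetAttribute(major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(minor, cudaDevAttrComputeCapabilityMinor, device);
}
```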
Benchmark results
Results of `benchmarks/benchmark_fp6.py`:
Tested on:
Before:
After (this PR):