[BugFix][AMD][Quantization] Fix torch.compile issue where wvSplitKQ not being called when it should when using quantized FP8 model #22281
Conversation
Signed-off-by: Randall Smith <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review
This pull request correctly refactors rocm_per_tensor_w8a8_scaled_mm to register a custom PyTorch operation. This ensures that the data-dependent control flow for dispatching to the wvSplitKQ kernel is correctly handled by torch.compile, which was the goal of this PR. The implementation is sound. However, I've identified a pre-existing critical issue where the bias term is dropped when the wvSplitKQ path is taken. I've included a suggested fix to ensure the bias is applied correctly in all cases.
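As background, here is a minimal sketch of the technique the review describes: wrapping a data-dependent kernel choice in a custom PyTorch op so that torch.compile records a single opaque node and the dispatch decision is made at runtime instead of being specialized away during tracing. This is not the vLLM implementation; the op name, the shape-based condition, and both "kernel" bodies are illustrative assumptions.

```python
import torch
from torch.library import custom_op, register_fake


@custom_op("demo::scaled_mm_dispatch", mutates_args=())
def scaled_mm_dispatch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow runs in eager mode inside the op,
    # so it is re-evaluated on every call even under torch.compile.
    if a.shape[0] == 1:
        # Stand-in for the skinny-GEMM path (wvSplitKQ in the real code).
        return (a @ b).contiguous()
    # Stand-in for the generic scaled-mm fallback.
    return torch.matmul(a, b)


@register_fake("demo::scaled_mm_dispatch")
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake (meta) implementation: shape/dtype propagation only,
    # no branching on runtime data.
    return torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device)


@torch.compile
def run(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # The compiled graph contains one call to the custom op; the
    # batch-size check above is not baked into the graph.
    return scaled_mm_dispatch(a, b)
```

vLLM registers its custom ops through its own helper utilities rather than the decorator shown here, but the effect is the same: the traced graph sees one op call, and the wvSplitKQ-versus-fallback decision stays inside the op's eager implementation.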
yewentao256
left a comment
Thanks for the work!
But in the PR description, it seems the output is the same both with and without the fix?
Without the fix, and running with eager mode enforced:
I was also confused by this 🙂
Yes, the output is the same; we just want the function to be called on ROCm when it is supposed to be called.
Yes, the output is the same. We want the function to be called; it was being called before, but it no longer is when it should be, so this fixes that.
This PR fixes an issue where wvSplitKQ is not being called when it should be when using a quantized FP8 model. This happens because, during compilation, this code path is not captured, so the kernel does not get called during model execution even though it should be (e.g. at batch size = 1). I tested using Llama-3.1-8B-Instruct-FP8-KV.
Before this fix, the kernel was not being called at all when eager mode was not enforced.
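For context, here is a minimal sketch of the kind of runtime check that selects the skinny-GEMM path; the function name, the batch-size threshold, and the dtype lists are assumptions for illustration, not the exact vLLM condition.

```python
import torch


def should_use_wv_split_kq(x: torch.Tensor, out_dtype: torch.dtype) -> bool:
    # Illustrative only: the skinny-GEMM kernel targets very small batches
    # (e.g. decode with batch size 1) and FP8 inputs; the real vLLM check
    # also considers the GPU architecture and the alignment of problem sizes.
    n = x.shape[0]
    return (
        n == 1
        and x.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
        and out_dtype in (torch.float16, torch.bfloat16)
    )
```

Because this check depends on the runtime batch size, it has to survive torch.compile; keeping it inside a registered custom op, as in the earlier sketch, prevents it from being traced away.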
Without the fix, and running with eager mode enforced:
After applying the fix and using the profiler without eager mode enforced:
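The profiler output itself is not reproduced here, but a rough sketch of how one might check for the kernel with torch.profiler follows; the model path and the kernel-name filter are assumptions.

```python
import torch
from torch.profiler import ProfilerActivity, profile

from vllm import LLM, SamplingParams

# Placeholder path for the FP8-quantized checkpoint mentioned above
# (e.g. a local copy of Llama-3.1-8B-Instruct-FP8-KV); enforce_eager is
# left at its default (False) so the torch.compile path is exercised.
llm = LLM(model="/path/to/Llama-3.1-8B-Instruct-FP8-KV")
params = SamplingParams(max_tokens=8)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    # Batch size 1, so the skinny-GEMM dispatch should apply on ROCm.
    llm.generate(["Hello"], params)

# With the fix in place, a wvSplitKQ kernel should appear among the device events.
print([e.key for e in prof.key_averages() if "wvSplitK" in e.key])
```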