Feature/silu block quant fusion v1 #32996
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
Code Review
This pull request introduces a fused CUDA kernel for SiLU, multiplication, and block-wise FP8 quantization, along with corresponding benchmarks, tests, and integration into the torch.compile fusion passes. The new kernel shows significant performance improvements in the provided benchmarks.
My review has identified a couple of important issues:
- A critical issue in the `torch.compile` fusion pass, where the pattern for the new fused kernel is hardcoded for a single `group_size`, which will prevent fusion for the other supported sizes.
- A high-severity issue in the CUDA kernel implementation regarding a hardcoded shared memory size, which makes the code brittle and prone to future bugs.
Addressing these points will improve the correctness and maintainability of the new feature. The rest of the changes, including the tests and benchmark code, look solid.
Cursor Bugbot has reviewed your changes and found 2 potential issues.
|
Hello @ProExpertProg. I've created the kernel for SiluMul+BlockQuant fusion, and it's working correctly (though not yet performant enough). I'm still having some issues with the fusion pass and pattern matching, which I'll be working on. I'd like some feedback on the kernel and how you think it can be made more optimized and efficient. @ElizaWszola, I used your #27883 PR as a reference to understand the internal workings; thanks for it. If you could also review the kernel and point out anything I missed, that would be really helpful. I'll look at and fix the issues raised by the bots shortly. |
|
Hello @ProExpertProg, quick follow-up. When you have some time, let me know your thoughts on the kernel part. |
|
Hello @ProExpertProg, I've updated my kernel, and now it's performing better than the unfused implementation. Can you please comment on this and share your review? Thanks. |
|
Hello @ProExpertProg, |
|
Hello @ProExpertProg, I'm having some trouble with the fusion pattern match. Could you please provide some guidance on this? I have tried using the existing matchers for silu_mul and block_quant, as well as expressing silu_mul inline, but I still can’t get it to match the pattern for replacement with the fused kernel. Additionally, I noticed that the silu_and_mul_per_block_quant_kernel_large kernel performs well across all cases, as it parallelizes computations on a per-token and per-group basis. If you have any suggestions for further optimizations, I would greatly appreciate it. Thank you. |
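For reference, the unfused computation that the pattern is supposed to match and replace can be written out numerically. The NumPy sketch below is a simplified stand-in (not vLLM's implementation): SiLU on the gate half, elementwise multiply with the up half, then per-group scales sized for FP8 E4M3 (finite max 448.0). It skips the actual FP8 cast and just produces the scaled values and scales:

```python
# A minimal NumPy sketch of the unfused reference: SiLU(gate) * up followed by
# per-group (block-wise) FP8-E4M3 quantization. This is illustrative only; the
# real kernel writes an FP8 tensor, while here we keep float32 scaled values.
import numpy as np

FP8_E4M3_MAX = 448.0  # finite max of the e4m3 format

def silu_mul_block_quant_ref(x: np.ndarray, group_size: int):
    """x has shape (num_tokens, 2*d): first d columns gate, last d columns up."""
    d = x.shape[1] // 2
    gate, up = x[:, :d], x[:, d:]
    y = (gate / (1.0 + np.exp(-gate))) * up            # SiLU(gate) * up
    groups = y.reshape(y.shape[0], d // group_size, group_size)
    # One scale per (token, group); guard against all-zero groups.
    amax = np.abs(groups).max(axis=-1, keepdims=True)
    scales = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    q = groups / scales                                # values fit in e4m3 range
    return q.reshape(y.shape[0], d), scales.squeeze(-1)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 256)).astype(np.float32)   # d=128, two groups of 64
q, s = silu_mul_block_quant_ref(x, group_size=64)
assert q.shape == (4, 128) and s.shape == (4, 2)
assert np.abs(q).max() <= FP8_E4M3_MAX + 1e-3
```

Having such a reference handy makes it easier to see which ops (and in which order) the traced pattern must contain for the matcher to find it.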
|
Hi @Monishver11, the pre-commit checks have failed. Please run:

    uv pip install pre-commit
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch. For future commits, pre-commit will run automatically.
|
|
This pull request has merge conflicts that must be resolved before it can be merged. |
|
Have you tried using the VLLM_PATTERN_MATCH_DEBUG env variable? You can set it to the name of the node in the graph where you expect the match to occur (the node of the first return from the pattern). |
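Since VLLM_PATTERN_MATCH_DEBUG is read from the environment, it has to be set in the launching process before compilation runs. A minimal sketch, where `"quant_node"` is a hypothetical node name standing in for the first return node of your pattern:

```python
# Set the debug variable before vLLM compiles the model. "quant_node" is a
# placeholder; substitute the actual graph-node name you expect to match.
import os

os.environ["VLLM_PATTERN_MATCH_DEBUG"] = "quant_node"
```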
Force-pushed from e18d654 to a528b86.
|
Hello @ProExpertProg. Thanks for your last comment. I've now fixed the kernel fusion pattern match. Can you please review this PR? |
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
|
One more request actually: could you update docs/design/fusions.md to mention this kernel is now supported? And can you check that this fusion is enabled by default for applicable models? |
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
|
Documentation preview: https://vllm--32996.org.readthedocs.build/en/32996/ |
|
This pull request has merge conflicts that must be resolved before it can be merged. |
|
Hello @ProExpertProg. I've updated docs/design/fusions.md with this new kernel. And for the default enablement: yes, this fusion is automatically enabled at O1+ via |
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
|
Breaks on ROCm |
Thanks for the fix @gshtras. |
Purpose
Adds a CUDA kernel and torch.compile pattern matching for fused SiluMul + groupwise FP8 quantization. Part of #27847.
Test Result
Experiments were run on an NVIDIA GeForce RTX 4070 with CUDA 13.0.
Test fused op:
Microbenchmark isolated op:
python benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py

    [------------------------------ silu-mul-block-quant ------------------------------]
                                                  | unfused_fp8_impl | unfused_groupwise_fp8_impl | fused_groupwise_fp8_impl
    1 threads: -------------------------------------------------------------------------
    N 16 x D 1024 x DT torch.float16 x GS 64      | 278.0 | 321.6 | 133.4
    N 16 x D 1024 x DT torch.float16 x GS 128     | 278.7 | 320.1 | 133.6
    N 16 x D 1024 x DT torch.bfloat16 x GS 64     | 279.5 | 321.2 | 133.5
    N 16 x D 1024 x DT torch.bfloat16 x GS 128    | 278.9 | 321.3 | 133.1
    N 16 x D 2048 x DT torch.float16 x GS 64      | 277.5 | 325.1 | 133.0
    N 16 x D 2048 x DT torch.float16 x GS 128     | 278.9 | 321.6 | 133.1
    N 16 x D 2048 x DT torch.bfloat16 x GS 64     | 277.6 | 320.2 | 133.2
    N 16 x D 2048 x DT torch.bfloat16 x GS 128    | 278.7 | 320.4 | 133.8
    N 16 x D 4096 x DT torch.float16 x GS 64      | 278.5 | 320.7 | 133.0
    N 16 x D 4096 x DT torch.float16 x GS 128     | 279.6 | 321.3 | 133.4
    N 16 x D 4096 x DT torch.bfloat16 x GS 64     | 278.3 | 321.4 | 132.5
    N 16 x D 4096 x DT torch.bfloat16 x GS 128    | 277.9 | 322.5 | 132.3
    N 16 x D 5120 x DT torch.float16 x GS 64      | 277.4 | 319.6 | 132.5
    N 16 x D 5120 x DT torch.float16 x GS 128     | 278.1 | 320.2 | 132.5
    N 16 x D 5120 x DT torch.bfloat16 x GS 64     | 276.9 | 319.2 | 132.7
    N 16 x D 5120 x DT torch.bfloat16 x GS 128    | 277.1 | 319.5 | 132.7
    N 16 x D 14336 x DT torch.float16 x GS 64     | 277.6 | 319.1 | 132.2
    N 16 x D 14336 x DT torch.float16 x GS 128    | 277.5 | 319.5 | 132.5
    N 16 x D 14336 x DT torch.bfloat16 x GS 64    | 278.0 | 321.8 | 131.9
    N 16 x D 14336 x DT torch.bfloat16 x GS 128   | 277.0 | 321.3 | 132.1
    N 128 x D 1024 x DT torch.bfloat16 x GS 64    | 276.8 | 318.4 | 132.1
    N 128 x D 1024 x DT torch.bfloat16 x GS 128   | 283.5 | 317.0 | 131.9
    N 128 x D 2048 x DT torch.float16 x GS 64     | 275.4 | 316.7 | 131.5
    N 128 x D 2048 x DT torch.float16 x GS 128    | 274.8 | 322.1 | 131.1
    N 128 x D 2048 x DT torch.bfloat16 x GS 64    | 274.7 | 316.2 | 131.2
    N 128 x D 2048 x DT torch.bfloat16 x GS 128   | 273.0 | 317.3 | 130.8
    N 128 x D 4096 x DT torch.float16 x GS 64     | 273.6 | 316.1 | 130.9
    N 128 x D 4096 x DT torch.float16 x GS 128    | 274.6 | 315.7 | 131.0
    N 128 x D 4096 x DT torch.bfloat16 x GS 64    | 273.7 | 315.0 | 130.4
    N 128 x D 4096 x DT torch.bfloat16 x GS 128   | 272.4 | 314.9 | 130.2
    N 128 x D 5120 x DT torch.float16 x GS 64     | 273.2 | 315.4 | 130.7
    N 128 x D 5120 x DT torch.float16 x GS 128    | 273.4 | 315.5 | 130.4
    N 128 x D 5120 x DT torch.bfloat16 x GS 64    | 272.9 | 314.2 | 130.9
    N 128 x D 5120 x DT torch.bfloat16 x GS 128   | 275.3 | 315.3 | 130.1
    N 128 x D 14336 x DT torch.float16 x GS 64    | 274.1 | 316.3 | 130.7
    N 128 x D 14336 x DT torch.float16 x GS 128   | 274.2 | 318.3 | 130.6
    N 128 x D 14336 x DT torch.bfloat16 x GS 64   | 273.9 | 316.3 | 130.7
    N 128 x D 14336 x DT torch.bfloat16 x GS 128  | 274.6 | 315.8 | 130.5
    N 512 x D 1024 x DT torch.float16 x GS 64     | 271.9 | 313.7 | 130.1
    N 512 x D 1024 x DT torch.float16 x GS 128    | 271.0 | 313.2 | 130.5
    N 512 x D 1024 x DT torch.bfloat16 x GS 64    | 270.6 | 312.2 | 129.2
    N 512 x D 1024 x DT torch.bfloat16 x GS 128   | 271.0 | 313.5 | 129.9
    N 512 x D 2048 x DT torch.float16 x GS 64     | 270.4 | 314.2 | 130.2
    N 512 x D 2048 x DT torch.float16 x GS 128    | 271.2 | 313.4 | 129.9
    N 512 x D 2048 x DT torch.bfloat16 x GS 64    | 271.7 | 312.3 | 130.2
    N 512 x D 2048 x DT torch.bfloat16 x GS 128   | 271.0 | 313.7 | 129.7
    N 512 x D 4096 x DT torch.float16 x GS 64     | 272.0 | 313.9 | 129.9
    N 512 x D 4096 x DT torch.float16 x GS 128    | 272.1 | 315.3 | 130.8
    N 512 x D 4096 x DT torch.bfloat16 x GS 64    | 270
    N 512 x D 5120 x DT torch.bfloat16 x GS 128   | 271.3 | 313.5 | 129.4
    N 512 x D 14336 x DT torch.float16 x GS 64    | 274.0 | 316.3 | 130.4
    N 512 x D 14336 x DT torch.float16 x GS 128   | 272.6 | 313.3 | 130.4
    N 512 x D 14336 x DT torch.bfloat16 x GS 64   | 273.0 | 314.7 | 130.1
    N 512 x D 14336 x DT torch.bfloat16 x GS 128  | 273.6 | 315.2 | 129.5
    N 2048 x D 1024 x DT torch.float16 x GS 64    | 270.2 | 313.2 | 130.0
    N 2048 x D 1024 x DT torch.float16 x GS 128   | 271.1 | 313.3 | 129.7
    N 2048 x D 1024 x DT torch.bfloat16 x GS 64   | 269.5 | 312.0 | 129.5
    N 2048 x D 1024 x DT torch.bfloat16 x GS 128  | 269.9 | 340.8 | 128.9
    N 2048 x D 2048 x DT torch.float16 x GS 64    | 271.3 | 313.1 | 129.0
    N 2048 x D 2048 x DT torch.float16 x GS 128   | 270.7 | 312.4 | 128.9
    N 2048 x D 2048 x DT torch.bfloat16 x GS 64   | 271.2 | 312.1 | 129.2
    N 2048 x D 2048 x DT torch.bfloat16 x GS 128  | 270.7 | 312.6 | 128.3
    N 2048 x D 4096 x DT torch.float16 x GS 64    | 270.8 | 313.7 | 140.6
    N 2048 x D 4096 x DT torch.float16 x GS 128   | 272.2 | 313.0 | 140.2
    N 2048 x D 4096 x DT torch.bfloat16 x GS 64   | 271.3 | 313.5 | 140.9
    N 2048 x D 4096 x DT torch.bfloat16 x GS 128  | 272.9 | 313.9 | 140.1
    N 2048 x D 5120 x DT torch.float16 x GS 64    | 293.2 | 333.5 | 180.2
    N 2048 x D 5120 x DT torch.float16 x GS 128   | 294.0 | 313.0 | 178.8
    N 2048 x D 5120 x DT torch.bfloat16 x GS 64   | 294.7 | 312.8 | 181.0
    N 2048 x D 5120 x DT torch.bfloat16 x GS 128  | 294.8 | 315.1 | 178.4
    N 2048 x D 14336 x DT torch.float16 x GS 64   | 997.6 | 971.5 | 503.7
    N 2048 x D 14336 x DT torch.float16 x GS 128  | 997.3 | 847.3 | 499.4
    N 2048 x D 14336 x DT torch.bfloat16 x GS 64  | 997.4 | 854.7 | 503.7
    N 2048 x D 14336 x DT torch.bfloat16 x GS 128 | 997.2 | 846.5 | 499.2

Times are in microseconds (us).

Compilation pattern matching of the fused op:
lm_eval & Benchmarks
Model: Qwen2.5-0.5B-Instruct (FP8_BLOCK quantized via llm-compressor)
GPU: NVIDIA RTX 4070 (12GB)
lm_eval (gsm8k, 5-shot, 250 samples)
- fusion disabled:
- fusion enabled:
- `+silu_and_mul` disabled, fusion enabled:

Results are within error bars; no accuracy degradation.
Serving Benchmark (sonnet, 640 prompts, 128 RPS)
- default (no `+silu_and_mul`, equivalent to main branch for FP8_BLOCK models):
- `+silu_and_mul` enabled, fusion disabled:
- `+silu_and_mul` enabled, fusion enabled:
- `+silu_and_mul` disabled, fusion enabled: