Skip to content

[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations#10867

Merged
vllm-bot merged 48 commits intovllm-project:mainfrom
neuralmagic:sage/silu-mul-quant
May 1, 2025
Merged

[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations#10867
vllm-bot merged 48 commits intovllm-project:mainfrom
neuralmagic:sage/silu-mul-quant

Conversation

@SageMoore
Copy link
Copy Markdown
Contributor

@SageMoore SageMoore commented Dec 3, 2024

Credit to @LucasWilkinson for the kernel.

This pass currently only supports static per-tensor quantization. Other quantization schemes will be included in a subsequent PRs.

I've attached some QPS sweeps that were run using neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 on an H100. Generally speaking, this pass improves the TPOT of FP8 Llama by 2-3%. There are similar improvements with TTFT with the exception of 20QPS which is much (~2x) faster.

fused_results
torch_compile_results

@github-actions
Copy link
Copy Markdown

github-actions bot commented Dec 3, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Dec 3, 2024
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
@SageMoore SageMoore force-pushed the sage/silu-mul-quant branch from 27be0bd to e2fda7f Compare December 6, 2024 20:34
@SageMoore SageMoore marked this pull request as ready for review December 6, 2024 20:36
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Comment thread vllm/compilation/activation_quant_fusion.py Outdated
Comment thread tests/kernels/test_fused_quant_activation.py
Comment thread csrc/torch_bindings.cpp
Copy link
Copy Markdown
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Focused on csrc/quantization/activation_kernels.cu. spotted a couple of potential int32_t overflows

Comment thread csrc/core/math.hpp Outdated
Comment thread csrc/quantization/activation_kernels.cu Outdated
Comment thread csrc/quantization/activation_kernels.cu
Comment thread csrc/quantization/activation_kernels.cu Outdated
@mergify
Copy link
Copy Markdown
Contributor

mergify bot commented Jan 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 21, 2025
@mergify mergify bot removed the needs-rebase label Jan 27, 2025
Signed-off-by: Sage Moore <sage@neuralmagic.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify bot commented Feb 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 8, 2025
…silu-mul-quant

Signed-off-by: Sage Moore <sage@neuralmagic.com>
@mergify mergify bot removed the needs-rebase label Apr 24, 2025
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
…silu-mul-quant

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Comment on lines +58 to +60
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still an issue on 2.7.0? (@zou3519)

Copy link
Copy Markdown
Collaborator

@zou3519 zou3519 Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have been fixed in pytorch/pytorch#139321 (@eellison), and yes that's in 2.7.0

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! In that case, @SageMoore could you clean this up before landing?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SageMoore I think the comment got left in

Signed-off-by: Sage Moore <sage@neuralmagic.com>
…silu-mul-quant

Signed-off-by: Sage Moore <sage@neuralmagic.com>
…silu-mul-quant

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
@SageMoore
Copy link
Copy Markdown
Contributor Author

Here are lm_eval results for neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 with fusion running.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.736|±  |0.0197|
|     |       |strict-match    |     5|exact_match|↑  |0.732|±  |0.0198|

Copy link
Copy Markdown
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still LGTM, and thanks for cleaning up that last piece!

@vllm-bot vllm-bot merged commit 460a2b1 into vllm-project:main May 1, 2025
72 of 75 checks passed
@ProExpertProg
Copy link
Copy Markdown
Collaborator

A follow-up question: are we planning on doing the dynamic pathway?

radeksm pushed a commit to radeksm/vllm that referenced this pull request May 2, 2025
…subsequent scaled_fp8_quant operations (vllm-project#10867)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…subsequent scaled_fp8_quant operations (vllm-project#10867)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…subsequent scaled_fp8_quant operations (vllm-project#10867)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
@SageMoore SageMoore deleted the sage/silu-mul-quant branch June 18, 2025 14:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build frontend ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants