[Kernel] Added flashinfer fp8 per-tensor gemms #22895
simon-mo merged 13 commits into vllm-project:main
Conversation
Code Review
This pull request introduces support for FlashInfer's FP8 GEMM kernels, which is expected to improve performance, particularly for large batch sizes. The changes primarily involve refactoring the GEMM dispatch logic to accommodate a new 'flashinfer' backend and adding an optimization to pre-calculate combined scales. While the implementation is largely sound, I've identified a critical issue in the new FlashInfer wrapper where the output tensor is not reshaped, potentially causing shape mismatches for inputs with more than two dimensions.
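The reshape issue flagged in the review can be illustrated with a minimal sketch (the wrapper name and call are hypothetical stand-ins, not the actual vLLM/FlashInfer code): a wrapper that flattens leading dimensions to present a 2D problem to the kernel must restore those dimensions on the output, or callers passing 3D+ inputs get a shape mismatch.

```python
import torch

def fp8_gemm_wrapper(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Flatten all leading dims into M so the kernel sees a 2D problem.
    a_2d = a.reshape(-1, a.shape[-1])
    out_2d = a_2d @ b  # stand-in for the FlashInfer FP8 GEMM call
    # Without this final reshape, callers with >2D inputs would receive
    # a 2D result -- the bug the review points out.
    return out_2d.reshape(*a.shape[:-1], b.shape[-1])
```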
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add 🚀
Signed-off-by: Julien Lin <jullin@nvidia.com>
Force-pushed from 4738c59 to 26ddfd9
Signed-off-by: Julien Lin <jullin@nvidia.com>
@nvjullin The Blackwell test failures look clearly related: https://buildkite.com/vllm/ci/builds/27727/steps/canvas?jid=0198c767-0112-49f4-9f26-c9fef601374c#0198c767-0112-49f4-9f26-c9fef601374c/98-3479
Signed-off-by: Julien Lin <jullin@nvidia.com>
@mgoin the remaining errors are all HuggingFace gateway timeouts.
ProExpertProg
left a comment
A few minor notes.
This might be a bit too urgent, but in general we should really improve the fp8 scaled_mm dispatching. I started a draft PR #19434 but never got around to it.
@nvjullin Please address the comments and rebase this PR. Thanks!
Signed-off-by: Julien Lin <jullin@nvidia.com>
Force-pushed from 09089d8 to 9a847f7
Signed-off-by: Julien Lin <jullin@nvidia.com>
No, I think it's out of the scope of this PR. But if you look at the dispatching for int8 or Marlin/Machete, that's closer to what we want to do in general when dispatching between multiple possible implementations.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Julien Lin <jullin@nvidia.com>
mgoin
left a comment
LGTM to get in, thanks. We should follow up with using an Enum instead of raw strings.
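The Enum follow-up suggested above could look roughly like the following. This is a hedged sketch: the class and function names are illustrative, not vLLM's actual API, and the backend set shown is only an example.

```python
from enum import Enum

class Fp8GemmBackend(Enum):
    # Illustrative backend names; not vLLM's actual registry.
    CUTLASS = "cutlass"
    FLASHINFER = "flashinfer"

def select_backend(name: str) -> Fp8GemmBackend:
    # Enum lookup raises ValueError on an unknown string, so typos
    # fail loudly instead of silently falling through a dispatch chain.
    return Fp8GemmBackend(name)

print(select_backend("flashinfer"))  # Fp8GemmBackend.FLASHINFER
```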
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Purpose
Added fp8 gemms from flashinfer.
The added GEMMs have better or equal performance compared to the original GEMMs, so they are used as the default.
For gemm sizes with small M, the added gemms are marginally faster.
For gemm sizes with large M, the added gemms are much faster.
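Alongside the new kernels, this PR pre-calculates combined scales for the per-tensor FP8 path. The idea can be emulated in fp32 (a sketch with illustrative names, not the actual kernel epilogue): since both scales are scalars, `scale_a * scale_b` can be folded into one value at weight-load time instead of applying two multiplies every forward pass.

```python
import torch

scale_a = torch.tensor(0.5)    # per-tensor activation scale (example value)
scale_b = torch.tensor(0.25)   # per-tensor weight scale (example value)
combined_scale = scale_a * scale_b  # precomputed once, reused every call

a_q = torch.randn(8, 16)   # stands in for the FP8 activation operand
b_q = torch.randn(16, 32)  # stands in for the FP8 weight operand

# One epilogue multiply with the combined scale...
out = (a_q @ b_q) * combined_scale
# ...matches applying the two per-tensor scales separately.
ref = (a_q @ b_q) * scale_a * scale_b
assert torch.allclose(out, ref)
```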
These are the results for Llama 3 with ISL=OSL=1024, concurrency=128, max_num_batched_tokens=8192, TP1.
As expected, TPOT is roughly the same but TTFT improved by ~13%.
Requires FlashInfer autotuning, so depends on #22346 (merged).
Functionality depends on FlashInfer PR flashinfer-ai/flashinfer#1479 (merged).
Perf numbers depend on FlashInfer PR flashinfer-ai/flashinfer#1491 (merged).
Requires the next FlashInfer release including the aforementioned PRs and vLLM updating its FlashInfer version. (updated flashinfer)
old: (benchmark results attached in original PR)
new: (benchmark results attached in original PR)
lm_eval shows: (results attached in original PR)
Test Plan
Tests to be added.
Test Result
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.