[Kernel][Performance] Add FlashInfer cutedsl NVFP4 GEMM backend#42235
mmangkad wants to merge 1 commit into
Code Review
This pull request introduces a new NVFP4 GEMM backend utilizing FlashInfer's CuteDSL, specifically targeting SM10x architectures. The changes include the implementation of the FlashInferCuteDslNvFp4LinearKernel, its registration within the kernel executor, and the addition of flashinfer-cutedsl as a valid environment variable option. Feedback highlights inconsistencies in the backend naming convention, recommending the use of "cutedsl" instead of "cute-dsl" across the codebase and tests for better alignment with existing backend identifiers.
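The naming feedback can be illustrated with a minimal sketch that accepts both spellings but resolves them to one canonical identifier. The helper `normalize_backend`, the alias table, and the single-entry set of canonical names are hypothetical and are not vLLM's actual validation code; only the `cutedsl` and `cute-dsl` spellings come from this review.

```python
# Hypothetical sketch: not vLLM's actual validation logic.
CANONICAL_BACKENDS = {"cutedsl"}  # canonical spelling recommended in the review

# Hypothetical alias table mapping the hyphenated spelling to the canonical one.
ALIASES = {"cute-dsl": "cutedsl"}


def normalize_backend(name: str) -> str:
    """Return the canonical backend identifier or raise ValueError."""
    key = name.strip().lower()
    key = ALIASES.get(key, key)
    if key not in CANONICAL_BACKENDS:
        raise ValueError(f"unknown backend: {name!r}")
    return key
```

Resolving aliases in one place would keep the codebase and the tests on a single identifier, which is the alignment the review asks for.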
I already started this integration some time ago in #39933, but I don’t think there’s a clear heuristic for deciding when this backend should be selected.
Across the different shapes I tested, the best results were typically in the range 16 <= bs <= 32. Outside of that range, this backend is not consistently the best option and can actually be significantly slower in some cases.
One example is:
Here, speedup=1 means that one of the existing kernels was selected instead, but CuTeDSL causes a regression.
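The observation above suggests gating the backend on batch size. A minimal sketch, assuming the 16 <= bs <= 32 window reported in this comment; the function `pick_backend` and the fallback label `"default"` are hypothetical and do not come from vLLM.

```python
def pick_backend(batch_size: int, lo: int = 16, hi: int = 32) -> str:
    """Pick CuTeDSL only inside the window where it was observed to win.

    The default window [16, 32] is the range reported in the comment above;
    "default" is a hypothetical placeholder for whatever kernel would
    otherwise be selected.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    return "cutedsl" if lo <= batch_size <= hi else "default"
```

A fixed window like this is only as good as the benchmarks behind it, which is why the thread goes on to ask for fresh measurements.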
Could you clarify when these SM100 benchmarks were collected and which FlashInfer version was used? Based on my testing,
@mmangkad Yeah, you have a point. This was with 0.6.8, and things might have changed since then. I see the latest FI release is 0.6.11. Can you please benchmark the shapes in my plot to see how different it looks now? I recommend using
@LopezCastroRoberto see below. I reran those shapes with
Thanks for the results, @mmangkad! Yeah, it seems like it might have improved since the last time I checked. Just to make sure, would you mind adding That way, we would have the full picture and it would be easier to define a heuristic, instead of just adding one more backend to the list.
@LopezCastroRoberto TRTLLM is still strongest at the very smallest M values, especially M=1-4, but CuTeDSL already matches or beats it in many small-M cases and takes over by M=8+. The clearer result is that CuTeDSL is almost always better than the current CUTLASS default across these shapes.
FlashInfer NVFP4 GEMM Results
Each backend cell is
Overall Winners
SM100 winners
SM103 winners
SM100
N=7168, K=2048
N=4096, K=7168
N=18432, K=7168
N=7168, K=18432
SM103
N=7168, K=2048
N=4096, K=7168
N=18432, K=7168
N=7168, K=18432
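A "winners" table like the ones listed above can be derived mechanically from per-backend latencies. A minimal sketch, assuming timings are stored per (M, N, K) shape; the `winners` helper and the data layout are hypothetical, and no benchmark numbers from this PR are reproduced here.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]  # (M, N, K)


def winners(timings: Dict[Shape, Dict[str, float]]) -> Dict[Shape, str]:
    """Map each (M, N, K) shape to the backend with the lowest latency.

    `timings[shape][backend]` is a latency in any consistent unit.
    Ties resolve to the backend listed first for that shape.
    """
    result: Dict[Shape, str] = {}
    for shape, per_backend in timings.items():
        if not per_backend:
            raise ValueError(f"no timings for shape {shape}")
        result[shape] = min(per_backend, key=per_backend.get)
    return result
```

Grouping the result by M would then show where one backend takes over from another, which is the kind of crossover described in the comment.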
Thanks for the results, @mmangkad. Yeah, I think this makes sense. We should also update the FI version to the latest, i.e.
cc: @mgoin
@LopezCastroRoberto we are already at
Never mind, my bad. I accidentally checked my own fork instead of upstream. Waiting for @mgoin's approval.
3b67dbc to d2c176d
d2c176d to 3700b17
Worth being aware of, BTW: flashinfer-ai/flashinfer#3295
Rebased after resolving conflicts caused by #39538 and aligning with its changes |
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
3700b17 to c34ce76
@mmangkad -- following up on the FlashInfer autotuning issue I flagged earlier (flashinfer-ai/flashinfer#3295). The discussion has progressed and there's now a concrete fix, so wanted to share the conclusions since they directly affect this PR.
Right now vLLM defaults to O2, which has
Interestingly, seems like autotuning
To fix this, flashinfer-ai/flashinfer#3396 adds a

    with flashinfer.autotune(skip_ops="fp4_gemm"):
        ...

This brought warmup from 587s → 8s on DSV3.2-NVFP4 TP=4. I think we should track a follow-up to integrate
cc: @mgoin
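To make the `skip_ops` idea concrete, here is a self-contained stand-in that mimics the behaviour described above: ops named in `skip_ops` are excluded from tuning while the context is active. This is a hypothetical mock, not FlashInfer's implementation; the module-level state and the `should_tune` helper are invented for illustration, and only the `skip_ops="fp4_gemm"` usage comes from the comment.

```python
from contextlib import contextmanager
from typing import Iterable, Iterator, Set, Union

# Hypothetical state: ops currently excluded from tuning.
_skipped: Set[str] = set()


@contextmanager
def autotune(skip_ops: Union[str, Iterable[str]] = ()) -> Iterator[None]:
    """Stand-in for an autotune context that skips the named ops."""
    global _skipped
    names = {skip_ops} if isinstance(skip_ops, str) else set(skip_ops)
    previous = _skipped
    _skipped = previous | names
    try:
        yield
    finally:
        _skipped = previous  # restore on exit, even if the body raises


def should_tune(op_name: str) -> bool:
    """Return True if the op would be profiled under the active context."""
    return op_name not in _skipped
```

Inside `with autotune(skip_ops="fp4_gemm"):` the mock reports that `fp4_gemm` should not be tuned while other ops still are. Skipping one op's profiling in this way is the general mechanism behind the warmup reduction reported in the comment.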


Summary
Adds flashinfer-cutedsl for dense NVFP4 GEMM and makes it the highest-priority CUDA backend when supported on SM10x. In serving benchmarks, cutedsl is fastest across concurrency 1-512 and improves tok/s/user by up to 27.07% over the tested FlashInfer backends.
Performance Comparison
Setup:
nvidia/Llama-3.1-8B-Instruct-NVFP4
Test Plan
CI, which now includes:
cute-dsl.
flashinfer-cutedsl.
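The priority rule described in the Summary can be sketched as a first-match scan over an ordered backend list. This is a minimal illustration and not vLLM's kernel-selection code: the `fallback` entry, the integer encoding of compute capability, and the reading of "SM10x" as capabilities 100 through 109 are assumptions.

```python
from typing import Callable, List, Tuple

# (name, is_supported) pairs in priority order.
Backend = Tuple[str, Callable[[int], bool]]


def is_sm10x(capability: int) -> bool:
    """Assumed reading of SM10x: capabilities 100-109 (e.g. SM100, SM103)."""
    return 100 <= capability < 110


BACKENDS: List[Backend] = [
    ("flashinfer-cutedsl", is_sm10x),       # highest priority when supported
    ("fallback", lambda capability: True),  # hypothetical catch-all entry
]


def select_backend(capability: int) -> str:
    """Return the first backend in priority order that supports the device."""
    for name, supported in BACKENDS:
        if supported(capability):
            return name
    raise RuntimeError("no NVFP4 GEMM backend available")
```

Under these assumptions, `select_backend(103)` picks `flashinfer-cutedsl`, while a device outside the assumed SM10x range falls through to the next entry.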