[Kernel][Quantization][MoE] add marlin kernel support for turing (sm75)#29901
vllm-bot merged 15 commits into vllm-project:main
Conversation
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Code Review
This pull request adds support for the Turing architecture (sm75) to the Marlin kernels, including both dense and MoE variants. The changes involve adding architecture-specific compilation paths in CMake, providing synchronous implementations for cp_async on older architectures, and using m16n8k8 MMA instructions to emulate m16n8k16. The changes look mostly correct and well-structured. However, I've found a few critical issues: a likely debugging leftover in a preprocessor directive that would cause performance regressions on newer GPUs, and the removal of static_asserts that could hide potential shared memory corruption bugs. There is also a minor correctness issue in a CMake file. Please address these points.
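To illustrate the MMA emulation the review describes, here is a small, hypothetical Python sketch (plain-Python matrices, not the actual PTX) showing why stacking two k=8 MMA steps that accumulate into the same output tile gives the same result as one m16n8k16 instruction: the K dimension of a matrix product can be split in half and the two partial products summed.

```python
# Hypothetical illustration (pure Python, not the actual PTX):
# one m16n8k16 MMA == two m16n8k8 MMAs accumulated into the same C tile.

def matmul_acc(A, B, C):
    """C += A @ B for plain list-of-lists matrices."""
    for i in range(len(A)):
        for j in range(len(B[0])):
            C[i][j] += sum(A[i][k] * B[k][j] for k in range(len(B)))
    return C

M, N, K = 16, 8, 16
A = [[(i + k) % 7 for k in range(K)] for i in range(M)]
B = [[(k * j + 1) % 5 for j in range(N)] for k in range(K)]

# Reference: a single k=16 step.
C_ref = matmul_acc(A, B, [[0] * N for _ in range(M)])

# Emulation: split K into two halves and issue two k=8 steps,
# accumulating into the same output tile C.
A_lo = [row[:8] for row in A]          # k = 0..7
A_hi = [row[8:] for row in A]          # k = 8..15
B_lo, B_hi = B[:8], B[8:]

C = [[0] * N for _ in range(M)]
matmul_acc(A_lo, B_lo, C)              # first m16n8k8 step
matmul_acc(A_hi, B_hi, C)              # second m16n8k8 step

assert C == C_ref
```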
@mgoin - could you take a look at this?
mgoin
left a comment
This looks really solid. It seems the added complexity isn't much, just the emulation and fp16_accum. Am I correct that it supports all weight types?
It supports all weight types except MXFP4, which requires BF16 activation, but Turing doesn't support BF16. However, for most practical weights, the value range of E8M0 scales should be within E5M0, so we could also make it support FP16 and add some checks when loading weights.
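A hypothetical sketch of the kind of load-time check described above (names and structure are illustrative, not vLLM's actual API): an E8M0 scale stores a biased exponent e in 0..254 encoding 2**(e - 127), and it is exactly representable in FP16 only when that power of two falls inside FP16's dynamic range.

```python
# Hypothetical load-time check (illustrative, not vLLM's actual API):
# an E8M0 scale stores a biased exponent e in 0..254, encoding 2**(e - 127).
# FP16 can represent a power of two exactly only for exponents in
# [-24, 15]: 2**-24 is the smallest FP16 subnormal, and 2**15 = 32768
# is the largest power of two below the FP16 max of 65504.

FP16_MIN_EXP = -24
FP16_MAX_EXP = 15

def e8m0_fits_in_fp16(biased_exp: int) -> bool:
    """True if the E8M0 scale 2**(biased_exp - 127) is exact in FP16."""
    assert 0 <= biased_exp <= 254, "255 encodes NaN in E8M0"
    return FP16_MIN_EXP <= biased_exp - 127 <= FP16_MAX_EXP

def check_scales(biased_exps):
    """Reject a weight whose scales would overflow/underflow FP16."""
    bad = [e for e in biased_exps if not e8m0_fits_in_fp16(e)]
    if bad:
        raise ValueError(f"E8M0 scales out of FP16 range: {bad}")

check_scales([127, 120, 140])   # typical scales near 1.0: passes
# check_scales([254]) would raise: 2**127 overflows FP16.
```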
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
@mgoin I have added MXFP4 x FP16 support (and added the necessary checks). If you think this support is inappropriate, I can revert it.
Hi @jinzhen-lin, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
@jinzhen-lin Personally I think supporting MXFP4 x FP16 is too confusing, especially since MXFP4 is still hardcoded for GPT-OSS at the moment with BF16 weights for the other layers. If you could remove it I would appreciate it. It is impressive you were able to support all the other formats though!
Could you show a benchmark comparing the original GPTQ to this Marlin GEMM on Turing? I'm curious if there is a large speedup. Also, does this potentially mean we can remove
Added. The improvement at small to medium batch sizes is significant. However, since the original GPTQ kernel uses dequant + cuBLAS for m > 8, Marlin is slower than it when the batch size is large.
This reverts commit dada848. Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Reverted. Thank you for your suggestions!
Don't we need to update other places as well? Such as vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py, vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py, vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py, and maybe some compressed-tensors MoE methods, I'm not sure.
I'm going to merge this PR for now since the tests look good (the failures are known), so please cover the capability updates in a follow-up PR. Thanks!
…5) (vllm-project#29901) Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Thank you for your contribution, but it still fails to run cpatonn/Qwen3-Next-80B-A3B-Thinking-AWQ-4bit. Before this PR was merged, the error was that the Marlin kernel was missing, but now a new error has emerged. Device: Tesla T10 x4. Logs:
@mokieli It seems that some modules are still running with BF16. Try
Hi @jinzhen-lin, any updates? I tested Qwen3-Coder-Next-FP8 and it showed me the same error message 'torch._dynamo.exc.SkipFrame: BF16 is not supported' |
This PR adds Marlin kernel support for Turing (sm75) (e.g. 2080 Ti / T4).
- Turing doesn't have the cp.async instruction, but we can still use synchronous instructions to read from global memory and write to shared memory.
- Turing doesn't have the m16n8k16 MMA instruction, but it does have the m16n8k8 instruction. We only need to stack the instruction twice to achieve the same effect as m16n8k16.
Kernel Benchmark
2080ti + Dense Marlin + GPTQ Channelwise
2080ti + Dense Marlin + GPTQ Group 128 (Comparing with gptq exllama v2)
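As a rough illustration of the cp.async fallback described in the PR summary, here is a hypothetical Python model of the copy stage (thread count and tile sizes are made up for the example): without cp.async, each "thread" performs an ordinary load from "global memory" followed directly by a store to "shared memory", rather than issuing an asynchronous copy.

```python
# Hypothetical model of the sm75 fallback for cp.async (sizes are made up):
# without cp.async, each thread performs a synchronous load from global
# memory followed by a store to shared memory.

THREADS = 8   # threads cooperating on the copy
VEC = 4       # elements per thread per iteration (like a 16-byte copy)

def copy_tile_sync(global_mem, shared_mem, tile_offset, tile_size):
    """Synchronous global->shared copy, one VEC-wide chunk per thread."""
    for start in range(0, tile_size, THREADS * VEC):
        for tid in range(THREADS):                 # all threads in lockstep
            base = tile_offset + start + tid * VEC
            for v in range(VEC):
                dst = start + tid * VEC + v
                if dst < tile_size:
                    shared_mem[dst] = global_mem[base + v]
    # (a real kernel would __syncthreads() here before the MMA stage)

global_mem = list(range(100))
shared_mem = [0] * 64
copy_tile_sync(global_mem, shared_mem, tile_offset=16, tile_size=64)
assert shared_mem == list(range(16, 80))
```

The trade-off this models: the data must arrive before compute can start, so the load latency is exposed, whereas cp.async on sm80+ lets the copy overlap with computation on the previous tile.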