[Bugfix] Fix TRITON_MLA FP8 KV cache decode on Blackwell GPUs #35833
ricky-chaoju wants to merge 4 commits into vllm-project:main from
Conversation
Code Review
This pull request adds support for FP8 KV cache in Triton MLA, specifically fixing an illegal instruction error on Blackwell GPUs. The changes cast the query to bfloat16 to avoid problematic FP8 tl.dot instructions and reduce the number of Triton pipeline stages to prevent shared memory overflow. While the changes are generally correct, I've identified a critical issue: hard-casting the query to bfloat16 can cause a data type mismatch and a runtime error in the downstream V up-projection when the model's data type is float16. I've provided a detailed explanation and a suggested fix for this issue.
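The reviewer's concern can be sketched in a few lines. This is a hedged illustration (not the actual vLLM code, and the function name is hypothetical): the 16-bit dtype used for the FP8-KV decode path should follow the model dtype rather than being hard-coded, so the attention output matches the float16 or bfloat16 weights of the downstream V up-projection.

```python
# Hedged sketch of the suggested fix: derive the cast target from the
# model dtype instead of hard-coding bfloat16.
def attention_compute_dtype(model_dtype: str) -> str:
    """Return the 16-bit dtype the query should be cast to.

    Hard-coding "bfloat16" breaks float16 models: the attention output
    would be bfloat16 while the V up-projection expects float16.
    """
    if model_dtype not in ("float16", "bfloat16"):
        raise ValueError(f"unexpected model dtype: {model_dtype}")
    return model_dtype

print(attention_compute_dtype("float16"))   # a float16 model stays float16
print(attention_compute_dtype("bfloat16"))  # a bfloat16 model stays bfloat16
```

Keeping the cast target tied to the model dtype means the kernel output never needs a second conversion before the up-projection matmul.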
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
Hi @ricky-chaoju, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then commit the changes and push to your branch.
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com> Signed-off-by: "vllm-dev" <ricky.chen@infinirc.com>
Documentation preview: https://vllm--35833.org.readthedocs.build/en/35833/
Cherry-pick upstream fixes for GB10 Spark (SM121):
- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell; adds on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
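The FlashInfer TRTLLM JIT patch above can be sketched as a one-line gate change. Names here are illustrative, not FlashInfer's actual identifiers; only the `[10] → [10, 12]` widening comes from the commit message. The idea is to accept SM12x (e.g. GB10 Spark, SM121) alongside the SM10x family when deciding whether the JIT path may compile.

```python
# Hedged sketch of the compute-capability gate after the local patch.
SUPPORTED_MAJOR_VERSIONS = [10, 12]  # was [10] before the patch

def trtllm_jit_supported(compute_capability: tuple) -> bool:
    """True if the TRTLLM JIT path may compile for this GPU (sketch)."""
    major, _minor = compute_capability
    return major in SUPPORTED_MAJOR_VERSIONS

print(trtllm_jit_supported((12, 1)))  # SM121: accepted after the patch
print(trtllm_jit_supported((9, 0)))   # SM90: still rejected
```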
This pull request has merge conflicts that must be resolved before it can be merged.
Summary
Cast the decode query to bfloat16 to avoid FP8 tl.dot instructions that produce illegal instruction errors on Blackwell (SM 12.x), and reduce Triton pipeline stages to prevent shared memory overflow. Fixes #35577
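The pipeline-stage reduction can be understood with a back-of-envelope shared-memory model. Every number below is an assumption for illustration (tile shape, head dimension, and the per-block shared-memory budget are not taken from the kernel): each extra Triton software-pipelining stage buffers another K/V tile pair, so stages multiply shared-memory use until the kernel no longer fits.

```python
# Simplified (assumed) model: each pipelining stage buffers one K tile
# and one V tile in shared memory.
SMEM_BUDGET = 100 * 1024  # assumed usable shared memory per block on SM12x

def smem_bytes(block_n: int, head_dim: int, elem_bytes: int, num_stages: int) -> int:
    # K tile + V tile per stage (illustrative model, not the real kernel)
    return 2 * block_n * head_dim * elem_bytes * num_stages

# Assumed MLA-like tile: BLOCK_N=32, head_dim=576, 1-byte FP8 KV elements.
for stages in (3, 2):
    used = smem_bytes(32, 576, 1, stages)
    print(stages, used, used <= SMEM_BUDGET)
```

Under these assumed numbers, three stages overflow the budget while two fit, which is the shape of the trade-off the fix makes: fewer stages mean less latency hiding but a kernel that actually launches.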
Test plan
With --kv-cache-dtype fp8, generation throughput was ~32–80 tok/s.