
[Bugfix] Fix TRITON_MLA FP8 KV cache decode on Blackwell GPUs#35833

Open
ricky-chaoju wants to merge 4 commits intovllm-project:mainfrom
ricky-chaoju:triton-mla-fp8-kv-cache

Conversation

@ricky-chaoju (Contributor) commented Mar 3, 2026

Summary

  • Cast the query to bfloat16 before the Triton decode kernel when the FP8 KV cache is enabled, avoiding FP8 `tl.dot` instructions that trigger illegal-instruction errors on Blackwell (SM 12.x).
  • Reduce Triton pipeline stages to 1 for FP8 to prevent shared-memory overflow from float32 dequantization intermediates.
  • KV cache dequantization is still performed on-the-fly inside the kernel, so there is no full-cache copy overhead.
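The first two bullets amount to a small host-side dispatch decision. A minimal, self-contained sketch of that logic (illustrative names only, not vLLM's actual code; dtypes are modeled as strings, while the real code works with torch dtypes and Triton launch parameters):

```python
def select_decode_config(q_dtype: str, kv_cache_dtype: str) -> dict:
    """Pick the query dtype and Triton pipelining for the MLA decode kernel.

    FP8 tl.dot is what faults on SM 12.x, so the query is promoted to
    bfloat16 and KV tiles are dequantized inside the kernel; the float32
    dequantization intermediates enlarge each pipeline stage's shared-memory
    footprint, hence num_stages=1 instead of the usual deeper pipeline.
    """
    if kv_cache_dtype.startswith("fp8"):
        return {"q_dtype": "bfloat16", "num_stages": 1, "dequant_in_kernel": True}
    # Unquantized cache: keep the query dtype and default pipelining.
    return {"q_dtype": q_dtype, "num_stages": 2, "dequant_in_kernel": False}
```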

Fixes #35577

Test plan

  • Verified on NVIDIA GB10 (SM 12.1) with GLM-4.7-Flash-NVFP4 and `--kv-cache-dtype fp8`; generation throughput ~32–80 tok/s
  • FP8 vs bfloat16 kernel output: cosine similarity 0.9997, max abs diff 0.0006
  • Memory savings confirmed: FP8 cache uses 50% of bfloat16 baseline
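The accuracy comparison in the test plan can be reproduced with two small helpers. A pure-Python sketch (dependency-free for brevity; in practice one would run this on flattened kernel output tensors):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two flat vectors, e.g. flattened FP8-path
    and bfloat16-path kernel outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two flat vectors."""
    return max(abs(x - y) for x, y in zip(a, b))
```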

@mergify mergify bot added v1 bug Something isn't working labels Mar 3, 2026
@ricky-chaoju ricky-chaoju force-pushed the triton-mla-fp8-kv-cache branch from 66a1ab8 to e7edcf5 Compare March 3, 2026 05:41
@gemini-code-assist bot left a comment
Code Review

This pull request adds FP8 KV cache support to Triton MLA, specifically fixing an illegal-instruction error on Blackwell GPUs. The changes cast the query to bfloat16 to avoid problematic FP8 `tl.dot` instructions and reduce Triton pipeline stages to prevent shared-memory overflow. While the changes are generally correct, I've identified a critical issue: hard-casting the query to bfloat16 can cause a data-type mismatch and a runtime error in the downstream V up-projection when the model's data type is float16. I've provided a detailed explanation and a suggested fix for this issue.
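The mismatch described in this review can be avoided by casting the kernel output back to the query's original dtype. A toy sketch of that shape of fix (dtypes modeled as plain strings standing in for tensors; names are illustrative, not vLLM's API):

```python
def decode_with_fp8_cache(q, kv_cache_dtype, triton_decode):
    """Run the decode kernel while preserving the caller-visible dtype.

    `q` is a dict with a "dtype" key standing in for a tensor, and
    `triton_decode` stands in for the Triton kernel launch, returning an
    output that carries the compute dtype.
    """
    orig_dtype = q["dtype"]
    if kv_cache_dtype.startswith("fp8"):
        q = {**q, "dtype": "bfloat16"}   # cast before the kernel (the PR's fix)
    out = triton_decode(q)               # output inherits the compute dtype
    if out["dtype"] != orig_dtype:
        # Cast back so a float16 model doesn't feed bfloat16 activations
        # into the V up-projection matmul.
        out = {**out, "dtype": orig_dtype}
    return out
```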

Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
@ricky-chaoju ricky-chaoju force-pushed the triton-mla-fp8-kv-cache branch from 41bf30e to 71a42cb Compare March 3, 2026 05:48
mergify bot commented Mar 3, 2026

Hi @ricky-chaoju, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>

Signed-off-by: "vllm-dev" <ricky.chen@infinirc.com>
mergify bot commented Mar 3, 2026

Documentation preview: https://vllm--35833.org.readthedocs.build/en/35833/

@mergify mergify bot added the documentation Improvements or additions to documentation label Mar 3, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
Cherry-pick upstream fixes for GB10 Spark (SM121):

- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8
  kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4
  by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds
  on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for
  non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x
  (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
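The SM121→SM120 mapping mentioned for PR #35568 is, in essence, a compute-capability family lookup for kernel dispatch. A hypothetical sketch of that idea (not the actual generate_kernels.py logic):

```python
def kernel_family(major: int, minor: int) -> tuple:
    """Map a compute capability to the kernel family it dispatches to.

    GB10 reports SM 12.1, which has no dedicated kernel builds; treating it
    as part of the SM 12.0 family lets the existing Marlin/CUTLASS FP8
    kernels be selected.
    """
    if major == 12:
        return (12, 0)   # all SM 12.x dispatch to SM120-family kernels
    return (major, minor)
```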
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 5, 2026
mergify bot commented Mar 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ricky-chaoju.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 9, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 18, 2026

Labels

bug (Something isn't working), documentation (Improvements or additions to documentation), needs-rebase, v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature] TRITON_MLA: support FP8 KV cache (needed for SM12.0 / Blackwell)

1 participant