[DSv4][Nvidia] SM12x DeepSeek V4 support #40991

Draft
jasl wants to merge 7 commits into vllm-project:main from jasl:ds4-sm120

Conversation

@jasl (Contributor) commented Apr 27, 2026

This PR incorporates #40929 and is now DeepGEMM-free. Thanks, @bbbearxyz!

UPDATE: To better align with the DeepSeek official API and the B200 code path, I built a harness to measure correctness, performance, and quality: https://github.com/jasl/vllm-ds4-sm120-harness
I will keep posting the latest report there for people to review.

Summary

This PR enables DeepSeek V4 Flash to serve on NVIDIA SM12x GPUs, tested on a
2x RTX PRO 6000 Blackwell Workstation Edition host.

The important change from the earlier prototype is that this PR no longer pins
or rewrites the DeepGEMM dependency. The branch keeps vLLM's upstream DeepGEMM
installer and CMake metadata intact, and implements the required SM12x runtime
fallbacks in vLLM:

  • DeepSeek V4 tokenizer / parser / model integration.
  • Portable Triton sparse MLA path for SM12x.
  • fp8_ds_mla sparse MLA cache handling.
  • Sink-aware SWA + compressed sparse attention.
  • vLLM-side SM12x fallbacks for DeepSeek V4-specific DeepGEMM calls.
  • SM12x sparse indexer and paged MQA fallback kernels.
  • Guardrails so existing SM90 / SM100 optimized paths remain unchanged (see the sketch below).
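
The guardrail pattern reduces to capability-family gating at each dispatch site, so SM90 / SM100 never enter the new code. A minimal sketch of the idiom, using the real current_platform helpers but illustrative kernel names that are not defined here:

from vllm.platforms import current_platform

def _use_sm12x_sparse_fallback() -> bool:
    # Only CUDA devices in the 12.x capability family (SM120 / SM121)
    # take the portable path; SM90 WGMMA and SM100 tcgen05 dispatch
    # stays exactly as it was.
    return (current_platform.is_cuda()
            and current_platform.is_device_capability_family(120))

def sparse_mla_decode(*args, **kwargs):
    # Illustrative dispatch; the real call sites live in the MLA backend.
    if _use_sm12x_sparse_fallback():
        return triton_sparse_mla_decode(*args, **kwargs)  # portable Triton path
    return flashmla_sparse_decode(*args, **kwargs)  # unchanged optimized path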

Motivation

DeepSeek V4 currently relies on kernels that are available on Hopper and
datacenter Blackwell paths, but not on SM120 / SM121 workstation and consumer
Blackwell GPUs. In particular, SM12x cannot directly reuse SM90 WGMMA kernels
or SM100 tcgen05 kernels.

This PR adds correctness-first portable kernels for the missing SM12x pieces,
then optimizes the hot sparse MLA paths enough for real serving. The result is
a reviewable vLLM-side compatibility layer that does not require maintainers to
accept a temporary DeepGEMM fork pin.

Scope

Included:

  • SM12x Triton sparse MLA decode and prefill paths.
  • fp8_ds_mla packed cache decode for SWA and compressed sparse candidates.
  • Sink-aware sparse attention denominator semantics (see the sketch after this list).
  • SM12x local fallbacks for DeepSeek V4-specific DeepGEMM call sites.
  • Sparse indexer memory bound fixes for long prefill.
  • DeepSeek V4 tokenizer handling and tool-call parser fixes needed by the new
    model path.
  • Targeted correctness tests and an HTTP logprobs oracle comparator.
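
On the sink-aware denominator item: assuming the standard attention-sink formulation (a learned per-head sink logit that enters the softmax denominator but contributes no value), a minimal reference sketch of the semantics looks like this. sink_aware_softmax is a hypothetical name, not a function in this PR:

import torch

def sink_aware_softmax(scores: torch.Tensor, sink_logit: torch.Tensor) -> torch.Tensor:
    # scores: [num_heads, q_len, k_len] logits over the selected sparse tokens.
    # sink_logit: [num_heads] learned sink, counted in the denominator only.
    sink = sink_logit.view(-1, 1, 1)
    # Include the sink in the running max for numerical stability.
    m = torch.maximum(scores.amax(dim=-1, keepdim=True), sink)
    num = torch.exp(scores - m)
    den = num.sum(dim=-1, keepdim=True) + torch.exp(sink - m)
    # Weights over real tokens sum to < 1; the gap is the sink's share.
    return num / den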

Not included:

  • Replacing FlashMLA on SM90 / SM100.
  • A final Tensor Core implementation for every SM12x kernel.
  • MTP speculative decoding fixes. Those are kept in a separate branch / PR.
  • Community performance experiments that are useful for evaluation but too
    broad for this PR.
  • Any DeepGEMM fork pin or DeepGEMM CMake / install-script rewrite.

Runtime controls

The SM12x sparse MLA path registers its environment variables in vllm.envs,
so users should not see unknown-variable warnings for these knobs.

| Variable | Default | Meaning |
| --- | --- | --- |
| VLLM_TRITON_MLA_SPARSE | auto | 1 forces the Triton sparse MLA path, 0 disables it. When unset, vLLM enables it on SM12x where FlashMLA sparse is unavailable. |
| VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE | 512 | Top-k candidate chunk size for sparse MLA accumulation. Lower values reduce transient workspace at the cost of more kernel work. |
| VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE | 256 | Query chunk size used by the prefill sparse MLA fallback. |
| VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE | auto | Optional decode head-block override. Supported values are 1, 2, and 4; benchmarks used 4. |
| VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE | auto | Optional matmul-based sparse MLA decode toggle. When unset, it auto-enables on SM12x. |
| VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH | context-dependent | Allows compile / CUDA graphs for the sparse MLA path. In the formal PR branch, unset keeps graphs for normal decode and disables them for speculative decoding; 1 forces allow, 0 disables. |
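
For reference, the auto defaults above boil down to a tri-state lookup at backend-selection time. A minimal sketch of the semantics, with a hypothetical helper name (the actual registration lives in vllm.envs):

import os

def resolve_tristate(name: str, auto_default: bool) -> bool:
    # "1" forces the knob on, "0" forces it off; unset falls back to a
    # platform-derived default, e.g. True on SM12x when FlashMLA sparse
    # is unavailable.
    val = os.environ.get(name)
    if val is None:
        return auto_default
    return val == "1"

# Hypothetical usage:
# use_triton_sparse = resolve_tristate(
#     "VLLM_TRITON_MLA_SPARSE",
#     auto_default=is_sm12x and not has_flashmla_sparse,
# )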

Operational warning: do not set
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True with the TP=2 CUDA graph
configuration used below. In local testing it made custom all-reduce fail
during CUDA graph address registration. Leaving it unset avoids that failure.

Branches

Formal PR branch:

jasl/vllm@ds4-sm120
HEAD: 7a34ed538

Preview / evaluation branch with extra community performance work and MTP fixes:

jasl/vllm@ds4-sm120-full
HEAD: ab7336f21

The preview branch is not intended as the review target. It exists so users can
try the broader optimization stack while this PR stays focused.

Test environment

Hardware:

Host: jasl-workstation
GPU: 2x NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Compute capability: SM120
GPU memory: 95 GiB class per GPU

Software:

OS: Ubuntu, Linux 7.0.0-14-generic
CUDA toolkit: /usr/local/cuda
Python: 3.13.13
PyTorch: 2.11.0+cu130
vLLM package metadata: 0.20.1rc1.dev12+g363ffa145

Benchmark environment:

export PATH="/usr/local/cuda/bin:$PATH"
export CUDA_HOME="/usr/local/cuda"
export TRITON_PTXAS_PATH="/usr/local/cuda/bin/ptxas"
export CUDA_ARCH_LIST="120a"
export TORCH_CUDA_ARCH_LIST="12.0a"
export VLLM_TRITON_MLA_SPARSE=1
export VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=4
export VLLM_RPC_TIMEOUT=100000
unset PYTORCH_CUDA_ALLOC_CONF

Note: DGX Spark uses 121a and 12.1a.

Validation

Formal PR branch checks:

python -m ruff check \
  vllm/envs.py \
  vllm/utils/deep_gemm.py \
  vllm/tokenizers/deepseek_v4_encoding.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/attention/backends/mla/sparse_mla_env.py \
  vllm/v1/attention/backends/mla/sparse_swa.py \
  tests/tokenizers_/test_deepseek_v4.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py \
  tests/v1/attention/test_sm120_deepgemm_fallbacks.py

Result:

All checks passed!

Compile check:

python -m py_compile \
  vllm/envs.py \
  vllm/utils/deep_gemm.py \
  vllm/tokenizers/deepseek_v4_encoding.py \
  vllm/v1/attention/backends/mla/sparse_mla_kernels.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/attention/backends/mla/sparse_swa.py

Targeted tests:

python -m pytest -q \
  tests/tokenizers_/test_deepseek_v4.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_sparse_mla_backends.py \
  tests/v1/attention/test_sm120_deepgemm_fallbacks.py \
  tests/v1/attention/test_sparse_attn_indexer.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

Result:

151 passed, 504 skipped, 16 warnings in 356.93s

Diff hygiene:

git diff --check origin/main...HEAD

Result: clean.

Preview branch focused checks:

python -m ruff check \
  vllm/v1/attention/backends/mla/sparse_mla_env.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  tests/v1/spec_decode/test_mtp.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

python -m pytest -q \
  tests/v1/spec_decode/test_mtp.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py

Result:

95 passed, 16 warnings in 48.35s

Serving command

Formal PR branch, no MTP:

PYTHONPATH=~/tmp/vllm-bench-ds4-sm120 \
~/tmp/vllm/.venv/bin/vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port 8017 \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --max-model-len 16384 \
  --gpu-memory-utilization 0.94 \
  --tensor-parallel-size 2 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4

Preview branch, MTP:

PYTHONPATH=~/tmp/vllm-bench-ds4-sm120-full \
~/tmp/vllm/.venv/bin/vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port 8018 \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --max-model-len 16384 \
  --gpu-memory-utilization 0.985 \
  --tensor-parallel-size 2 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --speculative-config '{"method":"mtp","num_speculative_tokens":2}'

Benchmark command

The short-context benchmark uses 128 -> 512 (input -> output tokens); the long-context benchmark uses 8192 -> 512. Each row uses 48 prompts and temperature=0.

~/tmp/vllm/.venv/bin/vllm bench serve \
  --model deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port <port> \
  --dataset-name random \
  --random-input-len <128-or-8192> \
  --random-output-len 512 \
  --num-prompts 48 \
  --max-concurrency <C> \
  --ignore-eos \
  --temperature 0 \
  --save-result \
  --result-dir <result-dir> \
  --result-filename <name>.json

Formal PR branch benchmark

Branch:

jasl/vllm@ds4-sm120
HEAD: 7a34ed538

Server memory setting:

--gpu-memory-utilization 0.94

MTP is not included in this branch. Starting the formal branch with
--speculative-config '{"method":"mtp","num_speculative_tokens":2}' fails
because the MTP fix stack is intentionally kept separate.

| Context | Concurrency | Output tok/s | Requests/s | Mean TPOT | Mean TTFT |
| --- | --- | --- | --- | --- | --- |
| 128 -> 512 | 1 | 100.38 | 0.196 | 9.76 ms | 113.4 ms |
| 128 -> 512 | 4 | 296.84 | 0.580 | 13.16 ms | 171.9 ms |
| 128 -> 512 | 8 | 478.34 | 0.934 | 16.18 ms | 291.6 ms |
| 8192 -> 512 | 1 | 58.61 | 0.114 | 10.94 ms | 3143.0 ms |
| 8192 -> 512 | 2 | 81.35 | 0.159 | 15.37 ms | 4732.0 ms |

Result directory:

/home/jasl/tmp/ds4_sm120_bench_20260429_032651

Preview branch benchmark

Branch:

jasl/vllm@ds4-sm120-full
HEAD: ab7336f21

Server memory setting:

--gpu-memory-utilization 0.985

This branch includes the separate MTP fixes and community performance patches.
It is for evaluation only, not the formal PR review target.

Startup notes:

  • no-MTP CUDA graph reserve: 3.67 GiB
  • no-MTP available KV cache: 10.6 GiB
  • MTP CUDA graph reserve: 4.38 GiB
  • MTP available KV cache: 6.2 GiB

| Context | Concurrency | no-MTP tok/s | MTP tok/s | MTP delta | no-MTP TPOT | MTP TPOT | no-MTP TTFT | MTP TTFT | MTP acceptance |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 128 -> 512 | 1 | 103.03 | 161.14 | +56.4% | 9.60 ms | 5.95 ms | 62.3 ms | 138.7 ms | 78.61% |
| 128 -> 512 | 4 | 303.20 | 326.51 | +7.7% | 12.93 ms | 11.47 ms | 145.6 ms | 346.0 ms | 80.14% |
| 128 -> 512 | 8 | 473.53 | 525.08 | +10.9% | 16.46 ms | 14.07 ms | 236.3 ms | 402.2 ms | 77.17% |
| 8192 -> 512 | 1 | 58.54 | 79.17 | +35.2% | 10.81 ms | 6.23 ms | 3223.4 ms | 3283.6 ms | 81.48% |
| 8192 -> 512 | 2 | 80.77 | 98.33 | +21.7% | 15.33 ms | 13.46 ms | 4843.8 ms | 3486.3 ms | 79.02% |

Result directory:

/home/jasl/tmp/ds4_sm120_full_bench_20260429_041151

Review notes

Changes made before this update:

  • Removed the temporary DeepGEMM fork pin and related env bridge.
  • Removed sparse MLA diagnostic dump hooks and tests.
  • Kept runtime-facing names production-oriented; test oracle helpers remain
    clearly separated from serving kernels.
  • Verified there are no stale prototype DeepGEMM refs.
  • Re-signed the branch with DCO trailers.
  • Re-ran targeted tests and benchmarks after the cleanup.

Known follow-ups

  • MTP speculative decoding should be reviewed as an independent PR.
  • ds4-sm120-full can continue to carry community performance patches for
    public evaluation.
  • Further SM12x optimization should focus on full decode profiling across
    indexer, MoE, collectives, sampling, and sparse MLA rather than broadening
    this PR.

@claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added ci/build deepseek Related to DeepSeek models nvidia v1 labels Apr 27, 2026
@jasl (Contributor, Author) commented Apr 27, 2026

@WoosukKwon
I rebased my original PR #40899
Here it is, please help to review

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for DeepSeek V4 models, including updates to DeepGEMM integration, new FP8 einsum kernels for SM12x, and infrastructure for sparse MLA attention. However, there are two critical issues: the removal of the optional dependency check for tilelang in vllm/model_executor/layers/mhc.py will break installations on non-CUDA platforms, and the replacement of DeepseekV4MLP with DeepseekV2MLP for shared experts removes necessary swiglu_limit clamping, which is vital for numerical stability in FP8 inference.

@chatgpt-codex-connector (Bot) left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4e2adf8a9f

@jasl (Contributor, Author) commented Apr 27, 2026

The PR is ready to review.
I'm benchmarking the latest result.

@wuwenthink

Thanks to your development, the length of the context currently supported locally has increased significantly, and the speed of decode has increased a lot. It's amazing!

@bbbearxyz

@jasl My understanding is that your current approach supports SM120 through a combination of DeepGEMM and Triton. I wonder whether a pure Triton implementation, without depending on DeepGEMM at all, would be cleaner and perhaps worth considering as an alternative. I’d be interested to hear your thoughts.

@jasl (Contributor, Author) commented Apr 27, 2026

@jasl My understanding is that your current approach supports SM120 through a combination of DeepGEMM and Triton. I wonder whether a pure Triton implementation, without depending on DeepGEMM at all, would be cleaner and perhaps worth considering as an alternative. I’d be interested to hear your thoughts.

I don't have a preference.
IMO, contributing to DeepGEMM would help align with DeepSeek's official behavior and get the team paying attention to the community's needs.
I can imagine a pure Triton path would improve performance; I can try it later.

@v1b3coder commented May 3, 2026

Reproduced on 2x RTX PRO 6000 Server Edition @ 300W TDP + PCIe 4.0: 88/91/70/87/84% of your reference numbers across the five workloads, which roughly fits PCIe 4.0 speeds (I assume you run PCIe 5.0). Boots fine at 250K context with max-num-seqs=16, gpu-memory-utilization=0.97.

Thank you for your work!

@v1b3coder

KV-cache reuse breaks under concurrent multi-session load (sparse-MLA + SWA page sharing)

Disclosure: I'm not a vLLM developer — just a downstream user trying to run this PR on 2× RTX PRO 6000 Blackwell Server Edition. The investigation below was done with AI assistance reading the patch source, so please correct me if my reading is off.

TL;DR

Two parallel chat sessions at ~50 K context each behave as if the KV cache is wiped between turns — every request re-prefills from scratch. Single-session multi-turn is fine.

| | Solo (1 session, 3 turns) | Parallel (2 sessions, 3 turns each) |
| --- | --- | --- |
| Turn 1 TTFT | 29,725 ms | A=297 / B=30,659 ms |
| Turn 2 TTFT | 635 ms | A=30,311 / B=60,326 ms |
| Cold→warm speedup | 46.8× | 0.3× |
| /metrics prefix-cache hit rate | 66.6% | 16.7% |

vllm:num_preemptions_total does not increment, there are no Preempted lines in the log, and the KV pool is at ~12% utilisation when the failure occurs (691K tokens / 2,701 blocks at block_size=256; two sessions need ~312 blocks).

Smoking gun

A's parallel turn 1 = 297 ms (cache hit — A's blocks from the solo phase were still resident). A's parallel turn 2 = 30 311 ms (full re-prefill). Between A.1 and A.2, B's cold prefill ran. B's prefill destroyed A's blocks despite ~88 % free pool.

Tested

  • Baseline (VLLM_TRITON_MLA_SPARSE=1): FAIL.
  • --attention-backend FLASHINFER_MLA_SPARSE: same FAIL pattern (hit rate 12.5 %, A.2 = 30 796 ms). Boot log still says Using DeepSeek's fp8_ds_mla KV cache format, so both backends route through the same KV path. Eliminates a kernel-specific cause.
  • VLLM_TRITON_MLA_SPARSE=0: not testable in our build (RuntimeError: vllm._flashmla_C is not available — image lacks the FlashMLA C ext for SM120).

My (AI-assisted) reading of the patch

DS-V4-Flash has sliding_window=128 and per-layer compress_ratios = [0, 0, 4, 128, …] — three layer types: SWAonly, C4A, C128A.

  • sparse_swa.py:75–79 documents that SWA and C4A pages share the same physical 64-token blocks.
  • deepseek_v4_attention.py:605: kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
  • The SlidingWindowMLASpec cache spec uses upstream SlidingWindowManager, which actively recycles pages mid-prefill via remove_skipped_blocks (kv_cache_interface.py:439–451).

The hypothesis I landed on: upstream SlidingWindowManager has no invariant that "before evicting a SWA page, check that the C4A view still references it" — because in upstream vLLM, sliding-window pages are never co-owned with another spec. The patch introduces co-ownership, but the upstream eviction policy isn't aware. Solo runs in lock-step so it works; concurrent B can recycle pages A's C4A view still needs.

This fits the index-level hit rate staying healthy on solo (66 %) but collapsing on concurrent (12–17 %) — the index thinks it has A's data, the memory says otherwise.

I'm reasonably confident about the shape of the bug, much less so about the exact location. Could very well be wrong.
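
If that reading is right, the missing invariant would look roughly like this. Pure pseudocode against made-up names, not the actual vLLM classes:

# Hypothetical guard, not real vLLM code: before SlidingWindowManager
# recycles a page that slid out of the SWA window, check whether a
# co-owning view (here the C4A compressed spec) still references the
# physical block.
def remove_skipped_blocks(self, request_id, last_useful_token):
    for block in self.blocks_before(request_id, last_useful_token):
        if block.ref_cnt > 1:  # co-owned by another cache spec
            continue           # upstream assumes this never happens for SWA
        self.free_block(block)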

Reproducer

~250-line standalone Python script (httpx only). Two distinct ~40 K-token prompts, 3-turn conversation per session, solo then concurrent in two threads. Streams /v1/completions, snapshots /metrics deltas, prints PASS/FAIL. ~3 min per run. Happy to share as a gist — just say where you'd like it.
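
A condensed sketch of the probe's shape (not the actual script; prompts, thresholds, and the PASS heuristic are placeholders):

import threading
import time

import httpx

BASE = "http://127.0.0.1:8017"
MODEL = "deepseek-ai/DeepSeek-V4-Flash"

def turn(prefix: str, history: list[str]) -> float:
    # One streamed completion turn; returns TTFT in ms.
    t0 = time.perf_counter()
    ttft = None
    with httpx.stream(
        "POST", f"{BASE}/v1/completions",
        json={"model": MODEL, "prompt": prefix + "".join(history),
              "max_tokens": 64, "temperature": 0, "stream": True},
        timeout=600,
    ) as resp:
        for _ in resp.iter_raw():
            if ttft is None:
                ttft = (time.perf_counter() - t0) * 1000
    return ttft

def session(name: str, prefix: str, out: dict) -> None:
    history: list[str] = []
    for i in range(3):
        out[f"{name}.turn{i + 1}"] = turn(prefix, history)
        history.append(f"\nuser follow-up {i + 1}\n")

a_prefix = "A " * 40_000  # two distinct ~40K-token contexts
b_prefix = "B " * 40_000
results: dict[str, float] = {}

session("solo_A", a_prefix, results)  # solo phase: turn 2+ should hit the cache

threads = [threading.Thread(target=session, args=(n, p, results))
           for n, p in (("par_A", a_prefix), ("par_B", b_prefix))]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)
# Healthy: warm parallel turns stay roughly an order of magnitude faster than
# the solo cold prefill. Buggy: par_A turns re-prefill despite a free pool.
print("PASS" if results["par_A.turn2"] < results["solo_A.turn1"] / 10 else "FAIL")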

Setup

  • 2× RTX PRO 6000 Blackwell Server Edition (SM120), driver 595.71.05, CUDA 13.2
  • Base: vllm/vllm-openai:deepseekv4-cu130
  • jasl/vllm ds4-sm120 @ 843fe9e (current HEAD)
  • jasl/DeepGEMM sm120 @ 7a7a41a
  • vLLM args: --tensor-parallel-size 2 --kv-cache-dtype fp8 --block-size 256 --max-model-len 250000 --max-num-seqs 16 --gpu-memory-utilization 0.97 --enable-prefix-caching --enable-chunked-prefill --max-num-batched-tokens 16384 --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}'
  • VLLM_TRITON_MLA_SPARSE=1, VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=4 (your canonical recipe)

@jasl (Contributor, Author) commented May 4, 2026

I think I have found the root cause, an embarrassing mistake: I missed an argument...

@wuwenthink

Once you update ds4-sm120-full, I'll test.

@pasta-paul

Now testing the DeepSeek-V4-Flash-W4A16-FP8 build (W4A16 GPTQ + FP8_BLOCK attention, calibrated against jasl's branch + #41276) on dual DGX Spark TP=2 — will post full vllm-ds4-sm120-harness results when it lands.

Note: this build uses Marlin INT4 kernels for the routed experts, not the SM12x sparse-MLA / FP4 path where the @wuwenthink garbled-output issue currently lives — so if it produces clean output on chat-smoke coding (aquarium_html, clock_html) where native FP4/FP8 currently doesn't, that's a useful diagnostic point about where the bug isn't.

Anyone with 2× RTX PRO 6000 (SM 12.0) want to also run this on the same harness while we wait for jasl's fix? --tensor-parallel-size 2 is the validated topology (TP>2 hits #41511). Recipe + reproduction: pasta-paul/dsv4-flash-w4a16-fp8. H200 SM 9.0 numbers from yesterday: chat-smoke 10/10, toolcall15 26/30 (87% vs native baseline 77%).

@wuwenthink

Now testing the DeepSeek-V4-Flash-W4A16-FP8 build (W4A16 GPTQ + FP8_BLOCK attention, calibrated against jasl's branch + #41276) on dual DGX Spark TP=2 — will post full vllm-ds4-sm120-harness results when it lands.

Note: this build uses Marlin INT4 kernels for the routed experts, not the SM12x sparse-MLA / FP4 path where the @wuwenthink garbled-output issue currently lives — so if it produces clean output on chat-smoke coding (aquarium_html, clock_html) where native FP4/FP8 currently doesn't, that's a useful diagnostic point about where the bug isn't.

Anyone with 2× RTX PRO 6000 (SM 12.0) want to also run this on the same harness while we wait for jasl's fix? --tensor-parallel-size 2 is the validated topology (TP>2 hits #41511). Recipe + reproduction: pasta-paul/dsv4-flash-w4a16-fp8. H200 SM 9.0 numbers from yesterday: chat-smoke 10/10, toolcall15 26/30 (87% vs native baseline 77%).

The author said he may have found the problem and will test once the fix is pushed.

@varamik commented May 4, 2026

I pulled down the latest version of these and now DS4-Flash works with opencode; speed is about 40-50 tok/s with dual RTX PRO 6000. Great progress!

Using
jasl/vllm ds4-sm120 @ 843fe9e (current HEAD)
jasl/DeepGEMM sm120 @ 7a7a41a
CUDA 13.1

For Opencode I needed to add these start parameters:
--tool-call-parser deepseek_v4
--enable-auto-tool-choice
--served-model-name deepseek-v4-flash

For some reason I need to apply this patch to avoid an "unsupported architecture" error:

diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index eb0ea8f52..5775831c1 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -623,8 +623,12 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         if seq_lens.dim() == 1:
             seq_lens = seq_lens.unsqueeze(-1)

         # DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and has_deep_gemm():
+        if (
+            current_platform.is_cuda()
+            and has_deep_gemm()
+            and not current_platform.is_device_capability_family(120)
+        ):

Any ideas why this is needed?

@wuwenthink commented May 4, 2026

@jasl Should the latest ds4-sm120-full branch be tested with export VLLM_DEEPSEEK_V4_USE_DEEPGEMM_SM12X_KERNELS=1? Also, can you share a set of export environment parameters and vLLM loading parameters that effectively improve prefill and decode speed?

@jasl (Contributor, Author) commented May 4, 2026

@jasl Should the latest ds4-sm120-full branch be tested with export VLLM_DEEPSEEK_V4_USE_DEEPGEMM_SM12X_KERNELS=1? Also, can you share a set of export environment parameters and vLLM loading parameters that effectively improve prefill and decode speed?

I'm not focusing on performance right now, since you reported the correctness and low-quality generation issues.
Those should now be resolved.

GPT summary:

For ds4-sm120@5bb51ec03, I believe this round has eliminated the main low-quality-generation problem users reported. With the official parameter settings, the generation smoke test is 105/105 OK: all finish_reason=stop, no length truncation; spot checks of code, HTML, Chinese writing, and translation output show no garbled text, misplaced paragraphs, or obvious gibberish. ToolCall-15 is 232/270 = 86%, in the same tier as B200 TP4 no-MTP at 243/270 = 90%, B200 TP2/MTP at 225/270 = 83%, and the official API single-turn at 81/90 = 90%. The remaining failures are mostly tool-selection/policy issues, not parser or kernel correctness problems.

I'm fixing the KV cache issue and building a new SM120 baseline to start the performance tuning.

@moshemalawach

Pre-warm sparse-MLA TileLang kernels at boot to avoid first-request JIT spikes

Running ds4-sm120-full @ 7f63eb1f9 in production (2× RTX 6000 Pro Blackwell, TP=2 + EP, MTP=1, max-model-len 200000, VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH=1, fresh ~/.cache/vllm). On first hit of a sparse-MLA shape bucket the JIT hasn't seen, three TileLang kernels compile per worker and block the engine for ~13s; queued requests (small or large) sit at 0-1 tok/s during that window, then jump back to ~30-40 tok/s with healthy MTP (80-89% accept, mean accept length ~1.85).

Excerpt from engine log on a fresh boot under live mixed traffic:

20:52:51  torch.compile took 4.28 s
20:52:52 → 20:52:57  TileLang compile mhc_pre_big_fuse_tilelang  (5s, both workers)
20:53:00 → 20:53:04  TileLang compile mhc_post_tilelang          (4s)
20:53:20 → 20:53:24  TileLang compile hc_head_fuse_tilelang      (4s)
20:53:34 → 20:53:39  TileLang compile mhc_pre_big_fuse_tilelang  (apparently distinct invocation — 5s more)

The compiled artifacts persist to ~/.cache/vllm, so the spike doesn't recur once enough shapes have been seen — but every container rebuild starts from cold, and any prompt landing in a fresh shape pays the full ~13s blocking compile. Long-context prompts trip new shapes more often than short ones, which is what makes it look like a long-context perf bug rather than a JIT bug.

Suggested: pre-warm sparse-MLA TileLang kernels at engine init, ideally driven by a list analogous to cudagraph_capture_sizes but along the seq-len / sparse-top-k axes the mhc_*/hc_* kernels specialize on. Simplest knob would be an env var enumerating buckets to warm; cleaner would be auto-deriving them from the cudagraph capture configuration so they stay in sync.
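
Roughly what I have in mind, as a sketch; warm_sparse_mla_kernels, derive_warm_buckets_from_cudagraph_config, and compile_sparse_mla_kernels are all hypothetical names:

import os

def warm_sparse_mla_kernels(engine_config) -> None:
    # Derive warm-up buckets along the seq-len axis, either from an env
    # override or from the cudagraph capture configuration so the two
    # stay in sync.
    env = os.environ.get("VLLM_SPARSE_MLA_WARM_BUCKETS")
    if env:
        seq_buckets = [int(x) for x in env.split(",")]
    else:
        seq_buckets = derive_warm_buckets_from_cudagraph_config(engine_config)
    for seq_len in seq_buckets:
        # One dummy invocation per bucket forces the TileLang JIT
        # (mhc_pre_big_fuse / mhc_post / hc_head_fuse) at init instead of
        # blocking the engine ~13s on the first live request.
        compile_sparse_mla_kernels(seq_len=seq_len,
                                   topk=engine_config.sparse_topk)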

Is there already a knob for this I missed? Happy to test a patch on the same hardware.

@jasl (Contributor, Author) commented May 4, 2026

Pre-warm sparse-MLA TileLang kernels at boot to avoid first-request JIT spikes

Running ds4-sm120-full @ 7f63eb1f9 in production (2× RTX 6000 Pro Blackwell, TP=2 + EP, MTP=1, max-model-len 200000, VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH=1, fresh ~/.cache/vllm). On first hit of a sparse-MLA shape bucket the JIT hasn't seen, three TileLang kernels compile per worker and block the engine for ~13s; queued requests (small or large) sit at 0-1 tok/s during that window, then jump back to ~30-40 tok/s with healthy MTP (80-89% accept, mean accept length ~1.85).

Excerpt from engine log on a fresh boot under live mixed traffic:

20:52:51  torch.compile took 4.28 s
20:52:52 → 20:52:57  TileLang compile mhc_pre_big_fuse_tilelang  (5s, both workers)
20:53:00 → 20:53:04  TileLang compile mhc_post_tilelang          (4s)
20:53:20 → 20:53:24  TileLang compile hc_head_fuse_tilelang      (4s)
20:53:34 → 20:53:39  TileLang compile mhc_pre_big_fuse_tilelang  (apparently distinct invocation — 5s more)

The compiled artifacts persist to ~/.cache/vllm, so the spike doesn't recur once enough shapes have been seen — but every container rebuild starts from cold, and any prompt landing in a fresh shape pays the full ~13s blocking compile. Long-context prompts trip new shapes more often than short ones, which is what makes it look like a long-context perf bug rather than a JIT bug.

Suggested: pre-warm sparse-MLA TileLang kernels at engine init, ideally driven by a list analogous to cudagraph_capture_sizes but along the seq-len / sparse-top-k axes the mhc_*/hc_* kernels specialize on. Simplest knob would be an env var enumerating buckets to warm; cleaner would be auto-deriving them from the cudagraph capture configuration so they stay in sync.

Is there already a knob for this I missed? Happy to test a patch on the same hardware.

Thank you! I'll check it.

@jasl (Contributor, Author) commented May 4, 2026

@wuwenthink another analysis from Opus 4.7 Max

Overall assessment: same tier of generation quality

baselines/20260503_sm120_nomtp_ds4_sm120_6e1e7ecad_diagnostic/generation is on the same level as the official API and the two B200 baselines; the closest like-for-like comparison (B200 TP2 no-MTP) is essentially indistinguishable in a blind read. GSM8K (EM flexible 95.30%, stderr 0.0058) and the oracle's 5/5 pass also corroborate that the model path itself is healthy.

Per-dimension observations

| Dimension | SM120 vs official API | SM120 vs B200 TP4 no-MTP | SM120 vs B200 TP2 no-MTP |
| --- | --- | --- | --- |
| Translation (zh⇄en technical) | Same tier; terminology and sentence-length distribution match | Same tier | Same tier |
| Code (en_code_be / zh_code_fe) | More reliable: the official API truncates 6/9 zh_code_fe runs with finish_reason: length at its 4096-token cap, while SM120 closes all 9/9 cleanly | Same tier | Same tier |
| Summarization (zh/en technical) | Same tier; structure, coverage, and length all track closely | Same tier | Same tier |
| Writing (zh/en technical) | Same tier; the Chinese prose reads naturally, with no English bleed-through | Same tier | Same tier |
| Token budget / length distribution | SM120 completion_tokens ranges fall almost entirely within the B200 ranges | | |

Actual flaws found (verified by hand)

  1. zh/zh2en_tech_001.3.think-max: after finishing the bamboo-paper body translation, the model invented a ~120-word "Making Bark Paper" section under the empty 造皮紙 heading (mentioning mulberry, wingceltis (qingtan), fengzhi, chengxin tang, and other content absent from the source), violating the prompt's instruction not to expand into an encyclopedia article. Under the same seed/mode, B200 TP2 stops right at the bare heading, B200 TP4 stops at "(abbreviated)", and the official API spends the entire 4096-token think-max budget on chain-of-thought drafting and produces no translation at all. So this is not a shared defect: SM120 independently sampled a deviant path on this one case at temperature=1.0, think-max.
  2. zh/zh_sum_tech_001.2.think-high: wrote the source's 柬紙 as 束纸, and added 绵纸, which the source never mentions.
  3. zh/zh_sum_tech_001.3.think-high: fabricated the term 楸纱纸, which does not exist in the source (likely a character-level confusion with 棂纱纸).

All three cases are sporadic sampling drift under zh long context + temperature=1.0, hitting only 1 of 9 seeds; the other seeds in the same task family (including the same think-max mode) are clean, and I saw no such flaws on the en side within my spot checks.

Is this within acceptable bounds?

Yes. Reasons:

  • In kind: these are sub-character-level or single-paragraph random deviations, with none of the "hard failure" signals such as token repetition, garbled output, language cross-talk, premature truncation, style collapse, or refusals.
  • In quantity: across 8 task families and roughly 144 SM120 samples I caught only 3 attributable problems, a hit rate of ~2%, the same order as the random jitter of the other two baselines at temperature 1.0 (B200 TP4 seed3 think-max also produced the mildly off-instruction "(abbreviated)" tail).
  • Baseline floor: the official API itself has a think-max failure mode of spending the whole token budget on drafting, plus the hard-cap problem of batch truncation on zh_code_fe. SM120 actually does better on both points.
  • The harness's own non-green gates: in report.md, the acceptance / generation / toolcall15 gates all exit 1, but the README states plainly that this is a diagnostic baseline, not a full pass; the toolcall15 failures at 224/270 all fall under TC-06 / TC-11 / TC-12 / TC-14, which are tool-calling policy issues (splitting translations, unnecessary calculator use, refusing unsupported operations, honesty about tool errors), not text-generation quality issues.

Conclusion: treating 20260503_sm120_nomtp_..._diagnostic as a generation-quality reference is safe. It sits on the same level as the official API and B200 TP4/TP2 (the no-MTP lane), the small deviations fall within sampling noise, and no systematic regression attributable to the sm_120 kernels was found.

@pasta-paul

re: 77bbc16 — Validated for production tool-use and chat at TP=2 on DGX Spark with --max-model-len 16384 at ~14–17 tok/s decode, pending full harness confirmation. Long-context behavior beyond 16K and 24-hour stability soak still pending.

@pasta-paul

Reporting a workspace-allocator bug we hit deploying this PR's W4A16 quant on dual DGX Spark TP=2: _forward_prefill allocates ~21.8 MB workspace but lock_workspace() fires at 21.62 MB because the dummy-run path (attention_impl, if not isinstance(attn_metadata, dict)) returns early and never sizes the prefill workspace.

The comment at deepseek_v4_attention.py:170-172 says "matching profile-time reservation in attention_impl's dummy-run branch" — implying a hook was intended. We have a working ~30-line patch that implements it (no --enforce-eager needed; ~14–17 tok/s decode vs ~3.9 tok/s under eager).
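
The shape of the patch, heavily condensed and with hypothetical names (the literal version is in #41700):

def attention_impl(self, q, kv, attn_metadata):
    # Hypothetical reconstruction of the dummy-run branch; real code differs.
    if not isinstance(attn_metadata, dict):
        # Profile-time dummy run. Previously this returned early without
        # touching the workspace, so lock_workspace() locked at 21.62 MB
        # while _forward_prefill later needed ~21.8 MB.
        self.workspace.reserve(self.max_prefill_workspace_bytes())  # the fix
        return self.dummy_output()
    return self._forward(q, kv, attn_metadata)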

Full report + patch + Spark TP=2 validation results (gsm8k 95.37%, HumanEval pass@1 80.49%) in #41700.

jasl and others added 6 commits May 5, 2026 15:33
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Protect hybrid-aligned DeepSeek V4 MLA prompt cache blocks so they survive decode and unrelated long-session cache churn. Keep common-prefix accounting aware of the extra protection reference and cover compressor-state SlidingWindowMLA groups in a regression test.

Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
@v1b3coder

Confirming e68cdb98 ("Fix DeepSeek V4 MLA prefix cache reuse") resolves the concurrent-load prefix-cache eviction I was hitting on 2× RTX PRO 6000 (SM120, TP=2, ds4-sm120 branch + jasl/DeepGEMM sm120).

Cherry-picked the commit on top of ds4-sm120 HEAD (77bbc1627, before today's rebase). Same image otherwise: no source rebuild, just docker cp of the two changed Python files plus a container restart.

@jasl (Contributor, Author) commented May 5, 2026

@pasta-paul @moshemalawach @v1b3coder
Thank you all for helping me improve the PR.
Your suggestions have landed on ds4-sm120-full; I'm testing now and will then build the full baseline.
If everything looks good, I'll start testing on GB10 and tuning performance.

@aabbccddwasd (Contributor)

MTP indexer performance fix for SM12x

@jasl We investigated the 1M-context MTP decode slowdown on SM120 (RTX PRO 6000 ×4) and found two issues in the paged-MQA-logits Triton kernels.

Problem

With --max-model-len 1024000 --speculative-config '{"method":"mtp","num_speculative_tokens":3}', MTP N=3 decode was only 107 tok/s at 28K actual context, versus 217 tok/s with --max-model-len 32768. Performance scaled with max_model_len rather than actual context length.

Root cause #1: MTP excluded from row-wise kernel

fp8_paged_mqa_logits_triton (line 863) gates the optimized row-wise kernel with next_n == 1. The row-wise kernel (your 4976b9741) reuses each KV tile across a block of heads — 6–10× faster than the general kernel for the near-full-context case. The gate was safe when MTP had not yet been validated on SM120, but the kernel itself already handles next_n > 1 correctly (it reads per-token context_len and the batch/q_pos computation is identical to the general kernel).

-    if next_n == 1 and head_dim % 64 == 0 and num_heads % 4 == 0:
+    if head_dim % 64 == 0 and num_heads % 4 == 0:

Root cause #2: Grid sized to max_model_len

Even with the row-wise kernel, token_count = max_model_len (250K compressed for C4A at 1M vs 8K at 32K). For a 28K request the grid launches 1954 blocks but only 63 cover valid columns. The other 1891 blocks do full QK dot products that are ultimately masked to -inf (line 1050).

Fix: early-exit in _fp8_paged_mqa_logits_rowwise_kernel — after loading context_len, if the block starts beyond the valid range, write -inf and return:

+    if token_start + pid_n * BLOCK_N >= context_len:
+        tl.store(
+            logits_ptr + row * stride_lm + offs_local_n * stride_ln,
+            tl.full((BLOCK_N,), float("-inf"), dtype=tl.float32),
+            mask=valid_row & valid_n,
+        )
+        return
     context_mask = valid_n & (offs_n < context_len)

The grid stays at max_model_len (CUDA graph compatible); only per-block work is skipped.

Results (SM120, TP=4, DeepSeek-V4-Flash, MTP N=3)

| | Before | After | 32K baseline |
| --- | --- | --- | --- |
| Math decode | 107 tok/s | 203 tok/s | 217 tok/s |
| Code decode | 110 tok/s | 207 tok/s | 187 tok/s |
| Draft speed | 118 tok/s | 220 tok/s | 208 tok/s |
| paged_mqa_logits kernel | 433 µs | 29 µs | 22 µs |

Full commit: aabbccddwasd@b1b1b532c

Both changes are correctness-preserving: the paper defines the indexer scan as s < Floor(t/m) (strictly causal), and positions beyond context_len are already masked to -inf in the current kernel — we just skip computing them.

Tested with -O3 (CUDA graphs + torch.compile), no correctness degradation on math benchmarks.

Investigation by DeepSeek-V4-Pro under Claude Code

@moshemalawach

Different failure mode under sustained streaming load — silent wedge, KV blocks not freed at request teardown

Reporting a wedge we hit after e68cdb98 ("Fix DeepSeek V4 MLA prefix cache reuse") on the same hardware as @v1b3coder (2× RTX PRO 6000 Blackwell, SM120, TP=2). Their original eviction repro is solved by that commit, but a different mode shows up under sustained real-conversation streaming.

Setup

  • vLLM ds4-sm120-full 8920a2f4f (and re-confirmed on the post-rebase tip a026eea1f — same content, just rebased SHAs).
  • Python-overlay only (no C++ rebuild); jasl/DeepGEMM 7a7a41a.
  • Launch: TP=2 + EP, MTP N=1, fp8 KV, --max-num-seqs 16 --max-num-batched-tokens 4096 --gpu-memory-utilization 0.96, --compilation-config cudagraph_mode=FULL_AND_PIECEWISE, VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH=1.
  • Workload: real production chat traffic, mix of /v1/chat/completions and /v1/messages?beta=true (Anthropic Messages API). Conversations grow turn-by-turn; not a synthetic shared-prefix probe.

Signature

Healthy operation at high prefix-hit rate, then in a single 10 s sampling window the engine transitions to a non-recovering wedge. Two captures, both showing the same shape:

# capture 1 — last healthy → wedge
09:54:11 Engine: Running 1 Waiting 0  KV 83.1%  Prefix hit 85.6%  gen 58.9 tok/s
09:54:21 Engine: Running 1 Waiting 1  KV 83.1%  Prefix hit 46.3%  gen 58.6 tok/s   <- second req queues, hit drops
09:54:31 Engine: Running 1 Waiting 1  KV 83.1%  Prefix hit 31.7%  gen 59.2 tok/s
09:54:41 Engine: Running 0 Waiting 2  KV 82.8%  Prefix hit  0.0%  gen 18.1 tok/s   <- prior req finishes, KV does NOT free
09:54:51 Engine: Running 0 Waiting 2  KV 82.8%  Prefix hit  0.0%  gen  0.0 tok/s   <- wedged from here, indefinitely
[ ... no further engine activity for 4+ minutes ...]

After this point, new POST /v1/chat/completions and POST /v1/messages?beta=true requests still return 200 OK from the API server, but the engine never schedules them — Running stays at 0, queue depth grows. GPU util 0% on both cards, KV pool stuck at ~82%, no exceptions raised, no traceback, no log line at the transition.

The other capture (different image build, same hardware/config) hit the wedge at KV 81.3% with the same Running 1 → 0, Waiting 1, Prefix hit 94.9% → 0.0% shape over a single 10 s window.

Why I think this is distinct from the eviction race

  • Silent: no exception, no warning. Just metrics showing the state change.
  • Block accounting: Running drops to 0 but KV usage stays at ~80%. Blocks held by no live request are not being released.
  • Scheduler stuck: with KV reporting near-full and no completed reqs freeing pages, new requests can never be admitted — perpetual Waiting.
  • Trigger: appears under sustained streaming traffic with growing conversation histories (multi-turn /v1/messages and /v1/chat/completions mixed). Both wedges fired ~40-50 min into healthy operation, sustained 50-115 tok/s aggregate generation, prefix hit 85-96%.

This shape is consistent with a refcount leak on streaming-request teardown — possibly the new _protect_* reference accounting in the prefix-cache fix not balancing on streaming-disconnect / streaming-finalize paths, leaving blocks pinned. A few seconds before each wedge the prefix hit metric started degrading mid-run (94.9% → 46.3% → 31.7% → 0%) which would track with the global pool filling with un-evictable hybrid-aligned blocks.
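
If that hypothesis holds, the fix shape would be balancing the protect/unprotect pair on every teardown path, including mid-stream disconnects. Illustrative pseudocode only, with made-up names:

# Hypothetical illustration of the suspected imbalance, not vLLM code:
# every protect must be released on *all* exit paths, or the blocks
# stay pinned and the pool fills with un-evictable pages.
def serve_streaming_request(req, kv_manager):
    blocks = kv_manager.protect_prompt_blocks(req)   # takes +1 ref
    try:
        for tok in generate(req):
            yield tok                                # client may disconnect here
    finally:
        kv_manager.unprotect_prompt_blocks(blocks)   # must run on every exit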

What I can share

  • Two anonymized engine-log excerpts (timestamps + loggers.py/metrics.py lines) covering the last-healthy → first-wedged → 4-minutes-still-wedged windows for each capture.
  • Container args / launch flags exactly as above.

Not posting the full Docker logs publicly (they contain client IPs and internal endpoint paths), but happy to forward via DM or any channel that works for @jasl.

Speculation on next step

The new commits on ds4-sm120-experimental — "Stabilize DeepSeek V4 MTP draft sampling" and "Tune SM12x sparse MLA graph defaults" — touch the spec-decode + sparse-MLA paths that ran through the workload that wedged. Worth an isolated repro on -experimental to see whether the wedge is in the prefix-cache refcount logic or in the MTP/sparse-MLA interaction. I can build and try once jasl says it's worth a soak run.

cc @jasl @v1b3coder @pasta-paul

Co-authored-by: OpenAI Codex <codex@openai.com>
@jasl (Contributor, Author) commented May 5, 2026

@moshemalawach @aabbccddwasd
Fixed.
Thank you for reporting issues.
