
DeepSeek V4 support on SM12x with Triton sparse MLA fallback#40899

Closed
jasl wants to merge 70 commits into vllm-project:main from jasl:ds4-sm120

Conversation

@jasl
Contributor

@jasl jasl commented Apr 26, 2026

This PR is based on #40760
Companion with deepseek-ai/DeepGEMM#318
Tested on 2 x RTX Pro 6000 (SM120)

Updated: 4.27

Summary

This PR is a runnable prototype for DeepSeek V4 Flash on NVIDIA SM12x GPUs, tested on RTX PRO 6000 Blackwell / SM120.

The main goal is to unblock correctness and end-to-end serving on workstation/consumer Blackwell GPUs where the existing FlashMLA / DeepGEMM SM90-SM100 paths are not available.

This stack includes:

  • DeepSeek V4 model / tokenizer / parser integration.
  • SM12x-compatible DeepGEMM build and runtime wiring.
  • A Triton sparse MLA fallback path for DeepSeek V4.
  • fp8_ds_mla cache support for sparse MLA.
  • SM12x fallback paths for DeepSeek V4-specific FP8 einsum / indexer / paged MQA usage.
  • Runtime knobs for the sparse MLA fallback and DeepGEMM SM120 paged MQA tiled kernel.
  • A temporary DeepGEMM pin to jasl/DeepGEMM@7a7a41a1bac7dacabe74057e7600e59f98f85bce.

Why

DeepSeek V4 currently depends on kernels that are available on datacenter Hopper/Blackwell paths but not on SM120/SM121 GPUs.

In particular, SM12x cannot directly reuse the existing SM90 WGMMA or SM100 tcgen05-based implementations. This PR adds a portable fallback path so DeepSeek V4 can run on SM12x first, with performance optimization left as incremental follow-up work.

Scope

Implemented / included:

  • DeepSeek V4 model registration and tokenizer mode.
  • Sparse MLA fallback for SM12x.
  • Sink-aware sparse MLA semantics.
  • SWA + compressed sparse subset handling.
  • Triton sparse MLA kernel primitives.
  • fp8_ds_mla packed cache handling.
  • DeepGEMM SM120 compatibility pin.
  • SM12x-specific guardrails for unsupported kernel selections.

Not intended as final form:

  • This is not a replacement for FlashMLA on SM90/SM100.
  • This is not a final optimized SM12x Tensor Core implementation.
  • The DeepGEMM dependency is pinned to a fork while SM120 compatibility work is being prepared upstream.
  • Some DeepSeek V4 parser / output formatting behavior is still evolving separately.

Kernel capability matrix

This stack keeps the existing optimized SM90/SM100 paths intact and only adds
SM12x fallback coverage for the DeepSeek V4 blockers. The last column is listed
to make the dependency boundary explicit: this PR is tested with the pinned
DeepGEMM fork, not as a fully DeepGEMM-free stack.

| Function / path | SM90 / SM100 path | SM12x path in this PR | Without DeepGEMM / FlashMLA |
| --- | --- | --- | --- |
| Sparse MLA attention | FlashMLA sparse native kernels | Triton sparse MLA kernels for fp8_ds_mla, SWA + compressed sparse candidates, and sink-aware denominator merge. Enabled automatically on SM12x; VLLM_TRITON_MLA_SPARSE=1 can force it. | PyTorch reference/oracle path only; useful for debugging, not acceptable for serving throughput. |
| fp8_gemm_nt | DeepGEMM native FP8 GEMM where selected by vLLM | Existing vLLM fallback selection for unsupported DeepGEMM shapes, primarily CUTLASS/Triton scaled-mm paths. | Same vLLM CUTLASS/Triton fallback selection, subject to normal kernel support. |
| fp8_m_grouped_gemm_nt / FP8 MoE | DeepGEMM grouped FP8 GEMM where selected | Existing vLLM MoE fallback path, primarily Marlin/CUTLASS depending on quantization and shape. | Same vLLM MoE fallback path; performance is below native DeepGEMM. |
| fp8_einsum | DeepGEMM native fp8_einsum | V4-specific Triton kernel for bhr,hdr->bhd with recipe (1,128,128) on SM12x; other recipes still use the DeepGEMM wrapper. | The V4-specific Triton recipe still works; other DeepGEMM-only recipes are unavailable. |
| fp8_fp4_mqa_logits | DeepGEMM native MQA logits | vLLM FP8 torch reference fallback for the SM12x FP8-Q path. FP4 remains outside this PR scope. | Same FP8 torch reference fallback; FP4 remains unavailable. |
| fp8_fp4_paged_mqa_logits | DeepGEMM native paged MQA logits | Pinned DeepGEMM SM12x compatibility path. VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED=1 forwards DG_SM120_PAGED_MQA_TILED=1 to use the tiled implementation when present. | Unavailable for serving; the DeepGEMM pin is still required for this blocker. |
| tf32_hc_prenorm_gemm | DeepGEMM native HyperConnection kernel | Pinned DeepGEMM SM12x compatibility implementation. | Unavailable for serving; the DeepGEMM pin is still required for this blocker. |

Validation

Test machine:

2x NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Compute capability: SM120
CUDA: 13.x
vLLM branch: ds4-sm120
vLLM commit: 8d0ebb76c
DeepGEMM pin: 7a7a41a1bac7dacabe74057e7600e59f98f85bce

Static / unit checks

python -m pytest -q \
  tests/v1/worker/test_kv_cache_view_utils.py \
  tests/v1/attention/test_sparse_mla_env.py
# 10 passed

python -m pytest -q \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py \
  tests/v1/attention/test_sparse_mla_env.py
# 85 passed, 16 warnings

python -m ruff check \
  csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu \
  vllm/v1/attention/backends/mla/sparse_mla_kernels.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/worker/kv_cache_view_utils.py \
  vllm/v1/worker/gpu/attn_utils.py \
  vllm/v1/worker/gpu_model_runner.py \
  tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py \
  tests/v1/attention/test_sparse_mla_env.py \
  tests/v1/worker/test_kv_cache_view_utils.py
# passed

python -m py_compile \
  vllm/v1/attention/backends/mla/sparse_mla_kernels.py \
  vllm/model_executor/layers/deepseek_v4_attention.py \
  vllm/v1/worker/kv_cache_view_utils.py
# passed

DeepGEMM pin checks:

bash -n tools/install_deepgemm.sh
git diff --check -- cmake/external_projects/deepgemm.cmake tools/install_deepgemm.sh

Runtime environment knobs

Build / CUDA environment used for local editable installs and CUDA extension builds:

Note: on DGX Spark, change 120a and 12.0a to 121a and 12.1a.

export PATH="/usr/local/cuda/bin:$PATH"
export CUDA_HOME="/usr/local/cuda"
export TRITON_PTXAS_PATH="/usr/local/cuda/bin/ptxas"
export CUDA_ARCH_LIST="120a"
export TORCH_CUDA_ARCH_LIST="12.0a"

Sparse MLA / DeepGEMM runtime knobs:

| Variable | Default | Meaning |
| --- | --- | --- |
| VLLM_TRITON_MLA_SPARSE | unset / auto | When unset, SM12x automatically uses the Triton sparse MLA fallback. Set 1 to force-enable it, or 0 to disable it. |
| VLLM_DEEP_GEMM_SM120_PAGED_MQA_TILED | 1 | Enables the DeepGEMM SM120 tiled paged-MQA path when the pinned DeepGEMM build provides it. Set 0 to force the simpler compatibility fallback. |
| VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH | 1 | Keeps vLLM compile and CUDA graphs enabled for the Triton sparse MLA fallback. Set 0 to disable compile/cudagraph capture for debugging. |
| VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE | 512 | Top-k chunk size for the sparse MLA fallback when processing compressed candidates. Lower values reduce temporary pressure; higher values can reduce loop overhead. |
| VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE | 256 | Query-token chunk size for sparse MLA prefill. This bounds prefill workspace and launch granularity. |
| VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE | auto | Optional override for decode head grouping. Valid values are 1, 2, and 4; unset uses 1 for single-token decode, 2 for small decode batches, and 4 for larger decode batches. |
| VLLM_TRITON_MLA_SPARSE_DUMP | 0 | Debug-only switch that writes sparse MLA tensor metadata to JSONL and aborts the request. |
| VLLM_TRITON_MLA_SPARSE_DUMP_PATH | /tmp/deepseek_v4_triton_mla_sparse_dump.jsonl | Output path used when VLLM_TRITON_MLA_SPARSE_DUMP=1. |
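
As a quick illustration of the VLLM_TRITON_MLA_SPARSE semantics above, a minimal sketch (the helper name and device check are placeholders, not the actual vLLM code):

```python
import os

def use_triton_sparse_mla(is_sm12x: bool) -> bool:
    """Auto-enable the Triton sparse MLA fallback on SM12x; allow explicit override."""
    raw = os.getenv("VLLM_TRITON_MLA_SPARSE")
    if raw is None:
        return is_sm12x  # unset: SM12x opts in, SM90/SM100 stay on FlashMLA sparse
    return raw == "1"    # "1" forces the fallback, "0" disables it
```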

Unsupported follow-up knobs checked

Two potentially useful runtime knobs were checked separately and are not currently usable on this SM120 path.

use_fp4_indexer_cache must be passed through the JSON attention config, not as a dotted argparse option:

--attention-config '{"use_fp4_indexer_cache":true}'

The dotted form is rejected by argparse:

--attention_config.use_fp4_indexer_cache=True
# error: unrecognized arguments

Even with the correct JSON form, startup fails on RTX PRO 6000 Blackwell / SM120 because the current vLLM indexer metadata builder gates this feature to datacenter Blackwell SM10x:

AssertionError: use_fp4_indexer_cache requires Blackwell datacenter GPUs
(sm_10x, e.g. B200/GB200); sm_120 (consumer Blackwell) and earlier
architectures are not supported.

MTP speculative decoding must also be configured through the dashed CLI flag with a JSON value:

--speculative-config '{"method":"mtp","num_speculative_tokens":2}'

The model-side MTP patch is present in this branch: the target model exposes the
pre-hc_head residual buffer, and startup loads the MTP draft model and shares
the target embedding / lm_head / topk_indices_buffer weights. The SM12x
sparse MLA decode fallback now also accepts MTP-shaped q_len > 1 decode by
using explicit global-slot sparse indices for the SWA subset instead of the
single-token paged SWA window path.

This is intentionally still experimental. It is useful as a correctness bridge
and as an optimization target for other developers, but current throughput is
below the non-MTP path.

Serving smoke

Short-context serving works with:

VLLM_TRITON_MLA_SPARSE=1 \
vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --max-model-len 8192 \
  --gpu-memory-utilization 0.985 \
  --tensor-parallel-size 2 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4

1M context initialization also works with:

--max-model-len 1048576
--gpu-memory-utilization 0.985
--tensor-parallel-size 2

Relevant startup log:

Using max model len 1048576
Model loading took 74.8 GiB memory
Estimated CUDA graph memory: 3.84 GiB total
Available KV cache memory: 9.15 GiB
GPU KV cache size: 7,468 tokens
Maximum concurrency for 1,048,576 tokens per request: 1.32x
Graph capturing finished in 23 secs, took 1.83 GiB
Supported tasks: ['generate']

Serving benchmark

Common benchmark shape:

random input len: 1024
random output len: 1024
num prompts: 32
max_model_len=8192
gpu_memory_utilization=0.985

Peak VRAM is the maximum memory.used observed by a 1-second nvidia-smi sampler during each benchmark run.
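
The sampler itself is not part of this PR; a minimal sketch of such a 1-second memory.used poller, assuming nvidia-smi is on PATH:

```python
import subprocess
import time

def sample_peak_vram_mib(duration_s: int, interval_s: float = 1.0) -> list[int]:
    """Poll nvidia-smi once per interval and return the per-GPU peak memory.used in MiB."""
    peaks: list[int] = []
    for _ in range(int(duration_s / interval_s)):
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True,
        ).stdout.split()
        used = [int(x) for x in out]
        peaks = used if not peaks else [max(p, u) for p, u in zip(peaks, used)]
        time.sleep(interval_s)
    return peaks
```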

TP=2, PP=1

parallelism: TP=2, PP=1
benchmark artifacts: /tmp/vllm-ds4-sm120-tp2-bench-20260427-023128
| Max concurrency | Success | Duration (s) | Req/s | Output tok/s | Total tok/s | Mean TTFT (ms) | Mean TPOT (ms) | Peak VRAM GPU0 / GPU1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32/32 | 335.70 | 0.10 | 97.61 | 195.22 | 283.32 | 9.98 | 93.42 / 93.43 GiB |
| 4 | 32/32 | 120.39 | 0.27 | 272.18 | 544.36 | 964.58 | 13.77 | 93.49 / 93.49 GiB |
| 8 | 32/32 | 78.19 | 0.41 | 419.10 | 838.21 | 2016.99 | 17.13 | 94.31 / 94.32 GiB |
| 16 | 32/32 | 54.72 | 0.58 | 598.84 | 1197.68 | 3370.12 | 23.44 | 94.31 / 94.32 GiB |
| 32 | 32/32 | 41.93 | 0.76 | 781.47 | 1562.94 | 6024.58 | 35.06 | 94.43 / 94.44 GiB |

Compared with the previous PR table, the same TP=2/PP=1 c=32 benchmark moved
from 626.69 to 781.47 output tok/s, and c=1 moved from 45.14 to 97.61 output
tok/s.

TP=1, PP=2

parallelism: TP=1, PP=2
benchmark artifacts: /tmp/vllm-ds4-sm120-pp2tp1-bench-20260427-024506
| Max concurrency | Success | Duration (s) | Req/s | Output tok/s | Total tok/s | Mean TTFT (ms) | Mean TPOT (ms) | Peak VRAM GPU0 / GPU1 | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32/32 | 428.22 | 0.07 | 76.52 | 153.04 | 561.98 | 12.53 | 90.95 / 90.08 GiB | stable |
| 4 | 32/32 | 168.00 | 0.19 | 195.04 | 390.09 | 1537.56 | 19.02 | 91.81 / 90.14 GiB | stable |
| 8 | 32/32 | 117.04 | 0.27 | 279.98 | 559.97 | 3215.98 | 25.46 | 94.56 / 91.35 GiB | stable |
| 16 | 32/32 | 85.92 | 0.37 | 381.37 | 762.74 | 4792.50 | 37.29 | 94.83 / 94.19 GiB | stable |
| 32 | 0/32 | n/a | n/a | n/a | n/a | n/a | n/a | 94.94 / 94.80 GiB | overcommitted; PP0 OOM allocating 720 MiB in the SM12x fp8_fp4_mqa_logits torch reference path, then EngineCore waited on a dead worker |

PP=2/TP=1 is stable through max_concurrency=16 in this short-context benchmark,
but max_concurrency=32 still overcommits memory. The c=32 row is included as a
failure case only and should not be interpreted as usable throughput.

MTP speculative decoding

parallelism: TP=2, PP=1
speculative config: {"method":"mtp","num_speculative_tokens":2}
gpu_memory_utilization=0.92
startup artifact: /tmp/vllm-ds4-sm120-mtp-smoke-092.log
benchmark artifact: /tmp/vllm-ds4-sm120-mtp-c4-bench-20260427-033828

At gpu_memory_utilization=0.985, startup reaches readiness but the first
request OOMs inside the DeepSeek compressor Triton path. Lowering to 0.92
leaves enough transient headroom for a short serving smoke and the benchmark
below.

| Max concurrency | Success | Duration (s) | Req/s | Output tok/s | Total tok/s | Mean TTFT (ms) | Mean TPOT (ms) | p99 TPOT (ms) | Acceptance rate | Acceptance length | Peak VRAM GPU0 / GPU1 | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 4 | 32/32 | 400.35 | 0.08 | 81.85 | 163.70 | 949.67 | 43.90 | 64.84 | 26.80% | 1.54 | 92.15 / 92.16 GiB | Stable, but substantially slower than non-MTP TP=2/PP=1 c=4 at 272.18 output tok/s. |

A comparable c=1 MTP run was stopped after 5/32 requests because it was taking
roughly 40 seconds per request. No official throughput row is reported for that
partial run.

Known limitations

  • This is a prototype branch and still contains a large stacked diff.
  • Performance is correctness-first and not yet equivalent to optimized SM90/SM100 FlashMLA / DeepGEMM paths.
  • 1M context is memory-sensitive and requires a high --gpu-memory-utilization; tested working at 0.985, while 0.99 is rejected by the vLLM startup memory guard on this machine.
  • 0.985 supports approximately one full 1M-token request, not multiple full-context concurrent requests.
  • The DeepGEMM dependency is pinned to a fork commit until SM120 compatibility lands upstream.
  • MTP model plumbing and sparse MLA q_len > 1 decode are included, but MTP speculative serving remains experimental. It currently needs extra memory headroom and regresses throughput versus the non-MTP path.
  • --attention-config '{"use_fp4_indexer_cache":true}' is not supported on SM120; current vLLM gates it to datacenter Blackwell SM10x.

zyongye and others added 30 commits April 26, 2026 10:44
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Yifan Qiao <yifanqiao@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Add an experimental SM120 DeepSeek V4 path that keeps the existing
FlashMLA sparse metadata, KV-cache, and top-k/global-slot plumbing, but
uses a correctness-first reference attention implementation instead of
calling FlashMLA kernels that are unavailable on SM120.

The prototype is gated behind VLLM_SM120_REFERENCE_DEEPSEEK_V4_ATTENTION
and leaves the SM90/SM100 FlashMLA path unchanged. It also adds diagnostic
dumping via VLLM_SM120_DUMP_DEEPSEEK_V4_ATTENTION so shape and metadata
issues can be captured without changing normal execution.

Implemented pieces:
- sink-aware reference sparse attention with online softmax state
- SWA-only decode reference path
- compressed decode reference paths for C4A and C128A
- chunked C128A top-k processing to avoid materializing full 8192-slot KV
  tensors per token batch
- prefill reference path over the existing gathered KV workspace and
  combined sparse/SWA indices
- fp8_ds_mla global-slot dequantization helper for arbitrary physical KV
  cache slots
- SM120 tile-scheduler bypass when the reference path is enabled
- torch.compile defunctionalization for DeepSeek V4 FP8/CUTLASS custom ops
- E8M0 scale upcast before CUTLASS scaled-mm calls

Validation on the SM120 host:
- python -m py_compile over the modified Python modules
- git diff --check
- GPU smoke tests for single-chunk, multi-chunk, and prefill reference
  attention all reported max_abs=0.0 against PyTorch golden references
- vllm serve DeepSeek-V4-Flash with --max-model-len 262144 started
  successfully and returned HTTP 200 for a one-token /v1/completions request

The default 1M context remains outside this prototype's current memory
budget: after the reference path and CUDA graph capture, vLLM's KV-cache
admission reports insufficient available KV memory for 1,048,576 tokens.
DeepGEMM now has an experimental SM120/SM121 compatibility path in the
local dependency branch, but vLLM's vendored DeepGEMM CMake integration was
still filtering CUDA architectures down to SM90/SM100. On an SM120 host this
made the _deep_gemm_C extension fall through the unsupported-architecture
branch even when CUDA 13 and a compatible DeepGEMM source checkout were
available.

Add 12.0f to the DeepGEMM supported architecture list for CUDA 13.0+. vLLM's
cuda_archs_loose_intersection helper maps this to the requested SM12x target
(for example CUDA_ARCH_LIST=120a resolves to DeepGEMM CUDA architectures:
12.0a), matching the SM120/SM121 compatibility model used by the prototype.

Validation on the SM120 host:
- git diff --check
- DEEPGEMM_SRC_DIR=~/tmp/DeepGEMM CCACHE_NOHASHDIR=true MAX_JOBS=64
  pip install --verbose --no-build-isolation -e . with CUDA_ARCH_LIST=120a
  and TORCH_CUDA_ARCH_LIST=12.0a
- build log reported: DeepGEMM CUDA architectures: 12.0a
Refactor the experimental SM120 DeepSeek V4 reference attention path so the
attention subsets compute no-sink normalized outputs plus log-sum-exp values,
then apply the learnable attention sink exactly once in a small merge step.

This makes the prototype line up with the intended optimization boundary:
SWA and compressed top-k attention can be replaced independently by Triton or
CUDA kernels that return (out_no_sink, lse_no_sink), while the shared
sink-aware merge remains small and easy to verify.

The decode path now computes compressed and SWA subsets separately and merges
both LSEs with the sink denominator. SWA-only decode and prefill still use the
same math, but go through the no-sink finalize plus merge helper so all
reference paths share one sink application point.

Validation on the SM120 host:
- python -m py_compile vllm/model_executor/layers/deepseek_v4_attention.py
- git diff --check
- GPU smoke tests reported single_chunk_max_abs=0.0,
  multi_chunk_max_abs=0.0, and prefill_max_abs=0.0 against PyTorch golden
  references
- vllm serve DeepSeek-V4-Flash with --max-model-len 262144 started
  successfully and returned HTTP 200 for a one-token /v1/completions request
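
To make that merge boundary concrete, a minimal PyTorch sketch of the sink-aware two-subset LSE merge described above (shapes and names are illustrative, not the actual vLLM helpers):

```python
import torch

def merge_two_subsets_with_sink(out_a, lse_a, out_b, lse_b, sink):
    """Stable LSE merge of two no-sink subset results; the sink logit
    contributes only to the softmax denominator.

    out_a, out_b: [tokens, heads, head_dim] subset-normalized outputs
    lse_a, lse_b: [tokens, heads] subset log-sum-exp values
    sink:         [heads] learnable sink logit
    """
    sink = sink.view(1, -1).expand_as(lse_a)
    m = torch.maximum(torch.maximum(lse_a, lse_b), sink)
    w_a, w_b, w_sink = (torch.exp(x - m) for x in (lse_a, lse_b, sink))
    denom = (w_a + w_b + w_sink).unsqueeze(-1)
    return (w_a.unsqueeze(-1) * out_a + w_b.unsqueeze(-1) * out_b) / denom
```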
DeepSeek V4 chat prompts can enter the sparse attention indexer prefill path before the SM120 sparse-attention reference path runs. That path calls DeepGEMM fp8_fp4_mqa_logits, which still rejects SM12x in attention.hpp and kills the engine with Unsupported architecture.

Add an SM120-only FP8 torch reference implementation for non-paged MQA logits. It dequantizes FP8 K rows with the existing per-row scale, accumulates relu(q @ k.T) weighted across heads in small head chunks, and preserves the existing clean_logits mask behavior. SM90/SM100 and FP4 continue to use the DeepGEMM implementation.

Add an SM120 CUDA regression test comparing the fallback against an explicit PyTorch reference so the chat/prefill indexer path remains covered.
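
A rough sketch of what such a torch reference computes; the argument names and layouts are assumptions for illustration, not the fallback's real signature:

```python
import torch

def mqa_logits_reference(q, k_fp8, k_scale, head_weights):
    """q: [heads, seq_q, d]; k_fp8: [seq_k, d] FP8; k_scale: [seq_k];
    head_weights: [heads]. Returns logits of shape [seq_q, seq_k]."""
    # Dequantize K rows with their per-row scale.
    k = k_fp8.to(torch.float32) * k_scale.to(torch.float32).unsqueeze(-1)
    logits = q.new_zeros(q.shape[1], k.shape[0], dtype=torch.float32)
    # Accumulate relu(q @ k.T) weighted across heads; the real fallback
    # processes heads in small chunks to bound memory.
    for h in range(q.shape[0]):
        logits += head_weights[h].float() * torch.relu(q[h].float() @ k.T)
    return logits
```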
The SM120 DeepSeek V4 prototype reads several experimental controls directly from the environment. vLLM warns about VLLM_* variables that are not present in envs.environment_variables, so startup printed unknown-variable warnings even though the controls worked.

Register the SM120 reference attention flags, dump path, and chunk-size controls in envs.py so startup validation recognizes them without changing the existing os.getenv-based behavior.
The SM120 vLLM prototype needs the DeepGEMM branch that contains the experimental SM120 reference fallbacks. The CUDA-13 CMake architecture patch only allows DeepGEMM to be built for SM12x; if DEEPGEMM_SRC_DIR is not set, the default upstream DeepGEMM tag still lacks those fallback kernels.

Point the prototype vendored DeepGEMM fetch to jasl/DeepGEMM at the SM120 fallback commit so rebuilds without DEEPGEMM_SRC_DIR use the same dependency that was validated on the DGX Spark host.
Add a minimal pipeline-parallel path for DeepSeek V4 so the model can run with TP=1 and PP=2 on the SM120 prototype branch. The causal LM wrapper now advertises SupportsPP, creates the LM head only on the last pipeline rank, and exposes the model intermediate tensor factory.

Split the core model forward by pipeline rank: the first rank embeds tokens and expands them into the HyperConnection stream, intermediate ranks receive flattened HC hidden states, non-last ranks return IntermediateTensors, and the last rank applies the HC head and final RMSNorm. The model loader now skips parameters that belong to PP-missing layers, including stacked attention/MLP weights, per-expert MoE weights, attn sinks, embeddings, and final norm parameters.

Keep this intentionally scoped to the tested prototype configuration: TP=1, PP=2, no speculative/MTP pipeline support. The MTP hidden-state buffer is only allocated and populated on the final rank, so get_mtp_target_hidden_states returns None away from the last stage.

Verification:
- supports_pp(DeepseekV4ForCausalLM) changed from False before the patch to True after the patch.
- python -m py_compile vllm/model_executor/models/deepseek_v4.py tests/models/test_deepseek_v4_pp.py
- manual invocation of tests/models/test_deepseek_v4_pp.py assertion passed in the venv.
- Started DeepSeek-V4-Flash with --tensor-parallel-size 1 --pipeline-parallel-size 2 on port 8001; PP ranks initialized as PP0/PP1, layers split as [22,21], checkpoint loaded, CUDA graphs captured, and /v1/chat/completions returned HTTP 200 with normal assistant text.

Note: python -m pytest tests/models/test_deepseek_v4_pp.py -q was not runnable because the remote venv has no pytest module.
Introduce VLLM_TRITON_MLA_SPARSE as the generic control for the correctness-first sparse MLA fallback, with dump and chunk-size knobs under the same namespace. The old VLLM_SM120_* names remain registered and are still honored as aliases so existing prototype scripts keep working without unknown-env warnings.

When the new control is unset, SM12x devices now select the reference sparse MLA path automatically. This keeps SM90/SM100 on the FlashMLA sparse path while avoiding the unavailable FlashMLA tile-scheduler path on SM120/SM121. The fallback can still be force-disabled with VLLM_TRITON_MLA_SPARSE=0 or force-enabled for debugging with VLLM_TRITON_MLA_SPARSE=1.

The attention helper names and diagnostics now describe the implementation as a sparse MLA reference fallback instead of a DeepSeek-V4-specific SM120 switch, which is a step toward a reviewable TRITON_MLA_SPARSE backend shape.

Signed-off-by: jasl <jasl9187@hotmail.com>
Honor VLLM_TRITON_MLA_SPARSE_DUMP explicitly when it is set, even if the legacy VLLM_SM120_DUMP_DEEPSEEK_V4_ATTENTION alias is also present. This keeps the new generic sparse MLA namespace authoritative while preserving the old alias as a fallback when the new variable is unset.

Signed-off-by: jasl <jasl9187@hotmail.com>
Keep the vLLM CMake change scoped to the CUDA 13 SM12x architecture entry. The prototype can still build against the SM120 DeepGEMM branch by passing DEEPGEMM_SRC_DIR, but the default FetchContent source now stays on the upstream DeepGEMM repository and tag so this patch is reviewable as a simple build-compatibility change.

Signed-off-by: jasl <jasl9187@hotmail.com>
Restore the default FetchContent source to the SM120 DeepGEMM prototype branch so users can build this vLLM branch directly without setting DEEPGEMM_SRC_DIR. The separate CUDA 13 SM12x architecture entry remains in place for local-source and vendored builds.

Signed-off-by: jasl <jasl9187@hotmail.com>
Cover the correctness-first reference path used by the SM12x sparse MLA fallback. The tests assert that the no-sink subset attention returns log-sum-exp state, that the sink contributes only to the denominator, and that merging SWA and compressed subsets by LSE is equivalent to attention over the concatenated sparse candidates.

Also add a CUDA regression for the V4 fp8_ds_mla global-slot dequant path. It builds the 584-byte block-packed cache layout directly, checks UE8M0 scale decoding, verifies invalid slots are zeroed, and exercises both 2D and 3D slot-id inputs.

Signed-off-by: jasl <jasl9187@hotmail.com>
Move the PyTorch reference sparse MLA math out of DeepSeek V4 attention into vllm.v1.attention.backends.mla.sparse_mla_reference. The model path now calls shared helpers for no-sink accumulation, LSE finalization, and sink-aware subset merging instead of carrying those methods on DeepseekV4MLAAttention.

Update the sparse MLA correctness tests to target the shared helper module directly. This keeps the reference contract reusable for a future TRITON_MLA_SPARSE backend while preserving the current SM12x DeepSeek V4 fallback behavior.

Signed-off-by: jasl <jasl9187@hotmail.com>
Extract the DeepSeek V4 prefill reference fallback into sparse_mla_reference.reference_sparse_mla_prefill. The model layer now passes the prepared combined sparse indices, lens, sink, scale, and chunk-size knobs into a shared helper instead of carrying the prefill accumulation loop inline.

Add a direct prefill correctness test that compares the shared helper against a dense golden formulation across multiple query and top-k chunk sizes, including duplicate indices, invalid -1 entries, and an all-invalid row.

Signed-off-by: jasl <jasl9187@hotmail.com>
Move VLLM_TRITON_MLA_SPARSE and legacy SM120 alias parsing into a shared sparse_mla_env module. DeepSeek V4 attention and the SWA metadata builder now use the same helpers for reference fallback enablement, diagnostic dumps, and chunk-size knobs.

This preserves the current behavior where unset VLLM_TRITON_MLA_SPARSE auto-enables the fallback on SM12x, while an explicit new value overrides the legacy aliases.

Signed-off-by: jasl <jasl9187@hotmail.com>
Allocate the intermediate wo_a einsum output through the V1 workspace manager instead of torch.empty. This removes one hot-path dynamic allocation after sparse MLA attention has completed while leaving the attention output buffer as a regular tensor, since that buffer must survive nested workspace use inside the attention implementation.

Signed-off-by: jasl <jasl9187@hotmail.com>
Cover the shared sparse MLA environment helper semantics, including new-name precedence over the legacy SM120 aliases, legacy chunk-size fallbacks, invalid chunk-size defaults, and diagnostic dump path precedence.

Signed-off-by: jasl <jasl9187@hotmail.com>
Add a small Triton kernel for merging the compressed and SWA sparse MLA subset outputs with the DeepSeek V4 attention sink. The kernel implements the stable LSE merge formula and is used by the SM12x compressed decode reference path after the two subset attentions have produced out/lse pairs.

Keep the PyTorch merge helper as the correctness oracle and add a CUDA regression that compares the Triton merge against that reference, including -inf subset LSE entries.

Signed-off-by: jasl <jasl9187@hotmail.com>
Add a portable Triton online-softmax accumulator for gathered sparse MLA subsets. The kernel updates per-token/head max, denominator, and accumulator state across chunks, and a finish kernel emits the subset-normalized output and LSE.

Route the DeepSeek V4 compressed decode reference path through the Triton accumulator for both compressed top-k and SWA subsets while keeping the existing fp8_ds_mla dequantize-by-slot kernels and sink-aware LSE merge boundary intact.

Extend sparse MLA reference tests with CUDA parity coverage for chunked slot-id accumulation, the slot-id-free SWA path, and 512-dim heads.

Signed-off-by: jasl <jasl9187@hotmail.com>
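
For orientation, the online-softmax recurrence such an accumulator implements can be sketched in plain PyTorch; this is a simplified single-token, single-subset version, not the Triton kernel itself:

```python
import torch

def chunked_no_sink_attention(q, kv_chunks, scale):
    """Online softmax over KV chunks, returning (out_no_sink, lse_no_sink).

    q:         [heads, head_dim] single decode token
    kv_chunks: iterable of (k, v) pairs, each [chunk_len, head_dim]
    """
    heads, head_dim = q.shape
    running_max = q.new_full((heads,), float("-inf"))
    denom = q.new_zeros(heads)
    acc = q.new_zeros(heads, head_dim)
    for k, v in kv_chunks:
        scores = (q @ k.T) * scale                  # [heads, chunk_len]
        new_max = torch.maximum(running_max, scores.max(dim=-1).values)
        rescale = torch.exp(running_max - new_max)  # shrink previously accumulated state
        probs = torch.exp(scores - new_max.unsqueeze(-1))
        denom = denom * rescale + probs.sum(dim=-1)
        acc = acc * rescale.unsqueeze(-1) + probs @ v
        running_max = new_max
    out = acc / denom.unsqueeze(-1)
    lse = running_max + torch.log(denom)
    return out, lse
```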
Add a Triton accumulator that reads DeepSeek V4 fp8_ds_mla packed KV-cache entries directly from global slot ids, dequantizes the 448 FP8 NoPE dimensions plus 64 BF16 RoPE dimensions, and updates the same online softmax state used by the portable sparse MLA fallback.

Use the fused packed-cache accumulator for the compressed top-k decode subset so the SM12x fallback no longer materializes a BF16 compressed_kv scratch buffer before attention. The SWA subset still uses the gathered BF16 accumulator and the final sink-aware LSE merge remains unchanged.

Cover the new path with CUDA parity tests against the PyTorch reference, including chunked slot-id accumulation, invalid slots, lens truncation, and the 584-byte fp8_ds_mla layout.

Signed-off-by: jasl <jasl9187@hotmail.com>
Add a Triton sparse MLA accumulator that reads DeepSeek V4 fp8_ds_mla packed KV-cache entries through seq_lens and block_table instead of pre-gathering the sliding-window subset into a BF16 scratch buffer.

Use the paged accumulator for the SWA subset inside the compressed decode fallback, keeping the compressed top-k accumulator and sink-aware LSE merge unchanged.

Cover the new path with a CUDA parity test using non-trivial block-table mappings plus a 128-token, 512-dim SWA smoke run against the PyTorch reference.

Signed-off-by: jasl <jasl9187@hotmail.com>
Add a Triton single-subset sink merge for sparse MLA outputs and use it with the paged fp8_ds_mla accumulator to remove the BF16 gather/reference path from SWA-only decode.

The SWA-only fallback now reads the packed paged cache directly, finishes the no-sink attention state, then applies the sink denominator in Triton. Tests cover the single-subset merge and paged SWA attention with sink against the PyTorch reference.

Signed-off-by: jasl <jasl9187@hotmail.com>
Add a BF16 indexed sparse MLA accumulator that reads kv_flat by combined_indices and updates online softmax state without materializing gathered_kv.

Route the DeepSeek V4 sparse MLA prefill fallback through query/top-k chunked Triton accumulation, finish, and single-subset sink merge. The PyTorch reference helper remains as a test oracle.

Tests cover invalid indices, duplicates, all-invalid rows, query chunking, and top-k chunking against the reference prefill path.

Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
@jasl
Contributor Author

jasl commented Apr 26, 2026

@tonyliu312 Please check my latest commit, and feel free to cherry-pick.


I nearly doubled the performance.

jasl added 2 commits April 27, 2026 01:39
Use active prefill sequence and gather lengths to size the DeepSeek V4 sparse MLA staging workspace instead of reserving against max_model_len and max_num_batched_tokens. This keeps the gathered KV row stride bounded by the current request batch, which matters for long-prompt agent workloads and especially large max-model-len configurations.

Also route sparse indexer prefill logits sizing through a shared helper. SM12x now defaults to a 256 MiB logits cap when VLLM_SPARSE_INDEXER_MAX_LOGITS_MB is unset, while explicit env overrides keep the previous behavior. The profiling dummy allocation uses the same helper as runtime chunking so the memory profile reflects the configured cap.

Validation: python -m pytest -q tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py; python -m pytest -q tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py tests/v1/attention/test_sparse_mla_backends.py -k 'prefill_workspace_bounds or sparse_indexer_max_logits_bytes or split_indexer_prefill_chunks'; python -m ruff check vllm/model_executor/layers/deepseek_v4_attention.py vllm/v1/attention/backends/mla/sparse_swa.py vllm/v1/attention/backends/mla/indexer.py vllm/model_executor/layers/sparse_attn_indexer.py tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py tests/v1/attention/test_sparse_mla_backends.py; long-prompt smoke 15k input x10 on TP=2/EP/eager succeeded.
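
Purely as an illustration of how a MiB cap can bound indexer prefill logits chunking; the helper name and exact sizing formula below are assumptions, not the code added in this commit:

```python
import os

def indexer_prefill_query_chunk(kv_len: int, dtype_bytes: int = 4,
                                default_cap_mib: int = 256) -> int:
    """Largest number of query rows whose [rows, kv_len] logits stay under the cap."""
    cap_mib = int(os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", default_cap_mib))
    cap_bytes = cap_mib * 1024 * 1024
    return max(1, cap_bytes // max(1, kv_len * dtype_bytes))
```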
@jasl
Contributor Author

jasl commented Apr 26, 2026

@BehindTheCartan
I can't reproduce your cases, so I'm guessing the ds4-sm120 branch is good.
But on a 2 x RTX PRO 6000 configuration, OOM is a potential issue because the memory usage is extremely tight.
So I improved the memory usage slightly.

UPDATE: I'm working on improving stability for long prefill, will update tomorrow

Route multi-token speculative decode through global-slot sparse MLA accumulation for both SWA-only and compressed sparse paths while keeping single-token decode on the existing paged fast paths.

Disable sparse MLA CUDA graph capture by default when speculative decoding is configured, preserving the explicit env override.

Add regression coverage for MTP-shaped sparse MLA decode metadata and cudagraph policy.
list(APPEND DEEPGEMM_SUPPORT_ARCHS "10.0a")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
list(APPEND DEEPGEMM_SUPPORT_ARCHS "12.0f")

Do we need also 12.1a here for DGX Spark?

Contributor Author


12.0f means it targets the whole 12.x family


I see in the log:

[gpu_model_runner.py:4884] Model loading took 39.77 GiB memory and 522.056257 seconds
Running NVCC command: cd /root/.cache/vllm/deep_gemm/tmp && /usr/local/cuda/bin/nvcc /root/.cache/vllm/deep_gemm/tmp/203-af1d4b4f-d7902775-636c4fd8/kernel.cu -cubin -o /root/.cache/vllm/deep_gemm/tmp/203-af1d4b4f-d7902775-636c4fd8/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 -I/usr/local/lib/python3.12/dist-packages/vllm/third_party/deep_gemm/include -gencode=arch=compute_120f,code=sm_120f --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda

I was wondering whether it's an environment problem on my side; not sure, but I would expect to see 121f there :) Okay, I'll continue with testing...

@idonati

idonati commented Apr 27, 2026

DeepSeek-V4-Pro working on 8× DGX Spark (sm_121, TP=8) — recipe + suggested upstream nudge

Hi @jasl @tonyliu312 — first, thanks for the SM12x V4 work in #40899 and the
Marlin sm_12x cubin fix in #40923. Wanted to report success and one small
upstream finding that may help other GB10 users.

Stack

  • 8× NVIDIA DGX Spark (GB10, sm_121, 128 GB unified memory per node)
  • 200 Gbit/s RoCE multi-rail fabric, NCCL RDMA verified
  • Image: nvcr.io/nvidia/pytorch:25.11-py3 base + jasl/vllm@ds4-sm120-prototype + jasl/DeepGEMM@sm120
  • DeepSeek-V4-Pro (FP8 dense + MXFP4 384-expert MoE), 805 GB checkpoint, ~102 GiB / rank at TP=8

TL;DR

V4-Pro now fires up + serves coherently on the 8-Spark cluster with three changes:

  1. PR [Kernel] Marlin MoE: include SM 12.x in default arch list #40923 applied (Marlin MARLIN_ARCHS/MARLIN_BF16_ARCHS/MARLIN_MOE_ARCHS include 12.0;12.1)
  2. Rebuilt vLLM C extensions with TORCH_CUDA_ARCH_LIST="12.0;12.1" (replaces broken "12.0+PTX" which produces no native sm_12x cubins for the MoE Marlin path)
  3. --safetensors-load-strategy prefetch added to vllm serve invocation

Verified output (coherent, finish_reason: stop on all):

| Prompt | Response |
| --- | --- |
| "What is 7*8?" | "7 × 8 = 56." |
| "Capital of France?" | "The capital of France is Paris." |
| "Hello in Spanish?" | "¡Hola!" |

What we observed (and why the third flag was the breaker)

After applying #40923 + the rebuild, V4-Flash worked immediately but V4-Pro hung in MoE weight materialization for 5/8→3/8 workers (depending on run). The hung workers were ALL stuck in _load_w2 / _load_w13 at 100% CPU (single core). py-spy stack:

Process N: ray::RayWorkerWrapper.execute_method
Thread N (active): "MainThread"
    _load_w2 (vllm/model_executor/layers/fused_moe/layer.py:989)   ← just expert_data.copy_(loaded_weight)
    _load_model_weight_or_group_weight_scale (.../layer.py:822)
    weight_loader (.../layer.py:1336)
    load_weights (vllm/model_executor/models/deepseek_v4.py:745)
    ...
    load_model (vllm/v1/worker/gpu_worker.py:323)

The slow workers weren't doing anything special — _load_w2 is literally expert_data.copy_(loaded_weight). The slowness was per-tensor random-access against a lazily-mmap'd safetensors file. Engine log showed the auto-prefetch heuristic skipping us:

INFO weight_utils.py:904] Filesystem type for checkpoints: EXT4. Checkpoint size: 805.33 GiB. Available RAM: 9.30 GiB.
INFO weight_utils.py:934] Auto-prefetch is disabled because the filesystem (EXT4) is not a
                          recognized network FS (NFS/Lustre) and the checkpoint size (805.33 GiB)
                          exceeds 90% of available RAM (8.80 GiB).

Setting --safetensors-load-strategy prefetch (warms the OS page cache before workers do their per-tensor random reads) eliminates the per-worker straggler effect. All 8 workers complete MoE weight load uniformly, engine reaches Uvicorn startup, model serves.
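
For readers unfamiliar with the flag, the effect being described is essentially a sequential pass over the checkpoint shards before the per-tensor random reads begin, so the pages are already cached. A simplified sketch of that idea (not vLLM's actual prefetch implementation):

```python
from pathlib import Path

def warm_page_cache(checkpoint_dir: str, chunk_bytes: int = 64 << 20) -> None:
    """Read every safetensors shard sequentially to populate the OS page cache."""
    for shard in sorted(Path(checkpoint_dir).glob("*.safetensors")):
        with open(shard, "rb", buffering=0) as f:
            while f.read(chunk_bytes):
                pass
```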

Suggested upstream nudge

The weight_utils.py heuristic disables auto-prefetch when the FS isn't NFS/Lustre — but the problem isn't the FS type, it's the access pattern. Per-tensor random-access on lazily-mmap'd weights is slow on local NVMe too (especially under memory pressure where OS page cache evictions force redundant reads). The same condition the heuristic uses to warn (checkpoint size > 90% available RAM) is exactly the condition where prefetch matters most.

Suggested change: invert the condition — when checkpoint > 90% RAM, enable prefetch by default (regardless of FS), with the comment "user can --safetensors-load-strategy lazy to opt out". Or at minimum, log a hint like "consider --safetensors-load-strategy prefetch" when the warning fires.

This would have saved us about 5 hours of diagnostic cycles. Happy to file as a separate vllm issue if useful.

Numbers (V4-Pro on 8× DGX Spark, TP=8, max-model-len=1024, gpu-memory-utilization=0.88)

  • Cold-start to FIRED-UP: ~12 min (including 6 min shard load, ~5 min weight conversion + warmup, ~1 min KV alloc + Uvicorn)
  • Per-rank weight footprint: 102.39 GiB (logged)
  • Decode latency on the 3 sample prompts: <2s end-to-end each
  • Production V4-Flash on the same cluster: 11.4 tok/s sustained at TP=8 + RDMA RoCE multi-rail (separate ao-dsv4:dev image, no Marlin patch needed for that since V4-Flash has 256 experts vs Pro's 384 and was getting lucky on the broken PTX path)

vllm serve invocation that works

vllm serve deepseek-ai/DeepSeek-V4-Pro \
  --trust-remote-code --kv-cache-dtype fp8 --block-size 256 \
  --tensor-parallel-size 8 --pipeline-parallel-size 1 \
  --max-model-len 1024 --gpu-memory-utilization 0.88 \
  --safetensors-load-strategy prefetch \
  --tokenizer-mode deepseek_v4 --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice --reasoning-parser deepseek_v4 \
  --host 0.0.0.0 --port 5001 \
  --distributed-executor-backend ray --enforce-eager

Container env (key flags carried over from V4-Flash recipe):

NCCL_IB_HCA=rocep1s0f0,rocep1s0f1,roceP2p1s0f0,roceP2p1s0f1
NCCL_IB_GID_INDEX=3 NCCL_DEBUG=WARN NCCL_COLLNET_ENABLE=0
RAY_memory_monitor_refresh_ms=0 RAY_memory_usage_threshold=0.999
VLLM_USE_DEEP_GEMM_E8M0=0
VLLM_SM120_REFERENCE_DEEPSEEK_V4_ATTENTION=1

Attachments

  • v4pro-success-engine.log — full engine log from the successful run
  • v4pro-coherent-sample.txt — the three prompt + response samples
  • v4pro-mem-snapshot.txt — per-node free -h during V4-Pro inference

Thanks

The combination of #40899 (SM12x V4 support), #40923 (Marlin sm_12x cubins),
and the existing --safetensors-load-strategy prefetch knob landed
DeepSeek-V4-Pro on a workstation-class cluster that wasn't supposed to be the
target. Filing this as both a confirmation and a small documentation request,
not a complaint.

@idonati

idonati commented Apr 27, 2026

Supporting evidence (inlined for searchability):

Coherent V4-Pro samples (full, captured by diagnose-pro-marlin.sh after FIRED-UP)

=== V4-Pro diagnostic — Sun Apr 26 08:22:12 PM EDT 2026 ===
max_model_len: 1024

--- /v1/models ---
{"object":"list","data":[{"id":"deepseek-ai/DeepSeek-V4-Pro","object":"model","created":1777249332,"owned_by":"vllm","root":"deepseek-ai/DeepSeek-V4-Pro","parent":null,"max_model_len":1024,"permission":[{"id":"modelperm-a9e7dc4893af1a12","object":"model_permission","created":1777249332,"allow_create_engine":false,"allow_sampling":true,"allow_logprobs":true,"allow_search_indices":false,"allow_view":true,"allow_fine_tuning":false,"organization":"*","group":null,"is_blocking":false}]}]}
--- using MODEL_ID=deepseek-ai/DeepSeek-V4-Pro ---

PROMPT: What is 7*8?
RESPONSE: {"id":"chatcmpl-96835de946614aa3","object":"chat.completion","created":1777249332,"model":"deepseek-ai/DeepSeek-V4-Pro","choices":[{"index":0,"message":{"role":"assistant","content":"7 × 8 = **56**.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":11,"total_tokens":20,"completion_tokens":9,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
HTTP_STATUS=200

PROMPT: Capital of France?
RESPONSE: {"id":"chatcmpl-8717541853485b0f","object":"chat.completion","created":1777249353,"model":"deepseek-ai/DeepSeek-V4-Pro","choices":[{"index":0,"message":{"role":"assistant","content":"The capital of France is **Paris**.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":8,"total_tokens":17,"completion_tokens":9,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
HTTP_STATUS=200

PROMPT: Hello in Spanish?
RESPONSE: {"id":"chatcmpl-af8c2f9d5e809cc4","object":"chat.completion","created":1777249354,"model":"deepseek-ai/DeepSeek-V4-Pro","choices":[{"index":0,"message":{"role":"assistant","content":"¡Hola!","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":8,"total_tokens":13,"completion_tokens":5,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
HTTP_STATUS=200

Per-node memory snapshot during V4-Pro inference (TP=8, max-model-len=1024, gpu-mem=0.88)

ao1: 121Gi 116Gi 5.3Gi
ao2: 121Gi 115Gi 6.1Gi
ao3: 121Gi 115Gi 6.1Gi
ao4: 121Gi 115Gi 5.9Gi
ao5: 121Gi 115Gi 5.8Gi
ao6: 121Gi 115Gi 6.0Gi
ao7: 121Gi 116Gi 5.7Gi
ao8: 121Gi 115Gi 6.6Gi

Engine log around fired-up (last 80 lines)

(EngineCore pid=1315) (RayWorkerWrapper pid=370, ip=10.111.51.13) INFO 04-27 00:19:54 [gpu_model_runner.py:4848] Model loading took 102.39 GiB memory and 328.588696 seconds [repeated 2x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.11) INFO 04-27 00:20:11 [default_loader.py:384] Loading weights took 342.45 seconds
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.11) INFO 04-27 00:20:11 [mxfp4.py:1238] Using MoEPrepareAndFinalizeNoDPEPModular
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.11) INFO 04-27 00:20:19 [gpu_model_runner.py:4848] Model loading took 102.39 GiB memory and 353.257973 seconds
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) INFO 04-27 00:20:50 [default_loader.py:384] Loading weights took 381.90 seconds
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) INFO 04-27 00:20:51 [mxfp4.py:1238] Using MoEPrepareAndFinalizeNoDPEPModular
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) INFO 04-27 00:21:05 [gpu_model_runner.py:4848] Model loading took 102.39 GiB memory and 393.339219 seconds
(EngineCore pid=1315) (RayWorkerWrapper pid=368, ip=10.111.51.17) 2026-04-27 00:21:12  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:133): TileLang begins to compile kernel `mhc_pre_big_fuse_tilelang` with `out_idx=None`
(EngineCore pid=1315) (RayWorkerWrapper pid=370, ip=10.111.51.13) 2026-04-27 00:21:14  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:141): TileLang completes to compile kernel `mhc_pre_big_fuse_tilelang`
(EngineCore pid=1315) (RayWorkerWrapper pid=370, ip=10.111.51.13) 2026-04-27 00:21:16  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:133): TileLang begins to compile kernel `mhc_post_tilelang` with `out_idx=None`
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) 2026-04-27 00:21:13  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:133): TileLang begins to compile kernel `mhc_pre_big_fuse_tilelang` with `out_idx=None` [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.14) 2026-04-27 00:21:18  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:141): TileLang completes to compile kernel `mhc_post_tilelang`
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) 2026-04-27 00:21:16  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:141): TileLang completes to compile kernel `mhc_pre_big_fuse_tilelang` [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.14) INFO 04-27 00:21:31 [gpu_worker.py:436] Available KV cache memory: 0.98 GiB
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) 2026-04-27 00:21:17  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:133): TileLang begins to compile kernel `mhc_post_tilelang` with `out_idx=None` [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) 2026-04-27 00:21:20  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:141): TileLang completes to compile kernel `mhc_post_tilelang` [repeated 7x across cluster]
(EngineCore pid=1315) INFO 04-27 00:21:31 [kv_cache_utils.py:1693] GPU KV cache size: 544 tokens
(EngineCore pid=1315) INFO 04-27 00:21:31 [kv_cache_utils.py:1698] Maximum concurrency for 1,024 tokens per request: 1.78x
(EngineCore pid=1315) (RayWorkerWrapper pid=370, ip=10.111.51.16) 2026-04-27 00:21:32,102 - INFO - autotuner.py:457 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) /usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py:322: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance. [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=2450)   warnings.warn( [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) [rank0]:W0427 00:21:24.859000 2450 torch/_inductor/utils.py:1731] [2/0] Not enough SMs to use max_autotune_gemm mode [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) 2026-04-27 00:21:34,544 - INFO - autotuner.py:466 - flashinfer.jit: [Autotuner]: Autotuning process ends
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.15) 2026-04-27 00:21:37  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:133): TileLang begins to compile kernel `mhc_pre_big_fuse_tilelang` with `out_idx=None`
(EngineCore pid=1315) (RayWorkerWrapper pid=2450) INFO 04-27 00:21:31 [gpu_worker.py:436] Available KV cache memory: 4.64 GiB [repeated 7x across cluster]
(EngineCore pid=1315) (RayWorkerWrapper pid=369, ip=10.111.51.15) 2026-04-27 00:21:40  [TileLang:tilelang.jit.kernel:INFO] (kernel.py:141): TileLang completes to compile kernel `mhc_pre_big_fuse_tilelang`
(EngineCore pid=1315) INFO 04-27 00:21:42 [core.py:285] init engine (profile, create kv cache, warmup model) took 36.47 seconds
(EngineCore pid=1315) INFO 04-27 00:21:52 [vllm.py:819] Asynchronous scheduling is disabled.
(EngineCore pid=1315) WARNING 04-27 00:21:52 [vllm.py:877] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
(EngineCore pid=1315) WARNING 04-27 00:21:52 [vllm.py:888] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
(EngineCore pid=1315) INFO 04-27 00:21:52 [kernel.py:201] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['vllm_c', 'native'])
(EngineCore pid=1315) INFO 04-27 00:21:52 [vllm.py:1066] Cudagraph is disabled under eager mode
(EngineCore pid=1315) WARNING 04-27 00:21:52 [vllm.py:1234] Auto-initialization of reasoning token IDs failed. Please check whether your reasoning parser has implemented the `reasoning_start_str` and `reasoning_end_str`.
(EngineCore pid=1315) INFO 04-27 00:21:52 [compilation.py:294] Enabled custom fusions: norm_quant, act_quant
(APIServer pid=1137) INFO 04-27 00:21:52 [api_server.py:600] Supported tasks: ['generate']
(APIServer pid=1137) INFO 04-27 00:21:55 [parser_manager.py:202] "auto" tool choice has been enabled.
(APIServer pid=1137) INFO 04-27 00:21:56 [api_server.py:604] Starting vLLM server on http://0.0.0.0:5001
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:37] Available routes are:
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /openapi.json, Methods: HEAD, GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /docs, Methods: HEAD, GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /docs/oauth2-redirect, Methods: HEAD, GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /redoc, Methods: HEAD, GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /tokenize, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /detokenize, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /load, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /version, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /health, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /metrics, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/models, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /ping, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /ping, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /invocations, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/chat/completions, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/chat/completions/batch, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/responses, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/responses/{response_id}, Methods: GET
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/responses/{response_id}/cancel, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/completions, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/messages, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/messages/count_tokens, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /inference/v1/generate, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /scale_elastic_ep, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /is_scaling_elastic_ep, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/chat/completions/render, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /v1/completions/render, Methods: POST
(APIServer pid=1137) INFO 04-27 00:21:56 [launcher.py:46] Route: /generative_scoring, Methods: POST
(APIServer pid=1137) INFO:     Started server process [1137]
(APIServer pid=1137) INFO:     Waiting for application startup.
(APIServer pid=1137) INFO:     Application startup complete.
(APIServer pid=1137) INFO:     192.168.2.35:33644 - "GET /v1/models HTTP/1.1" 200 OK
(APIServer pid=1137) INFO:     192.168.2.35:58268 - "GET /v1/models HTTP/1.1" 200 OK
(APIServer pid=1137) INFO:     192.168.2.35:58284 - "GET /v1/models HTTP/1.1" 200 OK
(EngineCore pid=1315) INFO 04-27 00:22:12 [ray_executor.py:551] RAY_CGRAPH_get_timeout is set to 300
(EngineCore pid=1315) INFO 04-27 00:22:12 [ray_executor.py:555] VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = auto
(EngineCore pid=1315) INFO 04-27 00:22:12 [ray_executor.py:559] VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = False
(EngineCore pid=1315) INFO 04-27 00:22:12 [ray_executor.py:618] Using RayPPCommunicator (which wraps vLLM _PP GroupCoordinator) for Ray Compiled Graph communication.
(APIServer pid=1137) INFO:     192.168.2.35:58296 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=1137) INFO 04-27 00:22:34 [loggers.py:259] Engine 000: Avg prompt throughput: 1.9 tokens/s, Avg generation throughput: 1.7 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 1.2%, Prefix cache hit rate: 0.0%
(APIServer pid=1137) INFO:     192.168.2.35:42978 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=1137) INFO:     192.168.2.35:42990 - "POST /v1/chat/completions HTTP/1.1" 200 OK

@tonyliu312

@idonati this is fantastic data, thanks for the detailed write-up. A few quick reactions:

On #40923 reproduction. Independent confirmation from an 8-node TP=8 setup — using the exact same patch — is exactly the kind of validation that makes a small CMakeLists change credible. Two completely separate hardware/network configurations (dual-Spark TP=2 RoCE bonded + 8-node TP=8 RoCE multi-rail) hitting the same broken-PTX-JIT failure mode and being unblocked by the same arch-list addition is, IMO, the strongest signal we can give reviewers that the precedent already set by MARLIN_FP8_ARCHS = "8.9;12.0;12.1" should extend to the BF16/FP16/MoE entries too. I'll cross-link your comment from #40923.

On the V4-Flash vs V4-Pro divergence (256 vs 384 experts). Your point that V4-Flash was getting "lucky on the broken PTX path" while V4-Pro consistently hit the dead _load_w2 straggler is worth flagging — it suggests the silently-wrong-cubin failure surface is not deterministic per-arch, but contingent on which expert path the model exercises. That's a separate (orthogonal) line of risk for any sm_12x consumer of pre-#40923 vLLM builds, and another reason to land #40923 sooner rather than later.

On the prefetch nudge. +1 — the current heuristic optimises for the wrong axis (FS type). Inverting the condition (when checkpoint > 90% RAM, enable prefetch by default with lazy as opt-out) is the right shape, and the warning-vs-action mismatch you described is a clean argument. Please do file it as a separate vllm issue — happy to support the discussion, but it's structurally orthogonal to this PR (#40899) and to ours (#40923, #40925), so it should land as its own change.

On our V4-Flash baseline. For reference, dual DGX Spark (TP=2, sm_121) running V4-Flash with --moe-backend marlin --kv-cache-dtype fp8_ds_mla --enforce-eager --gpu-memory-utilization 0.80 --load-format instanttensor + VLLM_USE_RAY_COMPILED_DAG=0 (workaround for #36237 with V4-Flash's MoE) is currently sustaining a 24h soak with MMLU=90.0% (50q sample, +1.3pp vs published) and GSM8K=98.0% (50q, +7.2pp vs published). Stability under sustained load on this sm_121 stack appears solid once the Compiled DAG bypass is in place — would be useful to know whether your 8-node setup hits the same CDG-level instability (ours surfaced as silent worker death after a few prompts).

Thanks again — this is the kind of report that moves a PR forward.

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jasl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 27, 2026
tonyliu312 pushed a commit to tonyliu312/vllm that referenced this pull request Apr 27, 2026
`w8a8_triton_block_scaled_mm` falls back to a hardcoded default config
when no pre-tuned `configs/N=*,K=*,device_name=*.json` file matches the
GPU. The default uses `BLOCK_SIZE_M=64`, which wastes 98% of the M
dimension in single-request decode (M=1). GPUs without a pre-tuned JSON
file for their (N, K, device) tuple pay this cost.

Narrow the change: only specialize the M<=8 case (single-request decode
and short MTP-style draft batches). Larger M keeps the previous default
unchanged so non-decode paths and tuned configs are not perturbed.

  M <= 8 (CUDA)   -> BLOCK_SIZE_M=16, num_stages=3   (new)
  M <= 8 (ROCm)   -> BLOCK_SIZE_M=16, num_stages=2   (new)
  else            -> BLOCK_SIZE_M=64, num_stages=2   (previous default)

num_stages=3 is gated to non-ROCm because MI300/MI250X LDS (64 KB) is
borderline for 3-stage Triton pipelining at typical [128, 128] block
sizes; on ROCm we keep num_stages=2 so the M<=8 branch still gets the
BLOCK_SIZE_M=16 wave-quantisation win without LDS pressure.

Pre-tuned JSON configs are unaffected (they short-circuit before this
branch). Workloads that already have a JSON for their (N, K, device)
get the same kernel as before.

Verified on dual DGX Spark (GB10, sm_121, TP=2) running V4-Flash:
median single-request decode goes from 5.45 t/s to 6.73 t/s (+23%) with
no other changes. Output remains coherent. The win is expected to
generalize to other architectures lacking a pre-tuned JSON for the
target (N, K) pair, but only the GB10 case is verified here; reviewers
on Hopper/Ampere are welcome to confirm or push back.

Refs vllm-project#40860 (V4 rebase), vllm-project#40899 (jasl SM12x scope is orthogonal)

Signed-off-by: Tony Liu <tonyliu0512@gmail.com>
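To make the new branch concrete, here is a minimal sketch of the selection logic described in the commit message above. The helper name and config keys are illustrative stand-ins, not the actual vLLM symbols, and pre-tuned JSON configs would be consulted before this branch is ever reached:

```python
def pick_default_config(M: int, is_hip: bool) -> dict:
    """Illustrative only: the real vLLM helper and config keys may differ."""
    if M <= 8:
        # Single-request decode and short MTP-style draft batches:
        # BLOCK_SIZE_M=64 mostly pads at M=1, so shrink it. Stay at 2 stages
        # on ROCm to avoid LDS pressure on MI300/MI250X.
        return {"BLOCK_SIZE_M": 16, "num_stages": 2 if is_hip else 3}
    # Previous default, unchanged for larger M so non-decode paths
    # are not perturbed.
    return {"BLOCK_SIZE_M": 64, "num_stages": 2}
```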
@tonyliu312 tonyliu312 mentioned this pull request Apr 27, 2026
4 tasks
@sniper35
Copy link
Copy Markdown
Contributor

Hey @jasl, thanks for your contribution! I came across this PR while searching for RTX PRO 6000 and noticed a small latent issue in scale parameter initialization, so I opened a PR at jasl#1. Does it look reasonable to you? Thanks!

Copy link
Copy Markdown
Member

@youkaichao youkaichao left a comment

for such a big change, we will not be able to review and accept it. you can keep it in a fork. thanks for your interest.

@WoosukKwon
Copy link
Copy Markdown
Collaborator

@jasl Can you actually rebase with the current main branch, so that we can see the diffs more clearly?

@jasl
Copy link
Copy Markdown
Contributor Author

jasl commented Apr 27, 2026

@WoosukKwon I'm re-organizing this into a new PR based on the latest main

@idonati
Copy link
Copy Markdown

idonati commented Apr 27, 2026

@tonyliu312 Thanks again for the detailed reply. I ran your three suggestions on the 8× DGX Spark cluster and have results to share.

1. VLLM_USE_RAY_COMPILED_DAG=0 workaround — investigation result

Quick heads up: that env var name doesn't exist in current vLLM main. When I exported it in the launch script, vLLM emitted:

WARNING [envs.py:1846] Unknown vLLM environment variable detected: VLLM_USE_RAY_COMPILED_DAG

vllm/envs.py only carries these Ray-related toggles now:

  • VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE (auto / nccl / shm)
  • VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM (bool)
  • VLLM_USE_RAY_WRAPPED_PP_COMM (bool)
  • VLLM_USE_RAY_V2_EXECUTOR_BACKEND (bool)

So the master toggle for compiled DAG appears to have been removed in a refactor — the v1 engine + Ray multi-node path is hard-wired through compiled DAG now. But the underlying problem you flagged is very real on our cluster, and I hit it in a way I can characterize precisely. Details below.
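In case it helps others double-check their own build, a quick probe (assuming vllm.envs still exposes its `environment_variables` mapping) lists the Ray toggles the installed vLLM actually recognizes, so a stale flag is caught before launch:

```python
# List the VLLM_USE_RAY* toggles the installed vLLM knows about, so a stale
# flag like VLLM_USE_RAY_COMPILED_DAG is spotted before launch.
# Assumes vllm.envs exposes the `environment_variables` mapping.
import vllm.envs as envs

for name in sorted(envs.environment_variables):
    if name.startswith("VLLM_USE_RAY"):
        print(name)
```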

2. Soak with your other flag combo (fp8_ds_mla + marlin + gpu_mem=0.80)

Full 30-minute soak, 50 prompts (Q&A / math / code / multi-language) against an 8-Spark TP=8 V4-Flash deployment using fp8_ds_mla + --moe-backend marlin + gpu_mem=0.80:

=== Summary ===
Total: 50 (ok=50 err=0 empty_content=0)
Wall time: 1812s
Avg latency per prompt: 2.81s
Avg decode rate: 10.62 tok/s (across 9624 total tokens)

Stable, no quality regression vs our fp8 baseline. Decode dropped slightly from 11.4 tok/s to 10.62 tok/s — I attribute that to the lower gpu_memory_utilization (0.80 vs 0.85) tightening the KV budget.

3. --load-format instanttensor

Built the instanttensor loader into the container. Weight load time dropped from ~290s (prefetch) to ~24s — about 12× faster on our EXT4-on-NVMe layout. Inference quality unchanged across a 10-prompt mixed verification (math, science, multi-language, code).

This is now our default for cold-start sensitive workflows.

4. MMLU + GSM8K via lm-evaluation-harness

Running 5-shot defaults, 50 questions/task limit, against the InstantTensor-loaded V4-Flash endpoint:

| Task | Metric | Score | Notes |
|------|--------|-------|-------|
| GSM8K | flexible-extract | 96.0% ± 2.8% | 5-shot, no chat template applied |
| GSM8K | strict-match | 94.0% ± 3.4% | 5-shot, no chat template applied |
| MMLU | aggregate (acc) | 30.7% ± 0.84% | 0-shot, no chat template, gpt2 tokenizer used for length budgeting only — see caveat below |
| MMLU | stem | 37.4% | |
| MMLU | humanities | 28.3% | |
| MMLU | social sciences | 25.7% | |
| MMLU | other | 28.0% | |

MMLU caveat: aggregate is suppressed by 0-shot + no chat template + tokenizer fallback (lm-eval can't load the deepseek_v4 tokenizer config in transformers 5.6 because of a model_type=deepseek_v4 registration gap; I used the gpt2 tokenizer for context-length budgeting since tokenized_requests=False sends raw text to vLLM and the server tokenizes correctly). The score should not be read as the model's true MMLU number — it confirms the serving stack handled 11,400 sequential loglikelihood requests over 80 minutes without crashing or returning empty content. The stem-vs-humanities gradient looks directionally plausible.

GSM8K is the meaningful number — 96.0% is in line with reasoning-focused SOTA and confirms inference quality is intact end-to-end on the 8-Spark TP=8 InstantTensor stack.
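For anyone wanting to replicate, this is roughly how the harness was driven against the OpenAI-compatible endpoint. The endpoint URL and model name are placeholders, num_concurrent=4 reflects the workaround described in point 5 below, and argument names may differ across lm-evaluation-harness versions:

```python
# Rough reproduction sketch for the GSM8K run above (placeholders, not our
# exact recipe). 5-shot, 50-question limit, raw text sent to the server
# (tokenized_requests=False) so vLLM tokenizes with the model's own tokenizer.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="local-completions",
    model_args=(
        "model=<served-model-name>,"
        "base_url=http://<server>:8000/v1/completions,"
        "num_concurrent=4,tokenized_requests=False"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=50,
)
print(results["results"]["gsm8k"])
```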

5. The compiled-DAG hang — characterized precisely

This is what I think you were really pointing at, and I can now reproduce it deterministically.

I first ran GSM8K with num_concurrent=8. At request ~24, vLLM threw:

ray.exceptions.RayChannelTimeoutError: System error: If the execution is expected
to take a long time, increase RAY_CGRAPH_get_timeout which is currently 300 seconds.
Otherwise, this may indicate that the execution is hanging.

Engine died, all subsequent requests returned HTTP 500. Full Ray actor cancellation + compiled DAG teardown in the logs.

Re-running the same task with num_concurrent=4 against the same engine: clean run, 50/50 success, the 96% / 94% scores above. Per-request latency 5–10s steady state, individual ranks never block the channel longer than 300s.

So under concurrent generative load with longer chain-of-thought outputs and 8 parallel client streams against an 8-rank TP=8 cluster, the compiled DAG channel timeout (default 300s) gets hit before the slowest rank completes its allocation. Two practical workarounds:

  • Client-side: cap num_concurrent ≤ 4 in lm-evaluation-harness or any benchmark client.
  • Server-side: export RAY_CGRAPH_get_timeout=900 (or higher) at launch — Ray-level env var, propagates fine through the vLLM → Ray executor stack.

I've gone with the server-side env in our recipe so production isn't sensitive to client tuning. Single-stream production traffic at TP=8 has never reproduced this regardless of duration (the original 30-min soak ran clean — it never had >1 in-flight request).
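For anyone wiring the server-side workaround into a launch script, a minimal sketch (the model path, timeout value, and serve flags are placeholders, not a recommended recipe):

```python
# Export RAY_CGRAPH_get_timeout before the server process starts so it
# propagates through the vLLM -> Ray executor stack. Values and flags below
# are illustrative only.
import os
import subprocess

env = dict(os.environ, RAY_CGRAPH_get_timeout="900")
subprocess.run(
    ["vllm", "serve", "<your-model-path>", "--tensor-parallel-size", "8"],
    env=env,
    check=True,
)
```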

If a master toggle to fall back to the legacy non-compiled Ray executor is meant to still exist, I haven't found the right name in vllm/envs.py — happy to re-test if you point me at the current flag.

6. Separate finding worth a follow-up

--safetensors-load-strategy prefetch is load-bearing on EXT4 for V4-Pro on this cluster. Without it, post-shard-load weight materialization random-reads from NVMe per-tensor and 3 of 8 workers straggle past 60 minutes (effective hang). With prefetch, V4-Pro fires up in ~12 min. Want me to file a separate issue with the worker-stall trace so the heuristic can decide more aggressively when to default-on prefetch (probably "if model size > 500 GB AND filesystem != tmpfs")? Happy to tag you on it.
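To make the proposed default-on heuristic concrete, a rough sketch of the decision: enable prefetch when the checkpoint is large relative to system RAM and the weights do not live on tmpfs. Function names and thresholds are hypothetical, and the mount-point matching is deliberately simplistic:

```python
# Hypothetical sketch of the "default-on prefetch" heuristic suggested above.
# Linux-only; not the actual vLLM loader hook.
import os

def _fs_type(path: str) -> str:
    # Longest-prefix match against /proc/mounts (approximate on purpose).
    path = os.path.realpath(path)
    best, fstype = "", "unknown"
    with open("/proc/mounts") as mounts:
        for line in mounts:
            _, mnt, typ, *_ = line.split()
            if path.startswith(mnt) and len(mnt) > len(best):
                best, fstype = mnt, typ
    return fstype

def should_default_prefetch(model_dir: str, ram_fraction: float = 0.9) -> bool:
    ckpt_bytes = sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, names in os.walk(model_dir)
        for name in names
        if name.endswith(".safetensors")
    )
    ram_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    return ckpt_bytes > ram_fraction * ram_bytes and _fs_type(model_dir) != "tmpfs"
```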


Cluster details for reference: 8× NVIDIA DGX Spark (GB10 / sm_121, 128 GiB unified memory, ARM64), dual-rail 200G RoCE multi-switch fabric (4× MikroTik CRS804-4DDQ uplinks). Image built on nvcr.io/nvidia/pytorch:25.11-py3 + your sm_12x Marlin patch (PR #40923) + TORCH_CUDA_ARCH_LIST="12.0;12.1" + the InstantTensor loader.

@BehindTheCartan
Copy link
Copy Markdown

Re-test on ds4-sm120 head (8d0ebb76c) — Crash 1 looks fixed, Crash 2 still reproduces

Quick follow-up to the earlier report per @jasl's request to re-test on the new branch.

Setup

Same 2× RTX PRO 6000 Blackwell Max-Q (SM 12.0), same agent workload (~2-15k input-token prompts, ~1 in flight). Now on jasl/vllm@ds4-sm120 head (8d0ebb76c) + DeepGEMM SM120 prototype, with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True added per the OOM message hint. All other args unchanged: --tensor-parallel-size 2 --enable-expert-parallel --kv-cache-dtype fp8 --block-size 256 --max-model-len 65536 --gpu-memory-utilization 0.92 --enforce-eager.

Crash 1 — looks fixed ✅

Sustained a 30-min sync-mode (CUDA_LAUNCH_BLOCKING=1) run with default sparse chunks (256/128) — the exact config that fired the device-side assert in _accumulate_reference_attention_chunk on the prototype branch. Did not reproduce. 15 requests completed, all finish_reason=stop. Likely fixed by b7a70b9e5 and/or 6652949c2 (haven't tried to bisect).

Crash 2 — still reproduces ❌

RuntimeError: CUDA out of memory. Tried to allocate 1.38 GiB.
GPU 0 has a total capacity of 94.97 GiB of which 1.22 GiB is free.
This process has 93.73 GiB memory in use.
Of the allocated memory 90.04 GiB is allocated by PyTorch,
and 478.44 MiB is reserved by PyTorch but unallocated.
  • Triggers ~43s into the very first long-prompt planner call (iter 0 of the agent run).
  • Same shape as before: ~1.4 GiB indexer-forward transient request.
  • expandable_segments:True does not appear to mitigate. Verified the env reached the engine subprocess (tr '\0' '\n' < /proc/$(pgrep -f api_server)/environ shows PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True).
  • b7a70b9e5 ("Reduce SM12x long-prefill sparse MLA memory") reduced steady-state usage somewhat, but the pre-crash GPU0 peak was ~96428 MiB vs ~96186 MiB on the prototype (basically unchanged), and the transient still blows the budget on long prompts.
  • EngineCore raises EngineDeadError after the OOM. The API server stays up but every subsequent request returns 500; the engine subprocess needs an external restart.

Notes

  • Saw the comment that you're re-organizing into a fresh PR against current main (in response to @WoosukKwon). Happy to re-test there as soon as it's up; this comment is just to close the loop on the prior report and confirm one of the two paths is cleared.
  • If the long-prefill stability work mentioned yesterday lands in the new PR, the Crash 2 path above is the one to validate.

@jasl
Copy link
Copy Markdown
Contributor Author

jasl commented Apr 27, 2026

I created a new PR #40991 against the latest main.

@BehindTheCartan
I don't have your configuration, but I'll look into it.
I have mentioned my work to Johnny, and he may be able to help.


Labels

ci/build deepseek Related to DeepSeek models documentation Improvements or additions to documentation frontend gpt-oss Related to GPT-OSS models kv-connector needs-rebase new-model Requests to new models nvidia speculative-decoding tool-calling v1

Projects

Status: Done
Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

10 participants