[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes #41834

Open

jasl wants to merge 7 commits into vllm-project:main from jasl:codex/ds4-sm120-min-enable

Conversation

@jasl (Contributor) commented May 6, 2026

Purpose

This PR adds support for DeepSeek V4 Flash on SM12x (DGX Spark and RTX Pro 6000).

Note 1: This supersedes #40991 with a smaller branch and a cleaner file layout.
Note 2: SM12x hardware is hard-pressed to run inference for DeepSeek V4 Pro, so I only test and focus on DeepSeek V4 Flash.
Note 3: GB10 requires a patch; otherwise, loading the DeepSeek model hangs the device and forces a hard reboot.
Note 4: The other essential patch is also about usability: without it, a 2× GB10 or 2× RTX Pro 6000 configuration crashes on agentic use cases (such as OpenClaw, Open Code) within a single turn. The issue was reported by a tester in jasl#2.

Test Plan

Local/static checks:

BASE=$(git merge-base upstream/main HEAD)

uv run --no-project --with ruff ruff check \
  $(git diff --name-only "$BASE"..HEAD -- '*.py')

uv run --no-project python -m py_compile \
  $(git diff --name-only "$BASE"..HEAD -- '*.py')

git diff --check "$BASE"..HEAD

Serve test, SM120 / 2x RTX PRO 6000, TP=2:

export PATH="/home/jasl/tmp/vllm/.venv/bin:/usr/local/cuda/bin:$PATH"
export CUDA_HOME=/usr/local/cuda
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
export CUDA_ARCH_LIST=120a
export TORCH_CUDA_ARCH_LIST=12.0a
export VLLM_RPC_TIMEOUT=100000
unset PYTORCH_CUDA_ALLOC_CONF

vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host 127.0.0.1 \
  --port 8000 \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --tensor-parallel-size 2 \
  --enable-expert-parallel \
  --gpu-memory-utilization 0.95 \
  --max-model-len 65536 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --tokenizer-mode deepseek_v4 \
  --tool-call-parser deepseek_v4 \
  --enable-auto-tool-choice \
  --reasoning-parser deepseek_v4 \
  --no-enable-flashinfer-autotune

Full GSM8K accuracy:

lm_eval \
  --model local-completions \
  --model_args model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=4,max_retries=10,tokenized_requests=False,tokenizer_backend=none,max_gen_toks=2048,timeout=60000 \
  --tasks gsm8k \
  --num_fewshot 8 \
  --batch_size auto \
  --output_path ./eval_gsm8k

Test Result

Static checks:

ruff: All checks passed!
py_compile: passed
git diff --check: clean

GB10 / DGX Spark startup smoke:

SM121 / GB10
TP=2, PP=1
server reached /health=200
quick chat smoke passed

SM120 serve / max-context:

/health: 200
/v1/models max_model_len: 393216
KV cache capacity: 1,936,534 tokens
max concurrency for 393216 tokens: 4.92x (1,936,534-token KV capacity / 393,216-token context ≈ 4.92)
max-context reliability: passed

Long-context reliability smoke, SM120 / TP=2:

max_model_len=393216
tokenized prompt ~= 392K tokens
3 repeated max-context runs: passed
1 salted cold max-context run: passed
request_success error/abort/length/repetition: 0

Full GSM8K, SM120 / TP=2 / EP enabled / FlashInfer autotune disabled:

| Mode | exact_match (flexible) | exact_match (strict) |
|---|---|---|
| no-MTP | 0.9514783927 ± 0.0059184686 | 0.9514783927 ± 0.0059184686 |
| MTP, 2 speculative tokens | 0.9499620925 ± 0.0060054424 | 0.9507202426 ± 0.0059621507 |

Both runs completed with return code 0 and no server-side errors.


@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify bot added the deepseek (Related to DeepSeek models), nvidia, and v1 labels on May 6, 2026
@jasl (Contributor, Author) commented May 6, 2026

@zyongye I've cleaned up the old PR; could you help review this one?

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements support for DeepSeek V4 on SM12x (Blackwell) architectures by providing Triton-based fallbacks for DeepGEMM-dependent operations. Key enhancements include the introduction of specialized Triton kernels for sparse MLA, FP8 einsum, and MQA logits, as well as memory optimizations in the sparse attention indexer to compute top-k indices without materializing full logits. Additionally, the PR updates the model loader to support weight name filtering for skipping MTP weights and handles Blackwell-specific FP8 quantization scales. I have no feedback to provide.
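
The "top-k indices without materializing full logits" optimization mentioned above can be illustrated with a standalone PyTorch sketch. This is illustrative only, not the PR's actual Triton kernel; the function name, tensor shapes, and chunk size are assumptions for demonstration:

import torch

def chunked_topk_indices(q: torch.Tensor, k: torch.Tensor,
                         topk: int, chunk: int = 4096) -> torch.Tensor:
    # Running top-k over q @ k.T computed one key-chunk at a time, so the
    # full (num_tokens, num_keys) logits matrix is never materialized.
    num_tokens, num_keys = q.shape[0], k.shape[0]
    best_vals = q.new_full((num_tokens, topk), float("-inf"))
    best_idx = q.new_zeros((num_tokens, topk), dtype=torch.long)
    for start in range(0, num_keys, chunk):
        end = min(start + chunk, num_keys)
        scores = q @ k[start:end].T  # only one chunk of logits in memory
        cols = torch.arange(start, end, device=q.device).expand(num_tokens, -1)
        merged_vals = torch.cat([best_vals, scores], dim=1)
        merged_idx = torch.cat([best_idx, cols], dim=1)
        best_vals, pos = merged_vals.topk(topk, dim=1)  # keep running winners
        best_idx = merged_idx.gather(1, pos)
    return best_idx

The sketch only shows why peak memory drops from O(num_tokens × num_keys) to O(num_tokens × chunk); the PR fuses the same idea into its Triton indexer kernels.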

@chatgpt-codex-connector

💡 Codex Review

def _sparse_indexer_requires_deep_gemm() -> bool:
    return current_platform.is_cuda() and not (
        current_platform.is_device_capability_family(120)
    )

P1: Keep DeepGEMM requirement for SM120 FP4 indexer path

This helper now disables the DeepGEMM requirement for every SM120 run, but the FP4 indexer cache path still depends on DeepGEMM kernels (fp8_fp4_*) because the new SM120 fallback only handles q_scale is None (FP8 Q). With use_fp4_cache=True on SM120 and no DeepGEMM installed, construction succeeds and the first prefill/decode call fails at runtime with the DeepGEMM _missing() error instead of being rejected up front.
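
For illustration, a minimal sketch of the guard Codex is suggesting, assuming the FP4-cache state can be threaded into the helper (the use_fp4_cache parameter is hypothetical, not the PR's actual signature):

def _sparse_indexer_requires_deep_gemm(use_fp4_cache: bool = False) -> bool:
    # Hypothetical signature: SM120 may only skip the DeepGEMM requirement
    # when the Triton FP8-Q fallback actually covers the path, i.e. when the
    # fp8_fp4_* indexer cache is not in use.
    if not current_platform.is_cuda():
        return False
    if current_platform.is_device_capability_family(120):
        return use_fp4_cache  # FP4 cache path still needs DeepGEMM kernels
    return True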


if self.load_config.load_format == "fastsafetensors":
    weights_iterator = fastsafetensors_weights_iterator(
        hf_weights_files,
        self.load_config.use_tqdm_on_load,
    )

P2: Propagate weight_name_filter to fast safetensor loaders

The new pre-load weight_name_filter is only wired into safetensors_weights_iterator; this branch still loads all tensors for fastsafetensors (and similarly other non-default safetensor iterators), so skipped tensors are still materialized. For DeepSeek V4 this defeats the intended early skip of MTP weights and can reintroduce high transient memory use/OOM when these load formats are enabled.
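
A sketch of the propagation being asked for, mirroring what the PR does for safetensors_weights_iterator. This assumes fastsafetensors_weights_iterator grows the same weight_name_filter parameter (a later commit in this PR adds that wiring; the exact signature here is an assumption):

if self.load_config.load_format == "fastsafetensors":
    weights_iterator = fastsafetensors_weights_iterator(
        hf_weights_files,
        self.load_config.use_tqdm_on_load,
        # Assumed parameter: drop filtered keys before tensor
        # materialization, so skipped MTP weights are never read into memory.
        weight_name_filter=weight_name_filter,
    )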


@jasl changed the title from "[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash" to "[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes" on May 6, 2026
@jasl force-pushed the codex/ds4-sm120-min-enable branch from 042e366 to df2e6f8 on May 6, 2026 at 16:26
@jasl force-pushed the codex/ds4-sm120-min-enable branch from df2e6f8 to 5a774bb on May 6, 2026 at 16:50
@aidendle94

Do you have a performance result somewhere compared to the last PR?

@jasl (Contributor, Author) commented May 6, 2026

> Do you have a performance result somewhere compared to the last PR?

This is the minimum-support PR (as the reviewer requested), so it only enables support plus the two essential patches that prevent vLLM crashes and device hangs.

I'm porting all the following changes to a new preview branch, which should be close to production-ready.

@jasl force-pushed the codex/ds4-sm120-min-enable branch from 5a774bb to 4128d11 on May 6, 2026 at 22:14
jasl and others added 7 commits May 7, 2026 07:57
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Fix the SM12x fp8 einsum custom-op registration import, skip unused DeepSeek V4 MTP checkpoint tensors before safetensors materialization, and release MXFP4 setup temporaries after kernel setup.

Signed-off-by: jasl <jasl9187@hotmail.com>
Protect hybrid-aligned DeepSeek V4 MLA prompt cache blocks so they survive decode and unrelated cache churn. Release those protected references under admission pressure and before prefix-cache reset so they do not starve the block pool.

Add regression coverage for reuse after decode pressure, admission under protected refs, and reset cleanup.

Signed-off-by: jasl <jasl9187@hotmail.com>
Forward model skip_weight_name_before_load filters into the fastsafetensors iterator and skip filtered keys before materializing tensors. This keeps DeepSeek V4 non-MTP loads from reading MTP-only weights when users select --load-format fastsafetensors.

Keep the regression coverage at behavior level by checking the DefaultModelLoader path and pruning private implementation-field assertions from the adjacent DeepSeek V4 prefix-cache tests.

Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
@jasl force-pushed the codex/ds4-sm120-min-enable branch from 4128d11 to 1942bad on May 6, 2026 at 23:58
@qym-ll commented May 7, 2026

May I ask which is the final branch? How do we set it up and use it? Has the crash issue been resolved?

@jasl (Contributor, Author) commented May 7, 2026

> May I ask which is the final branch? How do we set it up and use it? Has the crash issue been resolved?

This PR's performance isn't great because it is the minimum-support PR (as the reviewer requested), but the key reliability patches are included.

I don't see any crashes on my 2× GB10 and 2× RTX Pro 6000 setups.

It needs follow-up PRs to restore performance.

I'm preparing the preview branch now. To avoid the mess I made earlier, I'll test it locally first, especially on GB10.

@qym-ll commented May 7, 2026

> 2× RTX Pro 6000

Do you have any recommended high-performance branches? Is there a deployment tutorial? I am using 4× RTX Pro 6000 and want to try setting them up.

@leavelet commented May 7, 2026

Hi, I opened a PR adding SM120 support for DeepGEMM: deepseek-ai/DeepGEMM#324 — might be helpful here

@johnnynunez (Contributor) commented:

cc @pavanimajety @askliar

@ehfd (Contributor) commented May 7, 2026

@jasl I have a suggestion, now that I understand how everything fits together.

First, we can try our best to merge #38476. Its current scope covers only DeepSeek v3.2 and GLM-5.x, not yet DeepSeek v4; however, DeepSeek v4 still needs Sparse MLA.

Then, this PR could be substantially shortened by dropping the TRITON_SPARSE_MLA components.

Does this sound good?

@jasl (Contributor, Author) commented May 7, 2026

> May I ask which is the final branch? How do we set it up and use it? Has the crash issue been resolved?

https://github.com/jasl/vllm/tree/ds4-sm120-preview

@jasl (Contributor, Author) commented May 7, 2026

> @jasl I have a suggestion, now that I understand how everything fits together.
>
> First, we can try our best to merge #38476. Its current scope covers only DeepSeek v3.2 and GLM-5.x, not yet DeepSeek v4; however, DeepSeek v4 still needs Sparse MLA.
>
> Then, this PR could be substantially shortened by dropping the TRITON_SPARSE_MLA components.
>
> Does this sound good?

I tested that branch a few days ago; however, it isn't the fast path for SM12x, and its accuracy isn't as good as my implementation's.

@ehfd (Contributor) commented May 7, 2026

> I tested that branch a few days ago; however, it isn't the fast path for SM12x, and its accuracy isn't as good as my implementation's.

Yes, I understand totally. So the PRs are both required for different purposes. Thank you for the response.

@ehfd (Contributor) commented May 7, 2026

Note that a breakthrough regarding TTFT was recently found in this Docker image: haosdent/vllm-nightly:fix-38006-4
