
Support SM120 #318

Open
jasl wants to merge 14 commits into deepseek-ai:main from jasl:sm120

Conversation

@jasl
Contributor

@jasl jasl commented Apr 25, 2026

Disclaimer: I don't work in the AI domain, but I wish my RTX Pro 6000 could run DeepSeek V4, so I made this.
GPT was my co-worker.

With this PoC, plus some other dirty hacks, I've made it run on my 2x RTX Pro 6000 workstation.

See my vLLM PR vllm-project/vllm#40991

If anyone can help, please feel free to take it over.
I'll try my best to make it fully work, but it's a challenge for me.

jasl added 4 commits April 25, 2026 04:47
DeepGEMM currently rejects SM12x devices in the DeepSeek V4 forward path because the HC prenorm GEMM and paged MQA logits families only dispatch SM90/SM100 implementations. This change adds an explicit compatibility path for compute capability 12.x while keeping the existing high-performance SM90/SM100 paths unchanged.

Normalize CUDA major 12 devices, including SM120 and SM121, to the sm_120f JIT target and sm120 include suffix so a GB10/DGX Spark can build a shared SM12x kernel family. Add an SM12x tf32_hc_prenorm_gemm reference fallback using ATen CUDA matmul plus the required square-sum output, preserving the existing split-output ABI by storing the full reference result in split zero and zeroing the remaining splits.
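
For readers following the fallback design, a minimal host-side sketch of the split-output trick might look like the block below; the function and tensor names are illustrative, and what exactly gets square-summed is an assumption, not the PR's actual code.

```cuda
// Hedged ATen-based sketch of the SM12x tf32_hc_prenorm_gemm fallback idea.
// The full reference product goes into split 0, the remaining splits are
// zeroed so the existing split-output ABI stays intact, and the extra
// square-sum side output required by the HC prenorm contract is produced
// with plain ATen ops (its exact reduction target is assumed here).
#include <ATen/ATen.h>

void sm120_hc_prenorm_gemm_fallback(const at::Tensor& a,        // [m, k]
                                    const at::Tensor& b,        // [k, n]
                                    at::Tensor& out_splits,     // [num_splits, m, n]
                                    at::Tensor& square_sum) {   // side output
    out_splits.zero_();                              // splits 1..N-1 stay zero
    out_splits.select(0, 0).copy_(at::matmul(a, b)); // full result in split 0
    square_sum.copy_((a * a).sum(-1));               // assumed: row-wise square sum
}
```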

Allow SM12x paged MQA metadata and add an FP8-only SM120 paged MQA logits reference runtime. The new CUDA kernel avoids SM90 WGMMA and SM100 tcgen05/TMA assumptions; it directly dequantizes FP8 q/kv values, applies the KV scale, accumulates in FP32, applies ReLU and per-head weights, and stores invalid token positions as -inf. FP4 paged MQA, non-paged MQA, MegaMoE, and general FP8/FP4 GEMM remain gated until they have separate SM12x implementations.
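
A hedged sketch of such a correctness-first scalar kernel is below; it assumes the KV tokens have already been gathered from the paged block table, and the names and indexing are illustrative rather than the PR's actual kernel.

```cuda
// Scalar FP8 paged-MQA logits fallback, simplified: dequantize FP8 Q/KV,
// apply the KV scale, accumulate in FP32, apply ReLU and the per-head weight,
// and store invalid token positions as -inf (as described above).
#include <cuda_fp8.h>
#include <math_constants.h>

__global__ void sm120_paged_mqa_logits_scalar(
        const __nv_fp8_e4m3* __restrict__ q,      // [num_heads, head_dim]
        const __nv_fp8_e4m3* __restrict__ kv,     // [num_tokens, head_dim], pre-gathered
        const float* __restrict__ kv_scale,       // per-token dequant scale
        const float* __restrict__ head_weights,   // [num_heads]
        const bool* __restrict__ token_valid,     // [num_tokens]
        float* __restrict__ logits,               // [num_heads, num_tokens]
        int num_heads, int num_tokens, int head_dim) {
    const int head  = blockIdx.y;
    const int token = blockIdx.x * blockDim.x + threadIdx.x;
    if (head >= num_heads || token >= num_tokens)
        return;
    if (!token_valid[token]) {
        logits[head * num_tokens + token] = -CUDART_INF_F;   // invalid position -> -inf
        return;
    }
    const float scale = kv_scale[token];
    float acc = 0.0f;                                        // FP32 accumulation
    for (int d = 0; d < head_dim; ++d) {
        const float qv  = static_cast<float>(q[head * head_dim + d]);            // dequantize Q
        const float kvv = static_cast<float>(kv[token * head_dim + d]) * scale;  // dequantize + KV scale
        acc = fmaf(qv, kvv, acc);
    }
    acc = fmaxf(acc, 0.0f);                                  // ReLU
    logits[head * num_tokens + token] = acc * head_weights[head];  // per-head weight
}
```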

Add focused SM120 regression coverage for HC prenorm GEMM and FP8 paged MQA logits against PyTorch references, including the fused paged KV-cache layout used by the API. Also install the requested project-scoped C++ readability skills under .codex/skills so future CUDA/C++ work in this checkout can use the same local guidance.

Verification on 10.0.0.116 (/home/jasl/tmp/DeepGEMM): ./develop.sh completed with CUDA 13.1, then PYTHONPATH=. pytest -q tests/test_sm120_reference.py -q returned '.. [100%]'. A separate JIT command check confirmed the SM12x path compiles with --gpu-architecture=sm_120f.
DeepSeek V4 on SM120 currently reaches an einsum path where the activation side is laid out as bhr and the output projection is logically hdr, but the weight and scale tensors can arrive from the vLLM Cutlass path as flattened 2D tensors reshaped into that logical form.

Add an SM120-specific reference implementation for the bhr,hdr->bhd FP8 einsum recipe using FP32 block scales and scalar FP32 accumulation. This is intentionally a correctness-first CUDA fallback rather than an optimized Tensor Core kernel, and is only dispatched for arch major 12 with recipe (1, 128, 128).
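
Sketched as a scalar CUDA kernel under assumed tensor and scale layouts for recipe (1, 128, 128), the fallback boils down to something like the block below; the real kernel in the PR may differ in grid mapping, output dtype, and scale indexing.

```cuda
// Correctness-first bhr,hdr->bhd FP8 einsum with FP32 block scales.
// Assumed layouts: act [B,H,R] with per-(1,128) scales, wgt [H,D,R] with
// per-(128,128) scales over (d, r); output kept in FP32 for simplicity.
#include <cuda_fp8.h>

__global__ void sm120_fp8_einsum_bhr_hdr_bhd_scalar(
        const __nv_fp8_e4m3* __restrict__ act,   // [B, H, R]
        const float* __restrict__ act_scale,     // [B, H, ceil(R/128)]
        const __nv_fp8_e4m3* __restrict__ wgt,   // [H, D, R]
        const float* __restrict__ wgt_scale,     // [H, ceil(D/128), ceil(R/128)]
        float* __restrict__ out,                 // [B, H, D]
        int B, int H, int D, int R) {
    const int b = blockIdx.z, h = blockIdx.y;
    const int d = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= B || h >= H || d >= D)
        return;
    const int r_blocks = (R + 127) / 128;
    const int d_blocks = (D + 127) / 128;
    float acc = 0.0f;                            // scalar FP32 accumulation
    for (int rb = 0; rb < r_blocks; ++rb) {
        // Combined activation x weight block scale for this (d, rb) tile.
        const float s = act_scale[(b * H + h) * r_blocks + rb]
                      * wgt_scale[(h * d_blocks + d / 128) * r_blocks + rb];
        const int r_end = min(R, (rb + 1) * 128);
        for (int r = rb * 128; r < r_end; ++r)
            acc += static_cast<float>(act[(b * H + h) * R + r])
                 * static_cast<float>(wgt[(h * D + d) * R + r]) * s;
    }
    out[(b * H + h) * D + d] = acc;
}
```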

Verified on 10.0.0.110 with the vLLM editable build against /home/jasl/tmp/DeepGEMM. The generated JIT compile used --gpu-architecture=sm_120f, and direct DeepGEMM plus vLLM wrapper checks matched the PyTorch reference with max_abs 0.0. The full DeepSeek-V4-Flash serve path now gets past the previous DeepGEMM einsum failure and reaches the separate FlashMLA backend blocker.
@wuwenthink

Thanks for it, I hope it can be used in vLLM & SGLang.

ergodic-flow and others added 6 commits April 26, 2026 03:38
This change introduces a new kernel for the SM120 architecture (RTX 5090 and RTX Pro 6000), specifically for tf32_hc_prenorm_gemm.
Add SM120 coverage for the varlen FP8 paged MQA logits path and the FP8 bhr,hdr->bhd einsum fallback used by the DeepSeek V4 bring-up. These tests exercise the existing fallback kernels against PyTorch references with physical block tables, repeated varlen rows, and FP8 scaling-factor layouts.

Also tighten the SM120 HyperConnection JIT wrapper so zero-stage split-K launches are rejected on the host. The regression now uses a valid split shape with enough K blocks, which keeps the test aligned with the kernel's split-K contract.
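
The host-side guard can be pictured roughly as below; parameter names are assumptions, not the wrapper's actual code.

```cuda
// Reject split-K launches where some split would own zero K blocks, since the
// kernel's split-K contract assumes every split accumulates at least one block.
#include <stdexcept>

inline void check_splitk_has_stages(int k, int block_k, int num_splits) {
    const int k_blocks = (k + block_k - 1) / block_k;
    if (k_blocks < num_splits)
        throw std::invalid_argument(
            "SM120 HC prenorm GEMM: split-K launch would contain zero-stage splits");
}
```
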
Replace the scalar per-dimension FP8 dot-product loop in the SM120 paged MQA fallback with fp8x4 loads and conversion. The launch geometry stays unchanged, which avoids the CTA-count regression observed with the warp-per-logit experiment while reducing conversion/load overhead inside each logit calculation.

Add a V4-shaped SM120 regression with 64 heads and 128-dimensional FP8 Q/KV to keep the vectorized path covered alongside the existing small-shape and varlen cases.
Use fp8x4 loads and conversion in the SM120 bhr,hdr->bhd fallback while keeping the existing launch geometry. The vectorized path is gated on contiguous rank strides and a rank size divisible by four, with the original scalar path retained for irregular shapes.

Add a V4-shaped regression covering 64 heads, rank 128, and output dim 128. On GB10 this improves representative event-timed shapes from roughly 0.99 TFLOPS to about 4.0 TFLOPS.
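
A minimal sketch of the fp8x4 load-and-convert pattern both of these commits describe, assuming CUDA's packed __nv_fp8x4_e4m3 type and a rank dimension that is contiguous and divisible by four:

```cuda
// One 32-bit load of four packed E4M3 values, converted to float4 before the
// FP32 FMAs. Only valid when the rank stride is contiguous and R % 4 == 0.
#include <cuda_fp8.h>

__device__ __forceinline__ float4 load_fp8x4_as_float4(const __nv_fp8_e4m3* ptr) {
    const __nv_fp8x4_e4m3 packed = *reinterpret_cast<const __nv_fp8x4_e4m3*>(ptr);
    return static_cast<float4>(packed);   // packed FP8x4 -> float4 conversion
}

// Inside the dot-product loop the rank index then steps by 4 instead of 1:
//   float4 a4 = load_fp8x4_as_float4(&act[act_base + r]);
//   float4 w4 = load_fp8x4_as_float4(&wgt[wgt_base + r]);
//   acc += (a4.x * w4.x + a4.y * w4.y + a4.z * w4.z + a4.w * w4.w) * s;
```
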
Add regression coverage for the SM120 tf32_hc_prenorm_gemm kernel beyond the original small m/n shape. The new cases cover n=16 and n=32, multiple 64-row M tiles, and split-K accumulation so future HC tuning has a stronger correctness guard.
@jasl jasl changed the title from [PoC] Support SM120 to Support SM120 on Apr 25, 2026
@jasl
Contributor Author

jasl commented Apr 25, 2026

With GPT's help, I've made low-risk changes to get DeepGEMM working on SM12x.
And it seems DeepGEMM isn't the bottleneck for DeepSeek V4 inference.

So I think the PR is ready for review.

jasl added 4 commits April 26, 2026 06:55
Use explicit NVCC gencode targets for JIT builds so SM12x family-specific MMA instructions do not get rejected while compiling generic PTX.

Add an env-gated SM120 FP8 paged MQA tiled kernel for the V4 shape: 64 heads, head_dim 128, block_kv 64, and float logits. The tiled path uses SM120 E4M3 QMMA, supports token group sizes 1/2/4/8, and defaults to a Q-cached 4-group schedule under DG_SM120_PAGED_MQA_TILED=1.

Keep the scalar reference fallback as the default path for now and add CUDA parity coverage for all tiled variants. On DGX Spark/GB10, the V4-shaped component benchmark improved from 1.918 ms to 0.677 ms for batch=512 ctx=512, and from 1.906 ms to 0.678 ms for batch=128 ctx=2048.
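
A rough sketch of the environment gate described here; only the DG_SM120_PAGED_MQA_TILED=1 switch comes from the commit message, and the launch function names are placeholders.

```cuda
// Host-side dispatch between the scalar fallback (default) and the tiled
// SM120 E4M3 QMMA kernel, gated on DG_SM120_PAGED_MQA_TILED=1.
#include <cstdlib>
#include <cstring>

inline bool sm120_paged_mqa_tiled_enabled() {
    const char* v = std::getenv("DG_SM120_PAGED_MQA_TILED");
    return v != nullptr && std::strcmp(v, "1") == 0;
}

// At dispatch time (placeholder launcher names):
//   if (sm120_paged_mqa_tiled_enabled())
//       launch_paged_mqa_logits_tiled(args);    // Q-cached tiled QMMA schedule
//   else
//       launch_paged_mqa_logits_scalar(args);   // correctness-first default
```
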
Extend the explicit Q-cache schedule to token group sizes 2 and 8, keeping env overrides for all tested group sizes. The runtime default now uses eight token groups with Q caching enabled, which reduces repeated Q fragment loads while keeping the scalar fallback unchanged unless DG_SM120_PAGED_MQA_TILED=1 is set.

DGX Spark benchmark after this change: batch=512 ctx=512 improves from 1.909 ms reference to 0.506 ms default tiled, batch=128 ctx=2048 improves from 1.877 ms to 0.504 ms, and batch=4 ctx=32768 improves from 0.970 ms to 0.251 ms.
Rename the SM120 paged MQA host dispatcher so it no longer advertises itself as reference-only now that it can select the tiled kernel via DG_SM120_PAGED_MQA_TILED.

Also tighten a temporary HC fallback comment and avoid binding the generated runtime args to a temporary reference.
Use scalar naming for the SM120 compatibility kernels that are not pure test references: FP8 einsum and FP8 paged MQA logits now expose scalar runtime/kernel names.

Keep reference naming only for the PyTorch oracle in tests, and rename the SM120 regression file to test_sm120_kernels.py so the PR reads as kernel coverage rather than temporary reference scaffolding.
@wuwenthink

wuwenthink commented Apr 26, 2026

Thanks for your hard work. I was indeed able to run it, but prefill is only under 500 tok/s and decode only 25-35 tok/s, so the subsequent optimization still has a long way to go before it reaches speeds usable for agent workloads.
If I need to test it, which branch in https://github.com/jasl/vllm should I switch to?

@jasl
Contributor Author

jasl commented Apr 26, 2026

Thanks for your hard work. I was indeed able to run it, but prefill is only under 500 tok/s and decode only 25-35 tok/s, so the subsequent optimization still has a long way to go before it reaches speeds usable for agent workloads. If I need to test it, which branch in https://github.com/jasl/vllm should I switch to?

My vLLM PR vllm-project/vllm#40899 has already switched to this branch.
This PR provides an unoptimized implementation intended as a correctness baseline.
There is still much work to do. My plan is to make it work for now and find high-value hot spots to improve.

The original DeepSeek V4 support PR has issues of its own; it is slow even on GH200 and an 8x H100 cluster (verified by my friends), so I'm not sure whether the slowness is caused by this PR.

@wuwenthink

wuwenthink commented Apr 26, 2026

Thanks, there are still many friends in the community paying attention to your implementation. Can I take it that there is no need to switch to other branches in the future (e.g. ds4-sm120-prototype or ds4-sm120-next), and that updates will land on this one branch, jasl:ds4-sm120?

@jasl
Contributor Author

jasl commented Apr 26, 2026

Thanks, there are still many friends in the community paying attention to your implementation. Can I take it that there is no need to switch to other branches in the future, and that updates will land on this one branch, jasl:ds4-sm120?

Yes, for now.
The PR is ready for review; I hope it can be merged so people no longer have to use my fork.

@johnnynunez

Do you know the maintainers? @jasl

@jasl
Contributor Author

jasl commented Apr 27, 2026

Do you know the maintainers? @jasl

Nope, but it seems @LyricZhao is an NVIDIAN. Could you reach him internally?

UPDATE: Sorry, he is an ex-NVIDIAN.

@myshytf

myshytf commented Apr 28, 2026

After applying your patch and some fine-tuning using the DSv4 Pro API, I got 1500 tok/s for prefill and 90 tok/s for decode with 2x RTX Pro 6000. Thank you for your hard work!

@zheanxu
Collaborator

zheanxu commented Apr 30, 2026

Hi @jasl, thank you for this wonderful contribution. The effort you and the community have put in is truly appreciated.

I'm a maintainer of this repo. Unfortunately, we aren't able to merge SM120 support, as we don't have the hardware to test it and lack the capacity to maintain an extra architecture.

We'd love to see a community-maintained fork for SM120, and we'd be happy to link to it from our README so other users can find it.

Thanks again for your hard work!

@armondhonore

@jasl

⏺ 14/14 PR #318 tests PASS on RTX Pro 6000.

Your SM120 implementation is solid.

I have a 4-card rig I am testing this on!

@linjiapro

@zheanxu

This is Lin Jia, an NVIDIA employee. Lots of folks around the world use workstation GPUs for local LLM hosting, so being able to support SM120 would be great! This PR is mostly added code rather than logic tangled with what you have in other files.

I know you mentioned you don't have the hardware. One way to go would be to merge this in; even if the SM120 path breaks later, the community can keep patching and testing it, and it will not damage your other code given the clean separation between files.

Thoughts?

@xiao-zaiyi

Hi, I would like to ask whether the L20 is supported?

@leavelet

leavelet commented May 1, 2026

Hi, I have a carefully tuned and performance-optimized version here. You can compare against it: #324

@johnnynunez

@zheanxu

This is Lin Jia, an NVIDIA employee. Lots of folks around the world use workstation GPUs for local LLM hosting, so being able to support SM120 would be great! This PR is mostly added code rather than logic tangled with what you have in other files.

I know you mentioned you don't have the hardware. One way to go would be to merge this in; even if the SM120 path breaks later, the community can keep patching and testing it, and it will not damage your other code given the clean separation between files.

Thoughts?

Another employee here… I think the same.
The community and NVIDIA can support SM12x…

@leavelet

leavelet commented May 1, 2026

@zheanxu
This is Lin Jia, an NVIDIA employee. Lots of folks around the world use workstation GPUs for local LLM hosting, so being able to support SM120 would be great! This PR is mostly added code rather than logic tangled with what you have in other files.
I know you mentioned you don't have the hardware. One way to go would be to merge this in; even if the SM120 path breaks later, the community can keep patching and testing it, and it will not damage your other code given the clean separation between files.
Thoughts?

Another employee here… I think the same. The community and NVIDIA can support SM12x…

@johnnynunez @linjiapro
We have an nv_dev branch, and my PR targets merging into nv_dev. DevTech will coordinate support for the nv_dev branch, so this avoids adding extra burden on DeepSeek.

@linjiapro

@leavelet I really appreciated your PR, and the nv_dev branch!

