Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
829ae2c
feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference
AliceChenyy May 8, 2026
0b46b07
fix: address PR review comments and handle KV cache uint8 dtype on SM120
AliceChenyy May 13, 2026
aa7000c
style: fix pre-commit lint issues (isort, ruff, black)
AliceChenyy May 14, 2026
3e33539
fix: wrap pytest.main in sys.exit for CI exit code propagation
AliceChenyy May 14, 2026
9db9a66
fix: add SM120 guards for sgl-kernel 0.4.2.post2 compatibility
AliceChenyy May 19, 2026
ef24964
fix: remove _is_sm120 module var from metadata.py, use is_sm120_suppo…
AliceChenyy May 19, 2026
1763009
refactor: simplify _is_sm120 by using is_sm120_supported() directly
AliceChenyy May 19, 2026
88a2c05
fix: replace broken _is_sm120 import in indexer.py with is_sm120_supp…
AliceChenyy May 20, 2026
a3d06b9
style: remove unused imports flagged by ruff (rotate_activation, act_…
AliceChenyy May 20, 2026
4491e60
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 20, 2026
5df66d2
fix: rename nsa_ to dsa_ in SM120 backend defaults after NSA→DSA refa…
AliceChenyy May 20, 2026
03be8ad
test: register SM120 MQA fallback tests in CI (base-b, 1-gpu-small)
AliceChenyy May 20, 2026
d9c4bfd
fix: address Fridge003 review — isolate SM120 code, auto-set envs, re…
AliceChenyy May 21, 2026
21c00bc
docs: add SM120 (RTX PRO 6000) to DeepSeek-V4 cookbook
AliceChenyy May 21, 2026
5f526de
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 21, 2026
cd44380
fix: revert unnecessary change to is_arch_support_pdl in jit_kernel/u…
AliceChenyy May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ tag: NEW
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}><strong><a href="https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash">DeepSeek-V4-Flash</a></strong></td>
<td style={{padding: "9px 12px", textAlign: "right", backgroundColor: "rgba(255,255,255,0.05)"}}><strong>284B</strong></td>
<td style={{padding: "9px 12px", textAlign: "right", backgroundColor: "rgba(255,255,255,0.02)"}}>13B</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs; RTX PRO 6000 (SM120) on 4 or 8 GPUs</td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}><strong><a href="https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro">DeepSeek-V4-Pro</a></strong></td>
Expand Down Expand Up @@ -128,6 +128,20 @@ PD-Disagg recipes on H200 may require `docker run --privileged --ulimit memlock=
can discover the IB HCAs; without IB exposure mooncake silently falls back to
TCP, which can lead to garbled KV transfer on large checkpoints.

<a id="sm120-note" />

**SM120 (RTX PRO 6000 Blackwell Server Edition) note**

DeepSeek-V4-Flash can run on RTX PRO 6000 Blackwell Server Edition (SM120, 96 GB GDDR7) with Tensor Parallelism only. We support two TP configurations:
- **TP=8 (recommended)**: `--tp 8 --mem-fraction-static 0.70 --cuda-graph-max-bs 32`. Leaves ~20 GB per GPU for KV cache.
- **TP=4 (memory-constrained)**: `--tp 4 --mem-fraction-static 0.90 --cuda-graph-max-bs 4`. Runs near the 96 GB memory limit (~98.8% usage) with minimal KV cache headroom.

SM120 uses Triton-based MoE and FlashMLA fallback kernels instead of CUTLASS/DeepGEMM (auto-detected, no manual flags needed). Use Docker image `lmsysorg/sglang:dev-cu13` (CUDA 13.0, required for SM120 / CC 12.0).

Performance is memory-bandwidth bound (~1.5 TB/s GDDR7 on RTX PRO 6000 vs ~8 TB/s HBM3e on B200); expect ~15-17 tok/s at batch size 1. Accuracy matches the reference (GSM8K 10/10, GPQA Diamond 72.0% vs 71.2% published).

V4-Pro is not supported on SM120 (model does not fit in 8x 96 GB).

**MegaMoE**

MegaMoE fuses expert dispatch + GEMM into a single kernel for higher throughput
Expand Down
59 changes: 41 additions & 18 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,16 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import ceil_align
from sglang.srt.utils.common import is_sm120_supported

if TYPE_CHECKING:
from flash_mla.flash_mla_interface import FlashMLASchedMeta

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner

_is_sm120 = is_sm120_supported()

logger = logging.getLogger(__name__)

SWA_WINDOW = 128
Expand All @@ -81,6 +84,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T:


def _create_flashmla_metadata():
    """Build the FlashMLA metadata object, or return ``None`` on SM120.

    When the module-level ``_is_sm120`` flag is set, ``flash_mla`` is never
    imported and ``None`` is returned. Otherwise the first element of
    ``flash_mla.get_mla_metadata()`` is returned.
    """
    if not _is_sm120:
        # Imported lazily so the module loads without flash_mla on SM120.
        import flash_mla

        metadata = flash_mla.get_mla_metadata()
        return metadata[0]
    return None
Expand Down Expand Up @@ -1031,24 +1036,42 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"

import flash_mla

o = flash_mla.flash_mla_with_kvcache(
q=q,
k_cache=swa_k_cache,
head_dim_v=self.head_dim_v,
block_table=None,
cache_seqlens=None,
tile_scheduler_metadata=flashmla_metadata,
softmax_scale=self.softmax_scale,
is_fp8_kvcache=True,
indices=swa_page_indices,
topk_length=swa_topk_lengths,
attn_sink=attn_sink,
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_lengths,
)[0]
if _is_sm120:
from sglang.srt.layers.attention.flash_mla_sm120 import (
flash_mla_with_kvcache_sm120,
)

o = flash_mla_with_kvcache_sm120(
q=q,
k_cache=swa_k_cache,
head_dim_v=self.head_dim_v,
softmax_scale=self.softmax_scale,
indices=swa_page_indices,
topk_length=swa_topk_lengths,
attn_sink=attn_sink,
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_lengths,
)[0]
else:
import flash_mla

o = flash_mla.flash_mla_with_kvcache(
q=q,
k_cache=swa_k_cache,
head_dim_v=self.head_dim_v,
block_table=None,
cache_seqlens=None,
tile_scheduler_metadata=flashmla_metadata,
softmax_scale=self.softmax_scale,
is_fp8_kvcache=True,
indices=swa_page_indices,
topk_length=swa_topk_lengths,
attn_sink=attn_sink,
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_lengths,
)[0]

o = o.squeeze(1)
return o
Expand Down
74 changes: 73 additions & 1 deletion python/sglang/srt/layers/attention/dsv4/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer
from sglang.srt.utils import add_prefix, is_hip
from sglang.srt.utils.common import is_sm120_supported

if TYPE_CHECKING:
from sglang.srt.layers.attention.dsv4.compressor import (
Expand Down Expand Up @@ -90,6 +91,74 @@ def fp8_paged_mqa_logits_torch(
return logits


def fp8_paged_mqa_logits_torch_sm120(
    q_fp8: torch.Tensor,
    kvcache_fp8: torch.Tensor,
    weight: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    deep_gemm_metadata: Any,
    max_seq_len: int,
    clean_logits: bool = True,
) -> torch.Tensor:
    """CUDA-graph-compatible FP8 paged MQA logits for SM120 (vectorized, no .item()).

    All intermediate shapes are derived from ``max_seq_len`` and the static
    tensor shapes, never from tensor contents, so no host read of device data
    is required.

    Args:
        q_fp8: Queries of shape ``(batch, 1, num_heads, 128)``.
        kvcache_fp8: Paged cache of shape ``(num_pages, 64, 1, 132)``. Viewed
            per page as a flat row: the first ``64 * 128`` elements are the
            key values, the trailing ``64 * 4`` elements are reinterpreted as
            64 float32 per-token scales (assumes a 1-byte element dtype).
        weight: Per-head weights of shape ``(batch, num_heads)``.
        seq_lens: Valid length of each request, shape ``(batch,)``; a trailing
            singleton dimension is squeezed away.
        page_table: Page ids per request, first dimension ``batch``. Only the
            first ``ceil(max_seq_len / 64)`` columns are read.
        deep_gemm_metadata: Accepted for signature compatibility; unused.
        max_seq_len: Width of the returned logits.
        clean_logits: Must be falsy; the caller has to pass it explicitly
            because the default ``True`` is rejected.

    Returns:
        Tensor of shape ``(batch, max_seq_len)``. Entry ``[b, t]`` is
        ``scale[b, t] * sum_h(relu(k[b, t] . q[b, h]) * weight[b, h])`` for
        ``t < seq_lens[b]`` and ``-inf`` otherwise.
    """
    del deep_gemm_metadata  # unused by the torch implementation
    batch_size, _, num_heads, head_dim = q_fp8.shape
    block_size = kvcache_fp8.shape[1]

    assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128"
    assert (
        block_size == 64
    ), "Vectorized torch impl hardcodes block_size=64 cache layout"
    assert q_fp8.shape == (batch_size, 1, num_heads, head_dim)
    assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4)
    assert weight.shape == (batch_size, num_heads)
    if seq_lens.dim() > 1:
        seq_lens = seq_lens.squeeze(-1)
    assert seq_lens.shape == (batch_size,)
    assert page_table.shape[0] == batch_size
    assert not clean_logits, "clean_logits=True is not supported by this implementation"

    # Round the requested width up to whole pages; padded_len >= max_seq_len.
    num_pages = (max_seq_len + block_size - 1) // block_size
    padded_len = num_pages * block_size

    # One flat row per page: [block_size * head_dim values | block_size * 4 scale bytes].
    scale_offset = block_size * head_dim
    pages_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4))
    gathered = pages_flat[page_table[:, :num_pages]]

    # Reinterpret the two regions of each gathered page.
    kv_value = (
        gathered[..., :scale_offset]
        .contiguous()
        .view(dtype=FP8_DTYPE)
        .to(torch.float32)
        .view(batch_size, padded_len, head_dim)
    )
    kv_scale = (
        gathered[..., scale_offset:]
        .contiguous()
        .view(dtype=torch.float32)
        .view(batch_size, padded_len)
    )

    # (batch, padded_len, head_dim) @ (batch, head_dim, num_heads)
    q = q_fp8[:, 0].to(torch.float32)
    score = torch.bmm(kv_value, q.transpose(1, 2))

    # ReLU per head, weight the heads, reduce over heads, then apply the
    # per-token scale to the reduced score.
    score = F.relu(score)
    score = score * weight.unsqueeze(1)
    score = score.sum(dim=2)
    score = score * kv_scale

    # Drop the page padding and mask positions at or beyond each request's
    # length. The out-of-place masked_fill returns a fresh tensor, so no
    # separate -inf prefill and copy is needed.
    positions = torch.arange(max_seq_len, device=q_fp8.device)
    invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)
    return score[:, :max_seq_len].masked_fill(invalid_mask, float("-inf"))


def topk_transform_512_pytorch_vectorized(
scores: torch.Tensor,
seq_lens: torch.Tensor,
Expand Down Expand Up @@ -372,7 +441,10 @@ def forward_c4_indexer(
tilelang_fp8_paged_mqa_logits as fn,
)
elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
fn = fp8_paged_mqa_logits_torch
if is_sm120_supported():
fn = fp8_paged_mqa_logits_torch_sm120
else:
fn = fp8_paged_mqa_logits_torch
else:
from deep_gemm import fp8_paged_mqa_logits as fn

Expand Down
Loading
Loading