Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
829ae2c
feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference
AliceChenyy May 8, 2026
0b46b07
fix: address PR review comments and handle KV cache uint8 dtype on SM120
AliceChenyy May 13, 2026
aa7000c
style: fix pre-commit lint issues (isort, ruff, black)
AliceChenyy May 14, 2026
3e33539
fix: wrap pytest.main in sys.exit for CI exit code propagation
AliceChenyy May 14, 2026
9db9a66
fix: add SM120 guards for sgl-kernel 0.4.2.post2 compatibility
AliceChenyy May 19, 2026
ef24964
fix: remove _is_sm120 module var from metadata.py, use is_sm120_suppo…
AliceChenyy May 19, 2026
1763009
refactor: simplify _is_sm120 by using is_sm120_supported() directly
AliceChenyy May 19, 2026
88a2c05
fix: replace broken _is_sm120 import in indexer.py with is_sm120_supp…
AliceChenyy May 20, 2026
a3d06b9
style: remove unused imports flagged by ruff (rotate_activation, act_…
AliceChenyy May 20, 2026
4491e60
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 20, 2026
5df66d2
fix: rename nsa_ to dsa_ in SM120 backend defaults after NSA→DSA refa…
AliceChenyy May 20, 2026
03be8ad
test: register SM120 MQA fallback tests in CI (base-b, 1-gpu-small)
AliceChenyy May 20, 2026
d9c4bfd
fix: address Fridge003 review — isolate SM120 code, auto-set envs, re…
AliceChenyy May 21, 2026
21c00bc
docs: add SM120 (RTX PRO 6000) to DeepSeek-V4 cookbook
AliceChenyy May 21, 2026
5f526de
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 21, 2026
cd44380
fix: revert unnecessary change to is_arch_support_pdl in jit_kernel/u…
AliceChenyy May 21, 2026
fe39f16
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 25, 2026
6a19b79
test(sm120): add P0 unit tests for paged MQA logits + FlashMLA fallbacks
AliceChenyy May 25, 2026
e1368c6
style(sm120): black/isort + drop unused import in P0 tests
AliceChenyy May 25, 2026
f8f982b
ci: retrigger CI after 120-min per-user cooldown expiry
AliceChenyy May 25, 2026
fb0b6a9
Merge remote-tracking branch 'origin/main' into sm120-dsv4-rebase
AliceChenyy May 28, 2026
d495c59
Merge branch 'main' into sm120-dsv4-rebase
b8zhong May 29, 2026
5387b1c
docs(cookbook): shrink SM120 note to selectable recipe
AliceChenyy May 29, 2026
6b5a290
Merge branch 'sm120-dsv4-rebase' of github.com:AliceChenyy/sglang int…
AliceChenyy May 29, 2026
787a47a
Merge branch 'main' into sm120-dsv4-rebase
Fridge003 May 29, 2026
c789f2b
fix: address Fridge003 review — JSX toggle, revert unnecessary changes
AliceChenyy May 29, 2026
579b359
Merge branch 'sm120-dsv4-rebase' of github.com:AliceChenyy/sglang int…
AliceChenyy May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/sglang/jit_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ def get_jit_cuda_arch() -> ArchInfo:
def is_arch_support_pdl() -> bool:
    """Return whether the current JIT target supports PDL.

    Returns:
        ``False`` on a HIP runtime; otherwise ``True`` when the major
        version reported by ``get_jit_cuda_arch()`` is 9 or higher.
    """
    if is_hip_runtime():
        return False
    arch = get_jit_cuda_arch()
    # PDL (griddepcontrol) instruction is supported on SM90+ (Hopper, Blackwell).
    # SM120 (desktop Blackwell) supports PDL despite lacking TMEM/tcgen05 —
    # PDL uses griddepcontrol for kernel scheduling, independent of TMEM.
    return arch.major >= 9


def _find_package_root(package: str) -> Optional[pathlib.Path]:
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/arg_groups/deepseek_v4_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ def apply_deepseek_v4_defaults(server_args: "ServerArgs", model_arch: str) -> No
f"Setting swa_full_tokens_ratio to {server_args.swa_full_tokens_ratio} for {model_arch}."
)

# SM120: auto-select marlin MoE backend (dispatches to SM120 Triton kernel)
from sglang.srt.utils.common import is_sm120_supported

if is_sm120_supported() and server_args.moe_runner_backend == "auto":
Comment thread
Fridge003 marked this conversation as resolved.
Outdated
server_args.moe_runner_backend = "marlin"
logger.info("Use marlin as MoE runner backend on SM120 for DeepSeekV4")


def validate_deepseek_v4_cp(server_args: "ServerArgs") -> None:
"""Validate DeepSeek V4 context-parallel configuration."""
Expand Down
14 changes: 10 additions & 4 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
from sglang.srt.layers.attention.dsv4.quant_k_cache import (
quant_to_nope_fp8_rope_bf16_pack_triton,
)
from sglang.srt.layers.attention.flash_mla_sm120_fallback import (
_is_sm120,
flash_mla_with_kvcache_entrypoint,
)
from sglang.srt.layers.dp_attention import (
get_attention_cp_rank,
get_attention_cp_size,
Expand Down Expand Up @@ -81,6 +85,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T:


def _create_flashmla_metadata():
    """Build the FlashMLA metadata object, or return ``None`` on SM120.

    Returns:
        ``None`` when the ``_is_sm120`` flag is set; otherwise the first
        element of ``flash_mla.get_mla_metadata()``.
    """
    if not _is_sm120:
        # Imported inside the function so that `flash_mla` is only imported
        # on the path that actually uses it (never when `_is_sm120` is set).
        import flash_mla

        metadata = flash_mla.get_mla_metadata()
        return metadata[0]
    return None
Expand Down Expand Up @@ -1031,9 +1037,7 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"

import flash_mla

o = flash_mla.flash_mla_with_kvcache(
input_dict = dict(
q=q,
k_cache=swa_k_cache,
head_dim_v=self.head_dim_v,
Expand All @@ -1048,7 +1052,9 @@ def forward(
extra_k_cache=extra_k_cache,
extra_indices_in_kvcache=extra_indices,
extra_topk_length=extra_topk_lengths,
)[0]
)

o = flash_mla_with_kvcache_entrypoint(**input_dict, backend="kernel")[0]
Comment thread
Fridge003 marked this conversation as resolved.
Outdated

o = o.squeeze(1)
return o
Expand Down
84 changes: 78 additions & 6 deletions python/sglang/srt/layers/attention/dsa/dsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_hip,
is_npu,
)
from sglang.srt.utils.common import is_sm120_supported
Comment thread
Fridge003 marked this conversation as resolved.
Outdated

logger = logging.getLogger(__name__)

Expand All @@ -57,9 +58,34 @@
if _is_cuda:
try:
import deep_gemm
except ImportError as e:
except (ImportError, AssertionError) as e:
# AssertionError: deep_gemm init fails on SM120 (no CUDA_HOME / unsupported arch)
deep_gemm = e

if is_sm120_supported():
import os as _os

if _os.environ.get("SGLANG_SM120_MQA_FALLBACK", "0") == "1":
from sglang.srt.layers.attention.dsa.sm120_mqa_fallback import (
compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata,
)
from sglang.srt.layers.attention.dsa.sm120_mqa_fallback import (
sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits,
)
from sglang.srt.layers.attention.dsa.sm120_mqa_fallback import (
sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits,
)
else:
from sglang.srt.layers.attention.dsa.sm120_mqa_triton import (
compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata,
)
from sglang.srt.layers.attention.dsa.sm120_mqa_triton import (
sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits,
)
from sglang.srt.layers.attention.dsa.sm120_mqa_triton import (
sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits,
)

if _use_aiter:
from aiter.ops.cache import indexer_k_quant_and_cache

Expand Down Expand Up @@ -218,7 +244,12 @@ def __init__(
self.cp_size = None
self.cp_rank = None
if _is_cuda:
self.sm_count = deep_gemm.get_num_sms()
if is_sm120_supported():
# SM120: deep_gemm.get_num_sms() crashes; use torch native API
props = torch.cuda.get_device_properties(torch.cuda.current_device())
self.sm_count = props.multi_processor_count
else:
self.sm_count = deep_gemm.get_num_sms()
self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)
pp_size = get_global_server_args().pp_size
self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank
Expand Down Expand Up @@ -269,7 +300,7 @@ def _with_real_sm_count(self):
# request to receive the PP proxy tensor or output from the previous stage, occupying one SM resource.
# Model execution runs in parallel with the recv operation, so the SMs available to the indexer must be reduced
# by 1. Currently, the last rank starts the send result + recv request only after waiting for execution results.
if self.logits_with_pp_recv:
if self.logits_with_pp_recv and not is_sm120_supported():
pp_recv_sm_count = 1
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
self.sm_count - pp_recv_sm_count
Expand Down Expand Up @@ -493,9 +524,16 @@ def _get_topk_paged(
seqlens_32_2d = seqlens_32.unsqueeze(-1)
if _is_cuda:
if schedule_metadata is None:
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
seqlens_32_2d, blocksize, self.sm_count
)
if is_sm120_supported():
schedule_metadata = _sm120_compute_paged_mqa_schedule_metadata(
seqlens_32_2d,
blocksize,
self.sm_count,
)
else:
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
seqlens_32_2d, blocksize, self.sm_count
)

assert len(q_fp8.shape) == 3
q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now
Expand Down Expand Up @@ -532,6 +570,17 @@ def _get_topk_paged(
Preshuffle=_use_aiter_preshuffle,
KVBlockSize=block_kv,
)
elif is_sm120_supported():
logits = _sm120_fp8_paged_mqa_logits(
q_fp8[:q_offset],
kv_cache_fp8,
weights[:q_offset],
seqlens_32_2d,
block_tables,
schedule_metadata,
max_seq_len,
clean_logits=False,
)
else:
logits = deep_gemm.fp8_paged_mqa_logits(
q_fp8[:q_offset],
Expand Down Expand Up @@ -707,6 +756,15 @@ def _get_topk_ragged(
logits = fp8_mqa_logits(
q_fp8[:q_offset], kv, scale, weights[:q_offset], ks, ke
)
elif is_sm120_supported():
logits = _sm120_fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
weights[:q_offset],
ks,
ke,
clean_logits=False,
)
else:
logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
Expand Down Expand Up @@ -757,6 +815,15 @@ def _get_topk_ragged(
ks[start:end],
ke[start:end],
)
elif is_sm120_supported():
logits_chunk = _sm120_fp8_mqa_logits(
q_fp8[start:end],
kv_fp8,
weights[start:end],
ks[start:end],
ke[start:end],
clean_logits=False,
)
else:
logits_chunk = deep_gemm.fp8_mqa_logits(
q_fp8[start:end],
Expand Down Expand Up @@ -848,6 +915,11 @@ def _get_topk_ragged_with_cp(
actual_seq_q: int,
cp_index: List[Tuple[int, int, int]] = None,
) -> torch.Tensor:
if is_sm120_supported():
raise NotImplementedError(
"Ragged CP path requires DeepGEMM fp8_mqa_logits which is not "
"supported on SM120. Use paged topk_transform instead."
)
if TYPE_CHECKING:
assert isinstance(forward_batch.token_to_kv_pool, DSATokenToKVPool)

Expand Down
Loading
Loading