diff --git a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx
index 392282aaa9f0..d59a613ee139 100644
--- a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx
+++ b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx
@@ -29,7 +29,7 @@ tag: NEW
DeepSeek-V4-Flash |
284B |
13B |
- single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs |
+ single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs; RTX PRO 6000 (SM120) on 4 or 8 GPUs |
| DeepSeek-V4-Pro |
@@ -128,6 +128,20 @@ PD-Disagg recipes on H200 may require `docker run --privileged --ulimit memlock=
can discover the IB HCAs; without IB exposure mooncake silently falls back to
TCP, which can lead to garbled KV transfer on large checkpoints.
+
+
+**SM120 (RTX PRO 6000 Blackwell Server Edition) note**
+
+DeepSeek-V4-Flash can run on RTX PRO 6000 Blackwell Server Edition (SM120, 96 GB GDDR7) with Tensor Parallelism only. We support two TP configurations:
+- **TP=8 (recommended)**: `--tp 8 --mem-fraction-static 0.70 --cuda-graph-max-bs 32`. Leaves ~20 GB per GPU for KV cache.
+- **TP=4 (memory-constrained)**: `--tp 4 --mem-fraction-static 0.90 --cuda-graph-max-bs 4`. Runs near the 96 GB memory limit (~98.8% usage) with minimal KV cache headroom.
+
+SM120 uses Triton-based MoE and FlashMLA fallback kernels instead of CUTLASS/DeepGEMM (auto-detected, no manual flags needed). Use Docker image `lmsysorg/sglang:dev-cu13` (CUDA 13.0, required for SM120 / CC 12.0).
+
+Performance is memory-bandwidth bound (~1.5 TB/s GDDR7 vs ~8 TB/s HBM3e on B200); expect ~15-17 tok/s at BS=1. Accuracy matches reference (GSM8K 10/10, GPQA Diamond 72.0% vs 71.2% published).
+
+V4-Pro is not supported on SM120 (model does not fit in 8x 96 GB).
+
**MegaMoE**
MegaMoE fuses expert dispatch + GEMM into a single kernel for higher throughput
diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
index f9f396428557..53741c1b8fd3 100644
--- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py
+++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py
@@ -55,6 +55,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import ceil_align
+from sglang.srt.utils.common import is_sm120_supported
if TYPE_CHECKING:
from flash_mla.flash_mla_interface import FlashMLASchedMeta
@@ -62,6 +63,8 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
+_is_sm120 = is_sm120_supported()
+
logger = logging.getLogger(__name__)
SWA_WINDOW = 128
@@ -81,6 +84,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T:
def _create_flashmla_metadata():
+ if _is_sm120:
+ return None
import flash_mla
return flash_mla.get_mla_metadata()[0]
@@ -1031,24 +1036,42 @@ def forward(
extra_indices.shape[-1] % 64 == 0
), f"{extra_indices.shape=}'s last dimension is not aligned to 64"
- import flash_mla
-
- o = flash_mla.flash_mla_with_kvcache(
- q=q,
- k_cache=swa_k_cache,
- head_dim_v=self.head_dim_v,
- block_table=None,
- cache_seqlens=None,
- tile_scheduler_metadata=flashmla_metadata,
- softmax_scale=self.softmax_scale,
- is_fp8_kvcache=True,
- indices=swa_page_indices,
- topk_length=swa_topk_lengths,
- attn_sink=attn_sink,
- extra_k_cache=extra_k_cache,
- extra_indices_in_kvcache=extra_indices,
- extra_topk_length=extra_topk_lengths,
- )[0]
+ if _is_sm120:
+ from sglang.srt.layers.attention.flash_mla_sm120 import (
+ flash_mla_with_kvcache_sm120,
+ )
+
+ o = flash_mla_with_kvcache_sm120(
+ q=q,
+ k_cache=swa_k_cache,
+ head_dim_v=self.head_dim_v,
+ softmax_scale=self.softmax_scale,
+ indices=swa_page_indices,
+ topk_length=swa_topk_lengths,
+ attn_sink=attn_sink,
+ extra_k_cache=extra_k_cache,
+ extra_indices_in_kvcache=extra_indices,
+ extra_topk_length=extra_topk_lengths,
+ )[0]
+ else:
+ import flash_mla
+
+ o = flash_mla.flash_mla_with_kvcache(
+ q=q,
+ k_cache=swa_k_cache,
+ head_dim_v=self.head_dim_v,
+ block_table=None,
+ cache_seqlens=None,
+ tile_scheduler_metadata=flashmla_metadata,
+ softmax_scale=self.softmax_scale,
+ is_fp8_kvcache=True,
+ indices=swa_page_indices,
+ topk_length=swa_topk_lengths,
+ attn_sink=attn_sink,
+ extra_k_cache=extra_k_cache,
+ extra_indices_in_kvcache=extra_indices,
+ extra_topk_length=extra_topk_lengths,
+ )[0]
o = o.squeeze(1)
return o
diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py
index f8899264d6d0..ec1b6d9665bc 100644
--- a/python/sglang/srt/layers/attention/dsv4/indexer.py
+++ b/python/sglang/srt/layers/attention/dsv4/indexer.py
@@ -20,6 +20,7 @@
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer
from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils.common import is_sm120_supported
if TYPE_CHECKING:
from sglang.srt.layers.attention.dsv4.compressor import (
@@ -90,6 +91,74 @@ def fp8_paged_mqa_logits_torch(
return logits
+def fp8_paged_mqa_logits_torch_sm120(
+    q_fp8: torch.Tensor,
+    kvcache_fp8: torch.Tensor,
+    weight: torch.Tensor,
+    seq_lens: torch.Tensor,
+    page_table: torch.Tensor,
+    deep_gemm_metadata: Any,
+    max_seq_len: int,
+    clean_logits: bool = True,
+) -> torch.Tensor:
+    """Compute FP8 paged MQA indexer logits with vectorized torch ops (SM120).
+
+    Every step is a fixed-shape tensor op with no ``.item()`` call, which is
+    what makes the function usable under CUDA graph capture.
+
+    Args:
+        q_fp8: ``(batch, 1, num_heads, 128)`` query.
+        kvcache_fp8: ``(num_pages, 64, 1, 132)`` paged cache. Each page is
+            read as 64 * 128 value elements followed by 64 * 4 elements that
+            are reinterpreted as one float32 scale per token.
+        weight: ``(batch, num_heads)`` per-head weights.
+        seq_lens: ``(batch,)`` or ``(batch, 1)`` valid sequence lengths.
+        page_table: ``(batch, P)`` page ids with
+            ``P >= ceil(max_seq_len / 64)``.
+        deep_gemm_metadata: Unused on this path.
+        max_seq_len: Width of the returned logits.
+        clean_logits: Must be ``False``; cleaning is not implemented here.
+
+    Returns:
+        ``(batch, max_seq_len)`` float32 logits. Positions at or beyond the
+        corresponding ``seq_lens`` entry are ``-inf``.
+    """
+    del deep_gemm_metadata  # accepted for call compatibility only
+    batch_size, _, num_heads, head_dim = q_fp8.shape
+    block_size = kvcache_fp8.shape[1]
+    device = q_fp8.device
+
+    assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128"
+    assert (
+        block_size == 64
+    ), "Vectorized torch impl hardcodes block_size=64 cache layout"
+    assert q_fp8.shape == (batch_size, 1, num_heads, head_dim)
+    assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4)
+    assert weight.shape == (batch_size, num_heads)
+    if seq_lens.dim() > 1:
+        seq_lens = seq_lens.squeeze(-1)
+    assert seq_lens.shape == (batch_size,)
+    assert page_table.shape[0] == batch_size
+    assert not clean_logits, "clean_logits=True is not supported by the torch impl"
+
+    max_pages = (max_seq_len + block_size - 1) // block_size
+    max_padded_seq = max_pages * block_size
+    # Without this check a short page table only surfaces later as an opaque
+    # shape error in the ``view`` calls below.
+    assert page_table.shape[1] >= max_pages, (
+        f"page_table has {page_table.shape[1]} columns, need at least "
+        f"{max_pages} to cover max_seq_len={max_seq_len}"
+    )
+
+    # One row per page: [values (block_size * head_dim) | scales (block_size * 4)].
+    kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4))
+    SCALE_OFFSET = block_size * head_dim
+
+    page_ids = page_table[:, :max_pages]
+    kvcache_gathered = kvcache_flat[page_ids]
+
+    kv_value_raw = kvcache_gathered[..., :SCALE_OFFSET]
+    kv_scale_raw = kvcache_gathered[..., SCALE_OFFSET:]
+
+    kv_value = kv_value_raw.contiguous().view(dtype=FP8_DTYPE).to(torch.float32)
+    kv_value = kv_value.view(batch_size, max_padded_seq, head_dim)
+
+    kv_scale = kv_scale_raw.contiguous().view(dtype=torch.float32)
+    kv_scale = kv_scale.view(batch_size, max_padded_seq)
+
+    q = q_fp8[:, 0].to(torch.float32)
+
+    # (batch, tokens, head_dim) @ (batch, head_dim, heads) -> (batch, tokens, heads)
+    score = torch.bmm(kv_value, q.transpose(1, 2))
+
+    score = F.relu(score)
+    score = score * weight.unsqueeze(1)
+    score = score.sum(dim=2)
+
+    score = score * kv_scale
+
+    positions = torch.arange(max_seq_len, device=device)
+    invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)
+    # max_padded_seq >= max_seq_len by construction, so slicing to max_seq_len
+    # and masking replaces the former allocate-copy-mask sequence.
+    return score[:, :max_seq_len].masked_fill(invalid_mask, float("-inf"))
+
+
def topk_transform_512_pytorch_vectorized(
scores: torch.Tensor,
seq_lens: torch.Tensor,
@@ -372,7 +441,10 @@ def forward_c4_indexer(
tilelang_fp8_paged_mqa_logits as fn,
)
elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
- fn = fp8_paged_mqa_logits_torch
+ if is_sm120_supported():
+ fn = fp8_paged_mqa_logits_torch_sm120
+ else:
+ fn = fp8_paged_mqa_logits_torch
else:
from deep_gemm import fp8_paged_mqa_logits as fn
diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120.py b/python/sglang/srt/layers/attention/flash_mla_sm120.py
new file mode 100644
index 000000000000..8d49d2a6e461
--- /dev/null
+++ b/python/sglang/srt/layers/attention/flash_mla_sm120.py
@@ -0,0 +1,252 @@
+"""SM120 FlashMLA sparse decode implementation.
+
+On SM120 (Blackwell Desktop / RTX PRO 6000) the flash_mla CUDA kernel
+is not available, so this module provides alternative implementations:
+
+- A fused Triton kernel (default, ``SGLANG_SM120_TRITON_FLASHMLA=1``)
+- A pure-PyTorch fallback (``SGLANG_SM120_TRITON_FLASHMLA=0``)
+
+The FP8 KV cache uses a page-internal layout where NOPE+ROPE data has
+stride (nope_dim + rope_dim*2) per token, and scales are stored in a
+separate region at the end of each page.
+"""
+
+import logging
+import os
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+# Page layout constants for DSv4-Flash (MODEL1):
+# nope_dim = 448, rope_dim = 64, quantize_block_size = 64
+# nope_rope_stride = 448 + 64*2 = 576 bytes per token
+# scale_stride = ceil(448/64) + 1 = 8 bytes per token (7 scales + 1 pad)
+# bytes_per_token = 448 + 128 + 8 = 584
+# page_bytes = ceil_div(page_size * 584, 576) * 576
+
+_NOPE_DIM = 448
+_ROPE_DIM = 64
+_NOPE_ROPE_STRIDE = _NOPE_DIM + _ROPE_DIM * 2 # 576
+_TILE_SIZE = 64
+_NUM_TILES = _NOPE_DIM // _TILE_SIZE # 7
+_SCALE_STRIDE = _NUM_TILES + 1 # 8 (7 scales + 1 pad)
+_D = _NOPE_DIM + _ROPE_DIM # 512
+
+
+def _gather_and_dequant(k_cache, indices, page_size):
+    """Fetch tokens from the paged FP8 KV buffer and dequantize them to bf16.
+
+    Within a page, token ``t`` owns a 576-byte data record at ``t * 576``
+    (448 FP8 nope bytes followed by 64 bf16 rope values). The scale records
+    start after all data records, at ``page_size * 576``; each is 8 bytes of
+    which the first 7 are E8M0 scales, one per 64-wide nope tile.
+
+    Args:
+        k_cache: (num_pages, page_size, 1, bytes_per_token) float8_e4m3fn.
+            May be a non-contiguous view; ``stride(0)`` is taken as the page
+            size in bytes.
+        indices: (...) int32/int64 token-level indices. Negative page/slot
+            values are clamped to 0, so the caller is responsible for masking
+            the rows that correspond to invalid entries.
+        page_size: tokens per page.
+
+    Returns:
+        (..., _D) bfloat16 tensor of dequantized KV vectors.
+    """
+    out_shape = indices.shape
+    token_ids = indices.reshape(-1)  # (N,)
+    n_tokens = token_ids.shape[0]
+    dev = k_cache.device
+
+    # Split each token index into (page, slot-in-page); clamp negatives so the
+    # gathers below never index out of range.
+    page_no = (token_ids // page_size).clamp(min=0)
+    slot_no = (token_ids % page_size).clamp(min=0)
+
+    # Byte-addressable (num_pages, page_bytes) view over the raw page buffer.
+    # float8_e4m3fn and uint8 are both one byte wide, so the dtype view is safe.
+    bytes_per_page = k_cache.stride(0)
+    page_bytes_view = k_cache.as_strided(
+        (k_cache.shape[0], bytes_per_page),
+        (bytes_per_page, 1),
+    ).view(torch.uint8)
+
+    def _fetch(first_byte, width):
+        # Gather `width` consecutive bytes per token, starting at `first_byte`
+        # inside that token's page.
+        cols = first_byte.unsqueeze(-1) + torch.arange(
+            width, device=dev, dtype=torch.long
+        )
+        rows = page_no.unsqueeze(-1).expand_as(cols)
+        return page_bytes_view[rows, cols]
+
+    data_start = slot_no * _NOPE_ROPE_STRIDE
+    scale_start = page_size * _NOPE_ROPE_STRIDE + slot_no * _SCALE_STRIDE
+
+    nope_fp8 = _fetch(data_start, _NOPE_DIM).view(torch.float8_e4m3fn)  # (N, 448)
+    rope_raw = _fetch(data_start + _NOPE_DIM, _ROPE_DIM * 2)  # (N, 128) uint8
+    rope_bf16 = rope_raw.contiguous().view(torch.bfloat16)  # (N, 64)
+    scale_e8m0 = _fetch(scale_start, _NUM_TILES).view(torch.float8_e8m0fnu)  # (N, 7)
+
+    # Per-tile dequantization in fp32, then narrow to bf16.
+    nope_f32 = nope_fp8.view(n_tokens, _NUM_TILES, _TILE_SIZE).float()
+    scale_f32 = scale_e8m0.view(n_tokens, _NUM_TILES, 1).float()
+    nope_bf16 = (nope_f32 * scale_f32).view(n_tokens, _NOPE_DIM).to(torch.bfloat16)
+
+    kv = torch.empty(n_tokens, _D, dtype=torch.bfloat16, device=dev)
+    kv[:, :_NOPE_DIM] = nope_bf16
+    kv[:, _NOPE_DIM:] = rope_bf16
+
+    return kv.reshape(*out_shape, _D)
+
+
+def _sm120_sparse_decode_fwd(
+    q,
+    k_cache,
+    indices,
+    topk_length,
+    attn_sink,
+    head_dim_v,
+    softmax_scale,
+    extra_k_cache=None,
+    extra_indices=None,
+    extra_topk_length=None,
+):
+    """Pure-PyTorch sparse MLA decode over one or two paged FP8 KV caches.
+
+    Tokens selected by ``indices`` (and, when both are given, by
+    ``extra_indices`` from ``extra_k_cache``) are gathered, dequantized and
+    attended to with a single softmax. A slot is ignored when its index is
+    negative or its position is at or beyond the matching ``*_topk_length``.
+    ``attn_sink`` (one value per head) only enlarges the softmax denominator.
+
+    Returns:
+        ``(out, lse)`` where ``out`` is bf16 with ``head_dim_v`` channels and
+        ``lse`` is the log-sum-exp of the scores *without* the sink term,
+        permuted with ``(0, 2, 1)``. Rows with no valid token produce a zero
+        output and ``-inf`` lse.
+    """
+    B, _, H_q, D_qk = q.shape
+    _, page_size, _, _ = k_cache.shape
+
+    def _select(cache, idx, lengths, tokens_per_page):
+        # Returns (dequantized KV, mask of slots to ignore) for one cache.
+        ignore = idx < 0
+        clamped = idx.clamp(min=0)
+        if lengths is not None:
+            width = idx.shape[-1]
+            slots = torch.arange(width, device=lengths.device).view(1, 1, width)
+            ignore = ignore | (slots >= lengths.view(B, 1, 1))
+        return _gather_and_dequant(cache, clamped, tokens_per_page), ignore
+
+    kv, ignore_mask = _select(k_cache, indices, topk_length, page_size)
+
+    if extra_k_cache is not None and extra_indices is not None:
+        extra_kv, extra_ignore = _select(
+            extra_k_cache, extra_indices, extra_topk_length, extra_k_cache.shape[1]
+        )
+        kv = torch.cat([kv, extra_kv], dim=2)
+        ignore_mask = torch.cat([ignore_mask, extra_ignore], dim=2)
+
+    # Zero ignored rows so whatever was gathered for them cannot leak into the
+    # weighted sum below.
+    kv[ignore_mask] = 0.0
+
+    query = q.float()
+    kv_f = kv.float()
+    kv_dim = kv_f.shape[-1]
+    if D_qk != kv_dim:
+        query = query[..., :kv_dim]
+
+    logits = torch.einsum("bshd,bstd->bsht", query, kv_f) * softmax_scale
+    logits.masked_fill_(ignore_mask.unsqueeze(2).expand_as(logits), float("-inf"))
+
+    lse = torch.logsumexp(logits, dim=-1)
+
+    if attn_sink is None:
+        denom_lse = lse.clone()
+    else:
+        sink = attn_sink.view(1, 1, H_q).expand_as(lse)
+        denom_lse = torch.logsumexp(torch.stack([lse, sink], dim=0), dim=0)
+
+    # Rows without any valid token: force the weights to exp(-inf) = 0 and
+    # clear the output explicitly.
+    empty_rows = lse == float("-inf")
+    denom_lse[empty_rows] = float("inf")
+    probs = torch.exp(logits - denom_lse.unsqueeze(-1))
+    out = torch.einsum("bsht,bstv->bshv", probs, kv_f[..., :head_dim_v])
+    out[empty_rows.unsqueeze(-1).expand_as(out)] = 0.0
+
+    return out.to(torch.bfloat16), lse.permute(0, 2, 1)
+
+
+# Backend used by ``flash_mla_with_kvcache_sm120``: "triton" (fused kernel) or
+# "torch" (pure-PyTorch fallback). SGLANG_SM120_TRITON_FLASHMLA=1 (the default
+# when unset) selects triton; any other value selects torch.
+_sm120_default_backend = "torch"
+if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1":
+    _sm120_default_backend = "triton"
+
+
+def flash_mla_with_kvcache_sm120(**kwargs):
+    """SM120 FlashMLA sparse decode entry point.
+
+    Reads ``q``, ``k_cache``, ``indices`` and ``head_dim_v`` (required) plus
+    ``topk_length``, ``attn_sink``, ``softmax_scale``, ``extra_k_cache``,
+    ``extra_indices_in_kvcache`` and ``extra_topk_length`` (optional) from the
+    keyword arguments; any other keyword is ignored. ``softmax_scale``
+    defaults to ``q.shape[-1] ** -0.5``.
+
+    Dispatches to the Triton kernel or the PyTorch fallback according to
+    ``_sm120_default_backend`` and returns that backend's ``(out, lse)``.
+    """
+    q = kwargs["q"]
+    k_cache = kwargs["k_cache"]
+    indices = kwargs["indices"]
+    head_dim_v = kwargs["head_dim_v"]
+
+    softmax_scale = kwargs.get("softmax_scale")
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    if _sm120_default_backend == "triton":
+        # Imported lazily so the torch fallback never needs the Triton module.
+        from sglang.srt.layers.attention.flash_mla_sm120_triton import (
+            flash_mla_sparse_decode_triton as decode_fn,
+        )
+    else:
+        decode_fn = _sm120_sparse_decode_fwd
+
+    out, lse = decode_fn(
+        q,
+        k_cache,
+        indices,
+        kwargs.get("topk_length"),
+        kwargs.get("attn_sink"),
+        head_dim_v,
+        softmax_scale,
+        kwargs.get("extra_k_cache"),
+        kwargs.get("extra_indices_in_kvcache"),
+        kwargs.get("extra_topk_length"),
+    )
+    return (out, lse)
diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py
new file mode 100644
index 000000000000..20ffb4d3cbab
--- /dev/null
+++ b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py
@@ -0,0 +1,370 @@
+"""SM120-optimized Triton FlashMLA sparse decode kernel — Tiled V2.
+
+Replaces V1's serial token loop with a tiled vectorized approach:
+ 1. BLOCK_T tokens loaded simultaneously via 2D gather (vs 1-at-a-time)
+ 2. All BLOCK_T QK scores computed at once via vectorized mul-reduce
+ 3. V accumulation via vectorized weighted sum across BLOCK_T tokens
+ 4. Online softmax operates on tile-level maxima (fewer rescales)
+
+Three typed views of the same paged buffer handle FP8/uint8/BF16 regions:
+- float8_e4m3fn view -> nope FP8 values (direct load + dequant)
+- uint8 view -> UE8M0 scale bytes (raw integer -> exp2 conversion)
+- bfloat16 view -> rope BF16 values (direct load)
+
+DSv4 page layout (per token, 576 bytes data + 8 bytes scales):
+ Data section: [0:448] FP8 nope | [448:576] BF16 rope (64 values = 128 bytes)
+ Scale section: [page_size*576 + offset*8 : +7] UE8M0 scales (7 groups of 64)
+
+Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7, 96MB L2)
+"""
+
+import logging
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+LOG2E = tl.constexpr(1.4426950408889634)
+
+# DSv4 KV cache layout constants
+_NOPE_DIM = 448
+_ROPE_DIM = 64
+_D = _NOPE_DIM + _ROPE_DIM # 512
+_TOKEN_DATA_STRIDE = 576 # bytes per token in data section
+_SCALE_STRIDE = 8 # bytes per token in scale section
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_T": 16}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_T": 16}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_T": 32}, num_warps=8, num_stages=2),
+    ],
+    key=["topk_rounded"],
+)
+@triton.jit
+def _tiled_sparse_decode_kernel(
+    # Q: [B, H, D] bf16
+    Q_ptr,
+    # Paged KV cache — three typed views of same underlying memory
+    cache_fp8_ptr,  # float8_e4m3fn flat (1 byte/elem) — for nope
+    cache_uint8_ptr,  # uint8 flat (1 byte/elem) — for scales
+    cache_bf16_ptr,  # bfloat16 flat (2 bytes/elem) — for rope
+    # Indices: [B, topk] int32
+    indices_ptr,
+    # Valid lengths: [B] int32
+    topk_len_ptr,
+    # Output: [B, H, D] bf16 and LSE: [B, H] float32
+    O_ptr,
+    LSE_ptr,
+    # Scalars
+    sm_scale: tl.float32,
+    page_size: tl.int32,
+    page_bytes: tl.int64,
+    scale_section_off: tl.int64,  # page_size * 576
+    H: tl.int32,
+    topk: tl.int32,
+    topk_rounded: tl.int32,  # for autotune key
+    has_topk_len: tl.constexpr,
+    # Strides
+    stride_qb: tl.int32,
+    stride_qh: tl.int32,
+    stride_ob: tl.int32,
+    stride_oh: tl.int32,
+    stride_ib: tl.int32,  # indices batch stride
+    # Constexprs
+    NOPE_PAD: tl.constexpr,  # 512 (padded from 448)
+    ROPE_DIM: tl.constexpr,  # 64
+    NOPE_DIM_RT: tl.int32,  # 448 (runtime, for masking)
+    BLOCK_T: tl.constexpr,  # tokens per tile (16 or 32)
+):
+    """Tiled sparse decode: vectorized gather + QK + softmax + V accumulation.
+
+    Grid: (B, H) — one block per (batch, head) pair.
+    Each block processes all topk tokens in tiles of BLOCK_T.
+
+    Addressing assumptions visible in the loads/stores below:
+    - Q and O are indexed as base + channel, i.e. unit stride on the last dim.
+    - indices are read at ``bid * stride_ib + t``, i.e. unit stride on topk.
+    - LSE is written at ``bid * H + hid``, i.e. a contiguous [B, H] buffer.
+    - The literals 576, 448 and 8 are the per-token data stride, nope width
+      and per-token scale stride in bytes.
+
+    A token slot is skipped when its position is >= the valid length or its
+    index is negative. If no slot is valid, the output is 0 and LSE is -inf.
+    """
+    bid = tl.program_id(0)
+    hid = tl.program_id(1)
+
+    # ---- Load Q for this (batch, head) ----
+    q_base = bid * stride_qb + hid * stride_qh
+    nope_offs = tl.arange(0, NOPE_PAD)  # [512]
+    nope_mask = nope_offs < NOPE_DIM_RT  # [512], True for [0:448]
+    rope_offs = tl.arange(0, ROPE_DIM)  # [64]
+
+    # sm_scale is folded into Q once so the per-token scores need no extra multiply.
+    q_nope = tl.load(Q_ptr + q_base + nope_offs, mask=nope_mask, other=0.0)
+    q_nope = q_nope.to(tl.float32) * sm_scale
+    q_rope = tl.load(Q_ptr + q_base + NOPE_DIM_RT + rope_offs)
+    q_rope = q_rope.to(tl.float32) * sm_scale
+
+    # ---- Valid token count ----
+    valid_topk = topk
+    if has_topk_len:
+        valid_topk = tl.load(topk_len_ptr + bid).to(tl.int32)
+        valid_topk = tl.minimum(valid_topk, topk)
+
+    # ---- Online softmax state (base-2 math for SM120 efficiency) ----
+    # m_i starts at a large finite negative rather than -inf so that
+    # exp2(m_i - m_new) never evaluates (-inf) - (-inf) = NaN.
+    m_i: tl.float32 = -1e30
+    l_i: tl.float32 = 0.0
+    acc_nope = tl.zeros([NOPE_PAD], dtype=tl.float32)
+    acc_rope = tl.zeros([ROPE_DIM], dtype=tl.float32)
+
+    # ---- Precompute constant index vectors ----
+    group_ids = (nope_offs // 64).to(tl.int64)  # [NOPE_PAD], scale group for each dim
+    t_offs = tl.arange(0, BLOCK_T)  # [BLOCK_T], token offsets within tile
+
+    # ---- Process tokens in tiles of BLOCK_T ----
+    # The loop runs over all `topk` slots; slots at or beyond `valid_topk` are
+    # masked out inside the tile rather than skipped.
+    for tile_start in range(0, topk, BLOCK_T):
+        t_idx = tile_start + t_offs  # [BLOCK_T], global token indices
+        t_in_bounds = t_idx < topk  # bounds for index load
+        t_valid = t_idx < valid_topk  # bounds for actual processing
+
+        # Load indices for this tile: [BLOCK_T]
+        raw_indices = tl.load(
+            indices_ptr + bid * stride_ib + t_idx,
+            mask=t_in_bounds,
+            other=-1,
+        )
+        idx_valid = t_valid & (raw_indices >= 0)  # [BLOCK_T] mask
+
+        # Page addressing: [BLOCK_T] (clamp for safe addressing of invalid tokens)
+        safe_indices = tl.where(idx_valid, raw_indices, tl.zeros_like(raw_indices))
+        page_ids = (safe_indices // page_size).to(tl.int64)
+        page_offs_t = (safe_indices % page_size).to(tl.int64)
+        token_data_bases = page_ids * page_bytes + page_offs_t * 576  # [BLOCK_T] int64
+
+        # ---- Vectorized NOPE FP8 gather: [BLOCK_T, NOPE_PAD] ----
+        nope_addrs = token_data_bases[:, None] + nope_offs[None, :].to(tl.int64)
+        nope_2d_mask = idx_valid[:, None] & nope_mask[None, :]
+        kv_nope_fp8 = tl.load(
+            cache_fp8_ptr + nope_addrs,
+            mask=nope_2d_mask,
+            other=0.0,
+        )
+
+        # ---- Vectorized scale gather + dequant: [BLOCK_T, NOPE_PAD] ----
+        # Each scale byte is a biased power-of-two exponent: scale = 2 ** (byte - 127).
+        # Masked lanes load 127, i.e. a neutral scale of 1.0.
+        scale_bases = page_ids * page_bytes + scale_section_off + page_offs_t * 8
+        scale_addrs = scale_bases[:, None] + group_ids[None, :]
+        scale_raw = tl.load(
+            cache_uint8_ptr + scale_addrs,
+            mask=nope_2d_mask,
+            other=127,
+        )
+        scale_f32 = tl.math.exp2(scale_raw.to(tl.float32) - 127.0)
+        kv_nope = tl.where(nope_2d_mask, kv_nope_fp8.to(tl.float32) * scale_f32, 0.0)
+
+        # ---- Vectorized ROPE BF16 gather: [BLOCK_T, ROPE_DIM] ----
+        # cache_bf16_ptr counts 2-byte elements, so the byte address is halved.
+        rope_byte_bases = token_data_bases + 448
+        rope_elem_bases = (rope_byte_bases // 2).to(tl.int64)
+        rope_addrs = rope_elem_bases[:, None] + rope_offs[None, :].to(tl.int64)
+        kv_rope = tl.load(
+            cache_bf16_ptr + rope_addrs,
+            mask=idx_valid[:, None],
+            other=0.0,
+        ).to(tl.float32)
+
+        # ---- Vectorized QK scores: [BLOCK_T] ----
+        # scores[t] = dot(q_nope, kv_nope[t]) + dot(q_rope, kv_rope[t])
+        scores = tl.sum(q_nope[None, :] * kv_nope, axis=1) + tl.sum(
+            q_rope[None, :] * kv_rope, axis=1
+        )
+        scores = tl.where(idx_valid, scores, -1e30)
+
+        # ---- Online softmax update (base-2, tile-level) ----
+        scores_log2 = scores * LOG2E  # [BLOCK_T]
+        tile_max = tl.max(scores_log2)  # scalar
+        m_new = tl.maximum(m_i, tile_max)
+
+        alpha = tl.math.exp2(m_i - m_new)  # rescale factor
+        p = tl.math.exp2(scores_log2 - m_new)  # [BLOCK_T] attention weights
+        p = tl.where(idx_valid, p, 0.0)  # zero out invalid
+
+        l_i = l_i * alpha + tl.sum(p)
+
+        # ---- Vectorized V accumulation (K=V in MLA) ----
+        # acc += sum_t(p[t] * kv[t, :]) for both nope and rope
+        acc_nope = acc_nope * alpha + tl.sum(p[:, None] * kv_nope, axis=0)
+        acc_rope = acc_rope * alpha + tl.sum(p[:, None] * kv_rope, axis=0)
+        m_i = m_new
+
+    # ---- Normalize output ----
+    # l_i == 0 means no valid token was seen; divide by 1 to keep the zeros.
+    safe_l = tl.where(l_i > 0.0, l_i, 1.0)
+    acc_nope = acc_nope / safe_l
+    acc_rope = acc_rope / safe_l
+
+    # LSE: convert from log2 back to natural log
+    lse = tl.where(l_i > 0.0, m_i / LOG2E + tl.math.log(safe_l), float("-inf"))
+
+    # ---- Store output ----
+    o_base = bid * stride_ob + hid * stride_oh
+    tl.store(O_ptr + o_base + nope_offs, acc_nope.to(tl.bfloat16), mask=nope_mask)
+    tl.store(O_ptr + o_base + NOPE_DIM_RT + rope_offs, acc_rope.to(tl.bfloat16))
+    tl.store(LSE_ptr + bid * H + hid, lse)
+
+
+def _run_triton_sparse_decode(
+    q: torch.Tensor,  # [B, 1, H, D] bf16
+    k_cache: torch.Tensor,  # [num_pages, page_size, 1, bpt] float8
+    indices: torch.Tensor,  # [B, ...] int32
+    topk_length: Optional[torch.Tensor],
+    softmax_scale: float,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Run the tiled Triton sparse decode kernel on one paged KV cache.
+
+    Args:
+        q: ``[B, 1, H, D]`` query; only a query length of 1 is supported.
+        k_cache: ``[num_pages, page_size, 1, bpt]`` paged cache whose
+            ``stride(0)`` is the page size in bytes.
+        indices: ``[B, ...]`` token indices, flattened to ``[B, topk]``.
+        topk_length: optional ``[B]`` number of valid entries per row.
+        softmax_scale: multiplier applied to the QK scores.
+
+    Returns:
+        ``(out, lse)`` with shapes ``[B, 1, H, D]`` (bf16) and ``[B, 1, H]``
+        (float32).
+
+    Raises:
+        AssertionError: if ``q.shape[1] != 1``.
+    """
+    B, s_q, H, D = q.shape
+    # The launch grid is (B, H) and Q is addressed as [B, H, D]. With s_q != 1
+    # the squeeze below would be a no-op and the kernel would read the wrong
+    # strides, so fail loudly instead of returning garbage.
+    assert s_q == 1, f"SM120 Triton sparse decode expects s_q == 1, got {q.shape=}"
+    num_pages = k_cache.shape[0]
+    page_size = k_cache.shape[1]
+    page_bytes = k_cache.stride(0)  # elements = bytes for float8
+
+    # Flatten indices to [B, topk]
+    flat_indices = indices.reshape(B, -1).contiguous()
+    topk = flat_indices.shape[1]
+
+    # Create three typed views of the flat cache memory.
+    # The KV cache may arrive as uint8 or float8_e4m3fn depending on the
+    # sglang version. Ensure each view has the correct dtype so Triton
+    # interprets the loaded values correctly (FP8 dequant vs raw integer).
+    total_elems = num_pages * page_bytes
+    raw_flat = k_cache.as_strided((total_elems,), (1,))
+    raw_uint8 = raw_flat.view(torch.uint8)
+    raw_fp8 = raw_uint8.view(torch.float8_e4m3fn)
+    raw_bf16 = raw_uint8.view(torch.bfloat16)
+
+    # Squeeze Q: [B, H, D]
+    q3 = q.squeeze(1)
+    if not q3.is_contiguous():
+        q3 = q3.contiguous()
+
+    out = torch.zeros(B, H, D, dtype=torch.bfloat16, device=q.device)
+    lse = torch.full((B, H), float("-inf"), dtype=torch.float32, device=q.device)
+
+    # Round topk for autotune key stability
+    topk_rounded = triton.next_power_of_2(topk)
+
+    grid = (B, H)
+    _tiled_sparse_decode_kernel[grid](
+        q3,
+        raw_fp8,
+        raw_uint8,
+        raw_bf16,
+        flat_indices,
+        (
+            topk_length
+            if topk_length is not None
+            # Placeholder pointer; the kernel never reads it when has_topk_len is False.
+            else torch.empty(0, device=q.device, dtype=torch.int32)
+        ),
+        out,
+        lse,
+        softmax_scale,
+        page_size,
+        int(page_bytes),  # page_bytes (int64)
+        int(page_size * _TOKEN_DATA_STRIDE),  # scale_section_off (int64)
+        H,
+        topk,
+        topk_rounded,
+        topk_length is not None,
+        q3.stride(0),
+        q3.stride(1),
+        out.stride(0),
+        out.stride(1),
+        flat_indices.stride(0),
+        NOPE_PAD=512,
+        ROPE_DIM=_ROPE_DIM,
+        NOPE_DIM_RT=_NOPE_DIM,
+    )
+
+    # Return [B, 1, H, D] and [B, 1, H]
+    return out.unsqueeze(1), lse.unsqueeze(1)
+
+
+def _merge_partial_attn(
+    out1: torch.Tensor,
+    lse1: torch.Tensor,
+    out2: torch.Tensor,
+    lse2: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Combine two partial attention results into one via their LSEs.
+
+    Each partial output is weighted by ``exp(lse - max(lse1, lse2))``; a
+    partial whose LSE is not above ``-1e20`` gets weight 0. The weight sum is
+    clamped to at least ``1e-20`` before dividing, so two empty partials give
+    a zero output and an LSE of ``-inf``.
+
+    out: [B, 1, H, D] bf16, lse: [B, 1, H] float32
+    """
+    peak = torch.maximum(lse1, lse2)
+
+    def _weight(lse: torch.Tensor) -> torch.Tensor:
+        # Relative softmax mass of one partial; empty partials contribute nothing.
+        return torch.where(lse > -1e20, torch.exp(lse - peak), torch.zeros_like(lse))
+
+    w1 = _weight(lse1)
+    w2 = _weight(lse2)
+    denom = (w1 + w2).clamp(min=1e-20)
+
+    numer = w1.unsqueeze(-1) * out1.float() + w2.unsqueeze(-1) * out2.float()
+    merged = numer / denom.unsqueeze(-1)
+    return merged.to(torch.bfloat16), peak + torch.log(denom)
+
+
+def _apply_attn_sink(
+    out: torch.Tensor,
+    lse: torch.Tensor,
+    attn_sink: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fold a per-head attention sink into an already-normalized result.
+
+    The sink enters the softmax denominator but carries no value vector, so
+    the output is scaled by ``exp(lse - logaddexp(lse, sink))``. Entries whose
+    ``lse`` is not above ``-1e20`` are scaled by 0.
+
+    out: [B, 1, H, D] bf16, lse: [B, 1, H] f32, attn_sink: [H] f32
+
+    Returns the rescaled bf16 output and the combined (sink-inclusive) LSE.
+    """
+    combined_lse = torch.logaddexp(lse, attn_sink.view(1, 1, -1).expand_as(lse))
+
+    has_mass = lse > -1e20
+    shrink = torch.where(
+        has_mass,
+        torch.exp(lse - combined_lse),
+        torch.zeros_like(lse),
+    )
+
+    scaled = out.float() * shrink.unsqueeze(-1)
+    return scaled.to(torch.bfloat16), combined_lse
+
+
+def flash_mla_sparse_decode_triton(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    indices: torch.Tensor,
+    topk_length: Optional[torch.Tensor],
+    attn_sink: Optional[torch.Tensor],
+    head_dim_v: int,
+    softmax_scale: float,
+    extra_k_cache: Optional[torch.Tensor] = None,
+    extra_indices: Optional[torch.Tensor] = None,
+    extra_topk_length: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """SM120-optimized sparse MLA decode using tiled Triton kernel.
+
+    Processes SWA and extra (c4/c128) caches separately via the same
+    Triton kernel, then merges results using LSE-weighted combination.
+
+    Args:
+        q: ``[B, 1, H, D]`` query.
+        k_cache: main paged KV cache.
+        indices: token indices into ``k_cache``.
+        topk_length: optional number of valid ``indices`` entries per row.
+        attn_sink: optional ``[H]`` sink values added to the softmax
+            denominator.
+        head_dim_v: number of leading output channels to return.
+        softmax_scale: score multiplier; ``None`` means ``D ** -0.5``.
+        extra_k_cache: optional second paged cache.
+        extra_indices: token indices into ``extra_k_cache``.
+        extra_topk_length: optional valid length for ``extra_indices``.
+
+    Returns:
+        ``(out, lse)``: ``out`` is ``[B, 1, H, head_dim_v]`` bf16 and ``lse``
+        is the ``[B, 1, H]`` LSE permuted with ``(0, 2, 1)``.
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    # Process main cache (SWA)
+    out, lse = _run_triton_sparse_decode(
+        q,
+        k_cache,
+        indices,
+        topk_length,
+        softmax_scale,
+    )
+
+    # Process extra cache (c4 / c128) if present
+    if extra_k_cache is not None and extra_indices is not None:
+        out_extra, lse_extra = _run_triton_sparse_decode(
+            q,
+            extra_k_cache,
+            extra_indices,
+            extra_topk_length,
+            softmax_scale,
+        )
+        out, lse = _merge_partial_attn(out, lse, out_extra, lse_extra)
+
+    # Apply attention sink
+    if attn_sink is not None:
+        out, lse = _apply_attn_sink(out, lse, attn_sink)
+
+    # The kernel always writes q.shape[-1] channels. Trim to the requested
+    # value width so the shape matches the PyTorch fallback, which returns
+    # only the first head_dim_v channels. No-op when head_dim_v == D.
+    out = out[..., :head_dim_v]
+
+    # NOTE(review): when attn_sink is given, the lse returned here includes the
+    # sink term, whereas the PyTorch fallback returns the pre-sink lse. Confirm
+    # which one callers expect before relying on it.
+    return out, lse.permute(0, 2, 1)
diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py
index de433fe2d505..e27e96fff326 100644
--- a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py
+++ b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py
@@ -18,6 +18,9 @@ def _compute_enable_deep_gemm():
sm_version = get_device_sm()
if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31):
return False
+ # DeepGEMM requires TMEM/tcgen05 (SM100+datacenter), not available on SM120
+ if sm_version == 120:
+ return False
if not (_is_cuda or _is_musa):
return False
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py
new file mode 100644
index 000000000000..95d303f0a8fb
--- /dev/null
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py
@@ -0,0 +1,454 @@
+"""SM120-optimized Triton MXFP4 MoE kernel — CUDA graph compatible.
+
+Replaces the PyTorch fallback (per-expert for-loop + full dequant + matmul)
+with fused Triton kernels that:
+1. Fuse FP4 dequant + GEMV (no intermediate BF16 weight materialization)
+2. Process each (token, expert) slot independently — no data-dependent routing
+3. Respect SM120 shared memory constraint (99 KB/block)
+
+CUDA graph compatibility:
+- No .unique(), .item(), .nonzero() — all routing is tensor-level
+- Fixed grid dimensions (M*topk, N_blocks) per captured batch size
+- All control flow is static or within Triton kernels
+
+SM120 constraints:
+- SMEM: 99 KB/block (vs SM100 228 KB)
+- No TMEM/tcgen05 — uses mma.sync.aligned via Triton
+- Max warps: 48/SM
+- Registers: ~128/thread practical limit
+"""
+
+import logging
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def _dequant_fp4_lut(nibble):
+    """Convert FP4 E2M1 codes to float32 arithmetically (no table lookup).
+
+    Only the low four bits of each element are read, laid out as
+    [sign | exp1 exp0 | mantissa]. An exponent field of 0 is the subnormal
+    range (0 or 0.5); otherwise the magnitude is
+    (1 + 0.5 * mantissa) * 2 ** (exponent - 1).
+    """
+    negative = ((nibble >> 3) & 1) != 0
+    exp_field = (nibble >> 1) & 3
+    frac = (nibble & 1).to(tl.float32) * 0.5
+
+    # Normal numbers carry an implicit leading one; subnormals do not.
+    normal = (1.0 + frac) * tl.math.exp2((exp_field - 1).to(tl.float32))
+    magnitude = tl.where(exp_field == 0, frac, normal)
+    return tl.where(negative, -magnitude, magnitude)
+
+
+# ── Per-slot GEMV kernel: processes one (token, expert) pair ──
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2),
+    ],
+    key=["N", "K"],
+)
+@triton.jit
+def _mxfp4_slot_gemv_kernel(
+    # Pointers
+    A_ptr,  # [M_total, K] bf16 — source rows
+    B_packed_ptr,  # [E, N, K//2] uint8 — packed FP4 expert weights
+    B_scale_ptr,  # [E, N, K//32] float32 — weight scales
+    C_ptr,  # [num_slots, N] bf16 — output
+    token_ids_ptr,  # [num_slots] int32 — which A row for each slot
+    expert_ids_ptr,  # [num_slots] int32 — which expert's B for each slot
+    # Dimensions
+    N: tl.int32,
+    K: tl.int32,
+    # A strides
+    stride_am: tl.int32,
+    # B strides (within an expert)
+    stride_bn: tl.int32,
+    stride_bk2: tl.int32,
+    # B_scale strides (within an expert)
+    stride_bsn: tl.int32,
+    stride_bsk32: tl.int32,
+    # Expert strides (between experts)
+    expert_b_stride: tl.int64,
+    expert_s_stride: tl.int64,
+    # C strides
+    stride_cm: tl.int32,
+    # Block sizes
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    """Per-slot fused MXFP4 dequant + GEMV.
+
+    Grid: (num_slots, cdiv(N, BLOCK_N))
+    Each program computes one (token, expert) pair for a BLOCK_N slice of
+    output: C[slot, n] = sum_k A[token_ids[slot], k] * dequant(B[expert, n, k]).
+
+    Layout assumptions baked into the address arithmetic below:
+    - A and C are addressed with a row stride only, i.e. their last dimension
+      must have unit stride.
+    - Each packed byte holds two consecutive K elements: the low nibble is the
+      even index, the high nibble the odd index.
+    - One scale covers 32 consecutive K elements; scale columns are derived
+      from ``k_start // 32``, so BLOCK_K is expected to be a multiple of 32
+      (true for every autotune config above).
+    """
+    slot_id = tl.program_id(0)
+    n_block = tl.program_id(1)
+
+    token_id = tl.load(token_ids_ptr + slot_id).to(tl.int64)
+    expert_id = tl.load(expert_ids_ptr + slot_id).to(tl.int64)
+
+    offs_n = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
+    n_mask = offs_n < N
+
+    acc = tl.zeros([BLOCK_N], dtype=tl.float32)
+
+    # Base offsets are computed in int64 so large buffers cannot wrap.
+    b_base = expert_id * expert_b_stride
+    s_base = expert_id * expert_s_stride
+    a_base = token_id * stride_am
+    # program_id is int32; promote before multiplying by the row stride so the
+    # output offset does not overflow once num_slots * N reaches 2**31.
+    c_base = slot_id.to(tl.int64) * stride_cm
+
+    for k_start in range(0, K, BLOCK_K):
+        # ── Load packed B: [BLOCK_N, BLOCK_K//2] ──
+        offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2)
+        b_mask = n_mask[:, None] & (offs_k2[None, :] < K // 2)
+        b_packed = tl.load(
+            B_packed_ptr
+            + b_base
+            + offs_n[:, None] * stride_bn
+            + offs_k2[None, :] * stride_bk2,
+            mask=b_mask,
+            other=0,
+        )
+
+        # ── FP4 dequant ──
+        b_u8 = b_packed.to(tl.int32)
+        val_lo = _dequant_fp4_lut(b_u8 & 0x0F)  # even K indices
+        val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F)  # odd K indices
+
+        # ── Load and apply scales: [BLOCK_N, BLOCK_K//2] ──
+        # 32 values per group and 2 values per byte -> 16 bytes per group.
+        group_ids = tl.arange(0, BLOCK_K // 2) // 16
+        scale_cols = k_start // 32 + group_ids[None, :]
+        s_mask = n_mask[:, None] & (scale_cols < K // 32)
+        scales = tl.load(
+            B_scale_ptr + s_base + offs_n[:, None] * stride_bsn + scale_cols * stride_bsk32,
+            mask=s_mask,
+            other=1.0,
+        )
+        val_lo = val_lo * scales
+        val_hi = val_hi * scales
+
+        # ── Load A even/odd: [BLOCK_K//2] each ──
+        offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2
+        offs_k_odd = offs_k_even + 1
+
+        a_even = tl.load(
+            A_ptr + a_base + offs_k_even,
+            mask=offs_k_even < K,
+            other=0.0,
+        ).to(tl.float32)
+        a_odd = tl.load(
+            A_ptr + a_base + offs_k_odd,
+            mask=offs_k_odd < K,
+            other=0.0,
+        ).to(tl.float32)
+
+        # ── Dot product: acc[n] += sum_k(a_even[k]*B_lo[n,k] + a_odd[k]*B_hi[n,k]) ──
+        acc += tl.sum(a_even[None, :] * val_lo, axis=1)
+        acc += tl.sum(a_odd[None, :] * val_hi, axis=1)
+
+    # ── Store output ──
+    tl.store(
+        C_ptr + c_base + offs_n,
+        acc.to(tl.bfloat16),
+        mask=n_mask,
+    )
+
+
+# ── Legacy per-expert GEMM kernel (kept for benchmarking) ──
+
+
+@triton.autotune(
+ configs=[
+ triton.Config(
+ {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8, num_stages=2
+ ),
+ triton.Config(
+ {"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2
+ ),
+ ],
+ key=["M", "N", "K"],
+)
+@triton.jit
+def _mxfp4_gemm_kernel(
+ # Pointers
+ A_ptr, # [M, K] bf16 activation
+ B_packed_ptr, # [N, K//2] uint8 packed FP4
+ B_scale_ptr, # [N, K//32] float32 scales
+ C_ptr, # [M, N] bf16 output
+ # Dimensions
+ M,
+ N,
+ K,
+ # Strides
+ stride_am,
+ stride_ak,
+ stride_bn,
+ stride_bk2,
+ stride_bsn,
+ stride_bsk32,
+ stride_cm,
+ stride_cn,
+ # Constexprs
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ """Fused MXFP4 dequant + GEMM: C = A @ dequant(B_packed, B_scale).T
+
+ Each program computes one [BLOCK_M, BLOCK_N] output tile; program axis 0
+ walks tiles along M and axis 1 walks tiles along N. The K dimension is
+ consumed BLOCK_K elements (BLOCK_K // 2 packed bytes) per iteration, with
+ the even-index and odd-index halves of K accumulated by two separate dots.
+ Accumulation is in float32 and the result is stored as bfloat16.
+ """
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ for k_start in range(0, K, BLOCK_K):
+ # Packed weights for this K chunk: [BLOCK_N, BLOCK_K//2] bytes.
+ # Masked-out bytes load as 0, which dequantizes to 0.0.
+ offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2)
+ b_mask = (offs_n[:, None] < N) & (offs_k2[None, :] < K // 2)
+ b_packed = tl.load(
+ B_packed_ptr + offs_n[:, None] * stride_bn + offs_k2[None, :] * stride_bk2,
+ mask=b_mask,
+ other=0,
+ )
+
+ # Low nibble pairs with the even K index, high nibble with the odd one
+ # (see the a_even / a_odd loads below).
+ b_u8 = b_packed.to(tl.int32)
+ val_lo = _dequant_fp4_lut(b_u8 & 0x0F)
+ val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F)
+
+ # One scale per 32 K elements = 16 packed bytes, hence the // 16.
+ # The scale column is derived from k_start // 32, so this indexing lines
+ # up with the data only when BLOCK_K is a multiple of 32.
+ group_ids = tl.arange(0, BLOCK_K // 2) // 16
+ scales_per_byte = tl.load(
+ B_scale_ptr
+ + offs_n[:, None] * stride_bsn
+ + (k_start // 32 + group_ids[None, :]) * stride_bsk32,
+ mask=(offs_n[:, None] < N)
+ & ((k_start // 32 + group_ids[None, :]) < K // 32),
+ other=1.0,
+ )
+ val_lo = val_lo * scales_per_byte
+ val_hi = val_hi * scales_per_byte
+
+ # Activations are gathered as two strided halves so each half lines up
+ # element-for-element with one nibble plane of the weights.
+ offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2
+ offs_k_odd = offs_k_even + 1
+
+ a_even_mask = (offs_m[:, None] < M) & (offs_k_even[None, :] < K)
+ a_even = tl.load(
+ A_ptr + offs_m[:, None] * stride_am + offs_k_even[None, :] * stride_ak,
+ mask=a_even_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ a_odd_mask = (offs_m[:, None] < M) & (offs_k_odd[None, :] < K)
+ a_odd = tl.load(
+ A_ptr + offs_m[:, None] * stride_am + offs_k_odd[None, :] * stride_ak,
+ mask=a_odd_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ # [BLOCK_M, BLOCK_K//2] x [BLOCK_K//2, BLOCK_N] for each nibble plane.
+ # NOTE(review): both operands are float32; confirm tl.dot's default
+ # input precision for float32 is acceptable for benchmarking accuracy.
+ acc += tl.dot(a_even, tl.trans(val_lo))
+ acc += tl.dot(a_odd, tl.trans(val_hi))
+
+ # Write the tile back, dropping rows/columns that fall outside [M, N].
+ c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+ tl.store(
+ C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
+ acc.to(tl.bfloat16),
+ mask=c_mask,
+ )
+
+
+def mxfp4_gemm_triton(
+    A: torch.Tensor,
+    B_packed: torch.Tensor,
+    B_scale: torch.Tensor,
+    K_full: int,
+) -> torch.Tensor:
+    """Triton fused MXFP4 dequant + GEMM: output = A @ dequant(B).T
+
+    Kept for standalone benchmarking. The MoE forward uses the slot kernel.
+
+    Args:
+        A: Activations, indexed as [M, K].
+        B_packed: Packed FP4 weights, [N, K // 2]; any 1-byte dtype, it is
+            reinterpreted as uint8 (two FP4 values per byte).
+        B_scale: Per-group weight scales, [N, K // 32]. Accepted dtypes are
+            float32, ``torch.float8_e8m0fnu``, or raw E8M0 exponent bytes
+            stored as uint8/int8; other dtypes are cast to float32.
+        K_full: Logical (unpacked) reduction length K.
+
+    Returns:
+        A new bfloat16 tensor of shape [M, N] on ``A``'s device.
+    """
+    M = A.shape[0]
+    N = B_packed.shape[0]
+    K = K_full
+
+    # Raw exponent bytes must be reinterpreted as E8M0 before the numeric
+    # cast; a plain .float() would turn the byte 127 into 127.0 instead of
+    # the intended scale 2**0 == 1.0.
+    if B_scale.dtype in (torch.uint8, torch.int8):
+        B_scale = B_scale.view(torch.uint8).view(torch.float8_e8m0fnu)
+    if B_scale.dtype != torch.float32:
+        B_scale = B_scale.to(torch.float32)
+
+    C = torch.empty(M, N, dtype=torch.bfloat16, device=A.device)
+    A = A.contiguous()
+    B_u8 = B_packed.contiguous().view(torch.uint8)
+    B_scale = B_scale.contiguous()
+
+    # One program per [BLOCK_M, BLOCK_N] output tile; block sizes come from
+    # whichever autotune config is selected.
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        triton.cdiv(N, meta["BLOCK_N"]),
+    )
+
+    _mxfp4_gemm_kernel[grid](
+        A,
+        B_u8,
+        B_scale,
+        C,
+        M,
+        N,
+        K,
+        A.stride(0),
+        A.stride(1),
+        B_u8.stride(0),
+        B_u8.stride(1),
+        B_scale.stride(0),
+        B_scale.stride(1),
+        C.stride(0),
+        C.stride(1),
+    )
+    return C
+
+
+def mxfp4_moe_forward_triton(
+    hidden_states: torch.Tensor,
+    w13_packed: torch.Tensor,
+    w2_packed: torch.Tensor,
+    w13_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    hidden_size: int,
+    intermediate_size: int,
+    inplace: bool = False,
+    routed_scaling_factor: Optional[float] = None,
+    clamp_limit: Optional[float] = None,
+) -> torch.Tensor:
+    """SM120-optimized MXFP4 MoE forward — CUDA graph compatible.
+
+    Uses per-slot GEMV kernels instead of per-expert Python loops.
+    Each (token, expert) slot is processed independently with a fixed grid,
+    eliminating .unique()/.item()/.nonzero() that break CUDA graph capture.
+
+    Args:
+        hidden_states: Input activations, [M, K].
+        w13_packed: Packed FP4 gate/up weights, [E, 2*I, K//2] bytes.
+        w2_packed: Packed FP4 down weights, [E, K, I//2] bytes.
+        w13_scale: Scales for ``w13_packed``; cast to float32 if needed.
+        w2_scale: Scales for ``w2_packed``; cast to float32 if needed.
+        topk_ids: Selected expert per (token, slot), [M, topk]. Negative ids
+            mark padded/filtered slots and contribute zero to the output.
+        topk_weights: Routing weight per (token, slot), [M, topk].
+        hidden_size: Accepted for interface compatibility; not used (K is
+            taken from ``hidden_states``).
+        intermediate_size: Expert intermediate width I.
+        inplace: Accepted for interface compatibility; not used. A new output
+            tensor is always returned.
+        routed_scaling_factor: Optional multiplier applied to the output
+            (skipped when None or 1.0).
+        clamp_limit: When positive, gate is clamped from above and up is
+            clamped to [-limit, limit] before the SiLU product.
+
+    Returns:
+        Tensor of shape [M, K] with the dtype of ``hidden_states``.
+    """
+    import torch.nn.functional as F
+
+    M, K = hidden_states.shape
+    topk = topk_ids.shape[1]
+    I = intermediate_size
+    num_slots = M * topk
+    device = hidden_states.device
+    dtype = hidden_states.dtype
+
+    # The slot kernel addresses A with a row stride only, so the K dimension
+    # must have unit stride. This is a metadata check (no GPU sync).
+    if hidden_states.stride(-1) != 1:
+        hidden_states = hidden_states.contiguous()
+
+    # ── Graph-safe routing: flatten topk assignments ──
+    # token_ids[slot]  = which row of A (original token index)
+    # expert_ids[slot] = which expert's weights to use
+    # topk_ids may contain -1 for padded/filtered tokens; clamp to 0 for safe
+    # Triton loads, then zero out invalid slots' output after GEMM.
+    flat_expert_ids_raw = topk_ids.reshape(-1).contiguous()  # [M*topk]
+    invalid_slot_mask = flat_expert_ids_raw < 0  # [M*topk]
+    flat_expert_ids = flat_expert_ids_raw.clamp(min=0)  # safe for indexing
+    token_ids = (
+        torch.arange(M, device=device, dtype=torch.int32)
+        .unsqueeze(1)
+        .expand(M, topk)
+        .reshape(-1)
+        .contiguous()
+    )  # [M*topk]
+
+    # ── Ensure scales are float32 ──
+    if w13_scale.dtype != torch.float32:
+        w13_scale = w13_scale.to(torch.float32)
+    if w2_scale.dtype != torch.float32:
+        w2_scale = w2_scale.to(torch.float32)
+
+    # ── GEMM1: gate_up projection ──
+    # hidden_states[token] @ w13[expert].T → [num_slots, 2*I]
+    intermediate = torch.empty(num_slots, 2 * I, dtype=dtype, device=device)
+
+    w13_u8 = w13_packed.view(torch.uint8)  # [E, 2*I, K//2]
+    grid1 = lambda meta: (num_slots, triton.cdiv(2 * I, meta["BLOCK_N"]))
+
+    _mxfp4_slot_gemv_kernel[grid1](
+        hidden_states,
+        w13_u8,
+        w13_scale,
+        intermediate,
+        token_ids,
+        flat_expert_ids,
+        2 * I,
+        K,
+        hidden_states.stride(0),
+        w13_u8.stride(1),
+        w13_u8.stride(2),
+        w13_scale.stride(1),
+        w13_scale.stride(2),
+        w13_u8.stride(0),
+        w13_scale.stride(0),
+        intermediate.stride(0),
+    )
+
+    # ── SiLU activation (graph-safe vectorized ops) ──
+    gate = intermediate[:, :I].float()
+    up = intermediate[:, I:].float()
+    if clamp_limit is not None and clamp_limit > 0:
+        gate = torch.clamp(gate, max=clamp_limit)
+        up = torch.clamp(up, min=-clamp_limit, max=clamp_limit)
+    activated = (F.silu(gate) * up).to(dtype)
+
+    # ── GEMM2: down projection ──
+    # activated[slot] @ w2[expert].T → [num_slots, K]
+    down = torch.empty(num_slots, K, dtype=dtype, device=device)
+
+    # For GEMM2, A is the activated buffer — each slot reads its own row
+    slot_ids = torch.arange(num_slots, device=device, dtype=torch.int32)
+
+    w2_u8 = w2_packed.view(torch.uint8)  # [E, K, I//2]
+    grid2 = lambda meta: (num_slots, triton.cdiv(K, meta["BLOCK_N"]))
+
+    _mxfp4_slot_gemv_kernel[grid2](
+        activated,
+        w2_u8,
+        w2_scale,
+        down,
+        slot_ids,
+        flat_expert_ids,
+        K,
+        I,
+        activated.stride(0),
+        w2_u8.stride(1),
+        w2_u8.stride(2),
+        w2_scale.stride(1),
+        w2_scale.stride(2),
+        w2_u8.stride(0),
+        w2_scale.stride(0),
+        down.stride(0),
+    )
+
+    # ── Zero out invalid slots (padded/filtered tokens with topk_ids == -1) ──
+    # masked_fill_ with a tensor mask is an ordinary elementwise kernel (no
+    # GPU→CPU sync), so it stays CUDA-graph-safe. Unlike multiplying by a 0/1
+    # mask it also discards NaN/Inf produced in an invalid slot, since
+    # 0 * NaN would still be NaN and poison the per-token sum below.
+    down.masked_fill_(invalid_slot_mask.unsqueeze(1), 0.0)
+
+    # ── Weighted sum across topk slots (graph-safe) ──
+    flat_weights = topk_weights.reshape(-1).unsqueeze(1).to(dtype)  # [M*topk, 1]
+    output = (down * flat_weights).view(M, topk, K).sum(dim=1)
+
+    if routed_scaling_factor is not None and routed_scaling_factor != 1.0:
+        output.mul_(routed_scaling_factor)
+
+    return output
diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py
index f7fc76dfc96a..618e9cec5bb2 100644
--- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py
+++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py
@@ -9,7 +9,7 @@
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import log_info_on_rank0, set_weight_attrs
-from sglang.srt.utils.common import is_sm90_supported
+from sglang.srt.utils.common import is_sm90_supported, is_sm120_supported
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
@@ -108,10 +108,42 @@ def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_mega_moe_weights_built", False):
return
- if not is_sm90_supported():
+ if not is_sm90_supported() and not is_sm120_supported():
raise RuntimeError(
"DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above."
)
+
+ # SM120: Skip Marlin repacking, keep original weight format
+ # for Triton dequant kernel (Marlin kernel produces NaN on SM120)
+ if is_sm120_supported():
+ from torch.nn import Parameter
+
+ log_info_on_rank0(
+ logger,
+ f"SM120 detected: using PyTorch MXFP4 MoE fallback "
+ f"(layer: {self.prefix})...",
+ )
+ # Keep weights in original packed int8 format
+ # Normalize scales to float32 for direct use in dequant
+ w13_s = layer.w13_weight_scale_inv.data
+ w2_s = layer.w2_weight_scale_inv.data
+ if w13_s.dtype == torch.float8_e8m0fnu:
+ pass # already in e8m0 format, will convert at runtime
+ elif w13_s.dtype in (torch.uint8, torch.int8):
+ layer.w13_weight_scale_inv = Parameter(
+ w13_s.view(torch.uint8)
+ .view(torch.float8_e8m0fnu)
+ .to(torch.float32),
+ requires_grad=False,
+ )
+ layer.w2_weight_scale_inv = Parameter(
+ w2_s.view(torch.uint8).view(torch.float8_e8m0fnu).to(torch.float32),
+ requires_grad=False,
+ )
+ # else: float32 scales are already usable directly
+ layer._dsv4_mxfp4_backend = "sm120_triton"
+ return
+
if not check_moe_marlin_supports_layer(layer, 32):
raise RuntimeError(
"Current DeepSeekV4 MoE layer does not satisfy Marlin constraints."
@@ -144,6 +176,43 @@ def apply(
if not TopKOutputChecker.format_is_standard(topk_output):
raise ValueError(f"Unsupported topk output format: {topk_output.format}")
+ # SM120: use Triton fused dequant+GEMM (Marlin kernel produces NaN on SM120)
+ if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_triton":
+ from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import (
+ mxfp4_moe_forward_triton,
+ )
+
+ hidden_states = dispatch_output.hidden_states
+ w13 = layer.w13_weight.data
+ w2 = layer.w2_weight.data
+ w13_scale = layer.w13_weight_scale_inv.data
+ w2_scale = layer.w2_weight_scale_inv.data
+ intermediate_size = w13.shape[1] // 2
+ hidden_size = w13.shape[2] * 2
+
+ output = mxfp4_moe_forward_triton(
+ hidden_states=hidden_states,
+ w13_packed=w13,
+ w2_packed=w2,
+ w13_scale=w13_scale,
+ w2_scale=w2_scale,
+ topk_ids=topk_output.topk_ids,
+ topk_weights=topk_output.topk_weights,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ routed_scaling_factor=(
+ self.runner.config.routed_scaling_factor
+ if hasattr(self.runner, "config")
+ else None
+ ),
+ clamp_limit=(
+ self.runner.config.swiglu_limit
+ if hasattr(self.runner, "config")
+ else None
+ ),
+ )
+ return StandardCombineInput(hidden_states=output)
+
quant_info = MarlinMoeQuantInfo(
w13_qweight=layer.w13_weight,
w2_qweight=layer.w2_weight,
diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py
index fdf6d557ca45..e307b640f90b 100644
--- a/python/sglang/srt/models/deepseek_v4.py
+++ b/python/sglang/srt/models/deepseek_v4.py
@@ -1336,7 +1336,7 @@ def _setup_fp8_wo_a_scales(self, is_nextn: bool) -> None:
)
def post_load_weights(self, is_nextn=False, weight_names=None):
- if _FP8_WO_A_GEMM:
+ if envs.SGLANG_OPT_FP8_WO_A_GEMM.get():
self._setup_fp8_wo_a_scales(is_nextn)
if is_nextn:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 1d1b8d29959d..dda65b62c64e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1940,6 +1940,21 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
)
+ elif is_sm120_supported():
+ # SM120: DSv4-Flash uses MXFP4 experts; marlin backend dispatches
+ # to our SM120 Triton fallback in mxfp4_marlin_moe.py
+ if self.moe_runner_backend == "auto":
+ self.moe_runner_backend = "marlin"
+ logger.info(
+ "Use marlin as MoE runner backend on SM120 for DeepseekV3/V4"
+ )
+ # SM120 lacks tcgen05/TMEM: disable features that depend on
+ # DeepGEMM or require >99KB SMEM (topk_v2).
+ envs.SGLANG_OPT_FP8_WO_A_GEMM.set(False)
+ envs.SGLANG_OPT_USE_TOPK_V2.set(False)
+ envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.set(False)
+ envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.set(False)
+ envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.set(True)
elif is_hip():
if not self.enable_dp_attention and self.nnodes == 1:
# TODO (Hubert): Put this back later
@@ -1985,6 +2000,20 @@ def _handle_model_specific_adjustments(self):
validate_deepseek_v4_cp(self)
+ if is_sm120_supported():
+ if self.moe_runner_backend == "auto":
+ self.moe_runner_backend = "marlin"
+ logger.info(
+ "Use marlin as MoE runner backend on SM120 for DeepseekV4"
+ )
+ # SM120 lacks tcgen05/TMEM: disable features that depend on
+ # DeepGEMM or require >99KB SMEM (topk_v2).
+ envs.SGLANG_OPT_FP8_WO_A_GEMM.set(False)
+ envs.SGLANG_OPT_USE_TOPK_V2.set(False)
+ envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.set(False)
+ envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.set(False)
+ envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.set(True)
+
elif model_arch in ["GptOssForCausalLM"]:
# Set attention backend for GPT-OSS
if self.is_attention_backend_not_set():