diff --git a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx index 392282aaa9f0..d59a613ee139 100644 --- a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx +++ b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx @@ -29,7 +29,7 @@ tag: NEW DeepSeek-V4-Flash 284B 13B - single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs + single-node serving: B200 / GB200 / GB300 / H200 on 4 GPUs; RTX PRO 6000 (SM120) on 4-8 GPUs DeepSeek-V4-Pro @@ -128,6 +128,20 @@ PD-Disagg recipes on H200 may require `docker run --privileged --ulimit memlock= can discover the IB HCAs; without IB exposure mooncake silently falls back to TCP, which can lead to garbled KV transfer on large checkpoints. + + +**SM120 (RTX PRO 6000 Blackwell Server Edition) note** + +DeepSeek-V4-Flash can run on RTX PRO 6000 Blackwell Server Edition (SM120, 96 GB GDDR7) with Tensor Parallelism only. We support two TP configurations: +- **TP=8 (recommended)**: `--tp 8 --mem-fraction-static 0.70 --cuda-graph-max-bs 32`. Leaves ~20 GB per GPU for KV cache. +- **TP=4 (memory-constrained)**: `--tp 4 --mem-fraction-static 0.90 --cuda-graph-max-bs 4`. Runs near the 96 GB memory limit (~98.8% usage) with minimal KV cache headroom. + +SM120 uses Triton-based MoE and FlashMLA fallback kernels instead of CUTLASS/DeepGEMM (auto-detected, no manual flags needed). Use Docker image `lmsysorg/sglang:dev-cu13` (CUDA 13.0, required for SM120 / CC 12.0). + +Performance is memory-bandwidth bound (~1.5 TB/s GDDR7 vs ~8 TB/s HBM3e on B200); expect ~15-17 tok/s at BS=1. Accuracy matches reference (GSM8K 10/10, GPQA Diamond 72.0% vs 71.2% published). + +V4-Pro is not supported on SM120 (model does not fit in 8x 96 GB). + **MegaMoE** MegaMoE fuses expert dispatch + GEMM into a single kernel for higher throughput diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index f9f396428557..53741c1b8fd3 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -55,6 +55,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ceil_align +from sglang.srt.utils.common import is_sm120_supported if TYPE_CHECKING: from flash_mla.flash_mla_interface import FlashMLASchedMeta @@ -62,6 +63,8 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner +_is_sm120 = is_sm120_supported() + logger = logging.getLogger(__name__) SWA_WINDOW = 128 @@ -81,6 +84,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: def _create_flashmla_metadata(): + if _is_sm120: + return None import flash_mla return flash_mla.get_mla_metadata()[0] @@ -1031,24 +1036,42 @@ def forward( extra_indices.shape[-1] % 64 == 0 ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" - import flash_mla - - o = flash_mla.flash_mla_with_kvcache( - q=q, - k_cache=swa_k_cache, - head_dim_v=self.head_dim_v, - block_table=None, - cache_seqlens=None, - tile_scheduler_metadata=flashmla_metadata, - softmax_scale=self.softmax_scale, - is_fp8_kvcache=True, - indices=swa_page_indices, - topk_length=swa_topk_lengths, - attn_sink=attn_sink, - extra_k_cache=extra_k_cache, - extra_indices_in_kvcache=extra_indices, - extra_topk_length=extra_topk_lengths, - )[0] + if _is_sm120: + from sglang.srt.layers.attention.flash_mla_sm120 import ( + flash_mla_with_kvcache_sm120, + ) + + o = flash_mla_with_kvcache_sm120( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + softmax_scale=self.softmax_scale, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + )[0] + else: + import flash_mla + + o = flash_mla.flash_mla_with_kvcache( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=self.softmax_scale, + is_fp8_kvcache=True, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + )[0] o = o.squeeze(1) return o diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index f8899264d6d0..ec1b6d9665bc 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -20,6 +20,7 @@ from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer from sglang.srt.utils import add_prefix, is_hip +from sglang.srt.utils.common import is_sm120_supported if TYPE_CHECKING: from sglang.srt.layers.attention.dsv4.compressor import ( @@ -90,6 +91,74 @@ def fp8_paged_mqa_logits_torch( return logits +def fp8_paged_mqa_logits_torch_sm120( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + """CUDA-graph-compatible FP8 paged MQA logits for SM120 (vectorized, no .item()).""" + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + device = q_fp8.device + + assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128" + assert ( + block_size == 64 + ), "Vectorized torch impl hardcodes block_size=64 cache layout" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + if seq_lens.dim() > 1: + seq_lens = seq_lens.squeeze(-1) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + max_pages = (max_seq_len + block_size - 1) // block_size + max_padded_seq = max_pages * block_size + + kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + SCALE_OFFSET = block_size * head_dim + + page_ids = page_table[:, :max_pages] + kvcache_gathered = kvcache_flat[page_ids] + + kv_value_raw = kvcache_gathered[..., :SCALE_OFFSET] + kv_scale_raw = kvcache_gathered[..., SCALE_OFFSET:] + + kv_value = kv_value_raw.contiguous().view(dtype=FP8_DTYPE).to(torch.float32) + kv_value = kv_value.view(batch_size, max_padded_seq, head_dim) + + kv_scale = kv_scale_raw.contiguous().view(dtype=torch.float32) + kv_scale = kv_scale.view(batch_size, max_padded_seq) + + q = q_fp8[:, 0].to(torch.float32) + + score = torch.bmm(kv_value, q.transpose(1, 2)) + + score = F.relu(score) + score = score * weight.unsqueeze(1) + score = score.sum(dim=2) + + score = score * kv_scale + + out_width = min(max_padded_seq, max_seq_len) + logits = score.new_full((batch_size, max_seq_len), float("-inf")) + logits[:, :out_width] = score[:, :out_width] + + positions = torch.arange(max_seq_len, device=device) + invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1) + logits.masked_fill_(invalid_mask, float("-inf")) + + return logits + + def topk_transform_512_pytorch_vectorized( scores: torch.Tensor, seq_lens: torch.Tensor, @@ -372,7 +441,10 @@ def forward_c4_indexer( tilelang_fp8_paged_mqa_logits as fn, ) elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): - fn = fp8_paged_mqa_logits_torch + if is_sm120_supported(): + fn = fp8_paged_mqa_logits_torch_sm120 + else: + fn = fp8_paged_mqa_logits_torch else: from deep_gemm import fp8_paged_mqa_logits as fn diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120.py b/python/sglang/srt/layers/attention/flash_mla_sm120.py new file mode 100644 index 000000000000..8d49d2a6e461 --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120.py @@ -0,0 +1,252 @@ +"""SM120 FlashMLA sparse decode implementation. + +On SM120 (Blackwell Desktop / RTX PRO 6000) the flash_mla CUDA kernel +is not available, so this module provides alternative implementations: + +- A fused Triton kernel (default, ``SGLANG_SM120_TRITON_FLASHMLA=1``) +- A pure-PyTorch fallback (``SGLANG_SM120_TRITON_FLASHMLA=0``) + +The FP8 KV cache uses a page-internal layout where NOPE+ROPE data has +stride (nope_dim + rope_dim*2) per token, and scales are stored in a +separate region at the end of each page. +""" + +import logging +import os + +import torch + +logger = logging.getLogger(__name__) + +# Page layout constants for DSv4-Flash (MODEL1): +# nope_dim = 448, rope_dim = 64, quantize_block_size = 64 +# nope_rope_stride = 448 + 64*2 = 576 bytes per token +# scale_stride = ceil(448/64) + 1 = 8 bytes per token (7 scales + 1 pad) +# bytes_per_token = 448 + 128 + 8 = 584 +# page_bytes = ceil_div(page_size * 584, 576) * 576 + +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_NOPE_ROPE_STRIDE = _NOPE_DIM + _ROPE_DIM * 2 # 576 +_TILE_SIZE = 64 +_NUM_TILES = _NOPE_DIM // _TILE_SIZE # 7 +_SCALE_STRIDE = _NUM_TILES + 1 # 8 (7 scales + 1 pad) +_D = _NOPE_DIM + _ROPE_DIM # 512 + + +def _gather_and_dequant(k_cache, indices, page_size): + """Gather KV entries from the paged buffer using correct page-internal addressing. + + Args: + k_cache: (num_pages, page_size, 1, bytes_per_token) float8_e4m3fn + Non-contiguous view of the raw page buffer. + indices: (...) int32/int64, token-level indices. -1 = invalid. + page_size: tokens per page (256) + + Returns: + kv: (..., _D) bfloat16, dequantized KV vectors + """ + idx_shape = indices.shape + flat_idx = indices.reshape(-1) # (N,) + N = flat_idx.shape[0] + device = k_cache.device + + # Page-level addressing + page_bytes = k_cache.stride(0) # actual byte stride between pages + pages = flat_idx // page_size + offsets = flat_idx % page_size + + # Clamp invalid indices + safe_pages = pages.clamp(min=0) + safe_offsets = offsets.clamp(min=0) + + # Access raw buffer as uint8 — use as_strided to get full page view + num_pages = k_cache.shape[0] + raw_pages = k_cache.as_strided( + (num_pages, page_bytes), + (page_bytes, 1), + ).view( + torch.uint8 + ) # (num_pages, page_bytes) uint8 + # Note: float8_e4m3fn and uint8 are both 1 byte, view is safe + + # Compute byte offsets within each page + # NOPE: page[safe_page, safe_offset * 576 + 0:448] + # ROPE: page[safe_page, safe_offset * 576 + 448:576] + # SCALES: page[safe_page, page_size * 576 + safe_offset * 8 + 0:7] + + nope_base = safe_offsets * _NOPE_ROPE_STRIDE # (N,) + nope_offsets = nope_base.unsqueeze(-1) + torch.arange( + _NOPE_DIM, device=device, dtype=torch.long + ) # (N, 448) + + rope_base = nope_base + _NOPE_DIM # (N,) + rope_offsets = rope_base.unsqueeze(-1) + torch.arange( + _ROPE_DIM * 2, device=device, dtype=torch.long + ) # (N, 128) + + scale_section_offset = page_size * _NOPE_ROPE_STRIDE # 147456 + scale_base = scale_section_offset + safe_offsets * _SCALE_STRIDE # (N,) + scale_offsets = scale_base.unsqueeze(-1) + torch.arange( + _NUM_TILES, device=device, dtype=torch.long + ) # (N, 7) + + # Gather bytes per page — use advanced indexing + # raw_pages[safe_pages, nope_offsets] → (N, 448) + page_idx_nope = safe_pages.unsqueeze(-1).expand_as(nope_offsets) + nope_bytes = raw_pages[page_idx_nope, nope_offsets] # (N, 448) uint8 + + page_idx_rope = safe_pages.unsqueeze(-1).expand_as(rope_offsets) + rope_bytes = raw_pages[page_idx_rope, rope_offsets] # (N, 128) uint8 + + page_idx_scale = safe_pages.unsqueeze(-1).expand_as(scale_offsets) + scale_bytes = raw_pages[page_idx_scale, scale_offsets] # (N, 7) uint8 + + # Reinterpret dtypes + nope_fp8 = nope_bytes.view(torch.float8_e4m3fn) # (N, 448) + rope_bf16 = rope_bytes.contiguous().view(torch.bfloat16) # (N, 64) + scale_e8m0 = scale_bytes.view(torch.float8_e8m0fnu) # (N, 7) + + # Dequantize: nope_tile * scale_tile → bf16 (vectorized) + result = torch.empty(N, _D, dtype=torch.bfloat16, device=device) + result[:, :_NOPE_DIM] = ( + ( + nope_fp8.view(N, _NUM_TILES, _TILE_SIZE).float() + * scale_e8m0.view(N, _NUM_TILES, 1).float() + ) + .view(N, _NOPE_DIM) + .to(torch.bfloat16) + ) + result[:, _NOPE_DIM:] = rope_bf16 + + return result.reshape(*idx_shape, _D) + + +def _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache=None, + extra_indices=None, + extra_topk_length=None, +): + B, s_q, H_q, D_qk = q.shape + num_pages, page_size, H_k, bpt = k_cache.shape + topk = indices.shape[-1] + + invalid_mask = indices < 0 + safe_indices = indices.clamp(min=0) + + if topk_length is not None: + topk_range = torch.arange(topk, device=topk_length.device).view(1, 1, topk) + invalid_mask = invalid_mask | (topk_range >= topk_length.view(B, 1, 1)) + + # Gather and dequantize using page-aware addressing + gathered_kv = _gather_and_dequant(k_cache, safe_indices, page_size) + + if extra_k_cache is not None and extra_indices is not None: + extra_topk = extra_indices.shape[-1] + extra_page_size = extra_k_cache.shape[1] + extra_invalid = extra_indices < 0 + extra_safe = extra_indices.clamp(min=0) + if extra_topk_length is not None: + extra_range = torch.arange( + extra_topk, device=extra_topk_length.device + ).view(1, 1, extra_topk) + extra_invalid = extra_invalid | ( + extra_range >= extra_topk_length.view(B, 1, 1) + ) + extra_kv = _gather_and_dequant(extra_k_cache, extra_safe, extra_page_size) + gathered_kv = torch.cat([gathered_kv, extra_kv], dim=2) + invalid_mask = torch.cat([invalid_mask, extra_invalid], dim=2) + + gathered_kv[invalid_mask] = 0.0 + + q_f = q.float() + kv_f = gathered_kv.float() + kv_d = kv_f.shape[-1] + if D_qk != kv_d: + q_f = q_f[..., :kv_d] + + scores = torch.einsum("bshd,bstd->bsht", q_f, kv_f) * softmax_scale + scores.masked_fill_(invalid_mask.unsqueeze(2).expand_as(scores), float("-inf")) + + lse = torch.logsumexp(scores, dim=-1) + + if attn_sink is not None: + lse_for_out = torch.logsumexp( + torch.stack([lse, attn_sink.view(1, 1, H_q).expand_as(lse)], dim=0), dim=0 + ) + else: + lse_for_out = lse.clone() + + lonely = lse == float("-inf") + lse_for_out[lonely] = float("inf") + weights = torch.exp(scores - lse_for_out.unsqueeze(-1)) + out = torch.einsum("bsht,bstv->bshv", weights, kv_f[..., :head_dim_v]) + out[lonely.unsqueeze(-1).expand_as(out)] = 0.0 + + return out.to(torch.bfloat16), lse.permute(0, 2, 1) + + +# Default SM120 FlashMLA backend: "triton" (optimized) or "torch" (pure-PyTorch fallback). +# Controlled by SGLANG_SM120_TRITON_FLASHMLA env var (1=triton, 0=torch). +_sm120_default_backend = ( + "triton" if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" else "torch" +) + + +def flash_mla_with_kvcache_sm120(**kwargs): + """SM120 FlashMLA sparse decode entry point. + + Dispatches to the Triton kernel (default) or PyTorch fallback. + """ + q = kwargs["q"] + k_cache = kwargs["k_cache"] + indices = kwargs["indices"] + topk_length = kwargs.get("topk_length") + attn_sink = kwargs.get("attn_sink") + head_dim_v = kwargs["head_dim_v"] + softmax_scale = kwargs.get("softmax_scale") + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + extra_k_cache = kwargs.get("extra_k_cache") + extra_indices = kwargs.get("extra_indices_in_kvcache") + extra_topk_length = kwargs.get("extra_topk_length") + + if _sm120_default_backend == "triton": + from sglang.srt.layers.attention.flash_mla_sm120_triton import ( + flash_mla_sparse_decode_triton, + ) + + out, lse = flash_mla_sparse_decode_triton( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) + + out, lse = _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py new file mode 100644 index 000000000000..20ffb4d3cbab --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py @@ -0,0 +1,370 @@ +"""SM120-optimized Triton FlashMLA sparse decode kernel — Tiled V2. + +Replaces V1's serial token loop with a tiled vectorized approach: + 1. BLOCK_T tokens loaded simultaneously via 2D gather (vs 1-at-a-time) + 2. All BLOCK_T QK scores computed at once via vectorized mul-reduce + 3. V accumulation via vectorized weighted sum across BLOCK_T tokens + 4. Online softmax operates on tile-level maxima (fewer rescales) + +Three typed views of the same paged buffer handle FP8/uint8/BF16 regions: +- float8_e4m3fn view -> nope FP8 values (direct load + dequant) +- uint8 view -> UE8M0 scale bytes (raw integer -> exp2 conversion) +- bfloat16 view -> rope BF16 values (direct load) + +DSv4 page layout (per token, 576 bytes data + 8 bytes scales): + Data section: [0:448] FP8 nope | [448:576] BF16 rope (64 values = 128 bytes) + Scale section: [page_size*576 + offset*8 : +7] UE8M0 scales (7 groups of 64) + +Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7, 96MB L2) +""" + +import logging +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +LOG2E = tl.constexpr(1.4426950408889634) + +# DSv4 KV cache layout constants +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_D = _NOPE_DIM + _ROPE_DIM # 512 +_TOKEN_DATA_STRIDE = 576 # bytes per token in data section +_SCALE_STRIDE = 8 # bytes per token in scale section + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_T": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_T": 32}, num_warps=8, num_stages=2), + ], + key=["topk_rounded"], +) +@triton.jit +def _tiled_sparse_decode_kernel( + # Q: [B, H, D] bf16 + Q_ptr, + # Paged KV cache — three typed views of same underlying memory + cache_fp8_ptr, # float8_e4m3fn flat (1 byte/elem) — for nope + cache_uint8_ptr, # uint8 flat (1 byte/elem) — for scales + cache_bf16_ptr, # bfloat16 flat (2 bytes/elem) — for rope + # Indices: [B, topk] int32 + indices_ptr, + # Valid lengths: [B] int32 + topk_len_ptr, + # Output: [B, H, D] bf16 and LSE: [B, H] float32 + O_ptr, + LSE_ptr, + # Scalars + sm_scale: tl.float32, + page_size: tl.int32, + page_bytes: tl.int64, + scale_section_off: tl.int64, # page_size * 576 + H: tl.int32, + topk: tl.int32, + topk_rounded: tl.int32, # for autotune key + has_topk_len: tl.constexpr, + # Strides + stride_qb: tl.int32, + stride_qh: tl.int32, + stride_ob: tl.int32, + stride_oh: tl.int32, + stride_ib: tl.int32, # indices batch stride + # Constexprs + NOPE_PAD: tl.constexpr, # 512 (padded from 448) + ROPE_DIM: tl.constexpr, # 64 + NOPE_DIM_RT: tl.int32, # 448 (runtime, for masking) + BLOCK_T: tl.constexpr, # tokens per tile (16 or 32) +): + """Tiled sparse decode: vectorized gather + QK + softmax + V accumulation. + + Grid: (B, H) — one block per (batch, head) pair. + Each block processes all topk tokens in tiles of BLOCK_T. + """ + bid = tl.program_id(0) + hid = tl.program_id(1) + + # ---- Load Q for this (batch, head) ---- + q_base = bid * stride_qb + hid * stride_qh + nope_offs = tl.arange(0, NOPE_PAD) # [512] + nope_mask = nope_offs < NOPE_DIM_RT # [512], True for [0:448] + rope_offs = tl.arange(0, ROPE_DIM) # [64] + + q_nope = tl.load(Q_ptr + q_base + nope_offs, mask=nope_mask, other=0.0) + q_nope = q_nope.to(tl.float32) * sm_scale + q_rope = tl.load(Q_ptr + q_base + NOPE_DIM_RT + rope_offs) + q_rope = q_rope.to(tl.float32) * sm_scale + + # ---- Valid token count ---- + valid_topk = topk + if has_topk_len: + valid_topk = tl.load(topk_len_ptr + bid).to(tl.int32) + valid_topk = tl.minimum(valid_topk, topk) + + # ---- Online softmax state (base-2 math for SM120 efficiency) ---- + m_i: tl.float32 = -1e30 + l_i: tl.float32 = 0.0 + acc_nope = tl.zeros([NOPE_PAD], dtype=tl.float32) + acc_rope = tl.zeros([ROPE_DIM], dtype=tl.float32) + + # ---- Precompute constant index vectors ---- + group_ids = (nope_offs // 64).to(tl.int64) # [NOPE_PAD], scale group for each dim + t_offs = tl.arange(0, BLOCK_T) # [BLOCK_T], token offsets within tile + + # ---- Process tokens in tiles of BLOCK_T ---- + for tile_start in range(0, topk, BLOCK_T): + t_idx = tile_start + t_offs # [BLOCK_T], global token indices + t_in_bounds = t_idx < topk # bounds for index load + t_valid = t_idx < valid_topk # bounds for actual processing + + # Load indices for this tile: [BLOCK_T] + raw_indices = tl.load( + indices_ptr + bid * stride_ib + t_idx, + mask=t_in_bounds, + other=-1, + ) + idx_valid = t_valid & (raw_indices >= 0) # [BLOCK_T] mask + + # Page addressing: [BLOCK_T] (clamp for safe addressing of invalid tokens) + safe_indices = tl.where(idx_valid, raw_indices, tl.zeros_like(raw_indices)) + page_ids = (safe_indices // page_size).to(tl.int64) + page_offs_t = (safe_indices % page_size).to(tl.int64) + token_data_bases = page_ids * page_bytes + page_offs_t * 576 # [BLOCK_T] int64 + + # ---- Vectorized NOPE FP8 gather: [BLOCK_T, NOPE_PAD] ---- + nope_addrs = token_data_bases[:, None] + nope_offs[None, :].to(tl.int64) + nope_2d_mask = idx_valid[:, None] & nope_mask[None, :] + kv_nope_fp8 = tl.load( + cache_fp8_ptr + nope_addrs, + mask=nope_2d_mask, + other=0.0, + ) + + # ---- Vectorized scale gather + dequant: [BLOCK_T, NOPE_PAD] ---- + scale_bases = page_ids * page_bytes + scale_section_off + page_offs_t * 8 + scale_addrs = scale_bases[:, None] + group_ids[None, :] + scale_raw = tl.load( + cache_uint8_ptr + scale_addrs, + mask=nope_2d_mask, + other=127, + ) + scale_f32 = tl.math.exp2(scale_raw.to(tl.float32) - 127.0) + kv_nope = tl.where(nope_2d_mask, kv_nope_fp8.to(tl.float32) * scale_f32, 0.0) + + # ---- Vectorized ROPE BF16 gather: [BLOCK_T, ROPE_DIM] ---- + rope_byte_bases = token_data_bases + 448 + rope_elem_bases = (rope_byte_bases // 2).to(tl.int64) + rope_addrs = rope_elem_bases[:, None] + rope_offs[None, :].to(tl.int64) + kv_rope = tl.load( + cache_bf16_ptr + rope_addrs, + mask=idx_valid[:, None], + other=0.0, + ).to(tl.float32) + + # ---- Vectorized QK scores: [BLOCK_T] ---- + # scores[t] = dot(q_nope, kv_nope[t]) + dot(q_rope, kv_rope[t]) + scores = tl.sum(q_nope[None, :] * kv_nope, axis=1) + tl.sum( + q_rope[None, :] * kv_rope, axis=1 + ) + scores = tl.where(idx_valid, scores, -1e30) + + # ---- Online softmax update (base-2, tile-level) ---- + scores_log2 = scores * LOG2E # [BLOCK_T] + tile_max = tl.max(scores_log2) # scalar + m_new = tl.maximum(m_i, tile_max) + + alpha = tl.math.exp2(m_i - m_new) # rescale factor + p = tl.math.exp2(scores_log2 - m_new) # [BLOCK_T] attention weights + p = tl.where(idx_valid, p, 0.0) # zero out invalid + + l_i = l_i * alpha + tl.sum(p) + + # ---- Vectorized V accumulation (K=V in MLA) ---- + # acc += sum_t(p[t] * kv[t, :]) for both nope and rope + acc_nope = acc_nope * alpha + tl.sum(p[:, None] * kv_nope, axis=0) + acc_rope = acc_rope * alpha + tl.sum(p[:, None] * kv_rope, axis=0) + m_i = m_new + + # ---- Normalize output ---- + safe_l = tl.where(l_i > 0.0, l_i, 1.0) + acc_nope = acc_nope / safe_l + acc_rope = acc_rope / safe_l + + # LSE: convert from log2 back to natural log + lse = tl.where(l_i > 0.0, m_i / LOG2E + tl.math.log(safe_l), float("-inf")) + + # ---- Store output ---- + o_base = bid * stride_ob + hid * stride_oh + tl.store(O_ptr + o_base + nope_offs, acc_nope.to(tl.bfloat16), mask=nope_mask) + tl.store(O_ptr + o_base + NOPE_DIM_RT + rope_offs, acc_rope.to(tl.bfloat16)) + tl.store(LSE_ptr + bid * H + hid, lse) + + +def _run_triton_sparse_decode( + q: torch.Tensor, # [B, 1, H, D] bf16 + k_cache: torch.Tensor, # [num_pages, page_size, 1, bpt] float8 + indices: torch.Tensor, # [B, ...] int32 + topk_length: Optional[torch.Tensor], + softmax_scale: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the tiled Triton sparse decode kernel on one paged KV cache.""" + B, _, H, D = q.shape + num_pages = k_cache.shape[0] + page_size = k_cache.shape[1] + page_bytes = k_cache.stride(0) # elements = bytes for float8 + + # Flatten indices to [B, topk] + flat_indices = indices.reshape(B, -1).contiguous() + topk = flat_indices.shape[1] + + # Create three typed views of the flat cache memory. + # The KV cache may arrive as uint8 or float8_e4m3fn depending on the + # sglang version. Ensure each view has the correct dtype so Triton + # interprets the loaded values correctly (FP8 dequant vs raw integer). + total_elems = num_pages * page_bytes + raw_flat = k_cache.as_strided((total_elems,), (1,)) + raw_uint8 = raw_flat.view(torch.uint8) + raw_fp8 = raw_uint8.view(torch.float8_e4m3fn) + raw_bf16 = raw_uint8.view(torch.bfloat16) + + # Squeeze Q: [B, H, D] + q3 = q.squeeze(1) + if not q3.is_contiguous(): + q3 = q3.contiguous() + + out = torch.zeros(B, H, D, dtype=torch.bfloat16, device=q.device) + lse = torch.full((B, H), float("-inf"), dtype=torch.float32, device=q.device) + + # Round topk for autotune key stability + topk_rounded = triton.next_power_of_2(topk) + + grid = (B, H) + _tiled_sparse_decode_kernel[grid]( + q3, + raw_fp8, + raw_uint8, + raw_bf16, + flat_indices, + ( + topk_length + if topk_length is not None + else torch.empty(0, device=q.device, dtype=torch.int32) + ), + out, + lse, + softmax_scale, + page_size, + int(page_bytes), # page_bytes (int64) + int(page_size * _TOKEN_DATA_STRIDE), # scale_section_off (int64) + H, + topk, + topk_rounded, + topk_length is not None, + q3.stride(0), + q3.stride(1), + out.stride(0), + out.stride(1), + flat_indices.stride(0), + NOPE_PAD=512, + ROPE_DIM=_ROPE_DIM, + NOPE_DIM_RT=_NOPE_DIM, + ) + + # Return [B, 1, H, D] and [B, 1, H] + return out.unsqueeze(1), lse.unsqueeze(1) + + +def _merge_partial_attn( + out1: torch.Tensor, + lse1: torch.Tensor, + out2: torch.Tensor, + lse2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Merge two attention outputs using LSE-weighted combination. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] float32 + """ + max_lse = torch.maximum(lse1, lse2) + w1 = torch.where(lse1 > -1e20, torch.exp(lse1 - max_lse), torch.zeros_like(lse1)) + w2 = torch.where(lse2 > -1e20, torch.exp(lse2 - max_lse), torch.zeros_like(lse2)) + total = (w1 + w2).clamp(min=1e-20) + merged = ( + w1.unsqueeze(-1) * out1.float() + w2.unsqueeze(-1) * out2.float() + ) / total.unsqueeze(-1) + merged_lse = max_lse + torch.log(total) + return merged.to(torch.bfloat16), merged_lse + + +def _apply_attn_sink( + out: torch.Tensor, + lse: torch.Tensor, + attn_sink: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention sink normalization. + + The sink adds to the softmax denominator without contributing output, + effectively down-weighting all attention scores. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] f32, attn_sink: [H] f32 + """ + sink_lse = attn_sink.view(1, 1, -1).expand_as(lse) + combined_lse = torch.logaddexp(lse, sink_lse) + w = torch.where( + lse > -1e20, + torch.exp(lse - combined_lse), + torch.zeros_like(lse), + ) + return (out.float() * w.unsqueeze(-1)).to(torch.bfloat16), combined_lse + + +def flash_mla_sparse_decode_triton( + q: torch.Tensor, + k_cache: torch.Tensor, + indices: torch.Tensor, + topk_length: Optional[torch.Tensor], + attn_sink: Optional[torch.Tensor], + head_dim_v: int, + softmax_scale: float, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """SM120-optimized sparse MLA decode using tiled Triton kernel. + + Processes SWA and extra (c4/c128) caches separately via the same + Triton kernel, then merges results using LSE-weighted combination. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + # Process main cache (SWA) + out, lse = _run_triton_sparse_decode( + q, + k_cache, + indices, + topk_length, + softmax_scale, + ) + + # Process extra cache (c4 / c128) if present + if extra_k_cache is not None and extra_indices is not None: + out_extra, lse_extra = _run_triton_sparse_decode( + q, + extra_k_cache, + extra_indices, + extra_topk_length, + softmax_scale, + ) + out, lse = _merge_partial_attn(out, lse, out_extra, lse_extra) + + # Apply attention sink + if attn_sink is not None: + out, lse = _apply_attn_sink(out, lse, attn_sink) + + # Return format matching PyTorch fallback: (out, lse.permute(0,2,1)) + return out, lse.permute(0, 2, 1) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py index de433fe2d505..e27e96fff326 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py @@ -18,6 +18,9 @@ def _compute_enable_deep_gemm(): sm_version = get_device_sm() if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31): return False + # DeepGEMM requires TMEM/tcgen05 (SM100+datacenter), not available on SM120 + if sm_version == 120: + return False if not (_is_cuda or _is_musa): return False diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py new file mode 100644 index 000000000000..95d303f0a8fb --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py @@ -0,0 +1,454 @@ +"""SM120-optimized Triton MXFP4 MoE kernel — CUDA graph compatible. + +Replaces the PyTorch fallback (per-expert for-loop + full dequant + matmul) +with fused Triton kernels that: +1. Fuse FP4 dequant + GEMV (no intermediate BF16 weight materialization) +2. Process each (token, expert) slot independently — no data-dependent routing +3. Respect SM120 shared memory constraint (99 KB/block) + +CUDA graph compatibility: +- No .unique(), .item(), .nonzero() — all routing is tensor-level +- Fixed grid dimensions (M*topk, N_blocks) per captured batch size +- All control flow is static or within Triton kernels + +SM120 constraints: +- SMEM: 99 KB/block (vs SM100 228 KB) +- No TMEM/tcgen05 — uses mma.sync.aligned via Triton +- Max warps: 48/SM +- Registers: ~128/thread practical limit +""" + +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def _dequant_fp4_lut(nibble): + """Decode a 4-bit FP4 E2M1 nibble to float32 using arithmetic.""" + sign_bit = (nibble >> 3) & 1 + exp_bits = (nibble >> 1) & 3 + man_bit = nibble & 1 + + is_subnormal = exp_bits == 0 + mantissa = 1.0 + man_bit.to(tl.float32) * 0.5 + exponent = tl.math.exp2((exp_bits - 1).to(tl.float32)) + val = tl.where(is_subnormal, man_bit.to(tl.float32) * 0.5, mantissa * exponent) + val = tl.where(sign_bit != 0, -val, val) + return val + + +# ── Per-slot GEMV kernel: processes one (token, expert) pair ── + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2), + ], + key=["N", "K"], +) +@triton.jit +def _mxfp4_slot_gemv_kernel( + # Pointers + A_ptr, # [M_total, K] bf16 — source rows + B_packed_ptr, # [E, N, K//2] uint8 — packed FP4 expert weights + B_scale_ptr, # [E, N, K//32] float32 — weight scales + C_ptr, # [num_slots, N] bf16 — output + token_ids_ptr, # [num_slots] int32 — which A row for each slot + expert_ids_ptr, # [num_slots] int32 — which expert's B for each slot + # Dimensions + N: tl.int32, + K: tl.int32, + # A strides + stride_am: tl.int32, + # B strides (within an expert) + stride_bn: tl.int32, + stride_bk2: tl.int32, + # B_scale strides (within an expert) + stride_bsn: tl.int32, + stride_bsk32: tl.int32, + # Expert strides (between experts) + expert_b_stride: tl.int64, + expert_s_stride: tl.int64, + # C strides + stride_cm: tl.int32, + # Block sizes + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Per-slot fused MXFP4 dequant + GEMV. + + Grid: (num_slots, cdiv(N, BLOCK_N)) + Each program computes one (token, expert) pair for a BLOCK_N slice of output. + """ + slot_id = tl.program_id(0) + n_block = tl.program_id(1) + + token_id = tl.load(token_ids_ptr + slot_id).to(tl.int64) + expert_id = tl.load(expert_ids_ptr + slot_id).to(tl.int64) + + offs_n = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + # Expert weight base pointers + b_base = expert_id * expert_b_stride + s_base = expert_id * expert_s_stride + a_base = token_id * stride_am + + for k_start in range(0, K, BLOCK_K): + # ── Load packed B: [BLOCK_N, BLOCK_K//2] ── + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = n_mask[:, None] & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + + b_base + + offs_n[:, None] * stride_bn + + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + # ── FP4 dequant ── + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) # even K indices + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) # odd K indices + + # ── Load and apply scales: [BLOCK_N, BLOCK_K//2] ── + group_ids = tl.arange(0, BLOCK_K // 2) // 16 # 32 values per group, 2 per byte + s_mask = n_mask[:, None] & ((k_start // 32 + group_ids[None, :]) < K // 32) + scales = tl.load( + B_scale_ptr + + s_base + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=s_mask, + other=1.0, + ) + val_lo = val_lo * scales + val_hi = val_hi * scales + + # ── Load A even/odd: [BLOCK_K//2] each ── + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even = tl.load( + A_ptr + a_base + offs_k_even, + mask=offs_k_even < K, + other=0.0, + ).to(tl.float32) + a_odd = tl.load( + A_ptr + a_base + offs_k_odd, + mask=offs_k_odd < K, + other=0.0, + ).to(tl.float32) + + # ── Dot product: acc[n] += sum_k(a_even[k]*B_lo[n,k] + a_odd[k]*B_hi[n,k]) ── + acc += tl.sum(a_even[None, :] * val_lo, axis=1) + acc += tl.sum(a_odd[None, :] * val_hi, axis=1) + + # ── Store output ── + tl.store( + C_ptr + slot_id * stride_cm + offs_n, + acc.to(tl.bfloat16), + mask=n_mask, + ) + + +# ── Legacy per-expert GEMM kernel (kept for benchmarking) ── + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _mxfp4_gemm_kernel( + # Pointers + A_ptr, # [M, K] bf16 activation + B_packed_ptr, # [N, K//2] uint8 packed FP4 + B_scale_ptr, # [N, K//32] float32 scales + C_ptr, # [M, N] bf16 output + # Dimensions + M, + N, + K, + # Strides + stride_am, + stride_ak, + stride_bn, + stride_bk2, + stride_bsn, + stride_bsk32, + stride_cm, + stride_cn, + # Constexprs + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Fused MXFP4 dequant + GEMM: C = A @ dequant(B_packed, B_scale).T""" + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = (offs_n[:, None] < N) & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + offs_n[:, None] * stride_bn + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) + + group_ids = tl.arange(0, BLOCK_K // 2) // 16 + scales_per_byte = tl.load( + B_scale_ptr + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=(offs_n[:, None] < N) + & ((k_start // 32 + group_ids[None, :]) < K // 32), + other=1.0, + ) + val_lo = val_lo * scales_per_byte + val_hi = val_hi * scales_per_byte + + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even_mask = (offs_m[:, None] < M) & (offs_k_even[None, :] < K) + a_even = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_even[None, :] * stride_ak, + mask=a_even_mask, + other=0.0, + ).to(tl.float32) + + a_odd_mask = (offs_m[:, None] < M) & (offs_k_odd[None, :] < K) + a_odd = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_odd[None, :] * stride_ak, + mask=a_odd_mask, + other=0.0, + ).to(tl.float32) + + acc += tl.dot(a_even, tl.trans(val_lo)) + acc += tl.dot(a_odd, tl.trans(val_hi)) + + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store( + C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, + acc.to(tl.bfloat16), + mask=c_mask, + ) + + +def mxfp4_gemm_triton( + A: torch.Tensor, + B_packed: torch.Tensor, + B_scale: torch.Tensor, + K_full: int, +) -> torch.Tensor: + """Triton fused MXFP4 dequant + GEMM: output = A @ dequant(B).T + + Kept for standalone benchmarking. The MoE forward uses the slot kernel. + """ + M = A.shape[0] + N = B_packed.shape[0] + K = K_full + + if B_scale.dtype == torch.float8_e8m0fnu: + B_scale = B_scale.to(torch.float32) + elif B_scale.dtype != torch.float32: + B_scale = B_scale.float() + + C = torch.empty(M, N, dtype=torch.bfloat16, device=A.device) + A = A.contiguous() + B_packed = B_packed.contiguous() + B_scale = B_scale.contiguous() + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + B_u8 = B_packed.view(torch.uint8) + + _mxfp4_gemm_kernel[grid]( + A, + B_u8, + B_scale, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B_u8.stride(0), + B_u8.stride(1), + B_scale.stride(0), + B_scale.stride(1), + C.stride(0), + C.stride(1), + ) + return C + + +def mxfp4_moe_forward_triton( + hidden_states: torch.Tensor, + w13_packed: torch.Tensor, + w2_packed: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_size: int, + intermediate_size: int, + inplace: bool = False, + routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, +) -> torch.Tensor: + """SM120-optimized MXFP4 MoE forward — CUDA graph compatible. + + Uses per-slot GEMV kernels instead of per-expert Python loops. + Each (token, expert) slot is processed independently with a fixed grid, + eliminating .unique()/.item()/.nonzero() that break CUDA graph capture. + """ + import torch.nn.functional as F + + M, K = hidden_states.shape + topk = topk_ids.shape[1] + I = intermediate_size + num_slots = M * topk + device = hidden_states.device + dtype = hidden_states.dtype + + # ── Graph-safe routing: flatten topk assignments ── + # token_ids[slot] = which row of A (original token index) + # expert_ids[slot] = which expert's weights to use + # topk_ids may contain -1 for padded/filtered tokens; clamp to 0 for safe + # Triton loads, then zero out invalid slots' output after GEMM. + flat_expert_ids_raw = topk_ids.reshape(-1).contiguous() # [M*topk] + invalid_slot_mask = flat_expert_ids_raw < 0 # [M*topk] + flat_expert_ids = flat_expert_ids_raw.clamp(min=0) # safe for indexing + token_ids = ( + torch.arange(M, device=device, dtype=torch.int32) + .unsqueeze(1) + .expand(M, topk) + .reshape(-1) + .contiguous() + ) # [M*topk] + + # ── Ensure scales are float32 ── + if w13_scale.dtype != torch.float32: + w13_scale = w13_scale.to(torch.float32) + if w2_scale.dtype != torch.float32: + w2_scale = w2_scale.to(torch.float32) + + # ── GEMM1: gate_up projection ── + # hidden_states[token] @ w13[expert].T → [num_slots, 2*I] + intermediate = torch.empty(num_slots, 2 * I, dtype=dtype, device=device) + + w13_u8 = w13_packed.view(torch.uint8) # [E, 2*I, K//2] + grid1 = lambda meta: (num_slots, triton.cdiv(2 * I, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid1]( + hidden_states, + w13_u8, + w13_scale, + intermediate, + token_ids, + flat_expert_ids, + 2 * I, + K, + hidden_states.stride(0), + w13_u8.stride(1), + w13_u8.stride(2), + w13_scale.stride(1), + w13_scale.stride(2), + w13_u8.stride(0), + w13_scale.stride(0), + intermediate.stride(0), + ) + + # ── SiLU activation (graph-safe vectorized ops) ── + gate = intermediate[:, :I].float() + up = intermediate[:, I:].float() + if clamp_limit is not None and clamp_limit > 0: + gate = torch.clamp(gate, max=clamp_limit) + up = torch.clamp(up, min=-clamp_limit, max=clamp_limit) + activated = (F.silu(gate) * up).to(dtype) + + # ── GEMM2: down projection ── + # activated[slot] @ w2[expert].T → [num_slots, K] + down = torch.empty(num_slots, K, dtype=dtype, device=device) + + # For GEMM2, A is the activated buffer — each slot reads its own row + slot_ids = torch.arange(num_slots, device=device, dtype=torch.int32) + + w2_u8 = w2_packed.view(torch.uint8) # [E, K, I//2] + grid2 = lambda meta: (num_slots, triton.cdiv(K, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid2]( + activated, + w2_u8, + w2_scale, + down, + slot_ids, + flat_expert_ids, + K, + I, + activated.stride(0), + w2_u8.stride(1), + w2_u8.stride(2), + w2_scale.stride(1), + w2_scale.stride(2), + w2_u8.stride(0), + w2_scale.stride(0), + down.stride(0), + ) + + # ── Zero out invalid slots (padded/filtered tokens with topk_ids == -1) ── + # Use multiplication instead of boolean indexing to stay CUDA-graph-safe + # (no GPU→CPU sync). valid_mask is 1.0 for valid slots, 0.0 for invalid. + valid_mask = (~invalid_slot_mask).unsqueeze(1).to(dtype) # [M*topk, 1] + down = down * valid_mask + + # ── Weighted sum across topk slots (graph-safe) ── + flat_weights = topk_weights.reshape(-1).unsqueeze(1).to(dtype) # [M*topk, 1] + output = (down * flat_weights).view(M, topk, K).sum(dim=1) + + if routed_scaling_factor is not None and routed_scaling_factor != 1.0: + output.mul_(routed_scaling_factor) + + return output diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py index f7fc76dfc96a..618e9cec5bb2 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import log_info_on_rank0, set_weight_attrs -from sglang.srt.utils.common import is_sm90_supported +from sglang.srt.utils.common import is_sm90_supported, is_sm120_supported if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput @@ -108,10 +108,42 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_mega_moe_weights_built", False): return - if not is_sm90_supported(): + if not is_sm90_supported() and not is_sm120_supported(): raise RuntimeError( "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." ) + + # SM120: Skip Marlin repacking, keep original weight format + # for Triton dequant kernel (Marlin kernel produces NaN on SM120) + if is_sm120_supported(): + from torch.nn import Parameter + + log_info_on_rank0( + logger, + f"SM120 detected: using PyTorch MXFP4 MoE fallback " + f"(layer: {self.prefix})...", + ) + # Keep weights in original packed int8 format + # Normalize scales to float32 for direct use in dequant + w13_s = layer.w13_weight_scale_inv.data + w2_s = layer.w2_weight_scale_inv.data + if w13_s.dtype == torch.float8_e8m0fnu: + pass # already in e8m0 format, will convert at runtime + elif w13_s.dtype in (torch.uint8, torch.int8): + layer.w13_weight_scale_inv = Parameter( + w13_s.view(torch.uint8) + .view(torch.float8_e8m0fnu) + .to(torch.float32), + requires_grad=False, + ) + layer.w2_weight_scale_inv = Parameter( + w2_s.view(torch.uint8).view(torch.float8_e8m0fnu).to(torch.float32), + requires_grad=False, + ) + # else: float32 scales are already usable directly + layer._dsv4_mxfp4_backend = "sm120_triton" + return + if not check_moe_marlin_supports_layer(layer, 32): raise RuntimeError( "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." @@ -144,6 +176,43 @@ def apply( if not TopKOutputChecker.format_is_standard(topk_output): raise ValueError(f"Unsupported topk output format: {topk_output.format}") + # SM120: use Triton fused dequant+GEMM (Marlin kernel produces NaN on SM120) + if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_triton": + from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( + mxfp4_moe_forward_triton, + ) + + hidden_states = dispatch_output.hidden_states + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + + output = mxfp4_moe_forward_triton( + hidden_states=hidden_states, + w13_packed=w13, + w2_packed=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + topk_ids=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + routed_scaling_factor=( + self.runner.config.routed_scaling_factor + if hasattr(self.runner, "config") + else None + ), + clamp_limit=( + self.runner.config.swiglu_limit + if hasattr(self.runner, "config") + else None + ), + ) + return StandardCombineInput(hidden_states=output) + quant_info = MarlinMoeQuantInfo( w13_qweight=layer.w13_weight, w2_qweight=layer.w2_weight, diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index fdf6d557ca45..e307b640f90b 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -1336,7 +1336,7 @@ def _setup_fp8_wo_a_scales(self, is_nextn: bool) -> None: ) def post_load_weights(self, is_nextn=False, weight_names=None): - if _FP8_WO_A_GEMM: + if envs.SGLANG_OPT_FP8_WO_A_GEMM.get(): self._setup_fp8_wo_a_scales(is_nextn) if is_nextn: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1d1b8d29959d..dda65b62c64e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1940,6 +1940,21 @@ def _handle_model_specific_adjustments(self): logger.info( "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM" ) + elif is_sm120_supported(): + # SM120: DSv4-Flash uses MXFP4 experts; marlin backend dispatches + # to our SM120 Triton fallback in mxfp4_marlin_moe.py + if self.moe_runner_backend == "auto": + self.moe_runner_backend = "marlin" + logger.info( + "Use marlin as MoE runner backend on SM120 for DeepseekV3/V4" + ) + # SM120 lacks tcgen05/TMEM: disable features that depend on + # DeepGEMM or require >99KB SMEM (topk_v2). + envs.SGLANG_OPT_FP8_WO_A_GEMM.set(False) + envs.SGLANG_OPT_USE_TOPK_V2.set(False) + envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.set(False) + envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.set(False) + envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.set(True) elif is_hip(): if not self.enable_dp_attention and self.nnodes == 1: # TODO (Hubert): Put this back later @@ -1985,6 +2000,20 @@ def _handle_model_specific_adjustments(self): validate_deepseek_v4_cp(self) + if is_sm120_supported(): + if self.moe_runner_backend == "auto": + self.moe_runner_backend = "marlin" + logger.info( + "Use marlin as MoE runner backend on SM120 for DeepseekV4" + ) + # SM120 lacks tcgen05/TMEM: disable features that depend on + # DeepGEMM or require >99KB SMEM (topk_v2). + envs.SGLANG_OPT_FP8_WO_A_GEMM.set(False) + envs.SGLANG_OPT_USE_TOPK_V2.set(False) + envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.set(False) + envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.set(False) + envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.set(True) + elif model_arch in ["GptOssForCausalLM"]: # Set attention backend for GPT-OSS if self.is_attention_backend_not_set():