From a4316882f7253c7672528557649ae7f320c9256f Mon Sep 17 00:00:00 2001 From: alichen Date: Fri, 8 May 2026 03:30:42 -0700 Subject: [PATCH 1/5] feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference Adds full SM120 (RTX PRO 6000 / RTX 5090 / DGX Spark) support for DeepSeek-V4 on SGLang, rebased onto main branch. Key changes: - Triton MXFP4 MoE kernel for SM120 (no MARLIN/tcgen05 on desktop Blackwell) - Triton FlashMLA sparse decode kernel for SM120 - MQA wq-precompute with vectorized batch for CUDA graph compatibility - DeepGEMM/PDL guards for SM120 (no TMEM/tcgen05) - NSA backend SM120 dispatch (tilelang default, skip DeepGEMM metadata) - FlashMLA SM120 adapter for deepseek_v4_backend - 3 CUDA-graph-breaking paths fixed (MoE .unique/.item, NSA/Compressed MQA) Results (8x RTX PRO 6000, TP=8): - Decode: 10.26 tok/s BS=1 with CUDA graph (2.4x vs without) - GSM8K 5-shot: 98.0% accuracy (200 questions) Co-Authored-By: Claude Opus 4.6 --- python/sglang/jit_kernel/utils.py | 6 +- python/sglang/srt/environ.py | 1 + .../layers/attention/deepseek_v4_backend.py | 15 +- .../srt/layers/attention/dsv4/indexer.py | 98 ++-- .../srt/layers/attention/dsv4/metadata.py | 8 +- .../attention/flash_mla_sm120_fallback.py | 261 ++++++++++ .../attention/flash_mla_sm120_triton.py | 366 +++++++++++++++ .../srt/layers/attention/nsa/nsa_indexer.py | 80 +++- .../attention/nsa/sm120_mqa_fallback.py | 215 +++++++++ .../layers/attention/nsa/sm120_mqa_triton.py | 177 +++++++ .../srt/layers/attention/nsa_backend.py | 40 +- .../layers/deep_gemm_wrapper/configurer.py | 6 +- .../fused_moe_triton/mxfp4_moe_fallback.py | 185 ++++++++ .../mxfp4_moe_sm120_triton.py | 444 ++++++++++++++++++ .../layers/quantization/mxfp4_marlin_moe.py | 74 ++- python/sglang/srt/server_args.py | 25 +- python/sglang/test/test_sm120_mqa_fallback.py | 279 +++++++++++ 17 files changed, 2221 insertions(+), 59 deletions(-) create mode 100644 python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py create mode 100644 python/sglang/srt/layers/attention/flash_mla_sm120_triton.py create mode 100644 python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py create mode 100644 python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py create mode 100644 python/sglang/test/test_sm120_mqa_fallback.py diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index bcd42e5ce349..d42c57262000 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -309,7 +309,11 @@ def get_jit_cuda_arch() -> ArchInfo: def is_arch_support_pdl() -> bool: if is_hip_runtime(): return False - return get_jit_cuda_arch().major >= 9 + arch = get_jit_cuda_arch() + # PDL requires SM100+ datacenter (tcgen05/TMEM); SM120 (desktop Blackwell) lacks these + if arch.major == 12: + return False + return arch.major >= 9 def _find_package_root(package: str) -> Optional[pathlib.Path]: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 57edcdd80ec4..5852a0ff73cb 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -578,6 +578,7 @@ class Envs: SGLANG_OPT_USE_ONLINE_COMPRESS = EnvBool(False) SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) + SGLANG_HACK_FLASHMLA_BACKEND = EnvStr("kernel") # SWA radix cache SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index 93e4507656c1..48ddd6999c08 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -37,6 +37,10 @@ from sglang.srt.layers.attention.dsv4.quant_k_cache import ( quant_to_nope_fp8_rope_bf16_pack_triton, ) +from sglang.srt.layers.attention.flash_mla_sm120_fallback import ( + _is_sm120, + flash_mla_with_kvcache_entrypoint, +) from sglang.srt.layers.dp_attention import ( get_attention_cp_rank, get_attention_cp_size, @@ -71,6 +75,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: def _create_flashmla_metadata(): + if _is_sm120: + return None import flash_mla return flash_mla.get_mla_metadata()[0] @@ -1021,9 +1027,7 @@ def forward( extra_indices.shape[-1] % 64 == 0 ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" - import flash_mla - - o = flash_mla.flash_mla_with_kvcache( + input_dict = dict( q=q, k_cache=swa_k_cache, head_dim_v=self.head_dim_v, @@ -1038,7 +1042,10 @@ def forward( extra_k_cache=extra_k_cache, extra_indices_in_kvcache=extra_indices, extra_topk_length=extra_topk_lengths, - )[0] + ) + + backend = envs.SGLANG_HACK_FLASHMLA_BACKEND.get() + o = flash_mla_with_kvcache_entrypoint(**input_dict, backend=backend)[0] o = o.squeeze(1) return o diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index 3bc982446681..7eb751cf5f7c 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -16,7 +16,10 @@ from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config from sglang.srt.environ import envs from sglang.srt.layers.attention.dsv4.compressor import Compressor -from sglang.srt.layers.attention.dsv4.metadata import PagedIndexerMetadata +from sglang.srt.layers.attention.dsv4.metadata import ( + PagedIndexerMetadata, + _is_sm120, +) from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation from sglang.srt.layers.attention.nsa.triton_kernel import act_quant from sglang.srt.layers.linear import ReplicatedLinear @@ -51,44 +54,81 @@ def fp8_paged_mqa_logits_torch( max_seq_len: int, clean_logits: bool = True, ) -> torch.Tensor: + """CUDA-graph-compatible FP8 paged MQA logits (vectorized, no .item()). + + Vectorized across batches using batched gather + bmm instead of + per-batch Python loop with .item() calls. + """ _ = deep_gemm_metadata batch_size, _, num_heads, head_dim = q_fp8.shape block_size = kvcache_fp8.shape[1] + device = q_fp8.device - assert head_dim == 128, "torch reference impl hardcodes DSV4 indexer head_dim=128" - assert block_size == 64, "torch reference impl hardcodes block_size=64 cache layout" + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) assert weight.shape == (batch_size, num_heads) + if seq_lens.dim() > 1: + seq_lens = seq_lens.squeeze(-1) assert seq_lens.shape == (batch_size,) assert page_table.shape[0] == batch_size assert clean_logits == False - logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) - for i in range(batch_size): - q = q_fp8[i, 0] - q = q.to(torch.float32) - q_scale = weight[i] - seq_len = int(seq_lens[i].item()) - assert seq_len <= max_seq_len - num_pages = (seq_len + block_size - 1) // block_size - padded_seq_len = num_pages * block_size - pages = page_table[i, :num_pages] - kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) - kvcache = kvcache_fp8[pages] - SCALE_OFFSET = block_size * head_dim - kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE) - kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32) - kvcache_value = kvcache_value.to(torch.float32) - kvcache_scale = kvcache_scale.contiguous() - kvcache_value = kvcache_value.view(padded_seq_len, head_dim) - kvcache_scale = kvcache_scale.view(padded_seq_len) - score = F.linear(kvcache_value, q) - score = F.relu(score) - score *= q_scale[None, :] - score = score.sum(dim=1) - score *= kvcache_scale - logits[i, :seq_len] = score[:seq_len] + # ── Vectorized: no .item(), no per-batch loop ── + max_pages = (max_seq_len + block_size - 1) // block_size + max_padded_seq = max_pages * block_size + + # Flatten KV cache for indexing: [total_pages, block_size * (head_dim + 4)] + kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + SCALE_OFFSET = block_size * head_dim + + # Gather pages for all batches: [batch, max_pages] + page_ids = page_table[:, :max_pages] + # Gather KV data: [batch, max_pages, block_size * (head_dim + 4)] + kvcache_gathered = kvcache_flat[page_ids] + + # Split value and scale + kv_value_raw = kvcache_gathered[ + ..., :SCALE_OFFSET + ] # [batch, max_pages, block_size * head_dim] + kv_scale_raw = kvcache_gathered[ + ..., SCALE_OFFSET: + ] # [batch, max_pages, block_size * 4] + + # Dequant value: view as FP8, convert to float32 + kv_value = kv_value_raw.contiguous().view(dtype=FP8_DTYPE).to(torch.float32) + kv_value = kv_value.view(batch_size, max_padded_seq, head_dim) + + # Dequant scale + kv_scale = kv_scale_raw.contiguous().view(dtype=torch.float32) + kv_scale = kv_scale.view(batch_size, max_padded_seq) + + # Q: [batch, num_heads, head_dim] + q = q_fp8[:, 0].to(torch.float32) + + # Batched matmul: [batch, max_padded_seq, head_dim] @ [batch, head_dim, num_heads] + score = torch.bmm(kv_value, q.transpose(1, 2)) # [batch, max_padded_seq, num_heads] + + # ReLU + scale by weight + sum across heads + score = F.relu(score) + score = score * weight.unsqueeze(1) # [batch, max_padded_seq, num_heads] + score = score.sum(dim=2) # [batch, max_padded_seq] + + # Apply KV scale + score = score * kv_scale # [batch, max_padded_seq] + + # Create validity mask and write output — graph-safe (no torch.tensor() calls) + out_width = min(max_padded_seq, max_seq_len) + logits = score.new_full((batch_size, max_seq_len), float("-inf")) + logits[:, :out_width] = score[:, :out_width] + + # Mask invalid positions to -inf + positions = torch.arange(max_seq_len, device=device) + invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze( + 1 + ) # [batch, max_seq_len] + logits.masked_fill_(invalid_mask, float("-inf")) return logits @@ -381,7 +421,7 @@ def forward_c4_indexer( from sglang.srt.layers.attention.dsv4.tilelang_kernel import ( tilelang_fp8_paged_mqa_logits as fn, ) - elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or _is_sm120: fn = fp8_paged_mqa_logits_torch else: from deep_gemm import fp8_paged_mqa_logits as fn diff --git a/python/sglang/srt/layers/attention/dsv4/metadata.py b/python/sglang/srt/layers/attention/dsv4/metadata.py index 7995dbd959cb..6bef63619883 100644 --- a/python/sglang/srt/layers/attention/dsv4/metadata.py +++ b/python/sglang/srt/layers/attention/dsv4/metadata.py @@ -8,6 +8,10 @@ from sglang.srt.environ import envs from sglang.srt.utils import is_hip +from sglang.srt.utils.common import get_device_sm + +_is_cuda = torch.cuda.is_available() and not is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 if TYPE_CHECKING: pass @@ -103,7 +107,9 @@ class PagedIndexerMetadata: topk_metadata: torch.Tensor = field(init=False, repr=False) def __post_init__(self): - if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or _is_sm120: + # SM120: DeepGEMM get_paged_mqa_logits_metadata asserts + # "Unsupported architecture" on SM120. Use None (torch fallback path). self.deep_gemm_metadata = None else: import deep_gemm diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py new file mode 100644 index 000000000000..76cf3683df96 --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py @@ -0,0 +1,261 @@ +"""FlashMLA adapter with SM120 fallback. + +The FP8 KV cache uses a page-internal layout where NOPE+ROPE data has +stride (nope_dim + rope_dim*2) per token, and scales are stored in a +separate region at the end of each page. The tensor shape +``(num_pages, page_size, 1, bytes_per_token)`` is just metadata for the +FlashMLA CUDA kernel -- it does NOT mean each token occupies +*bytes_per_token* contiguous bytes. + +On SM120 (Blackwell Desktop / RTX PRO 6000) the flash_mla CUDA kernel +is not available, so this module provides a pure-PyTorch fallback that +reads the raw paged buffer with the correct addressing. + +When SGLANG_SM120_TRITON_FLASHMLA=1 (default), a fused Triton kernel is +used instead of the PyTorch fallback for significantly better performance. +Set to 0 to fall back to the pure-PyTorch path. +""" + +import logging +import os + +import torch + +from sglang.srt.utils import is_hip +from sglang.srt.utils.common import get_device_sm + +logger = logging.getLogger(__name__) + +_is_cuda = torch.cuda.is_available() and not is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 + +# Page layout constants for DSv4-Flash (MODEL1): +# nope_dim = 448, rope_dim = 64, quantize_block_size = 64 +# nope_rope_stride = 448 + 64*2 = 576 bytes per token +# scale_stride = ceil(448/64) + 1 = 8 bytes per token (7 scales + 1 pad) +# bytes_per_token = 448 + 128 + 8 = 584 +# page_bytes = ceil_div(page_size * 584, 576) * 576 + +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_NOPE_ROPE_STRIDE = _NOPE_DIM + _ROPE_DIM * 2 # 576 +_TILE_SIZE = 64 +_NUM_TILES = _NOPE_DIM // _TILE_SIZE # 7 +_SCALE_STRIDE = _NUM_TILES + 1 # 8 (7 scales + 1 pad) +_D = _NOPE_DIM + _ROPE_DIM # 512 + + +def _gather_and_dequant(k_cache, indices, page_size): + """Gather KV entries from the paged buffer using correct page-internal addressing. + + Args: + k_cache: (num_pages, page_size, 1, bytes_per_token) float8_e4m3fn + Non-contiguous view of the raw page buffer. + indices: (...) int32/int64, token-level indices. -1 = invalid. + page_size: tokens per page (256) + + Returns: + kv: (..., _D) bfloat16, dequantized KV vectors + """ + idx_shape = indices.shape + flat_idx = indices.reshape(-1) # (N,) + N = flat_idx.shape[0] + device = k_cache.device + + # Page-level addressing + page_bytes = k_cache.stride(0) # actual byte stride between pages + pages = flat_idx // page_size + offsets = flat_idx % page_size + + # Clamp invalid indices + safe_pages = pages.clamp(min=0) + safe_offsets = offsets.clamp(min=0) + + # Access raw buffer as uint8 — use as_strided to get full page view + num_pages = k_cache.shape[0] + raw_pages = k_cache.as_strided( + (num_pages, page_bytes), + (page_bytes, 1), + ).view( + torch.uint8 + ) # (num_pages, page_bytes) uint8 + # Note: float8_e4m3fn and uint8 are both 1 byte, view is safe + + # Compute byte offsets within each page + # NOPE: page[safe_page, safe_offset * 576 + 0:448] + # ROPE: page[safe_page, safe_offset * 576 + 448:576] + # SCALES: page[safe_page, page_size * 576 + safe_offset * 8 + 0:7] + + nope_base = safe_offsets * _NOPE_ROPE_STRIDE # (N,) + nope_offsets = nope_base.unsqueeze(-1) + torch.arange( + _NOPE_DIM, device=device, dtype=torch.long + ) # (N, 448) + + rope_base = nope_base + _NOPE_DIM # (N,) + rope_offsets = rope_base.unsqueeze(-1) + torch.arange( + _ROPE_DIM * 2, device=device, dtype=torch.long + ) # (N, 128) + + scale_section_offset = page_size * _NOPE_ROPE_STRIDE # 147456 + scale_base = scale_section_offset + safe_offsets * _SCALE_STRIDE # (N,) + scale_offsets = scale_base.unsqueeze(-1) + torch.arange( + _NUM_TILES, device=device, dtype=torch.long + ) # (N, 7) + + # Gather bytes per page — use advanced indexing + # raw_pages[safe_pages, nope_offsets] → (N, 448) + page_idx_nope = safe_pages.unsqueeze(-1).expand_as(nope_offsets) + nope_bytes = raw_pages[page_idx_nope, nope_offsets] # (N, 448) uint8 + + page_idx_rope = safe_pages.unsqueeze(-1).expand_as(rope_offsets) + rope_bytes = raw_pages[page_idx_rope, rope_offsets] # (N, 128) uint8 + + page_idx_scale = safe_pages.unsqueeze(-1).expand_as(scale_offsets) + scale_bytes = raw_pages[page_idx_scale, scale_offsets] # (N, 7) uint8 + + # Reinterpret dtypes + nope_fp8 = nope_bytes.view(torch.float8_e4m3fn) # (N, 448) + rope_bf16 = rope_bytes.contiguous().view(torch.bfloat16) # (N, 64) + scale_e8m0 = scale_bytes.view(torch.float8_e8m0fnu) # (N, 7) + + # Dequantize: nope_tile * scale_tile → bf16 (vectorized) + result = torch.empty(N, _D, dtype=torch.bfloat16, device=device) + result[:, :_NOPE_DIM] = ( + ( + nope_fp8.view(N, _NUM_TILES, _TILE_SIZE).float() + * scale_e8m0.view(N, _NUM_TILES, 1).float() + ) + .view(N, _NOPE_DIM) + .to(torch.bfloat16) + ) + result[:, _NOPE_DIM:] = rope_bf16 + + return result.reshape(*idx_shape, _D) + + +def _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache=None, + extra_indices=None, + extra_topk_length=None, +): + B, s_q, H_q, D_qk = q.shape + num_pages, page_size, H_k, bpt = k_cache.shape + topk = indices.shape[-1] + + invalid_mask = indices < 0 + safe_indices = indices.clamp(min=0) + + if topk_length is not None: + topk_range = torch.arange(topk, device=topk_length.device).view(1, 1, topk) + invalid_mask = invalid_mask | (topk_range >= topk_length.view(B, 1, 1)) + + # Gather and dequantize using page-aware addressing + gathered_kv = _gather_and_dequant(k_cache, safe_indices, page_size) + + if extra_k_cache is not None and extra_indices is not None: + extra_topk = extra_indices.shape[-1] + extra_page_size = extra_k_cache.shape[1] + extra_invalid = extra_indices < 0 + extra_safe = extra_indices.clamp(min=0) + if extra_topk_length is not None: + extra_range = torch.arange( + extra_topk, device=extra_topk_length.device + ).view(1, 1, extra_topk) + extra_invalid = extra_invalid | ( + extra_range >= extra_topk_length.view(B, 1, 1) + ) + extra_kv = _gather_and_dequant(extra_k_cache, extra_safe, extra_page_size) + gathered_kv = torch.cat([gathered_kv, extra_kv], dim=2) + invalid_mask = torch.cat([invalid_mask, extra_invalid], dim=2) + + gathered_kv[invalid_mask] = 0.0 + + q_f = q.float() + kv_f = gathered_kv.float() + kv_d = kv_f.shape[-1] + if D_qk != kv_d: + q_f = q_f[..., :kv_d] + + scores = torch.einsum("bshd,bstd->bsht", q_f, kv_f) * softmax_scale + scores.masked_fill_(invalid_mask.unsqueeze(2).expand_as(scores), float("-inf")) + + lse = torch.logsumexp(scores, dim=-1) + + if attn_sink is not None: + lse_for_out = torch.logsumexp( + torch.stack([lse, attn_sink.view(1, 1, H_q).expand_as(lse)], dim=0), dim=0 + ) + else: + lse_for_out = lse.clone() + + lonely = lse == float("-inf") + lse_for_out[lonely] = float("inf") + weights = torch.exp(scores - lse_for_out.unsqueeze(-1)) + out = torch.einsum("bsht,bstv->bshv", weights, kv_f[..., :head_dim_v]) + out[lonely.unsqueeze(-1).expand_as(out)] = 0.0 + + return out.to(torch.bfloat16), lse.permute(0, 2, 1) + + +_use_triton_flashmla = os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" + + +def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): + if _is_sm120: + q = kwargs["q"] + k_cache = kwargs["k_cache"] + indices = kwargs["indices"] + topk_length = kwargs.get("topk_length") + attn_sink = kwargs.get("attn_sink") + head_dim_v = kwargs["head_dim_v"] + softmax_scale = kwargs.get("softmax_scale") + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + extra_k_cache = kwargs.get("extra_k_cache") + extra_indices = kwargs.get("extra_indices_in_kvcache") + extra_topk_length = kwargs.get("extra_topk_length") + + if _use_triton_flashmla: + from sglang.srt.layers.attention.flash_mla_sm120_triton import ( + flash_mla_sparse_decode_triton, + ) + + out, lse = flash_mla_sparse_decode_triton( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) + + out, lse = _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) + + assert backend == "kernel", f"unsupported backend {backend!r}" + import flash_mla + + return flash_mla.flash_mla_with_kvcache(**kwargs) diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py new file mode 100644 index 000000000000..0eae57d3a350 --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py @@ -0,0 +1,366 @@ +"""SM120-optimized Triton FlashMLA sparse decode kernel — Tiled V2. + +Replaces V1's serial token loop with a tiled vectorized approach: + 1. BLOCK_T tokens loaded simultaneously via 2D gather (vs 1-at-a-time) + 2. All BLOCK_T QK scores computed at once via vectorized mul-reduce + 3. V accumulation via vectorized weighted sum across BLOCK_T tokens + 4. Online softmax operates on tile-level maxima (fewer rescales) + +Three typed views of the same paged buffer handle FP8/uint8/BF16 regions: +- float8_e4m3fn view -> nope FP8 values (direct load + dequant) +- uint8 view -> UE8M0 scale bytes (raw integer -> exp2 conversion) +- bfloat16 view -> rope BF16 values (direct load) + +DSv4 page layout (per token, 576 bytes data + 8 bytes scales): + Data section: [0:448] FP8 nope | [448:576] BF16 rope (64 values = 128 bytes) + Scale section: [page_size*576 + offset*8 : +7] UE8M0 scales (7 groups of 64) + +Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7, 96MB L2) +""" + +import logging +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +LOG2E = tl.constexpr(1.4426950408889634) + +# DSv4 KV cache layout constants +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_D = _NOPE_DIM + _ROPE_DIM # 512 +_TOKEN_DATA_STRIDE = 576 # bytes per token in data section +_SCALE_STRIDE = 8 # bytes per token in scale section + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_T": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_T": 32}, num_warps=8, num_stages=2), + ], + key=["topk_rounded"], +) +@triton.jit +def _tiled_sparse_decode_kernel( + # Q: [B, H, D] bf16 + Q_ptr, + # Paged KV cache — three typed views of same underlying memory + cache_fp8_ptr, # float8_e4m3fn flat (1 byte/elem) — for nope + cache_uint8_ptr, # uint8 flat (1 byte/elem) — for scales + cache_bf16_ptr, # bfloat16 flat (2 bytes/elem) — for rope + # Indices: [B, topk] int32 + indices_ptr, + # Valid lengths: [B] int32 + topk_len_ptr, + # Output: [B, H, D] bf16 and LSE: [B, H] float32 + O_ptr, + LSE_ptr, + # Scalars + sm_scale: tl.float32, + page_size: tl.int32, + page_bytes: tl.int64, + scale_section_off: tl.int64, # page_size * 576 + H: tl.int32, + topk: tl.int32, + topk_rounded: tl.int32, # for autotune key + has_topk_len: tl.constexpr, + # Strides + stride_qb: tl.int32, + stride_qh: tl.int32, + stride_ob: tl.int32, + stride_oh: tl.int32, + stride_ib: tl.int32, # indices batch stride + # Constexprs + NOPE_PAD: tl.constexpr, # 512 (padded from 448) + ROPE_DIM: tl.constexpr, # 64 + NOPE_DIM_RT: tl.int32, # 448 (runtime, for masking) + BLOCK_T: tl.constexpr, # tokens per tile (16 or 32) +): + """Tiled sparse decode: vectorized gather + QK + softmax + V accumulation. + + Grid: (B, H) — one block per (batch, head) pair. + Each block processes all topk tokens in tiles of BLOCK_T. + """ + bid = tl.program_id(0) + hid = tl.program_id(1) + + # ---- Load Q for this (batch, head) ---- + q_base = bid * stride_qb + hid * stride_qh + nope_offs = tl.arange(0, NOPE_PAD) # [512] + nope_mask = nope_offs < NOPE_DIM_RT # [512], True for [0:448] + rope_offs = tl.arange(0, ROPE_DIM) # [64] + + q_nope = tl.load(Q_ptr + q_base + nope_offs, mask=nope_mask, other=0.0) + q_nope = q_nope.to(tl.float32) * sm_scale + q_rope = tl.load(Q_ptr + q_base + NOPE_DIM_RT + rope_offs) + q_rope = q_rope.to(tl.float32) * sm_scale + + # ---- Valid token count ---- + valid_topk = topk + if has_topk_len: + valid_topk = tl.load(topk_len_ptr + bid).to(tl.int32) + valid_topk = tl.minimum(valid_topk, topk) + + # ---- Online softmax state (base-2 math for SM120 efficiency) ---- + m_i: tl.float32 = -1e30 + l_i: tl.float32 = 0.0 + acc_nope = tl.zeros([NOPE_PAD], dtype=tl.float32) + acc_rope = tl.zeros([ROPE_DIM], dtype=tl.float32) + + # ---- Precompute constant index vectors ---- + group_ids = (nope_offs // 64).to(tl.int64) # [NOPE_PAD], scale group for each dim + t_offs = tl.arange(0, BLOCK_T) # [BLOCK_T], token offsets within tile + + # ---- Process tokens in tiles of BLOCK_T ---- + for tile_start in range(0, topk, BLOCK_T): + t_idx = tile_start + t_offs # [BLOCK_T], global token indices + t_in_bounds = t_idx < topk # bounds for index load + t_valid = t_idx < valid_topk # bounds for actual processing + + # Load indices for this tile: [BLOCK_T] + raw_indices = tl.load( + indices_ptr + bid * stride_ib + t_idx, + mask=t_in_bounds, + other=-1, + ) + idx_valid = t_valid & (raw_indices >= 0) # [BLOCK_T] mask + + # Page addressing: [BLOCK_T] (clamp for safe addressing of invalid tokens) + safe_indices = tl.where(idx_valid, raw_indices, tl.zeros_like(raw_indices)) + page_ids = (safe_indices // page_size).to(tl.int64) + page_offs_t = (safe_indices % page_size).to(tl.int64) + token_data_bases = page_ids * page_bytes + page_offs_t * 576 # [BLOCK_T] int64 + + # ---- Vectorized NOPE FP8 gather: [BLOCK_T, NOPE_PAD] ---- + nope_addrs = token_data_bases[:, None] + nope_offs[None, :].to(tl.int64) + nope_2d_mask = idx_valid[:, None] & nope_mask[None, :] + kv_nope_fp8 = tl.load( + cache_fp8_ptr + nope_addrs, + mask=nope_2d_mask, + other=0.0, + ) + + # ---- Vectorized scale gather + dequant: [BLOCK_T, NOPE_PAD] ---- + scale_bases = page_ids * page_bytes + scale_section_off + page_offs_t * 8 + scale_addrs = scale_bases[:, None] + group_ids[None, :] + scale_raw = tl.load( + cache_uint8_ptr + scale_addrs, + mask=nope_2d_mask, + other=127, + ) + scale_f32 = tl.math.exp2(scale_raw.to(tl.float32) - 127.0) + kv_nope = tl.where(nope_2d_mask, kv_nope_fp8.to(tl.float32) * scale_f32, 0.0) + + # ---- Vectorized ROPE BF16 gather: [BLOCK_T, ROPE_DIM] ---- + rope_byte_bases = token_data_bases + 448 + rope_elem_bases = (rope_byte_bases // 2).to(tl.int64) + rope_addrs = rope_elem_bases[:, None] + rope_offs[None, :].to(tl.int64) + kv_rope = tl.load( + cache_bf16_ptr + rope_addrs, + mask=idx_valid[:, None], + other=0.0, + ).to(tl.float32) + + # ---- Vectorized QK scores: [BLOCK_T] ---- + # scores[t] = dot(q_nope, kv_nope[t]) + dot(q_rope, kv_rope[t]) + scores = tl.sum(q_nope[None, :] * kv_nope, axis=1) + tl.sum( + q_rope[None, :] * kv_rope, axis=1 + ) + scores = tl.where(idx_valid, scores, -1e30) + + # ---- Online softmax update (base-2, tile-level) ---- + scores_log2 = scores * LOG2E # [BLOCK_T] + tile_max = tl.max(scores_log2) # scalar + m_new = tl.maximum(m_i, tile_max) + + alpha = tl.math.exp2(m_i - m_new) # rescale factor + p = tl.math.exp2(scores_log2 - m_new) # [BLOCK_T] attention weights + p = tl.where(idx_valid, p, 0.0) # zero out invalid + + l_i = l_i * alpha + tl.sum(p) + + # ---- Vectorized V accumulation (K=V in MLA) ---- + # acc += sum_t(p[t] * kv[t, :]) for both nope and rope + acc_nope = acc_nope * alpha + tl.sum(p[:, None] * kv_nope, axis=0) + acc_rope = acc_rope * alpha + tl.sum(p[:, None] * kv_rope, axis=0) + m_i = m_new + + # ---- Normalize output ---- + safe_l = tl.where(l_i > 0.0, l_i, 1.0) + acc_nope = acc_nope / safe_l + acc_rope = acc_rope / safe_l + + # LSE: convert from log2 back to natural log + lse = tl.where(l_i > 0.0, m_i / LOG2E + tl.math.log(safe_l), float("-inf")) + + # ---- Store output ---- + o_base = bid * stride_ob + hid * stride_oh + tl.store(O_ptr + o_base + nope_offs, acc_nope.to(tl.bfloat16), mask=nope_mask) + tl.store(O_ptr + o_base + NOPE_DIM_RT + rope_offs, acc_rope.to(tl.bfloat16)) + tl.store(LSE_ptr + bid * H + hid, lse) + + +def _run_triton_sparse_decode( + q: torch.Tensor, # [B, 1, H, D] bf16 + k_cache: torch.Tensor, # [num_pages, page_size, 1, bpt] float8 + indices: torch.Tensor, # [B, ...] int32 + topk_length: Optional[torch.Tensor], + softmax_scale: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the tiled Triton sparse decode kernel on one paged KV cache.""" + B, _, H, D = q.shape + num_pages = k_cache.shape[0] + page_size = k_cache.shape[1] + page_bytes = k_cache.stride(0) # elements = bytes for float8 + + # Flatten indices to [B, topk] + flat_indices = indices.reshape(B, -1).contiguous() + topk = flat_indices.shape[1] + + # Create three typed views of the flat cache memory + total_elems = num_pages * page_bytes + raw_fp8 = k_cache.as_strided((total_elems,), (1,)) + raw_uint8 = raw_fp8.view(torch.uint8) + raw_bf16 = raw_uint8.view(torch.bfloat16) + + # Squeeze Q: [B, H, D] + q3 = q.squeeze(1) + if not q3.is_contiguous(): + q3 = q3.contiguous() + + out = torch.zeros(B, H, D, dtype=torch.bfloat16, device=q.device) + lse = torch.full((B, H), float("-inf"), dtype=torch.float32, device=q.device) + + # Round topk for autotune key stability + topk_rounded = triton.next_power_of_2(topk) + + grid = (B, H) + _tiled_sparse_decode_kernel[grid]( + q3, + raw_fp8, + raw_uint8, + raw_bf16, + flat_indices, + ( + topk_length + if topk_length is not None + else torch.empty(0, device=q.device, dtype=torch.int32) + ), + out, + lse, + softmax_scale, + page_size, + int(page_bytes), # page_bytes (int64) + int(page_size * _TOKEN_DATA_STRIDE), # scale_section_off (int64) + H, + topk, + topk_rounded, + topk_length is not None, + q3.stride(0), + q3.stride(1), + out.stride(0), + out.stride(1), + flat_indices.stride(0), + NOPE_PAD=512, + ROPE_DIM=_ROPE_DIM, + NOPE_DIM_RT=_NOPE_DIM, + ) + + # Return [B, 1, H, D] and [B, 1, H] + return out.unsqueeze(1), lse.unsqueeze(1) + + +def _merge_partial_attn( + out1: torch.Tensor, + lse1: torch.Tensor, + out2: torch.Tensor, + lse2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Merge two attention outputs using LSE-weighted combination. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] float32 + """ + max_lse = torch.maximum(lse1, lse2) + w1 = torch.where(lse1 > -1e20, torch.exp(lse1 - max_lse), torch.zeros_like(lse1)) + w2 = torch.where(lse2 > -1e20, torch.exp(lse2 - max_lse), torch.zeros_like(lse2)) + total = (w1 + w2).clamp(min=1e-20) + merged = ( + w1.unsqueeze(-1) * out1.float() + w2.unsqueeze(-1) * out2.float() + ) / total.unsqueeze(-1) + merged_lse = max_lse + torch.log(total) + return merged.to(torch.bfloat16), merged_lse + + +def _apply_attn_sink( + out: torch.Tensor, + lse: torch.Tensor, + attn_sink: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention sink normalization. + + The sink adds to the softmax denominator without contributing output, + effectively down-weighting all attention scores. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] f32, attn_sink: [H] f32 + """ + sink_lse = attn_sink.view(1, 1, -1).expand_as(lse) + combined_lse = torch.logaddexp(lse, sink_lse) + w = torch.where( + lse > -1e20, + torch.exp(lse - combined_lse), + torch.zeros_like(lse), + ) + return (out.float() * w.unsqueeze(-1)).to(torch.bfloat16), combined_lse + + +def flash_mla_sparse_decode_triton( + q: torch.Tensor, + k_cache: torch.Tensor, + indices: torch.Tensor, + topk_length: Optional[torch.Tensor], + attn_sink: Optional[torch.Tensor], + head_dim_v: int, + softmax_scale: float, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """SM120-optimized sparse MLA decode using tiled Triton kernel. + + Processes SWA and extra (c4/c128) caches separately via the same + Triton kernel, then merges results using LSE-weighted combination. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + # Process main cache (SWA) + out, lse = _run_triton_sparse_decode( + q, + k_cache, + indices, + topk_length, + softmax_scale, + ) + + # Process extra cache (c4 / c128) if present + if extra_k_cache is not None and extra_indices is not None: + out_extra, lse_extra = _run_triton_sparse_decode( + q, + extra_k_cache, + extra_indices, + extra_topk_length, + softmax_scale, + ) + out, lse = _merge_partial_attn(out, lse, out_extra, lse_extra) + + # Apply attention sink + if attn_sink is not None: + out, lse = _apply_attn_sink(out, lse, attn_sink) + + # Return format matching PyTorch fallback: (out, lse.permute(0,2,1)) + return out, lse.permute(0, 2, 1) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 854c76eccb38..e9ee2a8527ce 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -23,6 +23,7 @@ add_prefix, ceil_align, get_bool_env_var, + get_device_sm, is_cuda, is_gfx95_supported, is_hip, @@ -32,6 +33,7 @@ global _use_multi_stream _is_cuda = is_cuda() _is_hip = is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 # SM120/SM121 _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_fp8_fnuz = is_fp8_fnuz() @@ -39,9 +41,34 @@ if _is_cuda: try: import deep_gemm - except ImportError as e: + except (ImportError, AssertionError) as e: + # AssertionError: deep_gemm init fails on SM120 (no CUDA_HOME / unsupported arch) deep_gemm = e +if _is_sm120: + import os as _os + + if _os.environ.get("SGLANG_SM120_MQA_FALLBACK", "0") == "1": + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits, + ) + else: + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits, + ) + if _use_aiter: from aiter.ops.cache import indexer_k_quant_and_cache @@ -198,7 +225,12 @@ def __init__( self.cp_size = None self.cp_rank = None if _is_cuda: - self.sm_count = deep_gemm.get_num_sms() + if _is_sm120: + # SM120: deep_gemm.get_num_sms() crashes; use torch native API + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.sm_count = props.multi_processor_count + else: + self.sm_count = deep_gemm.get_num_sms() self.half_device_sm_count = ceil_align(self.sm_count // 2, 8) pp_size = get_global_server_args().pp_size self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank @@ -249,7 +281,7 @@ def _with_real_sm_count(self): # request to receive the PP proxy tensor or output from the previous stage, occupying one SM resource. # Model execution runs in parallel with the recv operation, so the SMs available to the indexer must be reduced # by 1. Currently, the last rank starts the send result + recv request only after waiting for execution results. - if self.logits_with_pp_recv: + if self.logits_with_pp_recv and not _is_sm120: pp_recv_sm_count = 1 with deep_gemm_wrapper.configure_deep_gemm_num_sms( self.sm_count - pp_recv_sm_count @@ -464,9 +496,16 @@ def _get_topk_paged( seqlens_32_2d = seqlens_32.unsqueeze(-1) if _is_cuda: if schedule_metadata is None: - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( - seqlens_32_2d, blocksize, self.sm_count - ) + if _is_sm120: + schedule_metadata = _sm120_compute_paged_mqa_schedule_metadata( + seqlens_32_2d, + blocksize, + self.sm_count, + ) + else: + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( + seqlens_32_2d, blocksize, self.sm_count + ) assert len(q_fp8.shape) == 3 q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now @@ -509,6 +548,17 @@ def _get_topk_paged( Preshuffle=False, KVBlockSize=block_kv, ) + elif _is_sm120: + logits = _sm120_fp8_paged_mqa_logits( + q_fp8[:q_offset], + kv_cache_fp8, + weights[:q_offset], + seqlens_32_2d, + block_tables, + schedule_metadata, + max_seq_len, + clean_logits=False, + ) else: logits = deep_gemm.fp8_paged_mqa_logits( q_fp8[:q_offset], @@ -638,6 +688,15 @@ def _get_topk_ragged( logits = fp8_mqa_logits( q_fp8[:q_offset], kv, scale, weights[:q_offset], ks, ke ) + elif _is_sm120: + logits = _sm120_fp8_mqa_logits( + q_fp8[:q_offset], + kv_fp8, + weights[:q_offset], + ks, + ke, + clean_logits=False, + ) else: logits = deep_gemm.fp8_mqa_logits( q_fp8[:q_offset], @@ -688,6 +747,15 @@ def _get_topk_ragged( ks[start:end], ke[start:end], ) + elif _is_sm120: + logits_chunk = _sm120_fp8_mqa_logits( + q_fp8[start:end], + kv_fp8, + weights[start:end], + ks[start:end], + ke[start:end], + clean_logits=False, + ) else: logits_chunk = deep_gemm.fp8_mqa_logits( q_fp8[start:end], diff --git a/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py b/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py new file mode 100644 index 000000000000..e696b5b9ccf8 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py @@ -0,0 +1,215 @@ +""" +SM120 fallback kernels for DeepGEMM FP8 MQA logits operations. + +On SM120 (RTX 5090, RTX PRO 6000, DGX Spark), DeepGEMM's fp8_paged_mqa_logits +and fp8_mqa_logits crash with 'Unsupported architecture'. This module provides +PyTorch-native fallback implementations that match the DeepGEMM API contract. + +Reference: vLLM PR#40991 (Triton sparse MLA fallback approach for SM120) +""" + +from __future__ import annotations + +import logging +from typing import Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def compute_paged_mqa_schedule_metadata( + seqlens: torch.Tensor, + block_size: int, + num_sms: int, +) -> None: + """SM120 fallback: scheduling is handled internally, return None.""" + return None + + +def _dequant_fp8_with_scale_suffix( + data_fp8: torch.Tensor, head_dim_qk: int +) -> torch.Tensor: + """ + Dequantize FP8 tensor that has per-row scale factors appended. + + DeepGEMM packs KV cache as [data_fp8 (head_dim_qk bytes) | scale (4 bytes)] + in a tensor of shape [..., head_dim_with_sf] where head_dim_with_sf = head_dim_qk + 4. + The scale is stored as a float32 value in the last 4 bytes. + """ + # Split data and scale + data_bytes = data_fp8[..., :head_dim_qk] + # Scale is stored in the last 4 bytes, reinterpret as float32 + scale_bytes = data_fp8[..., head_dim_qk:] + scale = scale_bytes.contiguous().view(torch.float32) # [..., 1] + + # Dequantize: cast FP8 to float32, multiply by scale + data_f32 = data_bytes.to(torch.float32) * scale + return data_f32 + + +def sm120_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + seqlens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata, + max_seq_len: int, + clean_logits: bool = False, +) -> torch.Tensor: + """ + SM120 fallback for deep_gemm.fp8_paged_mqa_logits(). + + Computes weighted multi-head dot-product logits over paged KV cache. + + Args: + q_fp8: [batch, next_n, n_heads, head_dim_with_sf] FP8 queries with appended scale + kv_cache_fp8: [num_blocks, block_kv, 1, head_dim_with_sf] FP8 paged KV cache + weights: [batch, n_heads] float32 head weights + seqlens: [batch, 1] or [batch] int32 sequence lengths + block_tables: [batch, max_blocks] int32 block table indices + schedule_metadata: ignored on SM120 (None) + max_seq_len: maximum sequence length for output + clean_logits: if True, fill unused positions with -inf + + Returns: + logits: [batch * next_n, max_seq_len] float32 + """ + batch, next_n, n_heads, head_dim_with_sf = q_fp8.shape + head_dim_qk = head_dim_with_sf - 4 # 128 typically + block_kv = kv_cache_fp8.shape[1] # typically 64 + device = q_fp8.device + + # Flatten seqlens + if seqlens.dim() == 2: + seqlens = seqlens.squeeze(-1) + + # Output logits + out = torch.full( + (batch * next_n, max_seq_len), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + # Dequantize queries: [batch, next_n, n_heads, head_dim_qk] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, head_dim_qk) + + for b in range(batch): + seq_len = seqlens[b].item() + if seq_len <= 0: + continue + + num_blocks_needed = (seq_len + block_kv - 1) // block_kv + + # Gather KV blocks for this batch element + block_ids = block_tables[b, :num_blocks_needed] + # [num_blocks_needed, block_kv, 1, head_dim_with_sf] + kv_blocks = kv_cache_fp8[block_ids] + # Flatten to [num_blocks_needed * block_kv, head_dim_with_sf] + kv_flat = kv_blocks.view(-1, head_dim_with_sf) + # Trim to actual sequence length + kv_flat = kv_flat[:seq_len] + + # Dequantize KV: [seq_len, head_dim_qk] + k_f32 = _dequant_fp8_with_scale_suffix(kv_flat.unsqueeze(-2), head_dim_qk) + k_f32 = k_f32.squeeze(-2) # [seq_len, head_dim_qk] + + # Vectorized over next_n: + # q_b: [next_n, n_heads, head_dim_qk] + q_b = q_f32[b] + # dots: [next_n, n_heads, seq_len] + dots = torch.einsum("tnd,sd->tns", q_b, k_f32) + # Apply head weights: [n_heads] -> weighted sum -> [next_n, seq_len] + w = weights[b] # [n_heads] + logits_b = torch.einsum("tns,n->ts", dots, w) # [next_n, seq_len] + out_start = b * next_n + out[out_start : out_start + next_n, :seq_len] = logits_b + + return out + + +def sm120_fp8_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: Tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + ks: torch.Tensor, + ke: torch.Tensor, + clean_logits: bool = False, +) -> torch.Tensor: + """ + SM120 fallback for deep_gemm.fp8_mqa_logits() (contiguous/ragged variant). + + Computes weighted multi-head dot-product logits over contiguous KV. + + Args: + q_fp8: [num_q, n_heads, head_dim_with_sf] FP8 queries with appended scale + kv_fp8: tuple of (k_data_fp8 [num_k, head_dim_with_sf], k_scale [num_k]) or + (k_data_fp8 [num_k, D], k_scale [num_k, scale_dim]) + weights: [num_q, n_heads] float32 head weights + ks: [num_q] int32 start indices into KV + ke: [num_q] int32 end indices into KV + + Returns: + logits: [num_q, num_k] float32 where num_k = max(ke) - min(ks) (or ke.max()) + """ + num_q, n_heads, head_dim_with_sf = q_fp8.shape + head_dim_qk = head_dim_with_sf - 4 + device = q_fp8.device + + k_data, k_scale = kv_fp8 + num_k = k_data.shape[0] + + # Determine output width + k_max = ke.max().item() if ke.numel() > 0 else 0 + out_width = max(k_max, num_k) + + # Output logits + out = torch.full( + (num_q, out_width), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + if num_q == 0 or num_k == 0: + return out + + # Dequantize queries: [num_q, n_heads, head_dim_qk] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, head_dim_qk) + + # Dequantize KV keys + if k_data.shape[-1] == head_dim_with_sf: + # Keys have appended scale suffix + k_f32 = _dequant_fp8_with_scale_suffix(k_data.unsqueeze(-2), head_dim_qk) + k_f32 = k_f32.squeeze(-2) # [num_k, head_dim_qk] + else: + # Keys and scales are separate + k_f32 = k_data.to(torch.float32) + if k_scale.dim() == 1: + k_f32 = k_f32 * k_scale.unsqueeze(-1) + else: + k_f32 = k_f32 * k_scale + + # Vectorized: compute all dot products at once + # q_f32: [num_q, n_heads, head_dim_qk], k_f32: [num_k, head_dim_qk] + # dots: [num_q, n_heads, num_k] + dots = torch.einsum("qhd,kd->qhk", q_f32, k_f32) + + # Apply head weights: [num_q, n_heads] -> [num_q, n_heads, 1] + w = weights.unsqueeze(-1) + # Weighted sum across heads: [num_q, num_k] + logits_all = (dots * w).sum(dim=1) + + # Mask to [ks, ke) ranges + k_indices = torch.arange(out_width, device=device).unsqueeze(0) # [1, out_width] + ks_expanded = ks.unsqueeze(1) # [num_q, 1] + ke_expanded = ke.unsqueeze(1) # [num_q, 1] + mask = (k_indices >= ks_expanded) & (k_indices < ke_expanded) # [num_q, out_width] + + # Place logits into output at valid positions + # logits_all is [num_q, num_k], but output is [num_q, out_width] + out[:, :num_k] = torch.where(mask[:, :num_k], logits_all, out[:, :num_k]) + + return out diff --git a/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py b/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py new file mode 100644 index 000000000000..4a19e8ea6392 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py @@ -0,0 +1,177 @@ +"""SM120-optimized MQA logits — CUDA graph compatible. + +Replaces the PyTorch fallback in sm120_mqa_fallback.py with an optimized +implementation that precomputes the head-weighted query vector before +scanning the KV cache, reducing per-position work from O(n_heads) to O(1). + +Key insight: logit[s] = sum_h(w[h] * dot(q[h], kv[s])) + = dot(sum_h(w[h] * q[h]), kv[s]) + = dot(wq, kv[s]) + +CUDA graph compatibility: +- No .item() calls — all computation stays on GPU tensors +- No per-batch Python loops — vectorized with torch.bmm +- Fixed tensor shapes derived from known parameters (max_seq_len, num_k) + +Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7) +""" + +import logging +from typing import Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def _dequant_fp8_with_scale_suffix( + data_fp8: torch.Tensor, + head_dim_qk: int, +) -> torch.Tensor: + """Dequantize FP8 tensor with appended float32 scale suffix.""" + data_bytes = data_fp8[..., :head_dim_qk] + scale_bytes = data_fp8[..., head_dim_qk:] + scale = scale_bytes.contiguous().view(torch.float32) + return data_bytes.to(torch.float32) * scale + + +def compute_paged_mqa_schedule_metadata( + seqlens: torch.Tensor, + block_size: int, + num_sms: int, +) -> None: + """SM120 fallback: scheduling is handled internally, return None.""" + return None + + +def sm120_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + seqlens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata, + max_seq_len: int, + clean_logits: bool = False, +) -> torch.Tensor: + """CUDA-graph-compatible paged MQA logits for SM120. + + Key optimizations vs fallback: + 1. Precompute wq = sum_h(w[h] * dequant(q[h])) — eliminates per-position head loop + 2. Batched matmul across all batch elements — no per-batch Python loop + 3. No .item() calls — all shapes derived from known parameters + """ + batch, next_n, n_heads, hd_with_sf = q_fp8.shape + hd = hd_with_sf - 4 + block_kv = kv_cache_fp8.shape[1] + device = q_fp8.device + + seqlens_flat = seqlens.view(-1).to(torch.int64) + + # Dequant Q: [batch, next_n, n_heads, hd] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, hd) + + # Precompute wq = sum_h(w[b,h] * q[b,t,h,:]) → [batch, next_n, hd] + w = weights.view(batch, 1, n_heads, 1) + wq = (q_f32 * w).sum(dim=2) # [batch, next_n, hd] + + # Batch-dequant all KV blocks: [num_blocks, block_kv, hd] + kv_data = kv_cache_fp8[..., :hd].squeeze(2) + kv_scale_raw = kv_cache_fp8[..., hd:].squeeze(2) + kv_scale = kv_scale_raw.contiguous().view(torch.float32) + kv_f32 = kv_data.float() * kv_scale # [num_blocks_total, block_kv, hd] + + # ── Vectorized batch gather (no per-batch loop, no .item()) ── + max_blocks = (max_seq_len + block_kv - 1) // block_kv + # Gather block IDs for all batches: [batch, max_blocks] + block_ids = block_tables[:, :max_blocks] + + # Gather KV for all batches: [batch, max_blocks, block_kv, hd] + kv_batched = kv_f32[block_ids] + max_padded = max_blocks * block_kv + kv_flat = kv_batched.reshape(batch, max_padded, hd) + + # Batched matmul: [batch, next_n, hd] @ [batch, hd, max_padded] + logits_batched = torch.bmm( + wq, kv_flat.transpose(1, 2) + ) # [batch, next_n, max_padded] + + # Create validity mask: [batch, max_padded] + positions = torch.arange(max_padded, device=device) + valid = positions.unsqueeze(0) < seqlens_flat.unsqueeze(1) # [batch, max_padded] + + # Apply mask (broadcast over next_n) + logits_batched = logits_batched.masked_fill(~valid.unsqueeze(1), float("-inf")) + + # Write to output: [batch * next_n, max_seq_len] + out_width = min(max_padded, max_seq_len) + out = torch.full( + (batch * next_n, max_seq_len), + float("-inf"), + device=device, + dtype=torch.float32, + ) + out[:, :out_width] = logits_batched[:, :, :out_width].reshape( + batch * next_n, out_width + ) + + return out + + +def sm120_fp8_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: Tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + ks: torch.Tensor, + ke: torch.Tensor, + clean_logits: bool = False, +) -> torch.Tensor: + """CUDA-graph-compatible ragged MQA logits for SM120. + + Key optimization: precompute wq = sum_h(w[h] * q[h]), then single matmul. + No .item() calls — uses num_k for output width. + """ + num_q, n_heads, hd_with_sf = q_fp8.shape + hd = hd_with_sf - 4 + device = q_fp8.device + + k_data, k_scale = kv_fp8 + num_k = k_data.shape[0] + + # Use num_k as output width — avoids ke.max().item() GPU-CPU sync + out_width = num_k + + out = torch.full( + (num_q, out_width), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + if num_q == 0 or num_k == 0: + return out + + # Dequant Q and precompute weighted query + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, hd) + w = weights.unsqueeze(-1) + wq = (q_f32 * w).sum(dim=1) # [num_q, hd] + + # Dequant KV + if k_data.shape[-1] == hd_with_sf: + k_f32 = _dequant_fp8_with_scale_suffix(k_data.unsqueeze(-2), hd).squeeze(-2) + else: + k_f32 = k_data.float() + if k_scale.dim() == 1: + k_f32 = k_f32 * k_scale.unsqueeze(-1) + else: + k_f32 = k_f32 * k_scale + + # Single matmul: [num_q, hd] @ [hd, num_k] → [num_q, num_k] + logits_all = wq @ k_f32.T + + # Apply ragged [ks, ke) masking + k_indices = torch.arange(out_width, device=device).unsqueeze(0) + mask = (k_indices >= ks.unsqueeze(1)) & (k_indices < ke.unsqueeze(1)) + out[:, :num_k] = torch.where(mask[:, :num_k], logits_all, out[:, :num_k]) + + return out diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index a938f0b01e22..9b5c2b21a8e6 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -36,7 +36,9 @@ ) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import get_device_sm, is_cuda, is_hip + +_is_sm120 = is_cuda() and get_device_sm() // 10 == 12 if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -641,10 +643,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): paged_mqa_schedule_metadata = None # DeepGEMM paged MQA logits path needs a schedule metadata tensor. # Compute it once per forward batch and reuse it across layers. - if is_cuda() and ( - forward_batch.forward_mode.is_decode_or_idle() - or forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_batch.forward_mode.is_decode_or_idle() + or forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm @@ -930,10 +936,14 @@ def init_forward_metadata_capture_cuda_graph( real_page_table = self._transform_table_1_to_real(page_table_1) paged_mqa_schedule_metadata = None - if is_cuda() and ( - forward_mode.is_decode_or_idle() - or forward_mode.is_target_verify() - or forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm @@ -1081,10 +1091,14 @@ def init_forward_metadata_replay_cuda_graph( ) # Update DeepGEMM paged MQA schedule metadata outside the captured graph. - if is_cuda() and ( - forward_mode.is_decode_or_idle() - or forward_mode.is_target_verify() - or forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py index de433fe2d505..9933b85e3dba 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py @@ -18,12 +18,16 @@ def _compute_enable_deep_gemm(): sm_version = get_device_sm() if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31): return False + # DeepGEMM requires TMEM/tcgen05 (SM100+datacenter), not available on SM120 + if sm_version // 10 == 12: + return False if not (_is_cuda or _is_musa): return False try: import deep_gemm # noqa: F401 - except ImportError: + except (ImportError, AssertionError): + # AssertionError: deep_gemm init may fail on unsupported architectures return False return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py new file mode 100644 index 000000000000..31ccff929ea7 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py @@ -0,0 +1,185 @@ +"""PyTorch fallback for MXFP4 MoE GEMM on SM120. + +The Marlin MXFP4 kernel produces NaN on SM120 (Blackwell Desktop). +This module provides a pure-PyTorch implementation that dequantizes +MXFP4 weights (packed int8 + float8_e8m0fnu scales) to BF16 and uses +torch.matmul for the GEMM, per active expert. + +Slow but functionally correct — matches the FlashMLA fallback pattern. +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +# ── FP4 E2M1 lookup table ────────────────────────────────────────── +# Nibble encoding: bit3=sign, bit2-1=exponent (bias=1), bit0=mantissa +# 16 possible values for 4-bit float +_FP4_E2M1_LUT = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, # positive (0x0-0x7) + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, # negative (0x8-0xF) + ], + dtype=torch.float32, +) + + +def _dequant_mxfp4_weight( + packed: torch.Tensor, + scales: torch.Tensor, + unpacked_k: int, +) -> torch.Tensor: + """Dequantize one expert's MXFP4 weight from packed int8 to bfloat16. + + Args: + packed: [N, K//2] int8 — 2 FP4 values per byte (low nibble=even, high=odd) + scales: [N, K//32] float32 — dequantization scale per group of 32 elements + unpacked_k: K, the full unpacked dimension + + Returns: + [N, K] bfloat16 weight matrix + """ + device = packed.device + lut = _FP4_E2M1_LUT.to(device=device) + + # View as unsigned bytes for bit manipulation + u8 = packed.view(torch.uint8).to(torch.int32) + low = u8 & 0x0F # even-index elements + high = (u8 >> 4) & 0x0F # odd-index elements + + # Lookup FP4 → float32 + vals_low = lut[low.long()] # [N, K//2] + vals_high = lut[high.long()] # [N, K//2] + + # Interleave: [low0, high0, low1, high1, ...] + unpacked = torch.stack([vals_low, vals_high], dim=-1) # [N, K//2, 2] + unpacked = unpacked.reshape(packed.shape[0], -1) # [N, K] + unpacked = unpacked[:, :unpacked_k] # trim if needed + + # Apply group scales (group_size=32) + # scales: [N, K//32] — each scale covers 32 consecutive elements along K + if scales.dtype == torch.float8_e8m0fnu: + scales_f32 = scales.to(torch.float32) + else: + scales_f32 = scales.float() + scales_expanded = scales_f32.repeat_interleave(32, dim=-1)[:, :unpacked_k] + + result = unpacked * scales_expanded + return result.to(torch.bfloat16) + + +def mxfp4_moe_forward_fallback( + hidden_states: torch.Tensor, + w13_packed: torch.Tensor, + w2_packed: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_size: int, + intermediate_size: int, + inplace: bool = False, + routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, +) -> torch.Tensor: + """Pure-PyTorch MXFP4 MoE forward pass. + + Args: + hidden_states: [M, K] bfloat16 input activations + w13_packed: [E, 2*I, K//2] int8 packed gate_up_proj weights + w2_packed: [E, K, I//2] int8 packed down_proj weights + w13_scale: [E, 2*I, K//32] scales for gate_up_proj + w2_scale: [E, K, I//32] scales for down_proj + topk_ids: [M, topk] int32 expert assignments + topk_weights: [M, topk] float32 routing weights + hidden_size: K + intermediate_size: I (per partition) + inplace: whether to write output in-place + routed_scaling_factor: optional global scaling factor + clamp_limit: optional SwiGLU clamp limit (2604B submode) + + Returns: + [M, K] bfloat16 output tensor + """ + M, K = hidden_states.shape + topk = topk_ids.shape[1] + device = hidden_states.device + dtype = hidden_states.dtype + I = intermediate_size + + output = hidden_states if inplace else torch.zeros(M, K, dtype=dtype, device=device) + if not inplace: + output.zero_() + + # Find all active experts + active_experts = topk_ids.unique() + + for eid in active_experts: + eid_val = eid.item() + if eid_val < 0: + continue + + # Find (token_idx, slot_idx) pairs assigned to this expert + mask = topk_ids == eid_val # [M, topk] + token_mask = mask.any(dim=1) # [M] + token_indices = token_mask.nonzero(as_tuple=True)[0] + + if len(token_indices) == 0: + continue + + # ── GEMM1: gate_up_proj ── + # w13: [2*I, K//2] int8 → dequant → [2*I, K] bf16 + w13_dq = _dequant_mxfp4_weight( + w13_packed[eid_val], w13_scale[eid_val], K + ) # [2*I, K] + + h = hidden_states[token_indices] # [n, K] + # y = h @ W13^T → [n, K] @ [K, 2*I] = [n, 2*I] + intermediate = torch.matmul(h.float(), w13_dq.float().T).to(dtype) + + # ── SiLU + Mul (with optional clamp) ── + gate = intermediate[:, :I] + up = intermediate[:, I:] + if clamp_limit is not None and clamp_limit > 0: + gate = torch.clamp(gate, max=clamp_limit) + up = torch.clamp(up, min=-clamp_limit, max=clamp_limit) + intermediate2 = F.silu(gate) * up # [n, I] + + # ── GEMM2: down_proj ── + # w2: [K, I//2] int8 → dequant → [K, I] bf16 + w2_dq = _dequant_mxfp4_weight( + w2_packed[eid_val], w2_scale[eid_val], I + ) # [K, I] + + # y = intermediate2 @ W2^T → [n, I] @ [I, K] = [n, K] + down = torch.matmul(intermediate2.float(), w2_dq.float().T).to(dtype) + + # ── Accumulate with topk weights (vectorized over topk slots) ── + expert_mask = (topk_ids[token_indices] == eid_val).to(dtype) # [n, topk] + combined_weights = (expert_mask * topk_weights[token_indices].to(dtype)).sum( + dim=1, keepdim=True + ) # [n, 1] + output[token_indices] += down * combined_weights + + if routed_scaling_factor is not None and routed_scaling_factor != 1.0: + output.mul_(routed_scaling_factor) + + return output diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py new file mode 100644 index 000000000000..c48faeeb2f22 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py @@ -0,0 +1,444 @@ +"""SM120-optimized Triton MXFP4 MoE kernel — CUDA graph compatible. + +Replaces the PyTorch fallback (per-expert for-loop + full dequant + matmul) +with fused Triton kernels that: +1. Fuse FP4 dequant + GEMV (no intermediate BF16 weight materialization) +2. Process each (token, expert) slot independently — no data-dependent routing +3. Respect SM120 shared memory constraint (99 KB/block) + +CUDA graph compatibility: +- No .unique(), .item(), .nonzero() — all routing is tensor-level +- Fixed grid dimensions (M*topk, N_blocks) per captured batch size +- All control flow is static or within Triton kernels + +SM120 constraints: +- SMEM: 99 KB/block (vs SM100 228 KB) +- No TMEM/tcgen05 — uses mma.sync.aligned via Triton +- Max warps: 48/SM +- Registers: ~128/thread practical limit +""" + +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def _dequant_fp4_lut(nibble): + """Decode a 4-bit FP4 E2M1 nibble to float32 using arithmetic.""" + sign_bit = (nibble >> 3) & 1 + exp_bits = (nibble >> 1) & 3 + man_bit = nibble & 1 + + is_subnormal = exp_bits == 0 + mantissa = 1.0 + man_bit.to(tl.float32) * 0.5 + exponent = tl.math.exp2((exp_bits - 1).to(tl.float32)) + val = tl.where(is_subnormal, man_bit.to(tl.float32) * 0.5, mantissa * exponent) + val = tl.where(sign_bit != 0, -val, val) + return val + + +# ── Per-slot GEMV kernel: processes one (token, expert) pair ── + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2), + ], + key=["N", "K"], +) +@triton.jit +def _mxfp4_slot_gemv_kernel( + # Pointers + A_ptr, # [M_total, K] bf16 — source rows + B_packed_ptr, # [E, N, K//2] uint8 — packed FP4 expert weights + B_scale_ptr, # [E, N, K//32] float32 — weight scales + C_ptr, # [num_slots, N] bf16 — output + token_ids_ptr, # [num_slots] int32 — which A row for each slot + expert_ids_ptr, # [num_slots] int32 — which expert's B for each slot + # Dimensions + N: tl.int32, + K: tl.int32, + # A strides + stride_am: tl.int32, + # B strides (within an expert) + stride_bn: tl.int32, + stride_bk2: tl.int32, + # B_scale strides (within an expert) + stride_bsn: tl.int32, + stride_bsk32: tl.int32, + # Expert strides (between experts) + expert_b_stride: tl.int64, + expert_s_stride: tl.int64, + # C strides + stride_cm: tl.int32, + # Block sizes + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Per-slot fused MXFP4 dequant + GEMV. + + Grid: (num_slots, cdiv(N, BLOCK_N)) + Each program computes one (token, expert) pair for a BLOCK_N slice of output. + """ + slot_id = tl.program_id(0) + n_block = tl.program_id(1) + + token_id = tl.load(token_ids_ptr + slot_id).to(tl.int64) + expert_id = tl.load(expert_ids_ptr + slot_id).to(tl.int64) + + offs_n = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + # Expert weight base pointers + b_base = expert_id * expert_b_stride + s_base = expert_id * expert_s_stride + a_base = token_id * stride_am + + for k_start in range(0, K, BLOCK_K): + # ── Load packed B: [BLOCK_N, BLOCK_K//2] ── + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = n_mask[:, None] & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + + b_base + + offs_n[:, None] * stride_bn + + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + # ── FP4 dequant ── + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) # even K indices + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) # odd K indices + + # ── Load and apply scales: [BLOCK_N, BLOCK_K//2] ── + group_ids = tl.arange(0, BLOCK_K // 2) // 16 # 32 values per group, 2 per byte + s_mask = n_mask[:, None] & ((k_start // 32 + group_ids[None, :]) < K // 32) + scales = tl.load( + B_scale_ptr + + s_base + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=s_mask, + other=1.0, + ) + val_lo = val_lo * scales + val_hi = val_hi * scales + + # ── Load A even/odd: [BLOCK_K//2] each ── + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even = tl.load( + A_ptr + a_base + offs_k_even, + mask=offs_k_even < K, + other=0.0, + ).to(tl.float32) + a_odd = tl.load( + A_ptr + a_base + offs_k_odd, + mask=offs_k_odd < K, + other=0.0, + ).to(tl.float32) + + # ── Dot product: acc[n] += sum_k(a_even[k]*B_lo[n,k] + a_odd[k]*B_hi[n,k]) ── + acc += tl.sum(a_even[None, :] * val_lo, axis=1) + acc += tl.sum(a_odd[None, :] * val_hi, axis=1) + + # ── Store output ── + tl.store( + C_ptr + slot_id * stride_cm + offs_n, + acc.to(tl.bfloat16), + mask=n_mask, + ) + + +# ── Legacy per-expert GEMM kernel (kept for benchmarking) ── + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _mxfp4_gemm_kernel( + # Pointers + A_ptr, # [M, K] bf16 activation + B_packed_ptr, # [N, K//2] uint8 packed FP4 + B_scale_ptr, # [N, K//32] float32 scales + C_ptr, # [M, N] bf16 output + # Dimensions + M, + N, + K, + # Strides + stride_am, + stride_ak, + stride_bn, + stride_bk2, + stride_bsn, + stride_bsk32, + stride_cm, + stride_cn, + # Constexprs + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Fused MXFP4 dequant + GEMM: C = A @ dequant(B_packed, B_scale).T""" + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = (offs_n[:, None] < N) & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + offs_n[:, None] * stride_bn + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) + + group_ids = tl.arange(0, BLOCK_K // 2) // 16 + scales_per_byte = tl.load( + B_scale_ptr + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=(offs_n[:, None] < N) + & ((k_start // 32 + group_ids[None, :]) < K // 32), + other=1.0, + ) + val_lo = val_lo * scales_per_byte + val_hi = val_hi * scales_per_byte + + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even_mask = (offs_m[:, None] < M) & (offs_k_even[None, :] < K) + a_even = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_even[None, :] * stride_ak, + mask=a_even_mask, + other=0.0, + ).to(tl.float32) + + a_odd_mask = (offs_m[:, None] < M) & (offs_k_odd[None, :] < K) + a_odd = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_odd[None, :] * stride_ak, + mask=a_odd_mask, + other=0.0, + ).to(tl.float32) + + acc += tl.dot(a_even, tl.trans(val_lo)) + acc += tl.dot(a_odd, tl.trans(val_hi)) + + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store( + C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, + acc.to(tl.bfloat16), + mask=c_mask, + ) + + +def mxfp4_gemm_triton( + A: torch.Tensor, + B_packed: torch.Tensor, + B_scale: torch.Tensor, + K_full: int, +) -> torch.Tensor: + """Triton fused MXFP4 dequant + GEMM: output = A @ dequant(B).T + + Kept for standalone benchmarking. The MoE forward uses the slot kernel. + """ + M = A.shape[0] + N = B_packed.shape[0] + K = K_full + + if B_scale.dtype == torch.float8_e8m0fnu: + B_scale = B_scale.to(torch.float32) + elif B_scale.dtype != torch.float32: + B_scale = B_scale.float() + + C = torch.empty(M, N, dtype=torch.bfloat16, device=A.device) + A = A.contiguous() + B_packed = B_packed.contiguous() + B_scale = B_scale.contiguous() + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + B_u8 = B_packed.view(torch.uint8) + + _mxfp4_gemm_kernel[grid]( + A, + B_u8, + B_scale, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B_u8.stride(0), + B_u8.stride(1), + B_scale.stride(0), + B_scale.stride(1), + C.stride(0), + C.stride(1), + ) + return C + + +def mxfp4_moe_forward_triton( + hidden_states: torch.Tensor, + w13_packed: torch.Tensor, + w2_packed: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_size: int, + intermediate_size: int, + inplace: bool = False, + routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, +) -> torch.Tensor: + """SM120-optimized MXFP4 MoE forward — CUDA graph compatible. + + Uses per-slot GEMV kernels instead of per-expert Python loops. + Each (token, expert) slot is processed independently with a fixed grid, + eliminating .unique()/.item()/.nonzero() that break CUDA graph capture. + """ + import torch.nn.functional as F + + M, K = hidden_states.shape + topk = topk_ids.shape[1] + I = intermediate_size + num_slots = M * topk + device = hidden_states.device + dtype = hidden_states.dtype + + # ── Graph-safe routing: flatten topk assignments ── + # token_ids[slot] = which row of A (original token index) + # expert_ids[slot] = which expert's weights to use + flat_expert_ids = topk_ids.reshape(-1).contiguous() # [M*topk] + token_ids = ( + torch.arange(M, device=device, dtype=torch.int32) + .unsqueeze(1) + .expand(M, topk) + .reshape(-1) + .contiguous() + ) # [M*topk] + + # ── Ensure scales are float32 ── + if w13_scale.dtype != torch.float32: + w13_scale = w13_scale.to(torch.float32) + if w2_scale.dtype != torch.float32: + w2_scale = w2_scale.to(torch.float32) + + # ── GEMM1: gate_up projection ── + # hidden_states[token] @ w13[expert].T → [num_slots, 2*I] + intermediate = torch.empty(num_slots, 2 * I, dtype=dtype, device=device) + + w13_u8 = w13_packed.view(torch.uint8) # [E, 2*I, K//2] + grid1 = lambda meta: (num_slots, triton.cdiv(2 * I, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid1]( + hidden_states, + w13_u8, + w13_scale, + intermediate, + token_ids, + flat_expert_ids, + 2 * I, + K, + hidden_states.stride(0), + w13_u8.stride(1), + w13_u8.stride(2), + w13_scale.stride(1), + w13_scale.stride(2), + w13_u8.stride(0), + w13_scale.stride(0), + intermediate.stride(0), + ) + + # ── SiLU activation (graph-safe vectorized ops) ── + gate = intermediate[:, :I].float() + up = intermediate[:, I:].float() + if clamp_limit is not None and clamp_limit > 0: + gate = torch.clamp(gate, max=clamp_limit) + up = torch.clamp(up, min=-clamp_limit, max=clamp_limit) + activated = (F.silu(gate) * up).to(dtype) + + # ── GEMM2: down projection ── + # activated[slot] @ w2[expert].T → [num_slots, K] + down = torch.empty(num_slots, K, dtype=dtype, device=device) + + # For GEMM2, A is the activated buffer — each slot reads its own row + slot_ids = torch.arange(num_slots, device=device, dtype=torch.int32) + + w2_u8 = w2_packed.view(torch.uint8) # [E, K, I//2] + grid2 = lambda meta: (num_slots, triton.cdiv(K, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid2]( + activated, + w2_u8, + w2_scale, + down, + slot_ids, + flat_expert_ids, + K, + I, + activated.stride(0), + w2_u8.stride(1), + w2_u8.stride(2), + w2_scale.stride(1), + w2_scale.stride(2), + w2_u8.stride(0), + w2_scale.stride(0), + down.stride(0), + ) + + # ── Weighted sum across topk slots (graph-safe) ── + flat_weights = topk_weights.reshape(-1).unsqueeze(1).to(dtype) # [M*topk, 1] + output = (down * flat_weights).view(M, topk, K).sum(dim=1) + + if routed_scaling_factor is not None and routed_scaling_factor != 1.0: + output.mul_(routed_scaling_factor) + + return output diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py index 90a3de66f4aa..2f0d32d89ae9 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import log_info_on_rank0 -from sglang.srt.utils.common import is_sm90_supported +from sglang.srt.utils.common import get_device_sm, is_sm90_supported if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput @@ -63,10 +63,43 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_mega_moe_weights_built", False): return - if not is_sm90_supported(): + _sm = get_device_sm() + if not is_sm90_supported() and _sm // 10 != 12: raise RuntimeError( "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." ) + + # SM120: Skip Marlin repacking, keep original weight format + # for PyTorch dequant fallback (Marlin kernel produces NaN on SM120) + if _sm // 10 == 12: + from torch.nn import Parameter + + log_info_on_rank0( + logger, + f"SM120 detected: using PyTorch MXFP4 MoE fallback " + f"(layer: {self.prefix})...", + ) + # Keep weights in original packed int8 format + # Normalize scales to float32 for direct use in dequant + w13_s = layer.w13_weight_scale_inv.data + w2_s = layer.w2_weight_scale_inv.data + if w13_s.dtype == torch.float8_e8m0fnu: + pass # already in e8m0 format, will convert at runtime + elif w13_s.dtype in (torch.uint8, torch.int8): + layer.w13_weight_scale_inv = Parameter( + w13_s.view(torch.uint8) + .view(torch.float8_e8m0fnu) + .to(torch.float32), + requires_grad=False, + ) + layer.w2_weight_scale_inv = Parameter( + w2_s.view(torch.uint8).view(torch.float8_e8m0fnu).to(torch.float32), + requires_grad=False, + ) + # else: float32 scales are already usable directly + layer._dsv4_mxfp4_backend = "sm120_fallback" + return + if not check_moe_marlin_supports_layer(layer, 32): raise RuntimeError( "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." @@ -99,6 +132,43 @@ def apply( if not TopKOutputChecker.format_is_standard(topk_output): raise ValueError(f"Unsupported topk output format: {topk_output.format}") + # SM120 fallback: use Triton fused dequant+GEMM (or PyTorch fallback) + if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_fallback": + from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( + mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback, + ) + + hidden_states = dispatch_output.hidden_states + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + + output = mxfp4_moe_forward_fallback( + hidden_states=hidden_states, + w13_packed=w13, + w2_packed=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + topk_ids=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + routed_scaling_factor=( + self.runner.config.routed_scaling_factor + if hasattr(self.runner, "config") + else None + ), + clamp_limit=( + self.runner.config.swiglu_limit + if hasattr(self.runner, "config") + else None + ), + ) + return StandardCombineInput(hidden_states=output) + quant_info = MarlinMoeQuantInfo( w13_qweight=layer.w13_weight, w2_qweight=layer.w2_weight, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f5f266fee1d9..f74f6bc51861 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1686,6 +1686,7 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: from sglang.srt.arg_groups.hisparse_hook import ( apply_hisparse_nsa_backend_defaults, ) + from sglang.srt.utils import is_sm120_supported user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None @@ -1699,7 +1700,13 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: self.nsa_prefill_backend = "tilelang" self.nsa_decode_backend = "tilelang" elif kv_cache_dtype == "fp8_e4m3": - if major >= 10: + if is_sm120_supported(): + # SM120: trtllm does not support SM120; use tilelang for both paths. + if not user_set_prefill: + self.nsa_prefill_backend = "tilelang" + if not user_set_decode: + self.nsa_decode_backend = "tilelang" + elif major >= 10: if not user_set_prefill: self.nsa_prefill_backend = "trtllm" if not user_set_decode: @@ -1712,7 +1719,13 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: self.nsa_decode_backend = "flashmla_kv" else: # set prefill/decode backends based on hardware architecture. - if major >= 10: + if is_sm120_supported(): + # SM120: trtllm does not support SM120; use tilelang (portable) + if not user_set_prefill: + self.nsa_prefill_backend = "tilelang" + if not user_set_decode: + self.nsa_decode_backend = "tilelang" + elif major >= 10: if not user_set_prefill: self.nsa_prefill_backend = "flashmla_sparse" if not user_set_decode: @@ -1929,6 +1942,14 @@ def _handle_model_specific_adjustments(self): logger.info( "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM" ) + elif is_sm120_supported(): + # SM120: DSv4-Flash uses MXFP4 experts; marlin backend dispatches + # to our SM120 Triton fallback in mxfp4_marlin_moe.py + if self.moe_runner_backend == "auto": + self.moe_runner_backend = "marlin" + logger.info( + "Use marlin as MoE runner backend on SM120 for DeepseekV3/V4" + ) elif is_hip(): if not self.enable_dp_attention and self.nnodes == 1: # TODO (Hubert): Put this back later diff --git a/python/sglang/test/test_sm120_mqa_fallback.py b/python/sglang/test/test_sm120_mqa_fallback.py new file mode 100644 index 000000000000..3e3a87c9fd10 --- /dev/null +++ b/python/sglang/test/test_sm120_mqa_fallback.py @@ -0,0 +1,279 @@ +""" +Unit tests for SM120 MQA fallback kernels. + +These tests verify correctness of the PyTorch-native fallback implementations +that replace DeepGEMM's fp8_paged_mqa_logits and fp8_mqa_logits on SM120. + +Run: python -m pytest python/sglang/test/test_sm120_mqa_fallback.py -v +""" + +import pytest +import torch + +from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + _dequant_fp8_with_scale_suffix, + compute_paged_mqa_schedule_metadata, + sm120_fp8_mqa_logits, + sm120_fp8_paged_mqa_logits, +) + + +def _make_fp8_with_scale(data_f32: torch.Tensor) -> torch.Tensor: + """Helper: pack float32 data into FP8 + appended scale suffix format. + + For testing, we use a scale of 1.0 so the FP8 values are the raw values. + The last 4 bytes of each row store the float32 scale. + """ + device = data_f32.device + shape = data_f32.shape + head_dim = shape[-1] + + # Clamp to FP8 E4M3 range + fp8_max = torch.finfo(torch.float8_e4m3fn).max + data_clamped = data_f32.clamp(-fp8_max, fp8_max) + data_fp8 = data_clamped.to(torch.float8_e4m3fn) + + # Scale = 1.0 as float32 -> 4 bytes + scale_val = torch.ones((*shape[:-1], 1), dtype=torch.float32, device=device) + scale_bytes = scale_val.view(torch.float8_e4m3fn) # reinterpret as 4 fp8 bytes + + # Concatenate: [data_fp8 | scale_bytes] + result = torch.cat([data_fp8, scale_bytes], dim=-1) + return result + + +class TestDequantFP8: + def test_roundtrip(self): + """Dequantized values should approximately match original float32.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + data = torch.randn(4, 128, device=device) + packed = _make_fp8_with_scale(data) + recovered = _dequant_fp8_with_scale_suffix(packed.unsqueeze(-2), 128) + recovered = recovered.squeeze(-2) + # FP8 E4M3 has limited precision, allow some tolerance + torch.testing.assert_close(recovered, data, atol=0.2, rtol=0.1) + + def test_scale_applied(self): + """Non-unity scale should be applied correctly.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + head_dim = 128 + data = torch.ones(2, head_dim, device=device) * 0.5 + data_fp8 = data.to(torch.float8_e4m3fn) + + # Scale = 2.0 + scale = torch.full((2, 1), 2.0, dtype=torch.float32, device=device) + scale_bytes = scale.view(torch.float8_e4m3fn) + packed = torch.cat([data_fp8, scale_bytes], dim=-1) + + result = _dequant_fp8_with_scale_suffix(packed.unsqueeze(-2), head_dim) + result = result.squeeze(-2) + expected = data.float() * 2.0 + torch.testing.assert_close(result, expected, atol=0.1, rtol=0.05) + + +class TestScheduleMetadata: + def test_returns_none(self): + """SM120 schedule metadata is always None (scheduling handled internally).""" + result = compute_paged_mqa_schedule_metadata( + torch.tensor([10, 20]), block_size=64, num_sms=84 + ) + assert result is None + + +class TestPagedMQALogits: + @pytest.fixture + def setup(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + batch = 2 + next_n = 1 + n_heads = 4 + head_dim = 128 + head_dim_with_sf = head_dim + 4 + block_kv = 64 + num_blocks = 8 + max_seq_len = 256 + + # Create random FP8 queries + q_raw = torch.randn(batch, next_n, n_heads, head_dim, device=device) * 0.1 + q_fp8 = _make_fp8_with_scale(q_raw) + + # Create random FP8 KV cache blocks + kv_raw = torch.randn(num_blocks, block_kv, 1, head_dim, device=device) * 0.1 + kv_fp8 = _make_fp8_with_scale(kv_raw) + + # Head weights + weights = torch.randn(batch, n_heads, device=device) + + # Sequence lengths + seqlens = torch.tensor([[100], [64]], dtype=torch.int32, device=device) + + # Block tables: batch 0 uses blocks [0,1], batch 1 uses blocks [2] + block_tables = torch.zeros(batch, 4, dtype=torch.int32, device=device) + block_tables[0, :2] = torch.tensor([0, 1]) + block_tables[1, :1] = torch.tensor([2]) + + return { + "q_fp8": q_fp8, + "kv_fp8": kv_fp8, + "weights": weights, + "seqlens": seqlens, + "block_tables": block_tables, + "max_seq_len": max_seq_len, + "batch": batch, + "next_n": next_n, + } + + def test_output_shape(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + expected_shape = ( + setup["batch"] * setup["next_n"], + setup["max_seq_len"], + ) + assert logits.shape == expected_shape + + def test_masked_positions_are_neginf(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + # Positions beyond seq_len should be -inf + seq_len_0 = setup["seqlens"][0, 0].item() + assert torch.all(logits[0, seq_len_0:] == float("-inf")) + + def test_valid_positions_are_finite(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + seq_len_0 = setup["seqlens"][0, 0].item() + assert torch.all(torch.isfinite(logits[0, :seq_len_0])) + + def test_zero_seqlen(self, setup): + """Batch element with zero seqlen should produce all -inf.""" + setup["seqlens"][1, 0] = 0 + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + assert torch.all(logits[1] == float("-inf")) + + +class TestContiguousMQALogits: + @pytest.fixture + def setup(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + num_q = 4 + n_heads = 4 + head_dim = 128 + head_dim_with_sf = head_dim + 4 + num_k = 200 + + # Queries with scale suffix + q_raw = torch.randn(num_q, n_heads, head_dim, device=device) * 0.1 + q_fp8 = _make_fp8_with_scale(q_raw) + + # KV with scale suffix + k_raw = torch.randn(num_k, head_dim, device=device) * 0.1 + k_fp8 = _make_fp8_with_scale(k_raw.unsqueeze(-2)).squeeze(-2) + k_scale = torch.ones(num_k, device=device) + + # Weights + weights = torch.randn(num_q, n_heads, device=device) + + # Ragged ranges + ks = torch.tensor([0, 50, 100, 150], dtype=torch.int32, device=device) + ke = torch.tensor([50, 100, 150, 200], dtype=torch.int32, device=device) + + return { + "q_fp8": q_fp8, + "kv_fp8": (k_fp8, k_scale), + "weights": weights, + "ks": ks, + "ke": ke, + "num_q": num_q, + "num_k": num_k, + } + + def test_output_shape(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + assert logits.shape[0] == setup["num_q"] + assert logits.shape[1] >= setup["num_k"] + + def test_masked_outside_range(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + # For q=0: valid range [0, 50), positions [50, num_k) should be -inf + assert torch.all(logits[0, 50 : setup["num_k"]] == float("-inf")) + + def test_valid_inside_range(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + # For q=0: valid range [0, 50), should be finite + assert torch.all(torch.isfinite(logits[0, :50])) + + def test_empty_input(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + q_fp8 = torch.zeros(0, 4, 132, dtype=torch.float8_e4m3fn, device=device) + k_fp8 = torch.zeros(10, 132, dtype=torch.float8_e4m3fn, device=device) + k_scale = torch.ones(10, device=device) + weights = torch.zeros(0, 4, device=device) + ks = torch.zeros(0, dtype=torch.int32, device=device) + ke = torch.zeros(0, dtype=torch.int32, device=device) + + logits = sm120_fp8_mqa_logits(q_fp8, (k_fp8, k_scale), weights, ks, ke) + assert logits.shape[0] == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 61a71a1b81722c4f4ed5d6dc66a7e14861a73428 Mon Sep 17 00:00:00 2001 From: alichen Date: Wed, 13 May 2026 06:06:59 -0700 Subject: [PATCH 2/5] fix: address PR review comments and handle KV cache uint8 dtype on SM120 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address all reviewer feedback from PR #24692: - Use is_sm120_supported() helper instead of raw sm_version checks - Guard SGLANG_OPT_DEEPGEMM_HC_PRENORM and SGLANG_OPT_USE_TILELANG_MHC_PRE with `not is_sm120_supported()` in deepseek_v4.py - Auto-select marlin MoE backend on SM120 in deepseek_v4_hook.py - Minor cleanups in indexer, metadata, nsa_backend, mxfp4_marlin_moe Fix FlashMLA Triton kernel garbled output on latest sglang:dev image: - Root cause: upstream changed KV cache dtype from float8_e4m3fn to uint8. The Triton kernel's as_strided() preserved the input dtype, so tl.load interpreted FP8 bit patterns as raw integers, corrupting attention scores. - Fix: explicitly view through uint8 → float8_e4m3fn before passing to Triton. Verified on sglang:dev-cu13 (sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130): - GSM8K 5-shot 200q: 99.0% - Decode BS=1: 11.40 tok/s, TPOT 87.7ms Co-Authored-By: Claude Opus 4.6 (1M context) --- python/sglang/jit_kernel/utils.py | 6 +++--- .../sglang/srt/arg_groups/deepseek_v4_hook.py | 9 +++++++++ python/sglang/srt/environ.py | 1 - .../layers/attention/deepseek_v4_backend.py | 3 +-- .../srt/layers/attention/dsv4/indexer.py | 4 ++-- .../srt/layers/attention/dsv4/metadata.py | 4 ++-- .../attention/flash_mla_sm120_fallback.py | 19 +++++++++++++++---- .../attention/flash_mla_sm120_triton.py | 10 +++++++--- .../srt/layers/attention/nsa/nsa_indexer.py | 9 +++++++-- .../srt/layers/attention/nsa_backend.py | 7 ++++--- .../layers/deep_gemm_wrapper/configurer.py | 6 +++--- .../mxfp4_moe_sm120_triton.py | 12 +++++++++++- .../layers/quantization/mxfp4_marlin_moe.py | 19 +++++++++---------- python/sglang/srt/models/deepseek_v4.py | 5 +++-- 14 files changed, 76 insertions(+), 38 deletions(-) diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index d42c57262000..96f4519ad004 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -310,9 +310,9 @@ def is_arch_support_pdl() -> bool: if is_hip_runtime(): return False arch = get_jit_cuda_arch() - # PDL requires SM100+ datacenter (tcgen05/TMEM); SM120 (desktop Blackwell) lacks these - if arch.major == 12: - return False + # PDL (griddepcontrol) instruction is supported on SM90+ (Hopper, Blackwell). + # SM120 (desktop Blackwell) supports PDL despite lacking TMEM/tcgen05 — + # PDL uses griddepcontrol for kernel scheduling, independent of TMEM. return arch.major >= 9 diff --git a/python/sglang/srt/arg_groups/deepseek_v4_hook.py b/python/sglang/srt/arg_groups/deepseek_v4_hook.py index b3af8e95f82c..3c49038f6d74 100644 --- a/python/sglang/srt/arg_groups/deepseek_v4_hook.py +++ b/python/sglang/srt/arg_groups/deepseek_v4_hook.py @@ -51,6 +51,15 @@ def apply_deepseek_v4_defaults(server_args: "ServerArgs", model_arch: str) -> No f"Setting swa_full_tokens_ratio to {server_args.swa_full_tokens_ratio} for {model_arch}." ) + # SM120: auto-select marlin MoE backend (dispatches to SM120 Triton kernel) + from sglang.srt.utils.common import is_sm120_supported + + if is_sm120_supported() and server_args.moe_runner_backend == "auto": + server_args.moe_runner_backend = "marlin" + logger.info( + "Use marlin as MoE runner backend on SM120 for DeepSeekV4" + ) + if server_args.disaggregation_mode != "null" and server_args.pp_size > 1: # get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized # flat KV ptrs by PP layer range. diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 5852a0ff73cb..57edcdd80ec4 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -578,7 +578,6 @@ class Envs: SGLANG_OPT_USE_ONLINE_COMPRESS = EnvBool(False) SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) - SGLANG_HACK_FLASHMLA_BACKEND = EnvStr("kernel") # SWA radix cache SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index 48ddd6999c08..ca4c02b3a3a5 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -1044,8 +1044,7 @@ def forward( extra_topk_length=extra_topk_lengths, ) - backend = envs.SGLANG_HACK_FLASHMLA_BACKEND.get() - o = flash_mla_with_kvcache_entrypoint(**input_dict, backend=backend)[0] + o = flash_mla_with_kvcache_entrypoint(**input_dict, backend="kernel")[0] o = o.squeeze(1) return o diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index 7eb751cf5f7c..bd08629bfa83 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -64,8 +64,8 @@ def fp8_paged_mqa_logits_torch( block_size = kvcache_fp8.shape[1] device = q_fp8.device - assert head_dim == 128, "TODO" - assert block_size == 64, "TODO" + assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128" + assert block_size == 64, "Vectorized torch impl hardcodes block_size=64 cache layout" assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) assert weight.shape == (batch_size, num_heads) diff --git a/python/sglang/srt/layers/attention/dsv4/metadata.py b/python/sglang/srt/layers/attention/dsv4/metadata.py index 6bef63619883..4c33b02e63b4 100644 --- a/python/sglang/srt/layers/attention/dsv4/metadata.py +++ b/python/sglang/srt/layers/attention/dsv4/metadata.py @@ -8,10 +8,10 @@ from sglang.srt.environ import envs from sglang.srt.utils import is_hip -from sglang.srt.utils.common import get_device_sm +from sglang.srt.utils.common import is_sm120_supported _is_cuda = torch.cuda.is_available() and not is_hip() -_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 +_is_sm120 = _is_cuda and is_sm120_supported() if TYPE_CHECKING: pass diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py index 76cf3683df96..886f521aef43 100644 --- a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py @@ -22,12 +22,12 @@ import torch from sglang.srt.utils import is_hip -from sglang.srt.utils.common import get_device_sm +from sglang.srt.utils.common import is_sm120_supported logger = logging.getLogger(__name__) _is_cuda = torch.cuda.is_available() and not is_hip() -_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 +_is_sm120 = _is_cuda and is_sm120_supported() # Page layout constants for DSv4-Flash (MODEL1): # nope_dim = 448, rope_dim = 64, quantize_block_size = 64 @@ -204,11 +204,22 @@ def _sm120_sparse_decode_fwd( return out.to(torch.bfloat16), lse.permute(0, 2, 1) -_use_triton_flashmla = os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" +# Default SM120 FlashMLA backend: "triton" (optimized) or "torch" (pure-PyTorch fallback). +# Controlled by SGLANG_SM120_TRITON_FLASHMLA env var for backward compat (1=triton, 0=torch). +_sm120_default_backend = ( + "triton" + if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" + else "torch" +) def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): if _is_sm120: + # On SM120, the `backend` parameter selects between "triton" (default, + # optimized Triton kernel) and "torch" (pure-PyTorch fallback). + # The original flash_mla CUDA "kernel" backend is unavailable on SM120. + sm120_backend = _sm120_default_backend if backend == "kernel" else backend + q = kwargs["q"] k_cache = kwargs["k_cache"] indices = kwargs["indices"] @@ -222,7 +233,7 @@ def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): extra_indices = kwargs.get("extra_indices_in_kvcache") extra_topk_length = kwargs.get("extra_topk_length") - if _use_triton_flashmla: + if sm120_backend == "triton": from sglang.srt.layers.attention.flash_mla_sm120_triton import ( flash_mla_sparse_decode_triton, ) diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py index 0eae57d3a350..20ffb4d3cbab 100644 --- a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py @@ -222,10 +222,14 @@ def _run_triton_sparse_decode( flat_indices = indices.reshape(B, -1).contiguous() topk = flat_indices.shape[1] - # Create three typed views of the flat cache memory + # Create three typed views of the flat cache memory. + # The KV cache may arrive as uint8 or float8_e4m3fn depending on the + # sglang version. Ensure each view has the correct dtype so Triton + # interprets the loaded values correctly (FP8 dequant vs raw integer). total_elems = num_pages * page_bytes - raw_fp8 = k_cache.as_strided((total_elems,), (1,)) - raw_uint8 = raw_fp8.view(torch.uint8) + raw_flat = k_cache.as_strided((total_elems,), (1,)) + raw_uint8 = raw_flat.view(torch.uint8) + raw_fp8 = raw_uint8.view(torch.float8_e4m3fn) raw_bf16 = raw_uint8.view(torch.bfloat16) # Squeeze Q: [B, H, D] diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index e9ee2a8527ce..182a759c9dfd 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -23,17 +23,17 @@ add_prefix, ceil_align, get_bool_env_var, - get_device_sm, is_cuda, is_gfx95_supported, is_hip, is_npu, ) +from sglang.srt.utils.common import is_sm120_supported global _use_multi_stream _is_cuda = is_cuda() _is_hip = is_hip() -_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 # SM120/SM121 +_is_sm120 = _is_cuda and is_sm120_supported() _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_fp8_fnuz = is_fp8_fnuz() @@ -850,6 +850,11 @@ def _get_topk_ragged_with_cp( actual_seq_q: int, cp_index: List[Tuple[int, int, int]] = None, ) -> torch.Tensor: + if _is_sm120: + raise NotImplementedError( + "Ragged CP path requires DeepGEMM fp8_mqa_logits which is not " + "supported on SM120. Use paged topk_transform instead." + ) if TYPE_CHECKING: assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 9b5c2b21a8e6..e271464d8a34 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -36,9 +36,10 @@ ) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import get_device_sm, is_cuda, is_hip +from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils.common import is_sm120_supported -_is_sm120 = is_cuda() and get_device_sm() // 10 == 12 +_is_sm120 = is_cuda() and is_sm120_supported() if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -1313,7 +1314,7 @@ def init_forward_metadata_replay_cuda_graph_from_precomputed( # this replay (the captured graph holds stale data otherwise, which can # deadlock the kernel when the runtime work decomposition diverges from # the captured one). - if is_cuda(): + if is_cuda() and not _is_sm120: try: import deep_gemm diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py index 9933b85e3dba..380e3e6ba523 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py @@ -7,6 +7,7 @@ is_cuda, is_musa, ) +from sglang.srt.utils.common import is_sm120_supported logger = logging.getLogger(__name__) @@ -19,15 +20,14 @@ def _compute_enable_deep_gemm(): if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31): return False # DeepGEMM requires TMEM/tcgen05 (SM100+datacenter), not available on SM120 - if sm_version // 10 == 12: + if is_sm120_supported(): return False if not (_is_cuda or _is_musa): return False try: import deep_gemm # noqa: F401 - except (ImportError, AssertionError): - # AssertionError: deep_gemm init may fail on unsupported architectures + except ImportError: return False return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py index c48faeeb2f22..95d303f0a8fb 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py @@ -356,7 +356,11 @@ def mxfp4_moe_forward_triton( # ── Graph-safe routing: flatten topk assignments ── # token_ids[slot] = which row of A (original token index) # expert_ids[slot] = which expert's weights to use - flat_expert_ids = topk_ids.reshape(-1).contiguous() # [M*topk] + # topk_ids may contain -1 for padded/filtered tokens; clamp to 0 for safe + # Triton loads, then zero out invalid slots' output after GEMM. + flat_expert_ids_raw = topk_ids.reshape(-1).contiguous() # [M*topk] + invalid_slot_mask = flat_expert_ids_raw < 0 # [M*topk] + flat_expert_ids = flat_expert_ids_raw.clamp(min=0) # safe for indexing token_ids = ( torch.arange(M, device=device, dtype=torch.int32) .unsqueeze(1) @@ -434,6 +438,12 @@ def mxfp4_moe_forward_triton( down.stride(0), ) + # ── Zero out invalid slots (padded/filtered tokens with topk_ids == -1) ── + # Use multiplication instead of boolean indexing to stay CUDA-graph-safe + # (no GPU→CPU sync). valid_mask is 1.0 for valid slots, 0.0 for invalid. + valid_mask = (~invalid_slot_mask).unsqueeze(1).to(dtype) # [M*topk, 1] + down = down * valid_mask + # ── Weighted sum across topk slots (graph-safe) ── flat_weights = topk_weights.reshape(-1).unsqueeze(1).to(dtype) # [M*topk, 1] output = (down * flat_weights).view(M, topk, K).sum(dim=1) diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py index 2f0d32d89ae9..cf7b337e3bb1 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import log_info_on_rank0 -from sglang.srt.utils.common import get_device_sm, is_sm90_supported +from sglang.srt.utils.common import is_sm120_supported, is_sm90_supported if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput @@ -63,15 +63,14 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_mega_moe_weights_built", False): return - _sm = get_device_sm() - if not is_sm90_supported() and _sm // 10 != 12: + if not is_sm90_supported() and not is_sm120_supported(): raise RuntimeError( "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." ) # SM120: Skip Marlin repacking, keep original weight format - # for PyTorch dequant fallback (Marlin kernel produces NaN on SM120) - if _sm // 10 == 12: + # for Triton dequant kernel (Marlin kernel produces NaN on SM120) + if is_sm120_supported(): from torch.nn import Parameter log_info_on_rank0( @@ -97,7 +96,7 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False, ) # else: float32 scales are already usable directly - layer._dsv4_mxfp4_backend = "sm120_fallback" + layer._dsv4_mxfp4_backend = "sm120_triton" return if not check_moe_marlin_supports_layer(layer, 32): @@ -132,10 +131,10 @@ def apply( if not TopKOutputChecker.format_is_standard(topk_output): raise ValueError(f"Unsupported topk output format: {topk_output.format}") - # SM120 fallback: use Triton fused dequant+GEMM (or PyTorch fallback) - if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_fallback": + # SM120: use Triton fused dequant+GEMM (Marlin kernel produces NaN on SM120) + if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_triton": from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( - mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback, + mxfp4_moe_forward_triton, ) hidden_states = dispatch_output.hidden_states @@ -146,7 +145,7 @@ def apply( intermediate_size = w13.shape[1] // 2 hidden_size = w13.shape[2] * 2 - output = mxfp4_moe_forward_fallback( + output = mxfp4_moe_forward_triton( hidden_states=hidden_states, w13_packed=w13, w2_packed=w2, diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index 4461a0a279a9..650969c56df2 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -70,6 +70,7 @@ log_info_on_rank0, make_layers, ) +from sglang.srt.utils.common import is_sm120_supported from sglang.srt.utils.hf_transformers_utils import get_rope_config logger = logging.getLogger(__name__) @@ -677,7 +678,7 @@ def hc_pre_torch_impl(x, hc_fn): ) return y, post, comb, False - if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get(): + if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get() and not is_sm120_supported(): from sglang.srt.layers.mhc import mhc_pre norm_kwargs = {} @@ -699,7 +700,7 @@ def hc_pre_torch_impl(x, hc_fn): ) return y, post.squeeze(-1), comb, norm is not None - if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): + if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get() and not is_sm120_supported(): import deep_gemm x_flat = x.flatten(1).bfloat16() From beac7a091392cc20f7b1df7012a0fad9ec5f62a5 Mon Sep 17 00:00:00 2001 From: alichen Date: Wed, 13 May 2026 20:17:39 -0700 Subject: [PATCH 3/5] style: fix pre-commit lint issues (isort, ruff, black) Co-Authored-By: Claude Opus 4.6 (1M context) --- python/sglang/srt/arg_groups/deepseek_v4_hook.py | 4 +--- python/sglang/srt/layers/attention/dsv4/indexer.py | 6 +++--- .../sglang/srt/layers/attention/flash_mla_sm120_fallback.py | 4 +--- python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/arg_groups/deepseek_v4_hook.py b/python/sglang/srt/arg_groups/deepseek_v4_hook.py index 3c49038f6d74..06b323fb3233 100644 --- a/python/sglang/srt/arg_groups/deepseek_v4_hook.py +++ b/python/sglang/srt/arg_groups/deepseek_v4_hook.py @@ -56,9 +56,7 @@ def apply_deepseek_v4_defaults(server_args: "ServerArgs", model_arch: str) -> No if is_sm120_supported() and server_args.moe_runner_backend == "auto": server_args.moe_runner_backend = "marlin" - logger.info( - "Use marlin as MoE runner backend on SM120 for DeepSeekV4" - ) + logger.info("Use marlin as MoE runner backend on SM120 for DeepSeekV4") if server_args.disaggregation_mode != "null" and server_args.pp_size > 1: # get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index df2c03e88b4c..87653b186ced 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -20,8 +20,6 @@ PagedIndexerMetadata, _is_sm120, ) -from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation -from sglang.srt.layers.attention.nsa.triton_kernel import act_quant from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer from sglang.srt.utils import add_prefix, is_hip @@ -65,7 +63,9 @@ def fp8_paged_mqa_logits_torch( device = q_fp8.device assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128" - assert block_size == 64, "Vectorized torch impl hardcodes block_size=64 cache layout" + assert ( + block_size == 64 + ), "Vectorized torch impl hardcodes block_size=64 cache layout" assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) assert weight.shape == (batch_size, num_heads) diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py index 886f521aef43..16a10108f28e 100644 --- a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py @@ -207,9 +207,7 @@ def _sm120_sparse_decode_fwd( # Default SM120 FlashMLA backend: "triton" (optimized) or "torch" (pure-PyTorch fallback). # Controlled by SGLANG_SM120_TRITON_FLASHMLA env var for backward compat (1=triton, 0=torch). _sm120_default_backend = ( - "triton" - if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" - else "torch" + "triton" if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" else "torch" ) diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py index 9b88e4e038da..618e9cec5bb2 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import log_info_on_rank0, set_weight_attrs -from sglang.srt.utils.common import is_sm120_supported, is_sm90_supported +from sglang.srt.utils.common import is_sm90_supported, is_sm120_supported if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput From 385a8d4ed9111f250c7bdf2936ba7bbd799e69c9 Mon Sep 17 00:00:00 2001 From: alichen Date: Wed, 13 May 2026 23:38:43 -0700 Subject: [PATCH 4/5] fix: wrap pytest.main in sys.exit for CI exit code propagation Co-Authored-By: Claude Opus 4.6 (1M context) --- python/sglang/test/test_sm120_mqa_fallback.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/test/test_sm120_mqa_fallback.py b/python/sglang/test/test_sm120_mqa_fallback.py index 3e3a87c9fd10..77f247918ca5 100644 --- a/python/sglang/test/test_sm120_mqa_fallback.py +++ b/python/sglang/test/test_sm120_mqa_fallback.py @@ -7,6 +7,8 @@ Run: python -m pytest python/sglang/test/test_sm120_mqa_fallback.py -v """ +import sys + import pytest import torch @@ -276,4 +278,4 @@ def test_empty_input(self): if __name__ == "__main__": - pytest.main([__file__, "-v"]) + sys.exit(pytest.main([__file__, "-v"])) From 1f2b7ef5a39e5b07f4da7375ecb54d7421c439be Mon Sep 17 00:00:00 2001 From: AdamPlatin Date: Mon, 18 May 2026 17:27:16 +0800 Subject: [PATCH 5/5] perf(sm120): add FP8 W8A8 Block GEMM autotune configs for RTX PRO 6000 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port 23 autotune configuration JSON files from the SM120 fork branch, eliminating all "Performance might be sub-optimal" warnings on RTX PRO 6000 Blackwell. These configs provide pre-tuned tile sizes for the Triton FP8 block-scaled GEMM kernel across all DeepSeek-V4 layer dimensions (N×K = 1280×5120 through 7168×18432). Verified: EAGLE 1/1/2 decode avg 44.4→45.0 tok/s, peak 52.4 tok/s on 2× RTX PRO 6000 (TP=2, ctx=8192). --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...type=fp8_w8a8,block_shape=[128, 128].json} | 132 ++++++++-------- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ ...type=fp8_w8a8,block_shape=[128, 128].json} | 148 +++++++++--------- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 +++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 20 +++ 23 files changed, 1970 insertions(+), 140 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json rename python/sglang/srt/layers/quantization/configs/{N=2048,K=4096,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json => N=3072,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json} (88%) create mode 100644 python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json rename python/sglang/srt/layers/quantization/configs/{N=5120,K=2048,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json => N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json} (78%) create mode 100644 python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..5c0c8d76195f --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +} diff --git a/python/sglang/srt/layers/quantization/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..77ba0d7477bd --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..0a5d7bfdba48 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json similarity index 88% rename from python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json rename to python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json index 15bf8b23f353..cb91a279d423 100644 --- a/python/sglang/srt/layers/quantization/configs/N=2048,K=4096,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -1,59 +1,27 @@ { - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, + "1": { + "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3 }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, @@ -61,30 +29,30 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 64, + "num_warps": 8, "num_stages": 5 }, "48": { @@ -97,35 +65,67 @@ }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, "num_stages": 3 }, - "1024": { + "256": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, - "1536": { + "512": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3 }, - "2048": { + "1536": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 4 + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 }, "3072": { "BLOCK_SIZE_M": 128, @@ -133,14 +133,14 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 4 + "num_stages": 2 }, "4096": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..7febe3d272b4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +} diff --git a/python/sglang/srt/layers/quantization/configs/N=5120,K=2048,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json similarity index 78% rename from python/sglang/srt/layers/quantization/configs/N=5120,K=2048,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json rename to python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json index 1e23bc0b3ec8..9d7658bfc41b 100644 --- a/python/sglang/srt/layers/quantization/configs/N=5120,K=2048,device_name=NVIDIA_L40,dtype=fp8_w8a8,block_shape=[128, 128].json +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -1,55 +1,23 @@ { - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, "1": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 5 + "num_stages": 3 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, @@ -57,90 +25,122 @@ }, "8": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 5 + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 + "num_warps": 8, + "num_stages": 2 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 4 }, "32": { - "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4 }, "64": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 4 + "num_stages": 2 }, - "1024": { - "BLOCK_SIZE_M": 128, + "96": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 }, - "1536": { - "BLOCK_SIZE_M": 128, + "128": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, - "2048": { + "256": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 }, - "3072": { + "512": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 2 }, - "4096": { + "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 4 + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 } } diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..03dba5ad15ba --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..9a5ff48b8942 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..15e91cde59a3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..c714b7f1928c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..386928de139c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..f33809b0ad05 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..9c908e804065 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..f78e7060e684 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..1d3ce5c94c2d --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..3ab5796ee15b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..3cb7eaa07c74 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..ca97131cb8ef --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,20 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}, + "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "4": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}, + "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, + "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3}, + "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2}, + "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2}, + "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, + "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2} +}