diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 0943acd8dc54..ed62e5ffa518 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -19,7 +19,7 @@ import os import random import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import numpy as np @@ -53,6 +53,9 @@ class BenchArgs: extra_request_body: Optional[str] = None apply_chat_template: bool = False profile: bool = False + profile_activities: Tuple[str] = ("CPU", "GPU") + profile_start_step: Optional[int] = None + profile_steps: Optional[int] = None skip_warmup: bool = False do_not_exit: bool = False prompt_suffix: str = "" @@ -169,6 +172,26 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Use Torch Profiler. The endpoint must be launched with " "SGLANG_TORCH_PROFILER_DIR to enable profiler.", ) + parser.add_argument( + "--profile-activities", + type=str, + nargs="+", + default=["CPU", "GPU"], + choices=["CPU", "GPU", "CUDA_PROFILER", "XPU"], + help="Profiler activities: CPU, GPU, XPU, CUDA_PROFILER. If CPU/GPU/XPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.", + ) + parser.add_argument( + "--profile-start-step", + type=int, + default=None, + help="Decode step at which to start profiling (0-indexed). If not specified, defaults to output_len // 2.", + ) + parser.add_argument( + "--profile-steps", + type=int, + default=None, + help="Number of decode steps to profile starting from profile-start-step. If not specified, profiles only one step.", + ) parser.add_argument( "--skip-warmup", action="store_true", @@ -210,6 +233,9 @@ def throughput_test_once( ignore_eos: bool, extra_request_body: Dict, profile: bool, + profile_activities=None, + profile_start_step=None, + profile_steps=None, return_logprob: bool = False, logprob_start_len: int = -1, ): @@ -241,7 +267,7 @@ def throughput_test_once( "SGLANG_TORCH_PROFILER_DIR" in os.environ ), "Please set SGLANG_TORCH_PROFILER_DIR." os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True) - backend.start_profile() + backend.start_profile(start_step=profile_start_step, num_steps=profile_steps, activities=profile_activities) st = time.perf_counter() gen_out = backend.generate( @@ -255,8 +281,9 @@ def throughput_test_once( if profile: dir = os.getenv("SGLANG_TORCH_PROFILER_DIR") known_files = set(os.listdir(dir)) - backend.stop_profile() - monitor_trace_file(known_files, dir) + if not profile_steps: + backend.stop_profile() + monitor_trace_file(known_files, dir) if backend_name == "runtime": gen_out = json.loads(gen_out) @@ -455,6 +482,9 @@ def throughput_test( ignore_eos=not bench_args.disable_ignore_eos, extra_request_body=extra_request_body, profile=bench_args.profile, + profile_activities=bench_args.profile_activities, + profile_start_step=bench_args.profile_start_step, + profile_steps=bench_args.profile_steps, return_logprob=bench_args.return_logprob, logprob_start_len=bench_args.logprob_start_len, ) diff --git a/python/sglang/jit_kernel/deepseek_v4.py b/python/sglang/jit_kernel/deepseek_v4.py index 671dbf89141f..a929702b64c8 100644 --- a/python/sglang/jit_kernel/deepseek_v4.py +++ b/python/sglang/jit_kernel/deepseek_v4.py @@ -13,10 +13,16 @@ make_cpp_args, ) from sglang.srt.environ import envs -from sglang.srt.utils import cpu_has_amx_support, is_cpu +from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_xpu _is_cpu = is_cpu() +# JIT-compiled CUDA kernels in this module require tvm_ffi and a working CUDA +# toolchain. On non-CUDA backends (e.g. XPU) those entrypoints fall back to +# triton/torch implementations. +_is_cuda = is_cuda() +_is_xpu = is_xpu() _cpu_amx = cpu_has_amx_support() + if TYPE_CHECKING: from tvm_ffi.module import Module @@ -421,10 +427,10 @@ def generate( device=seq_lens.device, pin_memory=seq_lens.is_cpu if not _is_cpu else False, ) - is_overlap = compress_ratio == 4 - if _is_cpu: - plan_lens = _plan_compress_prefill_torch( + if _is_cuda: + module = _jit_common_module() + plan_lens = module.plan_compress_prefill( extend_lens, seq_lens, plan_tensor[0], @@ -434,8 +440,7 @@ def generate( use_cuda_graph, ) else: - module = _jit_common_module() - plan_lens = module.plan_compress_prefill( + plan_lens = _plan_compress_prefill_torch( extend_lens, seq_lens, plan_tensor[0], @@ -487,6 +492,488 @@ def is_decode(self) -> bool: return False +def _decode_prefill_plan(plan_bytes: torch.Tensor) -> torch.Tensor: + """Decode packed PrefillPlan tensor ([N, 16] uint8) into [N, 4] int64. + + Each plan slot is 4 little-endian uint32: + (ragged_id, batch_id, position, window_len). Invalid entries use + ``0xFFFFFFFF`` for every field. + + On-device bitcast: avoids a D->H sync that drains the L0 queue every + layer on XPU. uint8 [N, 16] is reinterpreted as int32 [N, 4] (LE on + XPU/x86), promoted to int64, then masked to keep the unsigned value + so that the kInvalid sentinel 0xFFFFFFFF (which would bitcast to -1 + in int32) compares correctly against ``_INVALID_PLAN``. + """ + t = plan_bytes.detach().contiguous() + as_i32 = t.view(torch.int32).reshape(-1, 4) + return as_i32.to(torch.int64) & 0xFFFFFFFF + + +def _describe(x: Any) -> str: + if isinstance(x, torch.Tensor): + return f"Tensor(shape={tuple(x.shape)}, dtype={x.dtype}, device={x.device})" + if isinstance(x, (CompressorDecodePlan, CompressorPrefillPlan)): + fields = ", ".join( + f"{name}={_describe(getattr(x, name))}" for name in x._fields + ) + return f"{type(x).__name__}({fields})" + if x is None: + return "None" + return repr(x) + + +def _torch_compress_forward( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + out: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan], + extra_data: Optional[torch.Tensor], + *, + head_dim: int, + compress_ratio: Literal[4, 128], +) -> None: + if compress_ratio == 4: + if plan.is_decode: + _torch_c4_decode( + kv_score_buffer, + kv_score_input, + out, + ape, + indices, + plan.seq_lens, + extra_data, + head_dim=head_dim, + ) + else: + _torch_c4_prefill( + kv_score_buffer, + kv_score_input, + out, + ape, + indices, + plan.compress_plan, + plan.write_plan, + extra_data, + head_dim=head_dim, + ) + else: + assert compress_ratio == 128 + if plan.is_decode: + _torch_c128_decode( + kv_score_buffer, + kv_score_input, + out, + ape, + indices, + plan.seq_lens, + head_dim=head_dim, + ) + else: + _torch_c128_prefill( + kv_score_buffer, + kv_score_input, + out, + ape, + indices, + plan.compress_plan, + plan.write_plan, + extra_data, + head_dim=head_dim, + ) + + +def _softmax_weighted_sum( + kv: torch.Tensor, + score: torch.Tensor, + bias: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + """Safe softmax over dim=-2 then weighted sum of ``kv``. + + Shapes: ``kv``/``score``/``bias`` all ``[..., S, head_dim]``. + Returns ``[..., head_dim]`` cast to ``out_dtype``. + """ + s = (score.float() + bias.float()) + m = s.amax(dim=-2, keepdim=True) + w = (s - m).exp() + num = (kv.float() * w).sum(dim=-2) + den = w.sum(dim=-2) + return (num / den).to(out_dtype) + + +# --------------------------------------------------------------------------- +# c4 fallback +# --------------------------------------------------------------------------- + + +def _c4_split_chunks(buf_or_input: torch.Tensor, head_dim: int) -> torch.Tensor: + """Split last dim ``head_dim*4`` into ``[..., 4, head_dim]``. + + Layout: ``| kv_overlap | kv | score_overlap | score |``. + """ + return buf_or_input.view(*buf_or_input.shape[:-1], 4, head_dim) + + +def _torch_c4_decode( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + out: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + seq_lens: torch.Tensor, + extra: Optional[torch.Tensor], + *, + head_dim: int, +) -> None: + # Determine page mode from buffer shape. + page_size = kv_score_buffer.shape[1] + paged = page_size == 4 + assert paged or page_size == 8 + HD = head_dim + B = indices.shape[0] + device = kv_score_input.device + out_dtype = out.dtype + + indices_i64 = indices.to(torch.int64) + seq_lens_i64 = seq_lens.to(torch.int64) + + # 1) write current step into the buffer + write_pos = (seq_lens_i64 + (page_size - 1)) % page_size # [B] + kv_score_buffer[indices_i64, write_pos] = kv_score_input.to(kv_score_buffer.dtype) + + # 2) forward only when seq_len % 4 == 0 + # NOTE: avoid host syncs on the per-step hot path -- compute over all + # B batches and mask the writeback. On XPU/L0, ``bool(t.any())`` and + # ``t.nonzero()`` drain the command queue every layer. + do_fwd = (seq_lens_i64 % 4) == 0 # [B] + + # Gather 8 slots from buffer for every batch. + buf4 = _c4_split_chunks(kv_score_buffer, HD) # [N_idx, page, 4, HD] + + if paged: + assert extra is not None, "Page4Align mode requires extra tensor" + index_prev = extra.view(-1).to(torch.int64) # [B] + # i in [0,8): first 4 use index_prev (overlap), last 4 use index + kv_chunks = [] + score_chunks = [] + for i in range(8): + k = i % 4 + page_idx = index_prev if i < 4 else indices_i64 + chunk_kv = 0 if i < 4 else 1 + chunk_score = 2 if i < 4 else 3 + kv_chunks.append(buf4[page_idx, k, chunk_kv]) + score_chunks.append(buf4[page_idx, k, chunk_score]) + else: + # Ring buffer of size 8. + kv_chunks = [] + score_chunks = [] + for i in range(8): + k = (seq_lens_i64 + i) % 8 + chunk_kv = 0 if i < 4 else 1 + chunk_score = 2 if i < 4 else 3 + kv_chunks.append(buf4[indices_i64, k, chunk_kv]) + score_chunks.append(buf4[indices_i64, k, chunk_score]) + + kv_stack = torch.stack(kv_chunks, dim=1) # [B, 8, HD] + score_stack = torch.stack(score_chunks, dim=1) + bias = ape.unsqueeze(0).expand(kv_stack.shape[0], -1, -1) # [B, 8, HD] + + # seq_len == 4 special case: zero overlap kv, -inf overlap score. + # Apply unconditionally via torch.where to avoid a host sync. + sl4 = (seq_lens_i64 == 4) + sl4_b = sl4.view(-1, 1, 1) + zero = torch.zeros((), dtype=kv_stack.dtype, device=device) + ninf = torch.full((), -1e9, dtype=score_stack.dtype, device=device) + head_mask = torch.zeros(8, dtype=torch.bool, device=device) + head_mask[:4] = True + full_mask = sl4_b & head_mask.view(1, 8, 1) + kv_stack = torch.where(full_mask, zero, kv_stack) + score_stack = torch.where(full_mask, ninf, score_stack) + + result = _softmax_weighted_sum(kv_stack, score_stack, bias, out_dtype) + mask = do_fwd.view(-1, *([1] * (out.ndim - 1))) + out.copy_(torch.where(mask, result.to(out.dtype), out)) + + +def _torch_c4_prefill( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + out: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + compress_plan: torch.Tensor, + write_plan: torch.Tensor, + extra: Optional[torch.Tensor], + *, + head_dim: int, +) -> None: + page_size = kv_score_buffer.shape[1] + paged = page_size == 4 + assert paged or page_size == 8 + HD = head_dim + device = kv_score_input.device + indices_i64 = indices.to(torch.int64) + + # On-device, fully-valid plans. ``_torch_plan_compress_prefill`` already + # slices the plan tensors to their exact valid length on the host before + # the H->D copy (see ``CompressorPrefillPlan.generate``), and prefill + # plans never use cuda-graph padding. So we can skip the per-layer + # ``bool((plan != INVALID).any())`` host sync and the data-dependent + # boolean-mask gather entirely; both are intrinsic L0-queue drains. + cplan = _decode_prefill_plan(compress_plan) # [Nc, 4] int64 + wplan = _decode_prefill_plan(write_plan) # [Nw, 4] int64 + + extra_i64: Optional[torch.Tensor] = None + if paged: + assert extra is not None, "Page4Align c4 prefill requires extra tensor" + extra_i64 = extra.to(torch.int64) + + # NOTE: order matches the CUDA kernel launches in c4.cuh: compress + # (reads buffer) runs BEFORE write (mutates buffer). Reversing the order + # would feed write-modified slots into compress and corrupt the output. + + # ---- compress plan ---------------------------------------------------- + if cplan.shape[0] > 0: + ragged_ids = cplan[:, 0] + batch_ids = cplan[:, 1] + positions = cplan[:, 2] + window_lens = cplan[:, 3] + seq_lens = positions + 1 # [N] + N = ragged_ids.shape[0] + + buf4 = _c4_split_chunks(kv_score_buffer, HD) # [N_idx, page, 4, HD] + inp4 = _c4_split_chunks(kv_score_input, HD) # [N_q, 4, HD] + + if paged: + assert extra_i64 is not None + load_first_page = extra_i64[batch_ids, 0] + load_second_page = extra_i64[batch_ids, 1] + # Choose per-i page: window_len <= 4 means both halves use second + # page; otherwise overlap (i<4) uses first, normal (i>=4) uses + # second. + wl_le_4 = window_lens <= 4 + + kv_chunks = [] + score_chunks = [] + for i in range(8): + chunk_kv = 0 if i < 4 else 1 + chunk_score = 2 if i < 4 else 3 + use_buf = (i < window_lens) # [N] bool + + # Buffer source. + if paged: + if i < 4: + page_idx = torch.where( + wl_le_4, load_second_page, load_first_page + ) + else: + page_idx = load_second_page + k_buf = torch.full_like(positions, i % 4) + buf_kv = buf4[page_idx, k_buf, chunk_kv] + buf_score = buf4[page_idx, k_buf, chunk_score] + else: + page_idx = indices_i64[batch_ids] + k_buf = (seq_lens + i) % 8 + buf_kv = buf4[page_idx, k_buf, chunk_kv] + buf_score = buf4[page_idx, k_buf, chunk_score] + + # Ragged tail source (k = i - 7 <= 0). + rag_off = (ragged_ids + (i - 7)).clamp(min=0) + rag_kv = inp4[rag_off, chunk_kv] + rag_score = inp4[rag_off, chunk_score] + + ub = use_buf.unsqueeze(-1) + kv_chunks.append(torch.where(ub, buf_kv, rag_kv)) + score_chunks.append(torch.where(ub, buf_score, rag_score)) + + kv_stack = torch.stack(kv_chunks, dim=1) # [N, 8, HD] + score_stack = torch.stack(score_chunks, dim=1) + bias = ape.unsqueeze(0).expand(N, -1, -1) + + # Apply the seq_len==4 special case unconditionally via torch.where + # to keep this sync-free; rows where seq_len != 4 see no change. + sl4 = (seq_lens == 4) + sl4_b = sl4.view(-1, 1, 1) + zero = torch.zeros((), dtype=kv_stack.dtype, device=device) + ninf = torch.full((), -1e9, dtype=score_stack.dtype, device=device) + head_mask = torch.zeros(8, dtype=torch.bool, device=device) + head_mask[:4] = True + full_mask = sl4_b & head_mask.view(1, 8, 1) + kv_stack = torch.where(full_mask, zero, kv_stack) + score_stack = torch.where(full_mask, ninf, score_stack) + + result = _softmax_weighted_sum(kv_stack, score_stack, bias, out.dtype) + out[ragged_ids] = result + + # ---- write plan (must run AFTER compress) ---------------------------- + _torch_c4_prefill_write( + kv_score_buffer, + kv_score_input, + indices_i64, + extra_i64, + wplan, + paged, + device, + ) + + +def _torch_c4_prefill_write( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + indices_i64: torch.Tensor, + extra_i64: Optional[torch.Tensor], + wplan: torch.Tensor, + paged: bool, + device: torch.device, +) -> None: + # See ``_torch_c4_prefill`` for why no validity check / boolean-mask + # gather is needed: prefill plans are already pre-sliced to valid + # length on the host before the H->D copy. + if wplan.shape[0] == 0: + return + ragged_ids = wplan[:, 0] + batch_ids = wplan[:, 1] + positions = wplan[:, 2] + if paged: + assert extra_i64 is not None + last_pos = extra_i64[batch_ids, 3] + write_first_page = extra_i64[batch_ids, 2] + write_second_page = indices_i64[batch_ids] + tgt_index = torch.where( + positions < last_pos, write_first_page, write_second_page + ) + tgt_pos = positions % 4 + else: + tgt_index = indices_i64[batch_ids] + tgt_pos = positions % 8 + kv_score_buffer[tgt_index, tgt_pos] = kv_score_input[ragged_ids].to( + kv_score_buffer.dtype + ) + + +# --------------------------------------------------------------------------- +# c128 fallback +# --------------------------------------------------------------------------- + + +def _c128_split_chunks(buf_or_input: torch.Tensor, head_dim: int) -> torch.Tensor: + """Split last dim ``head_dim*2`` into ``[..., 2, head_dim]``. + + Layout: ``| kv | score |``. + """ + return buf_or_input.view(*buf_or_input.shape[:-1], 2, head_dim) + + +def _torch_c128_decode( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + out: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + seq_lens: torch.Tensor, + *, + head_dim: int, +) -> None: + HD = head_dim + device = kv_score_input.device + indices_i64 = indices.to(torch.int64) + seq_lens_i64 = seq_lens.to(torch.int64) + + # 1) write current step at (seq_len + 127) % 128. + write_pos = (seq_lens_i64 + 127) % 128 + kv_score_buffer[indices_i64, write_pos] = kv_score_input.to(kv_score_buffer.dtype) + + # 2) forward only when seq_len % 128 == 0; window_len = 128 (all from buf). + # NOTE: avoid host syncs (e.g. ``bool(do_fwd.any())`` / + # ``do_fwd.nonzero()``) on the per-step hot path -- on XPU/L0 those + # force-drain the queue every layer and turn one decode step into + # minutes. Compute over all B batches, then mask the writeback. + do_fwd = (seq_lens_i64 % 128) == 0 # [B] + + buf2 = _c128_split_chunks(kv_score_buffer, HD) # [N_idx, 128, 2, HD] + gathered = buf2[indices_i64] # [B, 128, 2, HD] + kv = gathered[..., 0, :] # [B, 128, HD] + score = gathered[..., 1, :] + bias = ape.unsqueeze(0).expand(kv.shape[0], -1, -1) + result = _softmax_weighted_sum(kv, score, bias, out.dtype) # [B, ...] + mask = do_fwd.view(-1, *([1] * (out.ndim - 1))) + out.copy_(torch.where(mask, result.to(out.dtype), out)) + + +def _torch_c128_prefill( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + out: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + compress_plan: torch.Tensor, + write_plan: torch.Tensor, + extra: Optional[torch.Tensor], + *, + head_dim: int, +) -> None: + HD = head_dim + device = kv_score_input.device + indices_i64 = indices.to(torch.int64) + # extra is optional load_indices; falls back to indices when absent. + load_indices_i64 = ( + extra.to(torch.int64) if extra is not None else indices_i64 + ) + + # On-device, fully-valid plans (see ``_torch_c4_prefill``). + cplan = _decode_prefill_plan(compress_plan) + wplan = _decode_prefill_plan(write_plan) + + # NOTE: order matches the CUDA kernel launches in c128.cuh: compress + # (reads buffer) runs BEFORE write (mutates buffer). Reversing the order + # would feed write-modified slots into compress and corrupt the output. + + # ---- compress plan (uses `load_indices`) ----------------------------- + if cplan.shape[0] > 0: + ragged_ids = cplan[:, 0] + batch_ids = cplan[:, 1] + window_lens = cplan[:, 3] + N = ragged_ids.shape[0] + + buf2 = _c128_split_chunks(kv_score_buffer, HD) # [N_idx, 128, 2, HD] + inp2 = _c128_split_chunks(kv_score_input, HD) # [N_q, 2, HD] + + page_idx = load_indices_i64[batch_ids] # [N] + buf_slot = buf2[page_idx] # [N, 128, 2, HD] + buf_kv = buf_slot[..., 0, :] # [N, 128, HD] + buf_score = buf_slot[..., 1, :] + + j = torch.arange(128, device=device, dtype=torch.int64) + rag_off = (ragged_ids.unsqueeze(1) + (j.unsqueeze(0) - 127)).clamp( + min=0 + ) # [N, 128] + rag_kv = inp2[..., 0, :][rag_off] # [N, 128, HD] + rag_score = inp2[..., 1, :][rag_off] + + use_buf = (j.unsqueeze(0) < window_lens.unsqueeze(1)).unsqueeze( + -1 + ) # [N,128,1] + kv = torch.where(use_buf, buf_kv, rag_kv) + score = torch.where(use_buf, buf_score, rag_score) + bias = ape.unsqueeze(0).expand(N, -1, -1) # [N, 128, HD] + + out[ragged_ids] = _softmax_weighted_sum(kv, score, bias, out.dtype) + + # ---- write plan (uses `indices`, must run AFTER compress) ------------ + if wplan.shape[0] > 0: + ragged_ids = wplan[:, 0] + batch_ids = wplan[:, 1] + positions = wplan[:, 2] + tgt_index = indices_i64[batch_ids] + tgt_pos = positions % 128 + kv_score_buffer[tgt_index, tgt_pos] = kv_score_input[ragged_ids].to( + kv_score_buffer.dtype + ) + + class CompressorDecodePlan(NamedTuple): compress_ratio: int seq_lens: torch.Tensor @@ -555,6 +1042,21 @@ def compress_forward( F = online_module.decode if plan.is_decode else online_module.prefill F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) return out + if _is_xpu: + # torch fallback for non-CUDA backends. Mirrors + # FlashCompress{4,128}Kernel in jit_kernel/csrc/deepseek_v4/c{4,128}.cuh. + _torch_compress_forward( + kv_score_buffer, + kv_score_input, + out, + ape, + indices, + plan, + extra_data, + head_dim=head_dim, + compress_ratio=compress_ratio, + ) + return out module = _jit_compress_module( head_dim, kv_score_input.dtype, @@ -574,6 +1076,18 @@ def compress_fused_norm_rope_inplace( plan: Union[CompressorDecodePlan, CompressorPrefillPlan], ) -> None: freq_cis = torch.view_as_real(freq_cis).flatten(-2) + mode = 1 if plan.is_decode else 0 + if _is_xpu: + _torch_fused_norm_rope( + kv, + weight, + plan[1], + freq_cis, + mode, + eps, + plan.compress_ratio, + ) + return module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) module.forward( kv, @@ -586,6 +1100,88 @@ def compress_fused_norm_rope_inplace( ) +def _torch_fused_norm_rope( + input_tensor: torch.Tensor, + weight: torch.Tensor, + handle: torch.Tensor, + freqs_cis: torch.Tensor, + mode: int, + eps: float, + compress_ratio: int, +) -> None: + """Pure-torch fallback for ``FusedNormRopeKernel::forward``. + + Mirrors ``fused_norm_rope`` in + ``jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh``: per-row RMSNorm + in-place, then RoPE on the trailing ``rope_dim`` of each selected row. + + ``mode``: + 0 = CompressExtend (handle = packed PrefillPlan, [N, 16] uint8) + 1 = CompressDecode (handle = seq_lens, [N]) + 2 = DefaultForward (handle = positions, [N]) + """ + head_dim = input_tensor.shape[-1] + rope_dim = freqs_cis.shape[-1] + device = input_tensor.device + in_dtype = input_tensor.dtype + + if mode == 0: + # Prefill plan is already pre-sliced to its exact valid length on + # the host (no cuda-graph padding on the prefill path), so all rows + # are valid. Avoid the per-layer ``bool((plan != INVALID).any())`` + # / boolean-mask gather host-syncs that drain the L0 queue. + plan = _decode_prefill_plan(handle) + if plan.shape[0] == 0: + return + rows = plan[:, 0] + positions = plan[:, 2] + 1 - compress_ratio + row_mask = None # writeback unconditional via index + elif mode == 1: + # Decode: avoid host syncs (``bool(valid.any())`` / + # ``valid.nonzero()``) that drain the L0 queue every layer. + # Compute over all rows and mask the writeback with torch.where. + seq_lens = handle.to(torch.int64) + valid = (seq_lens % compress_ratio) == 0 # [B] bool + num_works = seq_lens.shape[0] + rows = torch.arange(num_works, device=device, dtype=torch.int64) + # For invalid rows, ``positions`` would be negative; clamp to 0 + # so freqs_cis[positions] is always in-bounds. Result is masked + # out below before writeback. + positions = (seq_lens - compress_ratio).clamp_min(0) + row_mask = valid + elif mode == 2: + num_works = handle.shape[0] + positions = handle.to(torch.int64) + rows = torch.arange(num_works, device=device, dtype=torch.int64) + row_mask = None + else: + raise ValueError(f"unsupported fused_norm_rope mode: {mode}") + + # RMSNorm in-place on selected rows. + x = input_tensor[rows].float() # [N, head_dim] + var = (x * x).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(var + eps) * weight.float() + + # RoPE on the trailing rope_dim, viewed as (real, imag) interleaved pairs. + n = x.shape[0] + rope = x[..., -rope_dim:].view(n, rope_dim // 2, 2) + freq = freqs_cis.float()[positions].view(n, rope_dim // 2, 2) + xr, xi = rope[..., 0], rope[..., 1] + fr, fi = freq[..., 0], freq[..., 1] + out_r = xr * fr - xi * fi + out_i = xr * fi + xi * fr + rope_out = torch.stack([out_r, out_i], dim=-1).reshape(n, rope_dim) + x[..., -rope_dim:] = rope_out + if row_mask is None: + input_tensor[rows] = x.to(in_dtype) + else: + # Only update rows where ``valid`` is true; preserve others. + mask = row_mask.view(-1, *([1] * (x.ndim - 1))) + input_tensor[rows] = torch.where( + mask, x.to(in_dtype), input_tensor[rows] + ) + + def fused_rope( q: torch.Tensor, k: Optional[torch.Tensor], @@ -593,14 +1189,25 @@ def fused_rope( positions: torch.Tensor, inverse: bool = False, ) -> None: - if _is_cpu and _cpu_amx: + if _is_cuda: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + elif _is_cpu and _cpu_amx: torch.ops.sgl_kernel.apply_rotary_emb_interleaved_cpu( q, freqs_cis, inverse, positions, k ) else: - freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() - module = _jit_fused_rope_module() - module.forward(q, k, freqs_real, positions, inverse) + # Triton fallback for non-CUDA backends (e.g. XPU). Mirrors + # FusedQKRopeKernel: apply rotary embedding in-place to q and + # (when provided) k. + from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton + + apply_rotary_emb_triton(q, freqs_cis, positions=positions, inverse=inverse) + if k is not None: + apply_rotary_emb_triton( + k, freqs_cis, positions=positions, inverse=inverse + ) @triton.jit diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index bcd42e5ce349..d42c57262000 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -309,7 +309,11 @@ def get_jit_cuda_arch() -> ArchInfo: def is_arch_support_pdl() -> bool: if is_hip_runtime(): return False - return get_jit_cuda_arch().major >= 9 + arch = get_jit_cuda_arch() + # PDL requires SM100+ datacenter (tcgen05/TMEM); SM120 (desktop Blackwell) lacks these + if arch.major == 12: + return False + return arch.major >= 9 def _find_package_root(package: str) -> Optional[pathlib.Path]: diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 57edcdd80ec4..18176cd41fdb 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -570,14 +570,16 @@ class Envs: SGLANG_DSV4_REASONING_EFFORT = EnvStr("") # CUDA kernels - SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(True) - SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True) - SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True) + SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(False) + SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(False) + SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(False) SGLANG_OPT_USE_TILELANG_INDEXER = EnvBool(False) SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False) SGLANG_OPT_USE_ONLINE_COMPRESS = EnvBool(False) SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) - SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) + SGLANG_FP8_PAGED_MQA_LOGITS_TRITON = EnvBool(True) + SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(True) + SGLANG_HACK_FLASHMLA_BACKEND = EnvStr("kernel") # SWA radix cache SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) @@ -592,20 +594,20 @@ class Envs: SGLANG_OPT_FIX_MEGA_MOE_MEMORY = EnvBool(False) # TopK - SGLANG_OPT_USE_FUSED_HASH_TOPK = EnvBool(True) + SGLANG_OPT_USE_FUSED_HASH_TOPK = EnvBool(False) SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK = EnvBool(True) SGLANG_OPT_USE_TOPK_V2 = EnvBool(False) # GEMM / kernel fusion SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False) - SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas") + SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("torch") SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True) SGLANG_OPT_USE_JIT_NORM = EnvBool(False) SGLANG_OPT_FUSE_WQA_WKV = EnvBool(True) - SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(True) + SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(False) # Cache / overlap - SGLANG_OPT_USE_FUSED_STORE_CACHE = EnvBool(True) + SGLANG_OPT_USE_FUSED_STORE_CACHE = EnvBool(False) SGLANG_OPT_USE_OVERLAP_STORE_CACHE = EnvBool(True) SGLANG_OPT_USE_MULTI_STREAM_OVERLAP = EnvBool(True) diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py index 24153cdce0a7..7e5f054d373f 100644 --- a/python/sglang/srt/layers/attention/deepseek_v4_backend.py +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -37,6 +37,10 @@ from sglang.srt.layers.attention.dsv4.quant_k_cache import ( quant_to_nope_fp8_rope_bf16_pack_triton, ) +from sglang.srt.layers.attention.flash_mla_sm120_fallback import ( + _is_sm120, + flash_mla_with_kvcache_entrypoint, +) from sglang.srt.layers.dp_attention import ( get_attention_cp_rank, get_attention_cp_size, @@ -44,9 +48,10 @@ from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.spec_info import SpecInput -from sglang.srt.utils import ceil_align, cpu_has_amx_support, is_cpu +from sglang.srt.utils import ceil_align, cpu_has_amx_support, is_cpu, is_xpu _is_cpu = is_cpu() +_is_xpu = is_xpu() _cpu_amx = cpu_has_amx_support() if _is_cpu and _cpu_amx: @@ -66,6 +71,7 @@ SWA_WINDOW = 128 C4_TOPK = 512 PAGE_INDEX_ALIGNED_SIZE = 64 +_is_xpu = is_xpu() T = TypeVar("T", bound=Optional[torch.Tensor]) @@ -80,6 +86,8 @@ def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: def _create_flashmla_metadata(): + if _is_sm120 or _is_xpu: + return None try: import flash_mla @@ -1067,9 +1075,7 @@ def forward( )[0] else: - import flash_mla - - o = flash_mla.flash_mla_with_kvcache( + input_dict = dict( q=q, k_cache=swa_k_cache, head_dim_v=self.head_dim_v, @@ -1084,7 +1090,10 @@ def forward( extra_k_cache=extra_k_cache, extra_indices_in_kvcache=extra_indices, extra_topk_length=extra_topk_lengths, - )[0] + ) + + backend = envs.SGLANG_HACK_FLASHMLA_BACKEND.get() + o = flash_mla_with_kvcache_entrypoint(**input_dict, backend=backend)[0] o = o.squeeze(1) return o diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py index f370248c141c..40c311118db4 100644 --- a/python/sglang/srt/layers/attention/dsv4/indexer.py +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -16,7 +16,10 @@ from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config from sglang.srt.environ import envs from sglang.srt.layers.attention.dsv4.compressor import Compressor -from sglang.srt.layers.attention.dsv4.metadata import PagedIndexerMetadata +from sglang.srt.layers.attention.dsv4.metadata import ( + PagedIndexerMetadata, + _is_sm120, +) from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation from sglang.srt.layers.attention.nsa.triton_kernel import act_quant from sglang.srt.layers.linear import ReplicatedLinear @@ -62,48 +65,109 @@ def fp8_paged_mqa_logits_torch( max_seq_len: int, clean_logits: bool = True, ) -> torch.Tensor: + """CUDA-graph-compatible FP8 paged MQA logits (vectorized, no .item()). + + Vectorized across batches using batched gather + bmm instead of + per-batch Python loop with .item() calls. + """ _ = deep_gemm_metadata batch_size, _, num_heads, head_dim = q_fp8.shape block_size = kvcache_fp8.shape[1] + device = q_fp8.device - assert head_dim == 128, "torch reference impl hardcodes DSV4 indexer head_dim=128" - assert block_size == 64, "torch reference impl hardcodes block_size=64 cache layout" + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) assert weight.shape == (batch_size, num_heads) + if seq_lens.dim() > 1: + seq_lens = seq_lens.squeeze(-1) assert seq_lens.shape == (batch_size,) assert page_table.shape[0] == batch_size assert clean_logits == False - logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) - for i in range(batch_size): - q = q_fp8[i, 0] - q = q.to(torch.float32) - q_scale = weight[i] - seq_len = int(seq_lens[i].item()) - assert seq_len <= max_seq_len - num_pages = (seq_len + block_size - 1) // block_size - padded_seq_len = num_pages * block_size - pages = page_table[i, :num_pages] - kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) - kvcache = kvcache_fp8[pages] - SCALE_OFFSET = block_size * head_dim - kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE) - kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32) - kvcache_value = kvcache_value.to(torch.float32) - kvcache_scale = kvcache_scale.contiguous() - kvcache_value = kvcache_value.view(padded_seq_len, head_dim) - kvcache_scale = kvcache_scale.view(padded_seq_len) - score = F.linear(kvcache_value, q) - score = F.relu(score) - score *= q_scale[None, :] - score = score.sum(dim=1) - score *= kvcache_scale - logits[i, :seq_len] = score[:seq_len] + # ── Vectorized: no .item(), no per-batch loop ── + max_pages = (max_seq_len + block_size - 1) // block_size + max_padded_seq = max_pages * block_size + + # Flatten KV cache for indexing: [total_pages, block_size * (head_dim + 4)] + kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + SCALE_OFFSET = block_size * head_dim + + # Gather pages for all batches: [batch, max_pages] + page_ids = page_table[:, :max_pages] + # Gather KV data: [batch, max_pages, block_size * (head_dim + 4)] + kvcache_gathered = kvcache_flat[page_ids] + + # Split value and scale + kv_value_raw = kvcache_gathered[ + ..., :SCALE_OFFSET + ] # [batch, max_pages, block_size * head_dim] + kv_scale_raw = kvcache_gathered[ + ..., SCALE_OFFSET: + ] # [batch, max_pages, block_size * 4] + + # Dequant value: view as FP8, convert to float32 + kv_value = kv_value_raw.contiguous().view(dtype=FP8_DTYPE).to(torch.float32) + kv_value = kv_value.view(batch_size, max_padded_seq, head_dim) + + # Dequant scale + kv_scale = kv_scale_raw.contiguous().view(dtype=torch.float32) + kv_scale = kv_scale.view(batch_size, max_padded_seq) + + # Q: [batch, num_heads, head_dim] + q = q_fp8[:, 0].to(torch.float32) + + # Batched matmul: [batch, max_padded_seq, head_dim] @ [batch, head_dim, num_heads] + score = torch.bmm(kv_value, q.transpose(1, 2)) # [batch, max_padded_seq, num_heads] + + # ReLU + scale by weight + sum across heads + score = F.relu(score) + score = score * weight.unsqueeze(1) # [batch, max_padded_seq, num_heads] + score = score.sum(dim=2) # [batch, max_padded_seq] + + # Apply KV scale + score = score * kv_scale # [batch, max_padded_seq] + + # Create validity mask and write output — graph-safe (no torch.tensor() calls) + out_width = min(max_padded_seq, max_seq_len) + logits = score.new_full((batch_size, max_seq_len), float("-inf")) + logits[:, :out_width] = score[:, :out_width] + + # Mask invalid positions to -inf + positions = torch.arange(max_seq_len, device=device) + invalid_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze( + 1 + ) # [batch, max_seq_len] + logits.masked_fill_(invalid_mask, float("-inf")) return logits +_NEG_INF_I32_CACHE: dict = {} +_NEG_ONE_I32_CACHE: dict = {} + + +def _neg_inf_scalar(device: torch.device) -> torch.Tensor: + """Cached scalar -inf fp32 tensor per device.""" + key = str(device) + t = _NEG_INF_I32_CACHE.get(key) + if t is None: + t = torch.tensor(float("-inf"), device=device, dtype=torch.float32) + _NEG_INF_I32_CACHE[key] = t + return t + + +def _neg_one_i32_scalar(device: torch.device) -> torch.Tensor: + """Cached scalar -1 int32 tensor per device.""" + key = str(device) + t = _NEG_ONE_I32_CACHE.get(key) + if t is None: + t = torch.tensor(-1, device=device, dtype=torch.int32) + _NEG_ONE_I32_CACHE[key] = t + return t + + def topk_transform_512_pytorch_vectorized( scores: torch.Tensor, seq_lens: torch.Tensor, @@ -121,13 +185,18 @@ def topk_transform_512_pytorch_vectorized( page_bits = (page_size - 1).bit_length() if page_size > 1 else 0 page_mask = page_size - 1 - positions = ( - torch.arange(max_seq_len, device=device).unsqueeze(0).expand(batch_size, -1) - ) - valid_mask = positions < seq_lens.unsqueeze(1) + # Per-device cached scalar constants — constructing torch.tensor(...) on + # the device path here would otherwise force an H2D copy + sync per call + # (this function runs per-layer per prefill chunk). + neg_inf = _neg_inf_scalar(device) + neg_one_i32 = _neg_one_i32_scalar(device) - masked_scores = scores.clone() - masked_scores[~valid_mask] = float("-inf") + positions = torch.arange(max_seq_len, device=device).unsqueeze(0) # (1, S) + valid_mask = positions < seq_lens.unsqueeze(1) # (B, S) + + # NOTE: avoid `masked_scores[~valid_mask] = float("-inf")` — boolean- + # indexed scatter writes are a known L0 sync hot spot on Intel XPU. + masked_scores = torch.where(valid_mask, scores, neg_inf) actual_k = min(TOPK, max_seq_len) _, raw_indices = torch.topk( @@ -141,39 +210,35 @@ def topk_transform_512_pytorch_vectorized( ) raw_indices = torch.cat([raw_indices, padding], dim=1) - batch_indices = ( - torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, TOPK) - ) - gathered_scores = scores[ - batch_indices.flatten(), raw_indices.clamp(min=0).flatten() - ].view(batch_size, TOPK) + # Gather along dim=1 with a single int64 gather kernel — significantly + # cheaper than the 2-D advanced index `scores[batch_idx.flatten(), + # raw.flatten()].view(...)` pattern, which materializes a flat int64 + # index tensor of size B*TOPK for every call. + gather_idx = raw_indices.clamp(min=0).to(torch.long) + gathered_scores = torch.gather(scores, dim=1, index=gather_idx) valid_topk = gathered_scores != float("-inf") if actual_k < TOPK: pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k valid_topk = valid_topk & ~pad_mask - needs_sequential = seq_lens <= TOPK - if needs_sequential.any(): - sequential_indices = ( - torch.arange(TOPK, device=device, dtype=torch.int32) - .unsqueeze(0) - .expand(batch_size, -1) - ) - sequential_valid = sequential_indices < seq_lens.unsqueeze(1) - - raw_indices = torch.where( - needs_sequential.unsqueeze(1).expand(-1, TOPK), - torch.where( - sequential_valid, - sequential_indices, - torch.tensor(-1, device=device, dtype=torch.int32), - ), - raw_indices, - ) - valid_topk = torch.where( - needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk - ) + # Always run the sequential override (cheap when no row hits it). + # Previously this was guarded by `if needs_sequential.any():` which + # forced a D2H sync per call on XPU. + needs_sequential = (seq_lens <= TOPK).unsqueeze(1) # (B, 1) bool + sequential_indices = torch.arange( + TOPK, device=device, dtype=torch.int32 + ).unsqueeze(0) # (1, TOPK) + sequential_valid = sequential_indices < seq_lens.unsqueeze(1).to( + sequential_indices.dtype + ) # (B, TOPK) + + raw_indices = torch.where( + needs_sequential, + torch.where(sequential_valid, sequential_indices, neg_one_i32), + raw_indices, + ) + valid_topk = torch.where(needs_sequential, sequential_valid, valid_topk) page_idx = raw_indices >> page_bits offset_in_page = raw_indices & page_mask @@ -184,16 +249,12 @@ def topk_transform_512_pytorch_vectorized( page_indices = (physical_pages << page_bits) | offset_in_page page_indices = page_indices.to(torch.int32) - page_indices = torch.where( - valid_topk, page_indices, torch.tensor(-1, device=device, dtype=torch.int32) - ) + page_indices = torch.where(valid_topk, page_indices, neg_one_i32) out_page_indices.copy_(page_indices) if out_raw_indices is not None: - raw_indices = torch.where( - valid_topk, raw_indices, torch.tensor(-1, device=device, dtype=torch.int32) - ) + raw_indices = torch.where(valid_topk, raw_indices, neg_one_i32) out_raw_indices.copy_(raw_indices) @@ -442,10 +503,13 @@ def forward_c4_indexer( from sglang.srt.layers.attention.dsv4.tilelang_kernel import ( tilelang_fp8_paged_mqa_logits as fn, ) - elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or _is_sm120: fn = fp8_paged_mqa_logits_torch elif _is_cpu and _cpu_amx: fn = fp8_paged_mqa_logits_cpu + elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TRITON.get(): + from .triton_fp8_paged_mqa_logits import fp8_paged_mqa_logits_triton + fn = fp8_paged_mqa_logits_triton else: from deep_gemm import fp8_paged_mqa_logits as fn diff --git a/python/sglang/srt/layers/attention/dsv4/metadata.py b/python/sglang/srt/layers/attention/dsv4/metadata.py index 275b0b2ccff1..ab10ccdef12c 100644 --- a/python/sglang/srt/layers/attention/dsv4/metadata.py +++ b/python/sglang/srt/layers/attention/dsv4/metadata.py @@ -8,6 +8,10 @@ from sglang.srt.environ import envs from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_hip +from sglang.srt.utils.common import get_device_sm + +_is_cuda = torch.cuda.is_available() and not is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 if TYPE_CHECKING: pass @@ -103,9 +107,10 @@ class PagedIndexerMetadata: topk_metadata: torch.Tensor = field(init=False, repr=False) def __post_init__(self): - if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or ( - is_cpu() and cpu_has_amx_support() - ): + if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get() or envs.SGLANG_FP8_PAGED_MQA_LOGITS_TRITON.get() or ( + is_cpu() and cpu_has_amx_support()) or _is_sm120: + # SM120: DeepGEMM get_paged_mqa_logits_metadata asserts + # "Unsupported architecture" on SM120. Use None (torch fallback path). self.deep_gemm_metadata = None else: import deep_gemm diff --git a/python/sglang/srt/layers/attention/dsv4/triton_fp8_paged_mqa_logits.py b/python/sglang/srt/layers/attention/dsv4/triton_fp8_paged_mqa_logits.py new file mode 100644 index 000000000000..536aae040646 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/triton_fp8_paged_mqa_logits.py @@ -0,0 +1,197 @@ +import torch +import triton +import triton.language as tl +from typing import Any, Optional +from sglang.srt.utils import is_hip + +if is_hip(): + FP8_DTYPE = torch.float8_e4m3fnuz +else: + FP8_DTYPE = torch.float8_e4m3fn + +_FP8_PAGED_MQA_LOGITS_CHUNK_BYTES = 256 * 1024 * 1024 + + +@triton.jit +def _score_relu_weight_scale_kernel( + # score: [cb, padded_seq_len, num_heads] fp32 (input/output fused) + # weight: [cb, num_heads] fp32 + # kv_scale: [cb, padded_seq_len] fp32 + # seq_lens: [cb] int + # logits_out: [cb, max_seq_len] fp32 + score_ptr, + weight_ptr, + kv_scale_ptr, + seq_lens_ptr, + logits_ptr, + cb, + padded_seq_len, + max_seq_len, + num_heads: tl.constexpr, + score_stride_b, + score_stride_s, + weight_stride_b, + scale_stride_b, + logits_stride_b, + BLOCK_S: tl.constexpr, + BLOCK_H: tl.constexpr, +): + """Fused: relu(score) * weight → sum over heads → * kv_scale → masked store. + + Grid: (cb, cdiv(padded_seq_len, BLOCK_S)) + Each program handles BLOCK_S seq positions for one batch row. + """ + pid_b = tl.program_id(0) + pid_s = tl.program_id(1) + + if pid_b >= cb: + return + + seq_len = tl.load(seq_lens_ptr + pid_b).to(tl.int32) + + s_offs = pid_s * BLOCK_S + tl.arange(0, BLOCK_S) + s_mask = s_offs < padded_seq_len + + # Load weight[batch, :] — all heads + wt_base = weight_ptr + pid_b * weight_stride_b + h_offs = tl.arange(0, BLOCK_H) + h_mask = h_offs < num_heads + w = tl.load(wt_base + h_offs, mask=h_mask, other=0.0) # [BLOCK_H] + + # Accumulate: for each position, sum over heads of relu(score) * weight + score_base = score_ptr + pid_b * score_stride_b + acc = tl.zeros([BLOCK_S], dtype=tl.float32) + + for h in range(num_heads): + # score[batch, s, h] + s_ptrs = score_base + s_offs * score_stride_s + h + val = tl.load(s_ptrs, mask=s_mask, other=0.0) + # ReLU + val = tl.maximum(val, 0.0) + # Multiply by weight[h] + w_h = tl.load(wt_base + h).to(tl.float32) + acc += val * w_h + + # Multiply by kv_scale[batch, s] + scale_base = kv_scale_ptr + pid_b * scale_stride_b + kv_s = tl.load(scale_base + s_offs, mask=s_mask, other=0.0) + acc = acc * kv_s + + # Masked store to logits — zero out positions >= seq_len + valid = s_mask & (s_offs < seq_len) & (s_offs < max_seq_len) + write_mask = s_mask & (s_offs < max_seq_len) + out_base = logits_ptr + pid_b * logits_stride_b + tl.store(out_base + s_offs, tl.where(valid, acc, 0.0), mask=write_mask) + + +def fp8_paged_mqa_logits_triton( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + """Triton-accelerated fp8_paged_mqa_logits. + + Paged KV gather + fp8→fp32 upcast + einsum remain in PyTorch (already + vectorized, no host syncs). The post-matmul pipeline (ReLU, weight multiply, + head reduction, scale multiply, validity masking, store) is fused into a + single Triton kernel — replacing 5 separate PyTorch ops per chunk. + """ + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + + assert head_dim == 128 + assert block_size == 64 + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + if seq_lens.dim() > 1: + seq_lens = seq_lens.squeeze(-1) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + device = q_fp8.device + head_dim_with_sf = head_dim + 4 + SCALE_OFFSET = block_size * head_dim + + max_pages_eff = (max_seq_len + block_size - 1) // block_size + P = min(page_table.shape[1], max_pages_eff) + padded_seq_len = P * block_size + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + + kv_flat = kvcache_fp8.reshape(-1, block_size * head_dim_with_sf) + num_pages_total = kv_flat.shape[0] + + bytes_per_row = max(1, P * block_size * head_dim * 4) + chunk_size = max(1, _FP8_PAGED_MQA_LOGITS_CHUNK_BYTES // bytes_per_row) + + pt = page_table[:, :P] + if num_pages_total > 0: + pt = pt.clamp_(min=0, max=num_pages_total - 1) + + BLOCK_S = 64 + BLOCK_H = triton.next_power_of_2(num_heads) + + for s in range(0, batch_size, chunk_size): + e = min(s + chunk_size, batch_size) + cb = e - s + + # Paged gather (vectorized, no host sync) + kv = kv_flat[pt[s:e]] + kv_value_b = kv[..., :SCALE_OFFSET].contiguous() + kv_scale_b = kv[..., SCALE_OFFSET:].contiguous() + + # bytes -> fp8 -> fp32 + kv_value = ( + kv_value_b.view(dtype=FP8_DTYPE) + .view(cb, padded_seq_len, head_dim) + .to(torch.float32) + ) + # bytes -> fp32 scale per token + kv_scale = kv_scale_b.view(dtype=torch.float32).view(cb, padded_seq_len) + + # q: (cb, num_heads, head_dim) fp32 + q = q_fp8[s:e, 0].to(torch.float32) + + # Batched matmul: score[b,s,h] = sum_d(kv_value[b,s,d] * q[b,h,d]) + # shape: (cb, padded_seq_len, num_heads) + score = torch.einsum("bsd,bhd->bsh", kv_value, q) + + # Fused Triton kernel: relu -> weight -> sum_heads -> scale -> mask -> store + score = score.contiguous() + kv_scale = kv_scale.contiguous() + + write_len = min(padded_seq_len, max_seq_len) + grid = (cb, triton.cdiv(padded_seq_len, BLOCK_S)) + + _score_relu_weight_scale_kernel[grid]( + score, + weight[s:e], + kv_scale, + seq_lens[s:e], + logits[s:e], + cb, + padded_seq_len, + max_seq_len, + num_heads=num_heads, + score_stride_b=score.stride(0), + score_stride_s=score.stride(1), + weight_stride_b=weight.stride(0), + scale_stride_b=kv_scale.stride(0), + logits_stride_b=logits.stride(1) * 0 + max_seq_len, # logits is contiguous + BLOCK_S=BLOCK_S, + BLOCK_H=BLOCK_H, + ) + + # Zero-fill remaining columns if padded_seq_len < max_seq_len + if write_len < max_seq_len: + logits[s:e, write_len:] = 0 + + return logits diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py new file mode 100644 index 000000000000..d25b9ff853a2 --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_fallback.py @@ -0,0 +1,310 @@ +"""FlashMLA adapter with SM120 fallback. + +The FP8 KV cache uses a page-internal layout where NOPE+ROPE data has +stride (nope_dim + rope_dim*2) per token, and scales are stored in a +separate region at the end of each page. The tensor shape +``(num_pages, page_size, 1, bytes_per_token)`` is just metadata for the +FlashMLA CUDA kernel -- it does NOT mean each token occupies +*bytes_per_token* contiguous bytes. + +On SM120 (Blackwell Desktop / RTX PRO 6000) the flash_mla CUDA kernel +is not available, so this module provides a pure-PyTorch fallback that +reads the raw paged buffer with the correct addressing. + +When SGLANG_SM120_TRITON_FLASHMLA=1 (default), a fused Triton kernel is +used instead of the PyTorch fallback for significantly better performance. +Set to 0 to fall back to the pure-PyTorch path. +""" + +import logging +import os + +import torch + +from sglang.srt.utils import is_hip, is_xpu +from sglang.srt.utils.common import get_device_sm + +logger = logging.getLogger(__name__) + +_is_cuda = torch.cuda.is_available() and not is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 +_is_xpu = is_xpu() + +# Page layout constants for DSv4-Flash (MODEL1): +# nope_dim = 448, rope_dim = 64, quantize_block_size = 64 +# nope_rope_stride = 448 + 64*2 = 576 bytes per token +# scale_stride = ceil(448/64) + 1 = 8 bytes per token (7 scales + 1 pad) +# bytes_per_token = 448 + 128 + 8 = 584 +# page_bytes = ceil_div(page_size * 584, 576) * 576 + +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_NOPE_ROPE_STRIDE = _NOPE_DIM + _ROPE_DIM * 2 # 576 +_TILE_SIZE = 64 +_NUM_TILES = _NOPE_DIM // _TILE_SIZE # 7 +_SCALE_STRIDE = _NUM_TILES + 1 # 8 (7 scales + 1 pad) +_D = _NOPE_DIM + _ROPE_DIM # 512 + + +_GATHER_CHUNK = 16384 # tokens per chunk; ~16k * 1024 B ≈ 16 MiB output per chunk + +# Per-chunk peak-memory budget for the sparse decode fallback (MiB). Read +# once at import time so the forward path doesn't pay an os.environ lookup +# per layer per decode step. +_SM120_SPARSE_CHUNK_MIB = int(os.environ.get("SGLANG_SM120_SPARSE_CHUNK_MIB", "256")) + + +def _gather_and_dequant(k_cache, indices, page_size): + """Gather KV entries from the paged buffer using correct page-internal addressing. + + Args: + k_cache: (num_pages, page_size, 1, bytes_per_token) float8_e4m3fn + Non-contiguous view of the raw page buffer. + indices: (...) int32/int64, token-level indices. Invalid indices are + expected to already be clamped into [0, num_pages*page_size). + page_size: tokens per page (e.g. 256, 64, 2) + + Returns: + kv: (..., _D) bfloat16, dequantized KV vectors + """ + idx_shape = indices.shape + flat_idx = indices.reshape(-1) # (N,) + N = flat_idx.shape[0] + device = k_cache.device + + page_bytes = k_cache.stride(0) # actual byte stride between pages + num_pages = k_cache.shape[0] + + # Flatten the raw byte buffer so we can gather with a single int64 index + # per byte instead of paying for a full (N, 448) int64 index tensor up + # front. flat_buf has nelems = num_pages * page_bytes uint8. + raw_pages = k_cache.as_strided( + (num_pages, page_bytes), + (page_bytes, 1), + ).view(torch.uint8) + flat_buf = raw_pages.reshape(-1) + + scale_section_offset = page_size * _NOPE_ROPE_STRIDE + + nope_arange = torch.arange(_NOPE_DIM, device=device, dtype=torch.long) + rope_arange = torch.arange(_ROPE_DIM * 2, device=device, dtype=torch.long) + scale_arange = torch.arange(_NUM_TILES, device=device, dtype=torch.long) + + result = torch.empty(N, _D, dtype=torch.bfloat16, device=device) + + # Process in chunks to bound peak memory of the int64 advanced-index + # tensors (which would otherwise be N * 448 * 8 bytes — multiple GB on + # long-context prefills with large topk). + for start in range(0, N, _GATHER_CHUNK): + end = min(start + _GATHER_CHUNK, N) + chunk = flat_idx[start:end] + n = end - start + + pages = chunk // page_size + offsets = chunk % page_size + + # Per-token base byte offset into the flat raw buffer. + page_base = pages.to(torch.long) * page_bytes # (n,) + nope_base = page_base + offsets.to(torch.long) * _NOPE_ROPE_STRIDE # (n,) + + nope_idx = nope_base.unsqueeze(-1) + nope_arange # (n, 448) + rope_idx = nope_base.unsqueeze(-1) + (_NOPE_DIM + rope_arange) # (n, 128) + scale_idx = ( + page_base.unsqueeze(-1) + + scale_section_offset + + offsets.to(torch.long).unsqueeze(-1) * _SCALE_STRIDE + + scale_arange + ) # (n, 7) + + nope_bytes = flat_buf[nope_idx.reshape(-1)].view(n, _NOPE_DIM) + rope_bytes = flat_buf[rope_idx.reshape(-1)].view(n, _ROPE_DIM * 2) + scale_bytes = flat_buf[scale_idx.reshape(-1)].view(n, _NUM_TILES) + + nope_fp8 = nope_bytes.view(torch.float8_e4m3fn) # (n, 448) + rope_bf16 = rope_bytes.contiguous().view(torch.bfloat16) # (n, 64) + scale_e8m0 = scale_bytes.view(torch.float8_e8m0fnu) # (n, 7) + + result[start:end, :_NOPE_DIM] = ( + nope_fp8.view(n, _NUM_TILES, _TILE_SIZE).float() + * scale_e8m0.view(n, _NUM_TILES, 1).float() + ).view(n, _NOPE_DIM).to(torch.bfloat16) + result[start:end, _NOPE_DIM:] = rope_bf16 + + return result.reshape(*idx_shape, _D) + + +def _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache=None, + extra_indices=None, + extra_topk_length=None, +): + B, s_q, H_q, D_qk = q.shape + num_pages, page_size, H_k, bpt = k_cache.shape + topk = indices.shape[-1] + device = q.device + + # FlashMLA kernel treats `index == -1` as invalid; we additionally treat + # any index outside [0, num_pages*page_size) as invalid because the CUDA + # tile scheduler would simply never visit those slots, whereas this + # PyTorch fallback gathers them eagerly. + max_valid = num_pages * page_size + invalid_mask = (indices < 0) | (indices >= max_valid) + safe_indices = indices.clamp(min=0, max=max_valid - 1) + if topk_length is not None: + topk_range = torch.arange(topk, device=topk_length.device).view(1, 1, topk) + invalid_mask = invalid_mask | (topk_range >= topk_length.view(B, 1, 1)) + + have_extra = extra_k_cache is not None and extra_indices is not None + if have_extra: + extra_topk = extra_indices.shape[-1] + extra_num_pages, extra_page_size = extra_k_cache.shape[0], extra_k_cache.shape[1] + extra_max_valid = extra_num_pages * extra_page_size + extra_invalid = (extra_indices < 0) | (extra_indices >= extra_max_valid) + extra_safe = extra_indices.clamp(min=0, max=extra_max_valid - 1) + if extra_topk_length is not None: + extra_range = torch.arange(extra_topk, device=extra_topk_length.device).view(1, 1, extra_topk) + extra_invalid = extra_invalid | (extra_range >= extra_topk_length.view(B, 1, 1)) + else: + extra_topk = 0 + + total_topk = topk + extra_topk + # Flatten the (B, s_q) row dimension so we can chunk easily. + R = B * s_q # number of query rows + q_rows = q.reshape(R, H_q, D_qk) + safe_indices_rows = safe_indices.reshape(R, topk) + invalid_rows = invalid_mask.reshape(R, topk) + if have_extra: + extra_safe_rows = extra_safe.reshape(R, extra_topk) + extra_invalid_rows = extra_invalid.reshape(R, extra_topk) + + out_rows = torch.empty(R, H_q, head_dim_v, dtype=torch.bfloat16, device=device) + lse_rows = torch.empty(R, H_q, dtype=torch.float32, device=device) + + # Bound per-chunk peak memory. Dominant bf16 tensor is gathered KV: + # chunk * total_topk * _D * 2 bytes; fp32 working set adds ~3x on top. + # On Intel L0, per-launch overhead is high (~hundreds of us), so prefer + # fewer/larger chunks. Target 256 MiB peak (override via + # SGLANG_SM120_SPARSE_CHUNK_MIB at import time). + bytes_per_row = total_topk * _D * 2 + chunk_rows = max( + 1, min(R, (_SM120_SPARSE_CHUNK_MIB * 1024 * 1024) // max(1, bytes_per_row)) + ) + + for start in range(0, R, chunk_rows): + end = min(start + chunk_rows, R) + n = end - start + + # Gather KV for this chunk only. + kv_chunk = _gather_and_dequant( + k_cache, safe_indices_rows[start:end], page_size + ) # (n, topk, _D) + inv_chunk = invalid_rows[start:end] # (n, topk) + if have_extra: + extra_kv_chunk = _gather_and_dequant( + extra_k_cache, extra_safe_rows[start:end], extra_page_size + ) # (n, extra_topk, _D) + kv_chunk = torch.cat([kv_chunk, extra_kv_chunk], dim=1) + inv_chunk = torch.cat([inv_chunk, extra_invalid_rows[start:end]], dim=1) + del extra_kv_chunk + + # Zero-out invalid KV rows so they contribute nothing (kernel parity). + kv_chunk[inv_chunk] = 0.0 + + q_chunk = q_rows[start:end].float() # (n, H_q, D_qk) + kv_f = kv_chunk.float() # (n, T, _D) + kv_d = kv_f.shape[-1] + if D_qk != kv_d: + q_chunk = q_chunk[..., :kv_d] + + # scores: (n, H_q, T) + scores = torch.einsum("nhd,ntd->nht", q_chunk, kv_f) * softmax_scale + scores.masked_fill_(inv_chunk.unsqueeze(1).expand_as(scores), float("-inf")) + + lse = torch.logsumexp(scores, dim=-1) # (n, H_q) + + if attn_sink is not None: + lse_for_out = torch.logsumexp( + torch.stack([lse, attn_sink.view(1, H_q).expand_as(lse)], dim=0), + dim=0, + ) + else: + lse_for_out = lse.clone() + + lonely = lse == float("-inf") + lse_for_out[lonely] = float("inf") + weights = torch.exp(scores - lse_for_out.unsqueeze(-1)) + out_chunk = torch.einsum("nht,ntv->nhv", weights, kv_f[..., :head_dim_v]) + out_chunk[lonely.unsqueeze(-1).expand_as(out_chunk)] = 0.0 + + out_rows[start:end] = out_chunk.to(torch.bfloat16) + lse_rows[start:end] = lse + + del kv_chunk, kv_f, q_chunk, scores, weights, out_chunk, lse, lse_for_out, lonely + + out = out_rows.reshape(B, s_q, H_q, head_dim_v) + lse = lse_rows.reshape(B, s_q, H_q).permute(0, 2, 1) + return out, lse + + +_use_triton_flashmla = os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "0") == "1" + + +def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): + if _is_sm120 or _is_xpu: + q = kwargs["q"] + k_cache = kwargs["k_cache"] + indices = kwargs["indices"] + topk_length = kwargs.get("topk_length") + attn_sink = kwargs.get("attn_sink") + head_dim_v = kwargs["head_dim_v"] + softmax_scale = kwargs.get("softmax_scale") + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + extra_k_cache = kwargs.get("extra_k_cache") + extra_indices = kwargs.get("extra_indices_in_kvcache") + extra_topk_length = kwargs.get("extra_topk_length") + + if _use_triton_flashmla: + from sglang.srt.layers.attention.flash_mla_sm120_triton import ( + flash_mla_sparse_decode_triton, + ) + + out, lse = flash_mla_sparse_decode_triton( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) + + out, lse = _sm120_sparse_decode_fwd( + q, + k_cache, + indices, + topk_length, + attn_sink, + head_dim_v, + softmax_scale, + extra_k_cache, + extra_indices, + extra_topk_length, + ) + return (out, lse) + + assert backend == "kernel", f"unsupported backend {backend!r}" + import flash_mla + + return flash_mla.flash_mla_with_kvcache(**kwargs) diff --git a/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py new file mode 100644 index 000000000000..0eae57d3a350 --- /dev/null +++ b/python/sglang/srt/layers/attention/flash_mla_sm120_triton.py @@ -0,0 +1,366 @@ +"""SM120-optimized Triton FlashMLA sparse decode kernel — Tiled V2. + +Replaces V1's serial token loop with a tiled vectorized approach: + 1. BLOCK_T tokens loaded simultaneously via 2D gather (vs 1-at-a-time) + 2. All BLOCK_T QK scores computed at once via vectorized mul-reduce + 3. V accumulation via vectorized weighted sum across BLOCK_T tokens + 4. Online softmax operates on tile-level maxima (fewer rescales) + +Three typed views of the same paged buffer handle FP8/uint8/BF16 regions: +- float8_e4m3fn view -> nope FP8 values (direct load + dequant) +- uint8 view -> UE8M0 scale bytes (raw integer -> exp2 conversion) +- bfloat16 view -> rope BF16 values (direct load) + +DSv4 page layout (per token, 576 bytes data + 8 bytes scales): + Data section: [0:448] FP8 nope | [448:576] BF16 rope (64 values = 128 bytes) + Scale section: [page_size*576 + offset*8 : +7] UE8M0 scales (7 groups of 64) + +Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7, 96MB L2) +""" + +import logging +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + +LOG2E = tl.constexpr(1.4426950408889634) + +# DSv4 KV cache layout constants +_NOPE_DIM = 448 +_ROPE_DIM = 64 +_D = _NOPE_DIM + _ROPE_DIM # 512 +_TOKEN_DATA_STRIDE = 576 # bytes per token in data section +_SCALE_STRIDE = 8 # bytes per token in scale section + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_T": 16}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_T": 16}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_T": 32}, num_warps=8, num_stages=2), + ], + key=["topk_rounded"], +) +@triton.jit +def _tiled_sparse_decode_kernel( + # Q: [B, H, D] bf16 + Q_ptr, + # Paged KV cache — three typed views of same underlying memory + cache_fp8_ptr, # float8_e4m3fn flat (1 byte/elem) — for nope + cache_uint8_ptr, # uint8 flat (1 byte/elem) — for scales + cache_bf16_ptr, # bfloat16 flat (2 bytes/elem) — for rope + # Indices: [B, topk] int32 + indices_ptr, + # Valid lengths: [B] int32 + topk_len_ptr, + # Output: [B, H, D] bf16 and LSE: [B, H] float32 + O_ptr, + LSE_ptr, + # Scalars + sm_scale: tl.float32, + page_size: tl.int32, + page_bytes: tl.int64, + scale_section_off: tl.int64, # page_size * 576 + H: tl.int32, + topk: tl.int32, + topk_rounded: tl.int32, # for autotune key + has_topk_len: tl.constexpr, + # Strides + stride_qb: tl.int32, + stride_qh: tl.int32, + stride_ob: tl.int32, + stride_oh: tl.int32, + stride_ib: tl.int32, # indices batch stride + # Constexprs + NOPE_PAD: tl.constexpr, # 512 (padded from 448) + ROPE_DIM: tl.constexpr, # 64 + NOPE_DIM_RT: tl.int32, # 448 (runtime, for masking) + BLOCK_T: tl.constexpr, # tokens per tile (16 or 32) +): + """Tiled sparse decode: vectorized gather + QK + softmax + V accumulation. + + Grid: (B, H) — one block per (batch, head) pair. + Each block processes all topk tokens in tiles of BLOCK_T. + """ + bid = tl.program_id(0) + hid = tl.program_id(1) + + # ---- Load Q for this (batch, head) ---- + q_base = bid * stride_qb + hid * stride_qh + nope_offs = tl.arange(0, NOPE_PAD) # [512] + nope_mask = nope_offs < NOPE_DIM_RT # [512], True for [0:448] + rope_offs = tl.arange(0, ROPE_DIM) # [64] + + q_nope = tl.load(Q_ptr + q_base + nope_offs, mask=nope_mask, other=0.0) + q_nope = q_nope.to(tl.float32) * sm_scale + q_rope = tl.load(Q_ptr + q_base + NOPE_DIM_RT + rope_offs) + q_rope = q_rope.to(tl.float32) * sm_scale + + # ---- Valid token count ---- + valid_topk = topk + if has_topk_len: + valid_topk = tl.load(topk_len_ptr + bid).to(tl.int32) + valid_topk = tl.minimum(valid_topk, topk) + + # ---- Online softmax state (base-2 math for SM120 efficiency) ---- + m_i: tl.float32 = -1e30 + l_i: tl.float32 = 0.0 + acc_nope = tl.zeros([NOPE_PAD], dtype=tl.float32) + acc_rope = tl.zeros([ROPE_DIM], dtype=tl.float32) + + # ---- Precompute constant index vectors ---- + group_ids = (nope_offs // 64).to(tl.int64) # [NOPE_PAD], scale group for each dim + t_offs = tl.arange(0, BLOCK_T) # [BLOCK_T], token offsets within tile + + # ---- Process tokens in tiles of BLOCK_T ---- + for tile_start in range(0, topk, BLOCK_T): + t_idx = tile_start + t_offs # [BLOCK_T], global token indices + t_in_bounds = t_idx < topk # bounds for index load + t_valid = t_idx < valid_topk # bounds for actual processing + + # Load indices for this tile: [BLOCK_T] + raw_indices = tl.load( + indices_ptr + bid * stride_ib + t_idx, + mask=t_in_bounds, + other=-1, + ) + idx_valid = t_valid & (raw_indices >= 0) # [BLOCK_T] mask + + # Page addressing: [BLOCK_T] (clamp for safe addressing of invalid tokens) + safe_indices = tl.where(idx_valid, raw_indices, tl.zeros_like(raw_indices)) + page_ids = (safe_indices // page_size).to(tl.int64) + page_offs_t = (safe_indices % page_size).to(tl.int64) + token_data_bases = page_ids * page_bytes + page_offs_t * 576 # [BLOCK_T] int64 + + # ---- Vectorized NOPE FP8 gather: [BLOCK_T, NOPE_PAD] ---- + nope_addrs = token_data_bases[:, None] + nope_offs[None, :].to(tl.int64) + nope_2d_mask = idx_valid[:, None] & nope_mask[None, :] + kv_nope_fp8 = tl.load( + cache_fp8_ptr + nope_addrs, + mask=nope_2d_mask, + other=0.0, + ) + + # ---- Vectorized scale gather + dequant: [BLOCK_T, NOPE_PAD] ---- + scale_bases = page_ids * page_bytes + scale_section_off + page_offs_t * 8 + scale_addrs = scale_bases[:, None] + group_ids[None, :] + scale_raw = tl.load( + cache_uint8_ptr + scale_addrs, + mask=nope_2d_mask, + other=127, + ) + scale_f32 = tl.math.exp2(scale_raw.to(tl.float32) - 127.0) + kv_nope = tl.where(nope_2d_mask, kv_nope_fp8.to(tl.float32) * scale_f32, 0.0) + + # ---- Vectorized ROPE BF16 gather: [BLOCK_T, ROPE_DIM] ---- + rope_byte_bases = token_data_bases + 448 + rope_elem_bases = (rope_byte_bases // 2).to(tl.int64) + rope_addrs = rope_elem_bases[:, None] + rope_offs[None, :].to(tl.int64) + kv_rope = tl.load( + cache_bf16_ptr + rope_addrs, + mask=idx_valid[:, None], + other=0.0, + ).to(tl.float32) + + # ---- Vectorized QK scores: [BLOCK_T] ---- + # scores[t] = dot(q_nope, kv_nope[t]) + dot(q_rope, kv_rope[t]) + scores = tl.sum(q_nope[None, :] * kv_nope, axis=1) + tl.sum( + q_rope[None, :] * kv_rope, axis=1 + ) + scores = tl.where(idx_valid, scores, -1e30) + + # ---- Online softmax update (base-2, tile-level) ---- + scores_log2 = scores * LOG2E # [BLOCK_T] + tile_max = tl.max(scores_log2) # scalar + m_new = tl.maximum(m_i, tile_max) + + alpha = tl.math.exp2(m_i - m_new) # rescale factor + p = tl.math.exp2(scores_log2 - m_new) # [BLOCK_T] attention weights + p = tl.where(idx_valid, p, 0.0) # zero out invalid + + l_i = l_i * alpha + tl.sum(p) + + # ---- Vectorized V accumulation (K=V in MLA) ---- + # acc += sum_t(p[t] * kv[t, :]) for both nope and rope + acc_nope = acc_nope * alpha + tl.sum(p[:, None] * kv_nope, axis=0) + acc_rope = acc_rope * alpha + tl.sum(p[:, None] * kv_rope, axis=0) + m_i = m_new + + # ---- Normalize output ---- + safe_l = tl.where(l_i > 0.0, l_i, 1.0) + acc_nope = acc_nope / safe_l + acc_rope = acc_rope / safe_l + + # LSE: convert from log2 back to natural log + lse = tl.where(l_i > 0.0, m_i / LOG2E + tl.math.log(safe_l), float("-inf")) + + # ---- Store output ---- + o_base = bid * stride_ob + hid * stride_oh + tl.store(O_ptr + o_base + nope_offs, acc_nope.to(tl.bfloat16), mask=nope_mask) + tl.store(O_ptr + o_base + NOPE_DIM_RT + rope_offs, acc_rope.to(tl.bfloat16)) + tl.store(LSE_ptr + bid * H + hid, lse) + + +def _run_triton_sparse_decode( + q: torch.Tensor, # [B, 1, H, D] bf16 + k_cache: torch.Tensor, # [num_pages, page_size, 1, bpt] float8 + indices: torch.Tensor, # [B, ...] int32 + topk_length: Optional[torch.Tensor], + softmax_scale: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Run the tiled Triton sparse decode kernel on one paged KV cache.""" + B, _, H, D = q.shape + num_pages = k_cache.shape[0] + page_size = k_cache.shape[1] + page_bytes = k_cache.stride(0) # elements = bytes for float8 + + # Flatten indices to [B, topk] + flat_indices = indices.reshape(B, -1).contiguous() + topk = flat_indices.shape[1] + + # Create three typed views of the flat cache memory + total_elems = num_pages * page_bytes + raw_fp8 = k_cache.as_strided((total_elems,), (1,)) + raw_uint8 = raw_fp8.view(torch.uint8) + raw_bf16 = raw_uint8.view(torch.bfloat16) + + # Squeeze Q: [B, H, D] + q3 = q.squeeze(1) + if not q3.is_contiguous(): + q3 = q3.contiguous() + + out = torch.zeros(B, H, D, dtype=torch.bfloat16, device=q.device) + lse = torch.full((B, H), float("-inf"), dtype=torch.float32, device=q.device) + + # Round topk for autotune key stability + topk_rounded = triton.next_power_of_2(topk) + + grid = (B, H) + _tiled_sparse_decode_kernel[grid]( + q3, + raw_fp8, + raw_uint8, + raw_bf16, + flat_indices, + ( + topk_length + if topk_length is not None + else torch.empty(0, device=q.device, dtype=torch.int32) + ), + out, + lse, + softmax_scale, + page_size, + int(page_bytes), # page_bytes (int64) + int(page_size * _TOKEN_DATA_STRIDE), # scale_section_off (int64) + H, + topk, + topk_rounded, + topk_length is not None, + q3.stride(0), + q3.stride(1), + out.stride(0), + out.stride(1), + flat_indices.stride(0), + NOPE_PAD=512, + ROPE_DIM=_ROPE_DIM, + NOPE_DIM_RT=_NOPE_DIM, + ) + + # Return [B, 1, H, D] and [B, 1, H] + return out.unsqueeze(1), lse.unsqueeze(1) + + +def _merge_partial_attn( + out1: torch.Tensor, + lse1: torch.Tensor, + out2: torch.Tensor, + lse2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Merge two attention outputs using LSE-weighted combination. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] float32 + """ + max_lse = torch.maximum(lse1, lse2) + w1 = torch.where(lse1 > -1e20, torch.exp(lse1 - max_lse), torch.zeros_like(lse1)) + w2 = torch.where(lse2 > -1e20, torch.exp(lse2 - max_lse), torch.zeros_like(lse2)) + total = (w1 + w2).clamp(min=1e-20) + merged = ( + w1.unsqueeze(-1) * out1.float() + w2.unsqueeze(-1) * out2.float() + ) / total.unsqueeze(-1) + merged_lse = max_lse + torch.log(total) + return merged.to(torch.bfloat16), merged_lse + + +def _apply_attn_sink( + out: torch.Tensor, + lse: torch.Tensor, + attn_sink: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention sink normalization. + + The sink adds to the softmax denominator without contributing output, + effectively down-weighting all attention scores. + + out: [B, 1, H, D] bf16, lse: [B, 1, H] f32, attn_sink: [H] f32 + """ + sink_lse = attn_sink.view(1, 1, -1).expand_as(lse) + combined_lse = torch.logaddexp(lse, sink_lse) + w = torch.where( + lse > -1e20, + torch.exp(lse - combined_lse), + torch.zeros_like(lse), + ) + return (out.float() * w.unsqueeze(-1)).to(torch.bfloat16), combined_lse + + +def flash_mla_sparse_decode_triton( + q: torch.Tensor, + k_cache: torch.Tensor, + indices: torch.Tensor, + topk_length: Optional[torch.Tensor], + attn_sink: Optional[torch.Tensor], + head_dim_v: int, + softmax_scale: float, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """SM120-optimized sparse MLA decode using tiled Triton kernel. + + Processes SWA and extra (c4/c128) caches separately via the same + Triton kernel, then merges results using LSE-weighted combination. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + # Process main cache (SWA) + out, lse = _run_triton_sparse_decode( + q, + k_cache, + indices, + topk_length, + softmax_scale, + ) + + # Process extra cache (c4 / c128) if present + if extra_k_cache is not None and extra_indices is not None: + out_extra, lse_extra = _run_triton_sparse_decode( + q, + extra_k_cache, + extra_indices, + extra_topk_length, + softmax_scale, + ) + out, lse = _merge_partial_attn(out, lse, out_extra, lse_extra) + + # Apply attention sink + if attn_sink is not None: + out, lse = _apply_attn_sink(out, lse, attn_sink) + + # Return format matching PyTorch fallback: (out, lse.permute(0,2,1)) + return out, lse.permute(0, 2, 1) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 08434b06f992..ea83f63a6b57 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -24,30 +24,58 @@ ceil_align, cpu_has_amx_support, get_bool_env_var, + get_device_sm, is_cpu, is_cuda, is_gfx95_supported, is_hip, is_npu, + is_xpu, ) global _use_multi_stream _is_cuda = is_cuda() _is_hip = is_hip() +_is_sm120 = _is_cuda and get_device_sm() // 10 == 12 # SM120/SM121 _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_fp8_fnuz = is_fp8_fnuz() _is_gfx95_supported = is_gfx95_supported() -from sglang.srt.utils import cpu_has_amx_support, is_cpu - _is_cpu = is_cpu() _cpu_amx = cpu_has_amx_support() +_is_xpu = is_xpu() + if _is_cuda: try: import deep_gemm - except ImportError as e: + except (ImportError, AssertionError) as e: + # AssertionError: deep_gemm init fails on SM120 (no CUDA_HOME / unsupported arch) deep_gemm = e +if _is_sm120: + import os as _os + + if _os.environ.get("SGLANG_SM120_MQA_FALLBACK", "0") == "1": + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits, + ) + else: + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + compute_paged_mqa_schedule_metadata as _sm120_compute_paged_mqa_schedule_metadata, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + sm120_fp8_mqa_logits as _sm120_fp8_mqa_logits, + ) + from sglang.srt.layers.attention.nsa.sm120_mqa_triton import ( + sm120_fp8_paged_mqa_logits as _sm120_fp8_paged_mqa_logits, + ) + if _use_aiter: from aiter.ops.cache import indexer_k_quant_and_cache @@ -153,12 +181,34 @@ def topk_transform( """ +def _torch_hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor: + """Pure-torch FWHT fallback for backends without a fused kernel. + + Iterative Cooley-Tukey-style Walsh-Hadamard transform along the last + dim. Hidden size must be a power of two; same contract as the fused + ``hadamard_transform`` op. + """ + n = x.size(-1) + leading = x.shape[:-1] + out = x.reshape(-1, n).clone() + h = 1 + while h < n: + out = out.view(-1, n // (2 * h), 2, h) + a = out[:, :, 0, :] + b = out[:, :, 1, :] + out = torch.stack((a + b, a - b), dim=2).view(-1, n) + h *= 2 + return out.view(*leading, n) * scale + + def rotate_activation(x: torch.Tensor) -> torch.Tensor: # from sgl_kernel import hadamard_transform if _is_hip: from fast_hadamard_transform import hadamard_transform elif _is_cpu and _cpu_amx: hadamard_transform = torch.ops.sgl_kernel.fast_hadamard_transform_cpu + elif _is_xpu: + hadamard_transform = _torch_hadamard_transform else: from sglang.jit_kernel.hadamard import hadamard_transform @@ -206,7 +256,12 @@ def __init__( self.cp_size = None self.cp_rank = None if _is_cuda: - self.sm_count = deep_gemm.get_num_sms() + if _is_sm120: + # SM120: deep_gemm.get_num_sms() crashes; use torch native API + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + self.sm_count = props.multi_processor_count + else: + self.sm_count = deep_gemm.get_num_sms() self.half_device_sm_count = ceil_align(self.sm_count // 2, 8) pp_size = get_global_server_args().pp_size self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank @@ -257,7 +312,7 @@ def _with_real_sm_count(self): # request to receive the PP proxy tensor or output from the previous stage, occupying one SM resource. # Model execution runs in parallel with the recv operation, so the SMs available to the indexer must be reduced # by 1. Currently, the last rank starts the send result + recv request only after waiting for execution results. - if self.logits_with_pp_recv: + if self.logits_with_pp_recv and not _is_sm120: pp_recv_sm_count = 1 with deep_gemm_wrapper.configure_deep_gemm_num_sms( self.sm_count - pp_recv_sm_count @@ -472,9 +527,16 @@ def _get_topk_paged( seqlens_32_2d = seqlens_32.unsqueeze(-1) if _is_cuda: if schedule_metadata is None: - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( - seqlens_32_2d, blocksize, self.sm_count - ) + if _is_sm120: + schedule_metadata = _sm120_compute_paged_mqa_schedule_metadata( + seqlens_32_2d, + blocksize, + self.sm_count, + ) + else: + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( + seqlens_32_2d, blocksize, self.sm_count + ) assert len(q_fp8.shape) == 3 q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now @@ -517,6 +579,17 @@ def _get_topk_paged( Preshuffle=False, KVBlockSize=block_kv, ) + elif _is_sm120: + logits = _sm120_fp8_paged_mqa_logits( + q_fp8[:q_offset], + kv_cache_fp8, + weights[:q_offset], + seqlens_32_2d, + block_tables, + schedule_metadata, + max_seq_len, + clean_logits=False, + ) else: logits = deep_gemm.fp8_paged_mqa_logits( q_fp8[:q_offset], @@ -646,6 +719,15 @@ def _get_topk_ragged( logits = fp8_mqa_logits( q_fp8[:q_offset], kv, scale, weights[:q_offset], ks, ke ) + elif _is_sm120: + logits = _sm120_fp8_mqa_logits( + q_fp8[:q_offset], + kv_fp8, + weights[:q_offset], + ks, + ke, + clean_logits=False, + ) else: logits = deep_gemm.fp8_mqa_logits( q_fp8[:q_offset], @@ -696,6 +778,15 @@ def _get_topk_ragged( ks[start:end], ke[start:end], ) + elif _is_sm120: + logits_chunk = _sm120_fp8_mqa_logits( + q_fp8[start:end], + kv_fp8, + weights[start:end], + ks[start:end], + ke[start:end], + clean_logits=False, + ) else: logits_chunk = deep_gemm.fp8_mqa_logits( q_fp8[start:end], diff --git a/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py b/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py new file mode 100644 index 000000000000..e696b5b9ccf8 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/sm120_mqa_fallback.py @@ -0,0 +1,215 @@ +""" +SM120 fallback kernels for DeepGEMM FP8 MQA logits operations. + +On SM120 (RTX 5090, RTX PRO 6000, DGX Spark), DeepGEMM's fp8_paged_mqa_logits +and fp8_mqa_logits crash with 'Unsupported architecture'. This module provides +PyTorch-native fallback implementations that match the DeepGEMM API contract. + +Reference: vLLM PR#40991 (Triton sparse MLA fallback approach for SM120) +""" + +from __future__ import annotations + +import logging +from typing import Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def compute_paged_mqa_schedule_metadata( + seqlens: torch.Tensor, + block_size: int, + num_sms: int, +) -> None: + """SM120 fallback: scheduling is handled internally, return None.""" + return None + + +def _dequant_fp8_with_scale_suffix( + data_fp8: torch.Tensor, head_dim_qk: int +) -> torch.Tensor: + """ + Dequantize FP8 tensor that has per-row scale factors appended. + + DeepGEMM packs KV cache as [data_fp8 (head_dim_qk bytes) | scale (4 bytes)] + in a tensor of shape [..., head_dim_with_sf] where head_dim_with_sf = head_dim_qk + 4. + The scale is stored as a float32 value in the last 4 bytes. + """ + # Split data and scale + data_bytes = data_fp8[..., :head_dim_qk] + # Scale is stored in the last 4 bytes, reinterpret as float32 + scale_bytes = data_fp8[..., head_dim_qk:] + scale = scale_bytes.contiguous().view(torch.float32) # [..., 1] + + # Dequantize: cast FP8 to float32, multiply by scale + data_f32 = data_bytes.to(torch.float32) * scale + return data_f32 + + +def sm120_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + seqlens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata, + max_seq_len: int, + clean_logits: bool = False, +) -> torch.Tensor: + """ + SM120 fallback for deep_gemm.fp8_paged_mqa_logits(). + + Computes weighted multi-head dot-product logits over paged KV cache. + + Args: + q_fp8: [batch, next_n, n_heads, head_dim_with_sf] FP8 queries with appended scale + kv_cache_fp8: [num_blocks, block_kv, 1, head_dim_with_sf] FP8 paged KV cache + weights: [batch, n_heads] float32 head weights + seqlens: [batch, 1] or [batch] int32 sequence lengths + block_tables: [batch, max_blocks] int32 block table indices + schedule_metadata: ignored on SM120 (None) + max_seq_len: maximum sequence length for output + clean_logits: if True, fill unused positions with -inf + + Returns: + logits: [batch * next_n, max_seq_len] float32 + """ + batch, next_n, n_heads, head_dim_with_sf = q_fp8.shape + head_dim_qk = head_dim_with_sf - 4 # 128 typically + block_kv = kv_cache_fp8.shape[1] # typically 64 + device = q_fp8.device + + # Flatten seqlens + if seqlens.dim() == 2: + seqlens = seqlens.squeeze(-1) + + # Output logits + out = torch.full( + (batch * next_n, max_seq_len), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + # Dequantize queries: [batch, next_n, n_heads, head_dim_qk] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, head_dim_qk) + + for b in range(batch): + seq_len = seqlens[b].item() + if seq_len <= 0: + continue + + num_blocks_needed = (seq_len + block_kv - 1) // block_kv + + # Gather KV blocks for this batch element + block_ids = block_tables[b, :num_blocks_needed] + # [num_blocks_needed, block_kv, 1, head_dim_with_sf] + kv_blocks = kv_cache_fp8[block_ids] + # Flatten to [num_blocks_needed * block_kv, head_dim_with_sf] + kv_flat = kv_blocks.view(-1, head_dim_with_sf) + # Trim to actual sequence length + kv_flat = kv_flat[:seq_len] + + # Dequantize KV: [seq_len, head_dim_qk] + k_f32 = _dequant_fp8_with_scale_suffix(kv_flat.unsqueeze(-2), head_dim_qk) + k_f32 = k_f32.squeeze(-2) # [seq_len, head_dim_qk] + + # Vectorized over next_n: + # q_b: [next_n, n_heads, head_dim_qk] + q_b = q_f32[b] + # dots: [next_n, n_heads, seq_len] + dots = torch.einsum("tnd,sd->tns", q_b, k_f32) + # Apply head weights: [n_heads] -> weighted sum -> [next_n, seq_len] + w = weights[b] # [n_heads] + logits_b = torch.einsum("tns,n->ts", dots, w) # [next_n, seq_len] + out_start = b * next_n + out[out_start : out_start + next_n, :seq_len] = logits_b + + return out + + +def sm120_fp8_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: Tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + ks: torch.Tensor, + ke: torch.Tensor, + clean_logits: bool = False, +) -> torch.Tensor: + """ + SM120 fallback for deep_gemm.fp8_mqa_logits() (contiguous/ragged variant). + + Computes weighted multi-head dot-product logits over contiguous KV. + + Args: + q_fp8: [num_q, n_heads, head_dim_with_sf] FP8 queries with appended scale + kv_fp8: tuple of (k_data_fp8 [num_k, head_dim_with_sf], k_scale [num_k]) or + (k_data_fp8 [num_k, D], k_scale [num_k, scale_dim]) + weights: [num_q, n_heads] float32 head weights + ks: [num_q] int32 start indices into KV + ke: [num_q] int32 end indices into KV + + Returns: + logits: [num_q, num_k] float32 where num_k = max(ke) - min(ks) (or ke.max()) + """ + num_q, n_heads, head_dim_with_sf = q_fp8.shape + head_dim_qk = head_dim_with_sf - 4 + device = q_fp8.device + + k_data, k_scale = kv_fp8 + num_k = k_data.shape[0] + + # Determine output width + k_max = ke.max().item() if ke.numel() > 0 else 0 + out_width = max(k_max, num_k) + + # Output logits + out = torch.full( + (num_q, out_width), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + if num_q == 0 or num_k == 0: + return out + + # Dequantize queries: [num_q, n_heads, head_dim_qk] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, head_dim_qk) + + # Dequantize KV keys + if k_data.shape[-1] == head_dim_with_sf: + # Keys have appended scale suffix + k_f32 = _dequant_fp8_with_scale_suffix(k_data.unsqueeze(-2), head_dim_qk) + k_f32 = k_f32.squeeze(-2) # [num_k, head_dim_qk] + else: + # Keys and scales are separate + k_f32 = k_data.to(torch.float32) + if k_scale.dim() == 1: + k_f32 = k_f32 * k_scale.unsqueeze(-1) + else: + k_f32 = k_f32 * k_scale + + # Vectorized: compute all dot products at once + # q_f32: [num_q, n_heads, head_dim_qk], k_f32: [num_k, head_dim_qk] + # dots: [num_q, n_heads, num_k] + dots = torch.einsum("qhd,kd->qhk", q_f32, k_f32) + + # Apply head weights: [num_q, n_heads] -> [num_q, n_heads, 1] + w = weights.unsqueeze(-1) + # Weighted sum across heads: [num_q, num_k] + logits_all = (dots * w).sum(dim=1) + + # Mask to [ks, ke) ranges + k_indices = torch.arange(out_width, device=device).unsqueeze(0) # [1, out_width] + ks_expanded = ks.unsqueeze(1) # [num_q, 1] + ke_expanded = ke.unsqueeze(1) # [num_q, 1] + mask = (k_indices >= ks_expanded) & (k_indices < ke_expanded) # [num_q, out_width] + + # Place logits into output at valid positions + # logits_all is [num_q, num_k], but output is [num_q, out_width] + out[:, :num_k] = torch.where(mask[:, :num_k], logits_all, out[:, :num_k]) + + return out diff --git a/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py b/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py new file mode 100644 index 000000000000..4a19e8ea6392 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/sm120_mqa_triton.py @@ -0,0 +1,177 @@ +"""SM120-optimized MQA logits — CUDA graph compatible. + +Replaces the PyTorch fallback in sm120_mqa_fallback.py with an optimized +implementation that precomputes the head-weighted query vector before +scanning the KV cache, reducing per-position work from O(n_heads) to O(1). + +Key insight: logit[s] = sum_h(w[h] * dot(q[h], kv[s])) + = dot(sum_h(w[h] * q[h]), kv[s]) + = dot(wq, kv[s]) + +CUDA graph compatibility: +- No .item() calls — all computation stays on GPU tensors +- No per-batch Python loops — vectorized with torch.bmm +- Fixed tensor shapes derived from known parameters (max_seq_len, num_k) + +Target: RTX PRO 6000 (SM120, 188 SMs, 99KB SMEM, ~1.5 TB/s GDDR7) +""" + +import logging +from typing import Tuple + +import torch + +logger = logging.getLogger(__name__) + + +def _dequant_fp8_with_scale_suffix( + data_fp8: torch.Tensor, + head_dim_qk: int, +) -> torch.Tensor: + """Dequantize FP8 tensor with appended float32 scale suffix.""" + data_bytes = data_fp8[..., :head_dim_qk] + scale_bytes = data_fp8[..., head_dim_qk:] + scale = scale_bytes.contiguous().view(torch.float32) + return data_bytes.to(torch.float32) * scale + + +def compute_paged_mqa_schedule_metadata( + seqlens: torch.Tensor, + block_size: int, + num_sms: int, +) -> None: + """SM120 fallback: scheduling is handled internally, return None.""" + return None + + +def sm120_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + seqlens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata, + max_seq_len: int, + clean_logits: bool = False, +) -> torch.Tensor: + """CUDA-graph-compatible paged MQA logits for SM120. + + Key optimizations vs fallback: + 1. Precompute wq = sum_h(w[h] * dequant(q[h])) — eliminates per-position head loop + 2. Batched matmul across all batch elements — no per-batch Python loop + 3. No .item() calls — all shapes derived from known parameters + """ + batch, next_n, n_heads, hd_with_sf = q_fp8.shape + hd = hd_with_sf - 4 + block_kv = kv_cache_fp8.shape[1] + device = q_fp8.device + + seqlens_flat = seqlens.view(-1).to(torch.int64) + + # Dequant Q: [batch, next_n, n_heads, hd] + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, hd) + + # Precompute wq = sum_h(w[b,h] * q[b,t,h,:]) → [batch, next_n, hd] + w = weights.view(batch, 1, n_heads, 1) + wq = (q_f32 * w).sum(dim=2) # [batch, next_n, hd] + + # Batch-dequant all KV blocks: [num_blocks, block_kv, hd] + kv_data = kv_cache_fp8[..., :hd].squeeze(2) + kv_scale_raw = kv_cache_fp8[..., hd:].squeeze(2) + kv_scale = kv_scale_raw.contiguous().view(torch.float32) + kv_f32 = kv_data.float() * kv_scale # [num_blocks_total, block_kv, hd] + + # ── Vectorized batch gather (no per-batch loop, no .item()) ── + max_blocks = (max_seq_len + block_kv - 1) // block_kv + # Gather block IDs for all batches: [batch, max_blocks] + block_ids = block_tables[:, :max_blocks] + + # Gather KV for all batches: [batch, max_blocks, block_kv, hd] + kv_batched = kv_f32[block_ids] + max_padded = max_blocks * block_kv + kv_flat = kv_batched.reshape(batch, max_padded, hd) + + # Batched matmul: [batch, next_n, hd] @ [batch, hd, max_padded] + logits_batched = torch.bmm( + wq, kv_flat.transpose(1, 2) + ) # [batch, next_n, max_padded] + + # Create validity mask: [batch, max_padded] + positions = torch.arange(max_padded, device=device) + valid = positions.unsqueeze(0) < seqlens_flat.unsqueeze(1) # [batch, max_padded] + + # Apply mask (broadcast over next_n) + logits_batched = logits_batched.masked_fill(~valid.unsqueeze(1), float("-inf")) + + # Write to output: [batch * next_n, max_seq_len] + out_width = min(max_padded, max_seq_len) + out = torch.full( + (batch * next_n, max_seq_len), + float("-inf"), + device=device, + dtype=torch.float32, + ) + out[:, :out_width] = logits_batched[:, :, :out_width].reshape( + batch * next_n, out_width + ) + + return out + + +def sm120_fp8_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: Tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + ks: torch.Tensor, + ke: torch.Tensor, + clean_logits: bool = False, +) -> torch.Tensor: + """CUDA-graph-compatible ragged MQA logits for SM120. + + Key optimization: precompute wq = sum_h(w[h] * q[h]), then single matmul. + No .item() calls — uses num_k for output width. + """ + num_q, n_heads, hd_with_sf = q_fp8.shape + hd = hd_with_sf - 4 + device = q_fp8.device + + k_data, k_scale = kv_fp8 + num_k = k_data.shape[0] + + # Use num_k as output width — avoids ke.max().item() GPU-CPU sync + out_width = num_k + + out = torch.full( + (num_q, out_width), + float("-inf"), + device=device, + dtype=torch.float32, + ) + + if num_q == 0 or num_k == 0: + return out + + # Dequant Q and precompute weighted query + q_f32 = _dequant_fp8_with_scale_suffix(q_fp8, hd) + w = weights.unsqueeze(-1) + wq = (q_f32 * w).sum(dim=1) # [num_q, hd] + + # Dequant KV + if k_data.shape[-1] == hd_with_sf: + k_f32 = _dequant_fp8_with_scale_suffix(k_data.unsqueeze(-2), hd).squeeze(-2) + else: + k_f32 = k_data.float() + if k_scale.dim() == 1: + k_f32 = k_f32 * k_scale.unsqueeze(-1) + else: + k_f32 = k_f32 * k_scale + + # Single matmul: [num_q, hd] @ [hd, num_k] → [num_q, num_k] + logits_all = wq @ k_f32.T + + # Apply ragged [ks, ke) masking + k_indices = torch.arange(out_width, device=device).unsqueeze(0) + mask = (k_indices >= ks.unsqueeze(1)) & (k_indices < ke.unsqueeze(1)) + out[:, :num_k] = torch.where(mask[:, :num_k], logits_all, out[:, :num_k]) + + return out diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index a938f0b01e22..9b5c2b21a8e6 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -36,7 +36,9 @@ ) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import get_device_sm, is_cuda, is_hip + +_is_sm120 = is_cuda() and get_device_sm() // 10 == 12 if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -641,10 +643,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): paged_mqa_schedule_metadata = None # DeepGEMM paged MQA logits path needs a schedule metadata tensor. # Compute it once per forward batch and reuse it across layers. - if is_cuda() and ( - forward_batch.forward_mode.is_decode_or_idle() - or forward_batch.forward_mode.is_target_verify() - or forward_batch.forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_batch.forward_mode.is_decode_or_idle() + or forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm @@ -930,10 +936,14 @@ def init_forward_metadata_capture_cuda_graph( real_page_table = self._transform_table_1_to_real(page_table_1) paged_mqa_schedule_metadata = None - if is_cuda() and ( - forward_mode.is_decode_or_idle() - or forward_mode.is_target_verify() - or forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm @@ -1081,10 +1091,14 @@ def init_forward_metadata_replay_cuda_graph( ) # Update DeepGEMM paged MQA schedule metadata outside the captured graph. - if is_cuda() and ( - forward_mode.is_decode_or_idle() - or forward_mode.is_target_verify() - or forward_mode.is_draft_extend(include_v2=True) + if ( + is_cuda() + and not _is_sm120 + and ( + forward_mode.is_decode_or_idle() + or forward_mode.is_target_verify() + or forward_mode.is_draft_extend(include_v2=True) + ) ): try: import deep_gemm diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py index de433fe2d505..9933b85e3dba 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/configurer.py @@ -18,12 +18,16 @@ def _compute_enable_deep_gemm(): sm_version = get_device_sm() if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31): return False + # DeepGEMM requires TMEM/tcgen05 (SM100+datacenter), not available on SM120 + if sm_version // 10 == 12: + return False if not (_is_cuda or _is_musa): return False try: import deep_gemm # noqa: F401 - except ImportError: + except (ImportError, AssertionError): + # AssertionError: deep_gemm init may fail on unsupported architectures return False return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get() diff --git a/python/sglang/srt/layers/deepseek_v4_rope.py b/python/sglang/srt/layers/deepseek_v4_rope.py index c717850c63f7..3ff7615f3347 100644 --- a/python/sglang/srt/layers/deepseek_v4_rope.py +++ b/python/sglang/srt/layers/deepseek_v4_rope.py @@ -2,17 +2,20 @@ from functools import lru_cache from typing import Optional -import tilelang import torch import triton import triton.language as tl -tilelang.set_log_level("WARNING") +try: + import tilelang -pass_configs = { - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, -} + tilelang.set_log_level("WARNING") + pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + } +except ImportError: + pass FP8 = "float8_e4m3" BF16 = "bfloat16" diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py new file mode 100644 index 000000000000..31ccff929ea7 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_fallback.py @@ -0,0 +1,185 @@ +"""PyTorch fallback for MXFP4 MoE GEMM on SM120. + +The Marlin MXFP4 kernel produces NaN on SM120 (Blackwell Desktop). +This module provides a pure-PyTorch implementation that dequantizes +MXFP4 weights (packed int8 + float8_e8m0fnu scales) to BF16 and uses +torch.matmul for the GEMM, per active expert. + +Slow but functionally correct — matches the FlashMLA fallback pattern. +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + +# ── FP4 E2M1 lookup table ────────────────────────────────────────── +# Nibble encoding: bit3=sign, bit2-1=exponent (bias=1), bit0=mantissa +# 16 possible values for 4-bit float +_FP4_E2M1_LUT = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, # positive (0x0-0x7) + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, # negative (0x8-0xF) + ], + dtype=torch.float32, +) + + +def _dequant_mxfp4_weight( + packed: torch.Tensor, + scales: torch.Tensor, + unpacked_k: int, +) -> torch.Tensor: + """Dequantize one expert's MXFP4 weight from packed int8 to bfloat16. + + Args: + packed: [N, K//2] int8 — 2 FP4 values per byte (low nibble=even, high=odd) + scales: [N, K//32] float32 — dequantization scale per group of 32 elements + unpacked_k: K, the full unpacked dimension + + Returns: + [N, K] bfloat16 weight matrix + """ + device = packed.device + lut = _FP4_E2M1_LUT.to(device=device) + + # View as unsigned bytes for bit manipulation + u8 = packed.view(torch.uint8).to(torch.int32) + low = u8 & 0x0F # even-index elements + high = (u8 >> 4) & 0x0F # odd-index elements + + # Lookup FP4 → float32 + vals_low = lut[low.long()] # [N, K//2] + vals_high = lut[high.long()] # [N, K//2] + + # Interleave: [low0, high0, low1, high1, ...] + unpacked = torch.stack([vals_low, vals_high], dim=-1) # [N, K//2, 2] + unpacked = unpacked.reshape(packed.shape[0], -1) # [N, K] + unpacked = unpacked[:, :unpacked_k] # trim if needed + + # Apply group scales (group_size=32) + # scales: [N, K//32] — each scale covers 32 consecutive elements along K + if scales.dtype == torch.float8_e8m0fnu: + scales_f32 = scales.to(torch.float32) + else: + scales_f32 = scales.float() + scales_expanded = scales_f32.repeat_interleave(32, dim=-1)[:, :unpacked_k] + + result = unpacked * scales_expanded + return result.to(torch.bfloat16) + + +def mxfp4_moe_forward_fallback( + hidden_states: torch.Tensor, + w13_packed: torch.Tensor, + w2_packed: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_size: int, + intermediate_size: int, + inplace: bool = False, + routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, +) -> torch.Tensor: + """Pure-PyTorch MXFP4 MoE forward pass. + + Args: + hidden_states: [M, K] bfloat16 input activations + w13_packed: [E, 2*I, K//2] int8 packed gate_up_proj weights + w2_packed: [E, K, I//2] int8 packed down_proj weights + w13_scale: [E, 2*I, K//32] scales for gate_up_proj + w2_scale: [E, K, I//32] scales for down_proj + topk_ids: [M, topk] int32 expert assignments + topk_weights: [M, topk] float32 routing weights + hidden_size: K + intermediate_size: I (per partition) + inplace: whether to write output in-place + routed_scaling_factor: optional global scaling factor + clamp_limit: optional SwiGLU clamp limit (2604B submode) + + Returns: + [M, K] bfloat16 output tensor + """ + M, K = hidden_states.shape + topk = topk_ids.shape[1] + device = hidden_states.device + dtype = hidden_states.dtype + I = intermediate_size + + output = hidden_states if inplace else torch.zeros(M, K, dtype=dtype, device=device) + if not inplace: + output.zero_() + + # Find all active experts + active_experts = topk_ids.unique() + + for eid in active_experts: + eid_val = eid.item() + if eid_val < 0: + continue + + # Find (token_idx, slot_idx) pairs assigned to this expert + mask = topk_ids == eid_val # [M, topk] + token_mask = mask.any(dim=1) # [M] + token_indices = token_mask.nonzero(as_tuple=True)[0] + + if len(token_indices) == 0: + continue + + # ── GEMM1: gate_up_proj ── + # w13: [2*I, K//2] int8 → dequant → [2*I, K] bf16 + w13_dq = _dequant_mxfp4_weight( + w13_packed[eid_val], w13_scale[eid_val], K + ) # [2*I, K] + + h = hidden_states[token_indices] # [n, K] + # y = h @ W13^T → [n, K] @ [K, 2*I] = [n, 2*I] + intermediate = torch.matmul(h.float(), w13_dq.float().T).to(dtype) + + # ── SiLU + Mul (with optional clamp) ── + gate = intermediate[:, :I] + up = intermediate[:, I:] + if clamp_limit is not None and clamp_limit > 0: + gate = torch.clamp(gate, max=clamp_limit) + up = torch.clamp(up, min=-clamp_limit, max=clamp_limit) + intermediate2 = F.silu(gate) * up # [n, I] + + # ── GEMM2: down_proj ── + # w2: [K, I//2] int8 → dequant → [K, I] bf16 + w2_dq = _dequant_mxfp4_weight( + w2_packed[eid_val], w2_scale[eid_val], I + ) # [K, I] + + # y = intermediate2 @ W2^T → [n, I] @ [I, K] = [n, K] + down = torch.matmul(intermediate2.float(), w2_dq.float().T).to(dtype) + + # ── Accumulate with topk weights (vectorized over topk slots) ── + expert_mask = (topk_ids[token_indices] == eid_val).to(dtype) # [n, topk] + combined_weights = (expert_mask * topk_weights[token_indices].to(dtype)).sum( + dim=1, keepdim=True + ) # [n, 1] + output[token_indices] += down * combined_weights + + if routed_scaling_factor is not None and routed_scaling_factor != 1.0: + output.mul_(routed_scaling_factor) + + return output diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py new file mode 100644 index 000000000000..c48faeeb2f22 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/mxfp4_moe_sm120_triton.py @@ -0,0 +1,444 @@ +"""SM120-optimized Triton MXFP4 MoE kernel — CUDA graph compatible. + +Replaces the PyTorch fallback (per-expert for-loop + full dequant + matmul) +with fused Triton kernels that: +1. Fuse FP4 dequant + GEMV (no intermediate BF16 weight materialization) +2. Process each (token, expert) slot independently — no data-dependent routing +3. Respect SM120 shared memory constraint (99 KB/block) + +CUDA graph compatibility: +- No .unique(), .item(), .nonzero() — all routing is tensor-level +- Fixed grid dimensions (M*topk, N_blocks) per captured batch size +- All control flow is static or within Triton kernels + +SM120 constraints: +- SMEM: 99 KB/block (vs SM100 228 KB) +- No TMEM/tcgen05 — uses mma.sync.aligned via Triton +- Max warps: 48/SM +- Registers: ~128/thread practical limit +""" + +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def _dequant_fp4_lut(nibble): + """Decode a 4-bit FP4 E2M1 nibble to float32 using arithmetic.""" + sign_bit = (nibble >> 3) & 1 + exp_bits = (nibble >> 1) & 3 + man_bit = nibble & 1 + + is_subnormal = exp_bits == 0 + mantissa = 1.0 + man_bit.to(tl.float32) * 0.5 + exponent = tl.math.exp2((exp_bits - 1).to(tl.float32)) + val = tl.where(is_subnormal, man_bit.to(tl.float32) * 0.5, mantissa * exponent) + val = tl.where(sign_bit != 0, -val, val) + return val + + +# ── Per-slot GEMV kernel: processes one (token, expert) pair ── + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 64, "BLOCK_K": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 128, "BLOCK_K": 64}, num_warps=8, num_stages=2), + ], + key=["N", "K"], +) +@triton.jit +def _mxfp4_slot_gemv_kernel( + # Pointers + A_ptr, # [M_total, K] bf16 — source rows + B_packed_ptr, # [E, N, K//2] uint8 — packed FP4 expert weights + B_scale_ptr, # [E, N, K//32] float32 — weight scales + C_ptr, # [num_slots, N] bf16 — output + token_ids_ptr, # [num_slots] int32 — which A row for each slot + expert_ids_ptr, # [num_slots] int32 — which expert's B for each slot + # Dimensions + N: tl.int32, + K: tl.int32, + # A strides + stride_am: tl.int32, + # B strides (within an expert) + stride_bn: tl.int32, + stride_bk2: tl.int32, + # B_scale strides (within an expert) + stride_bsn: tl.int32, + stride_bsk32: tl.int32, + # Expert strides (between experts) + expert_b_stride: tl.int64, + expert_s_stride: tl.int64, + # C strides + stride_cm: tl.int32, + # Block sizes + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Per-slot fused MXFP4 dequant + GEMV. + + Grid: (num_slots, cdiv(N, BLOCK_N)) + Each program computes one (token, expert) pair for a BLOCK_N slice of output. + """ + slot_id = tl.program_id(0) + n_block = tl.program_id(1) + + token_id = tl.load(token_ids_ptr + slot_id).to(tl.int64) + expert_id = tl.load(expert_ids_ptr + slot_id).to(tl.int64) + + offs_n = n_block * BLOCK_N + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + + # Expert weight base pointers + b_base = expert_id * expert_b_stride + s_base = expert_id * expert_s_stride + a_base = token_id * stride_am + + for k_start in range(0, K, BLOCK_K): + # ── Load packed B: [BLOCK_N, BLOCK_K//2] ── + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = n_mask[:, None] & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + + b_base + + offs_n[:, None] * stride_bn + + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + # ── FP4 dequant ── + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) # even K indices + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) # odd K indices + + # ── Load and apply scales: [BLOCK_N, BLOCK_K//2] ── + group_ids = tl.arange(0, BLOCK_K // 2) // 16 # 32 values per group, 2 per byte + s_mask = n_mask[:, None] & ((k_start // 32 + group_ids[None, :]) < K // 32) + scales = tl.load( + B_scale_ptr + + s_base + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=s_mask, + other=1.0, + ) + val_lo = val_lo * scales + val_hi = val_hi * scales + + # ── Load A even/odd: [BLOCK_K//2] each ── + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even = tl.load( + A_ptr + a_base + offs_k_even, + mask=offs_k_even < K, + other=0.0, + ).to(tl.float32) + a_odd = tl.load( + A_ptr + a_base + offs_k_odd, + mask=offs_k_odd < K, + other=0.0, + ).to(tl.float32) + + # ── Dot product: acc[n] += sum_k(a_even[k]*B_lo[n,k] + a_odd[k]*B_hi[n,k]) ── + acc += tl.sum(a_even[None, :] * val_lo, axis=1) + acc += tl.sum(a_odd[None, :] * val_hi, axis=1) + + # ── Store output ── + tl.store( + C_ptr + slot_id * stride_cm + offs_n, + acc.to(tl.bfloat16), + mask=n_mask, + ) + + +# ── Legacy per-expert GEMM kernel (kept for benchmarking) ── + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_warps=8, num_stages=2 + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 64, "BLOCK_K": 64}, num_warps=4, num_stages=2 + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def _mxfp4_gemm_kernel( + # Pointers + A_ptr, # [M, K] bf16 activation + B_packed_ptr, # [N, K//2] uint8 packed FP4 + B_scale_ptr, # [N, K//32] float32 scales + C_ptr, # [M, N] bf16 output + # Dimensions + M, + N, + K, + # Strides + stride_am, + stride_ak, + stride_bn, + stride_bk2, + stride_bsn, + stride_bsk32, + stride_cm, + stride_cn, + # Constexprs + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Fused MXFP4 dequant + GEMM: C = A @ dequant(B_packed, B_scale).T""" + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k_start in range(0, K, BLOCK_K): + offs_k2 = k_start // 2 + tl.arange(0, BLOCK_K // 2) + b_mask = (offs_n[:, None] < N) & (offs_k2[None, :] < K // 2) + b_packed = tl.load( + B_packed_ptr + offs_n[:, None] * stride_bn + offs_k2[None, :] * stride_bk2, + mask=b_mask, + other=0, + ) + + b_u8 = b_packed.to(tl.int32) + val_lo = _dequant_fp4_lut(b_u8 & 0x0F) + val_hi = _dequant_fp4_lut((b_u8 >> 4) & 0x0F) + + group_ids = tl.arange(0, BLOCK_K // 2) // 16 + scales_per_byte = tl.load( + B_scale_ptr + + offs_n[:, None] * stride_bsn + + (k_start // 32 + group_ids[None, :]) * stride_bsk32, + mask=(offs_n[:, None] < N) + & ((k_start // 32 + group_ids[None, :]) < K // 32), + other=1.0, + ) + val_lo = val_lo * scales_per_byte + val_hi = val_hi * scales_per_byte + + offs_k_even = k_start + tl.arange(0, BLOCK_K // 2) * 2 + offs_k_odd = offs_k_even + 1 + + a_even_mask = (offs_m[:, None] < M) & (offs_k_even[None, :] < K) + a_even = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_even[None, :] * stride_ak, + mask=a_even_mask, + other=0.0, + ).to(tl.float32) + + a_odd_mask = (offs_m[:, None] < M) & (offs_k_odd[None, :] < K) + a_odd = tl.load( + A_ptr + offs_m[:, None] * stride_am + offs_k_odd[None, :] * stride_ak, + mask=a_odd_mask, + other=0.0, + ).to(tl.float32) + + acc += tl.dot(a_even, tl.trans(val_lo)) + acc += tl.dot(a_odd, tl.trans(val_hi)) + + c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store( + C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, + acc.to(tl.bfloat16), + mask=c_mask, + ) + + +def mxfp4_gemm_triton( + A: torch.Tensor, + B_packed: torch.Tensor, + B_scale: torch.Tensor, + K_full: int, +) -> torch.Tensor: + """Triton fused MXFP4 dequant + GEMM: output = A @ dequant(B).T + + Kept for standalone benchmarking. The MoE forward uses the slot kernel. + """ + M = A.shape[0] + N = B_packed.shape[0] + K = K_full + + if B_scale.dtype == torch.float8_e8m0fnu: + B_scale = B_scale.to(torch.float32) + elif B_scale.dtype != torch.float32: + B_scale = B_scale.float() + + C = torch.empty(M, N, dtype=torch.bfloat16, device=A.device) + A = A.contiguous() + B_packed = B_packed.contiguous() + B_scale = B_scale.contiguous() + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + B_u8 = B_packed.view(torch.uint8) + + _mxfp4_gemm_kernel[grid]( + A, + B_u8, + B_scale, + C, + M, + N, + K, + A.stride(0), + A.stride(1), + B_u8.stride(0), + B_u8.stride(1), + B_scale.stride(0), + B_scale.stride(1), + C.stride(0), + C.stride(1), + ) + return C + + +def mxfp4_moe_forward_triton( + hidden_states: torch.Tensor, + w13_packed: torch.Tensor, + w2_packed: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + hidden_size: int, + intermediate_size: int, + inplace: bool = False, + routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, +) -> torch.Tensor: + """SM120-optimized MXFP4 MoE forward — CUDA graph compatible. + + Uses per-slot GEMV kernels instead of per-expert Python loops. + Each (token, expert) slot is processed independently with a fixed grid, + eliminating .unique()/.item()/.nonzero() that break CUDA graph capture. + """ + import torch.nn.functional as F + + M, K = hidden_states.shape + topk = topk_ids.shape[1] + I = intermediate_size + num_slots = M * topk + device = hidden_states.device + dtype = hidden_states.dtype + + # ── Graph-safe routing: flatten topk assignments ── + # token_ids[slot] = which row of A (original token index) + # expert_ids[slot] = which expert's weights to use + flat_expert_ids = topk_ids.reshape(-1).contiguous() # [M*topk] + token_ids = ( + torch.arange(M, device=device, dtype=torch.int32) + .unsqueeze(1) + .expand(M, topk) + .reshape(-1) + .contiguous() + ) # [M*topk] + + # ── Ensure scales are float32 ── + if w13_scale.dtype != torch.float32: + w13_scale = w13_scale.to(torch.float32) + if w2_scale.dtype != torch.float32: + w2_scale = w2_scale.to(torch.float32) + + # ── GEMM1: gate_up projection ── + # hidden_states[token] @ w13[expert].T → [num_slots, 2*I] + intermediate = torch.empty(num_slots, 2 * I, dtype=dtype, device=device) + + w13_u8 = w13_packed.view(torch.uint8) # [E, 2*I, K//2] + grid1 = lambda meta: (num_slots, triton.cdiv(2 * I, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid1]( + hidden_states, + w13_u8, + w13_scale, + intermediate, + token_ids, + flat_expert_ids, + 2 * I, + K, + hidden_states.stride(0), + w13_u8.stride(1), + w13_u8.stride(2), + w13_scale.stride(1), + w13_scale.stride(2), + w13_u8.stride(0), + w13_scale.stride(0), + intermediate.stride(0), + ) + + # ── SiLU activation (graph-safe vectorized ops) ── + gate = intermediate[:, :I].float() + up = intermediate[:, I:].float() + if clamp_limit is not None and clamp_limit > 0: + gate = torch.clamp(gate, max=clamp_limit) + up = torch.clamp(up, min=-clamp_limit, max=clamp_limit) + activated = (F.silu(gate) * up).to(dtype) + + # ── GEMM2: down projection ── + # activated[slot] @ w2[expert].T → [num_slots, K] + down = torch.empty(num_slots, K, dtype=dtype, device=device) + + # For GEMM2, A is the activated buffer — each slot reads its own row + slot_ids = torch.arange(num_slots, device=device, dtype=torch.int32) + + w2_u8 = w2_packed.view(torch.uint8) # [E, K, I//2] + grid2 = lambda meta: (num_slots, triton.cdiv(K, meta["BLOCK_N"])) + + _mxfp4_slot_gemv_kernel[grid2]( + activated, + w2_u8, + w2_scale, + down, + slot_ids, + flat_expert_ids, + K, + I, + activated.stride(0), + w2_u8.stride(1), + w2_u8.stride(2), + w2_scale.stride(1), + w2_scale.stride(2), + w2_u8.stride(0), + w2_scale.stride(0), + down.stride(0), + ) + + # ── Weighted sum across topk slots (graph-safe) ── + flat_weights = topk_weights.reshape(-1).unsqueeze(1).to(dtype) # [M*topk, 1] + output = (down * flat_weights).view(M, topk, K).sum(dim=1) + + if routed_scaling_factor is not None and routed_scaling_factor != 1.0: + output.mul_(routed_scaling_factor) + + return output diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py index e953615bb132..638c632de4c0 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py @@ -5,10 +5,12 @@ from __future__ import annotations import functools +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch import torch.nn.functional as F +import triton import triton.language as tl from sglang.srt.environ import envs @@ -85,6 +87,217 @@ padding_size = get_moe_padding_size(_use_aiter) +logger = logging.getLogger(__name__) + + +def _is_mxfp4_xpu_packed( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> bool: + """Detect MXFP4-packed routed-expert weights on XPU. + + The DSv4 fp8 checkpoint loader passes ``use_fp8_w8a8=True`` for routed + experts that are actually MXFP4 (e.g. DeepSeek-V4-Flash), so we must + NOT exclude on ``use_fp8_w8a8``. We discriminate MXFP4 from real FP8 + weights via the packed-last-dim invariant: + - MXFP4 routed experts: w1.shape[-1] == hidden_size // 2 + - FP8 shared experts : w1.shape[-1] == hidden_size (skip) + """ + return ( + _is_xpu + and not (use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16) + and (w1.dtype == torch.uint8 or w1.dtype == torch.int8) + and (w2.dtype == torch.uint8 or w2.dtype == torch.int8) + and w1_scale is not None + and w2_scale is not None + and w1.shape[-1] * 2 == hidden_states.shape[-1] + ) + + +# E2M1 lookup table: nibble value 0x0–0xF → float +_E2M1_LUT = torch.tensor( + [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, # 0b0xxx (positive) + 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # 0b1xxx (negative) + ], + dtype=torch.float32, +) + +# Per-(device, dtype) cache of the LUT to avoid the per-call host->device +# copy + sync that ``_E2M1_LUT.to(device=..., dtype=...)`` triggers on XPU. +_E2M1_LUT_CACHE: dict = {} + + +def _get_e2m1_lut(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + key = (str(device), dtype) + cached = _E2M1_LUT_CACHE.get(key) + if cached is None: + cached = _E2M1_LUT.to(device=device, dtype=dtype) + _E2M1_LUT_CACHE[key] = cached + return cached + + +# --------------------------------------------------------------------------- +# Triton MXFP4 dequant kernel: replaces the PyTorch loop-based upcast. +# One kernel launch converts [E, N, half_K] packed uint8 -> [E, N, K] bf16 +# with fused block-scale multiplication. No int64 intermediates, no +# host syncs, no thousands of small kernel launches. +# --------------------------------------------------------------------------- + + +@triton.jit +def _mxfp4_dequant_kernel( + W_ptr, # [E*N, half_K] uint8 - packed weights (flattened E*N) + S_ptr, # [E*N, half_K // 16] float32 - block scales + LUT_ptr, # [16] bfloat16 - E2M1 lookup table + Out_ptr, # [E*N, K] bfloat16 - output + half_K, # K // 2 (packed dimension) + stride_wn, + stride_sn, + stride_on, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """Dequantise MXFP4-packed weights to bf16 with block-scale fusion. + + Grid: (cdiv(E*N, BLOCK_N), cdiv(half_K, BLOCK_K)) + """ + pid_n = tl.program_id(0) + pid_k = tl.program_id(1) + + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + + w_off = rn[:, None] * stride_wn + rk[None, :] + mask = rk[None, :] < half_K + packed = tl.load(W_ptr + w_off, mask=mask, other=0) + + lo = (packed & 0xF).to(tl.int32) + hi = ((packed >> 4) & 0xF).to(tl.int32) + + lo_val = tl.load(LUT_ptr + lo).to(tl.float32) + hi_val = tl.load(LUT_ptr + hi).to(tl.float32) + + scale_col = rk[None, :] // 16 + s_off = rn[:, None] * stride_sn + scale_col + scale = tl.load(S_ptr + s_off, mask=mask, other=1.0).to(tl.float32) + + lo_out = (lo_val * scale).to(tl.bfloat16) + hi_out = (hi_val * scale).to(tl.bfloat16) + + out_off_lo = rn[:, None] * stride_on + rk[None, :] * 2 + out_off_hi = out_off_lo + 1 + tl.store(Out_ptr + out_off_lo, lo_out, mask=mask) + tl.store(Out_ptr + out_off_hi, hi_out, mask=mask) + + +def _upcast_mxfp4_triton( + w_packed: torch.Tensor, + w_scale: torch.Tensor, + target_dtype: torch.dtype, +) -> torch.Tensor: + """Triton-accelerated MXFP4 -> bf16 dequant with fused block-scale multiply. + + Replaces the PyTorch loop that generated thousands of small kernel + launches with a single Triton kernel launch per weight tensor. + + w_packed : [E, N, K//2] uint8 - two E2M1 values per byte + w_scale : [E, N, K//32] float32 - MX block scale (direct multiplier) + Returns : [E, N, K] target_dtype - contiguous + """ + w_u8 = w_packed.view(torch.uint8).contiguous() + E, N, half_K = w_u8.shape + K = half_K * 2 + + lut = _get_e2m1_lut(w_u8.device, torch.bfloat16).contiguous() + out = torch.empty(E, N, K, dtype=torch.bfloat16, device=w_u8.device) + + w_flat = w_u8.reshape(E * N, half_K) + s_flat = w_scale.to(torch.float32).reshape(E * N, half_K // 16).contiguous() + out_flat = out.reshape(E * N, K) + + stride_wn = half_K + stride_sn = half_K // 16 + stride_on = K + + BLOCK_N = 4 + BLOCK_K = min(128, half_K) + total_rows = E * N + grid = ( + triton.cdiv(total_rows, BLOCK_N), + triton.cdiv(half_K, BLOCK_K), + ) + + _mxfp4_dequant_kernel[grid]( + w_flat, s_flat, lut, out_flat, + half_K, + stride_wn, stride_sn, stride_on, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + ) + + if target_dtype != torch.bfloat16: + out = out.to(target_dtype) + return out + + +def _upcast_mxfp4_one_xpu( + w_packed: torch.Tensor, + w_scale: torch.Tensor, + target_dtype: torch.dtype, +) -> torch.Tensor: + """Upcast MXFP4-packed expert weights to *target_dtype*. + + Uses a fused Triton dequant kernel (single kernel launch). + + w_packed : [E, N, K//2] uint8/int8 — two E2M1 values per byte + w_scale : [E, N, K//32] float32 — MX block scale (direct multiplier) + Returns : [E, N, K] target_dtype — contiguous + """ + return _upcast_mxfp4_triton(w_packed, w_scale, target_dtype) + + +def _log_mxfp4_xpu_budget( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + target_dtype: torch.dtype, +) -> None: + """Diagnostic: free/total XPU memory and per-weight bf16 transient cost. + + Logged at INFO so it's visible without raising the global log level; + fires once per MXFP4 MoE call. + """ + elem_size = torch.tensor([], dtype=target_dtype).element_size() + # Packed weights have last dim = K // 2; bf16 output has last dim = K. + w1_bytes = w1.numel() * 2 * elem_size + w2_bytes = w2.numel() * 2 * elem_size + try: + free_bytes, total_dev_bytes = torch.xpu.mem_get_info(hidden_states.device) + mem_str = ( + f"xpu_free={free_bytes / 1024**3:.2f} GiB " + f"xpu_total={total_dev_bytes / 1024**3:.2f} GiB" + ) + except Exception as exc: # pragma: no cover - diagnostic only + mem_str = f"xpu_free=" + logger.info( + "MXFP4 upcast (sequenced) on %s: %s; per-GEMM transient w1=%.2f GiB, " + "w2=%.2f GiB (target_dtype=%s, w1.shape=%s, w2.shape=%s, hidden=%d)", + hidden_states.device, + mem_str, + w1_bytes / 1024**3, + w2_bytes / 1024**3, + target_dtype, + tuple(w1.shape), + tuple(w2.shape), + hidden_states.shape[-1], + ) + @register_custom_op(mutates_args=["hidden_states"]) def inplace_fused_experts( @@ -239,6 +452,10 @@ def fused_experts( moe_runner_config.num_experts is None or moe_runner_config.num_experts != moe_runner_config.num_local_experts ) + # MXFP4-packed routed experts on XPU are kept in their packed form + # here and dequantized to bf16 lazily inside ``fused_experts_impl``, + # interleaved with each GEMM call so peak transient memory is one + # weight (~8 GiB w1 / ~4 GiB w2 at TP=1) instead of both at once. if moe_runner_config.inplace: assert not moe_runner_config.no_combine, "no combine + inplace makes no sense" inplace_fused_experts( @@ -346,15 +563,20 @@ def _prepare_fused_moe_run( use_int4_w4a16: bool, per_channel_quant: bool, block_shape: Optional[List[int]], + mxfp4_xpu: bool = False, ): """Resolve config, down_config, TMA flag, and aligned expert routing ids. Shared by ``fused_experts_impl`` and ``pre_permute_standard_to_triton`` so both paths compute alignment from the same source. """ - padded_size = padding_size - if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: - padded_size = 0 + gemm_block_shape: Optional[List[int]] = None + if mxfp4_xpu: + padding_size = 0 + else: + if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: + padded_size = 0 + gemm_block_shape = block_shape num_tokens = hidden_states.shape[0] E = w1.shape[0] @@ -366,13 +588,23 @@ def _prepare_fused_moe_run( dtype=hidden_states.dtype, ) + # MXFP4 weights are packed with last dim = K/2; the GEMM operates on + # the bf16 tensor with last dim = K, so feed the unpacked shape to + # the tile-config selector. + if mxfp4_xpu: + w1_shape_for_cfg = (w1.shape[0], w1.shape[1], hidden_states.shape[1]) + w2_shape_for_cfg = (w2.shape[0], w2.shape[1], w2.shape[2] * 2) + else: + w1_shape_for_cfg = w1.shape + w2_shape_for_cfg = (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size) + config, (down_config, _) = try_get_optimal_moe_config( - w1.shape, - (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), + w1_shape_for_cfg, + w2_shape_for_cfg, topk_ids.shape[1], config_dtype, num_tokens, - block_shape=block_shape, + block_shape=gemm_block_shape, per_channel_quant=per_channel_quant, return_down_config=True, ) @@ -434,6 +666,7 @@ def _fused_moe_kernel_sequence( filter_expert: bool, hooks: Optional[Any] = None, swiglu_limit: Optional[float] = None, + mxfp4_xpu: bool = False, ) -> torch.Tensor: """Run the MoE kernel/activation/kernel/combine sequence in a single shot. @@ -445,6 +678,20 @@ def _fused_moe_kernel_sequence( E, N, _ = w1.shape topk = topk_ids.shape[1] compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + mxfp4_target_dtype = torch.float32 + gemm_block_shape: Optional[List[int]] = None + gemm_w1_scale = None + gemm_w2_scale = None + if mxfp4_xpu: + mxfp4_target_dtype = ( + hidden_states.dtype + if hidden_states.dtype in (torch.float16, torch.bfloat16) + else torch.bfloat16 + ) + else: + gemm_block_shape = block_shape + gemm_w1_scale = w1_scale + gemm_w2_scale = w1_scale padded_tokens = ( min(num_tokens * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1) @@ -479,13 +726,20 @@ def _fused_moe_kernel_sequence( dtype=hidden_states.dtype, ) + # MXFP4 (XPU): upcast w1 just-in-time. Released right after GEMM1 + # so w2's upcast doesn't have to share the budget. + if mxfp4_xpu: + w1_eff = _upcast_mxfp4_one_xpu(w1, w1_scale, mxfp4_target_dtype) + else: + w1_eff = w1 + invoke_fused_moe_kernel( hidden_states, - w1, + w1_eff, b1, intermediate_cache1, a1_scale, - w1_scale, + gemm_w1_scale, w1_zp, topk_weights, topk_ids, @@ -501,11 +755,16 @@ def _fused_moe_kernel_sequence( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape, + block_shape=gemm_block_shape, c_sorted=down_moe_use_tma, filter_expert=filter_expert, ) + # Drop the bf16 w1 reference so the allocator can reuse its block + # for the w2 upcast that follows GEMM1. + if mxfp4_xpu: + del w1_eff + if hooks and hooks.after_gate_up: # Hooks expect intermediate_cache1 shaped (num_tokens, topk, N); the # underlying buffer is laid out as (total_tokens, N) where @@ -545,7 +804,7 @@ def _fused_moe_kernel_sequence( # fusion=False: explicit clamp_ on intermediate_cache1 (path checker) assert swiglu_limit == 10 assert intermediate_cache1.shape == (total_tokens, N) - assert _is_cuda or _is_hip, "DeepSeek V4 only supports CUDA/HIP downstream" + assert _is_cuda or _is_hip or _is_xpu, "DeepSeek V4 only supports CUDA/HIP/XPU downstream" swiglu_limit_for_triton: Optional[float] = None swiglu_limit_for_silu_and_mul_clamp: Optional[float] = None @@ -661,9 +920,15 @@ def _fused_moe_kernel_sequence( out_slice = out_hidden_states out_slice.zero_() + # MXFP4 (XPU): upcast w2 just-in-time for GEMM2. + if mxfp4_xpu: + w2_eff = _upcast_mxfp4_one_xpu(w2, w2_scale, mxfp4_target_dtype) + else: + w2_eff = w2 + invoke_fused_moe_kernel( intermediate_cache2, - w2, + w2_eff, b2, ( out_slice @@ -675,7 +940,7 @@ def _fused_moe_kernel_sequence( ) ), a2_scale, - w2_scale, + gemm_w2_scale, w2_zp, topk_weights, topk_ids, @@ -691,7 +956,7 @@ def _fused_moe_kernel_sequence( use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, - block_shape=block_shape, + block_shape=gemm_block_shape, a_use_tma=down_moe_use_tma, b_use_tma=down_moe_use_tma, filter_expert=filter_expert, @@ -699,6 +964,9 @@ def _fused_moe_kernel_sequence( router_topk=topk, ) + if mxfp4_xpu: + del w2_eff + if hooks and hooks.after_down: hooks.after_down( intermediate_cache2, intermediate_cache3, topk_weights, topk_ids @@ -814,12 +1082,35 @@ def fused_experts_impl( filter_expert: bool = True, swiglu_limit: Optional[float] = None, ): - padded_size = padding_size - if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: + # MXFP4-packed routed experts on XPU arrive here as uint8 with last + # dim = K/2 plus uint8 E8M0 scales. We dequantize to bf16 lazily, one + # weight at a time, freeing each bf16 buffer between the up- and + # down-projection GEMMs to keep peak transient at ~8 GiB (TP=1) + # instead of ~12 GiB. + mxfp4_xpu = _is_mxfp4_xpu_packed( + hidden_states, w1, w2, w1_scale, w2_scale, + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + ) + if mxfp4_xpu: + mxfp4_target_dtype = ( + hidden_states.dtype + if hidden_states.dtype in (torch.float16, torch.bfloat16) + else torch.bfloat16 + ) + #_log_mxfp4_xpu_budget(hidden_states, w1, w2, mxfp4_target_dtype) + # The bf16 GEMM that follows must NOT see fp8/block-quant flags or + # the packed scales: those are folded into the upcast result. + gemm_use_fp8_w8a8 = False padded_size = 0 + else: + gemm_use_fp8_w8a8 = use_fp8_w8a8 + padded_size = padding_size + if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: + padded_size = 0 # Check constraints. - if use_int4_w4a16: + if use_int4_w4a16 or mxfp4_xpu: + # Packed last dim is K/2; bf16 last dim after upcast is K. assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch" else: assert ( @@ -843,12 +1134,13 @@ def fused_experts_impl( w1, w2, topk_ids, - use_fp8_w8a8=use_fp8_w8a8, + use_fp8_w8a8=gemm_use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, + mxfp4_xpu=mxfp4_xpu, ) return _fused_moe_kernel_sequence( @@ -865,7 +1157,7 @@ def fused_experts_impl( down_moe_use_tma, b1=b1, b2=b2, - use_fp8_w8a8=use_fp8_w8a8, + use_fp8_w8a8=gemm_use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, @@ -888,6 +1180,7 @@ def fused_experts_impl( filter_expert=filter_expert, hooks=None, swiglu_limit=swiglu_limit, + mxfp4_xpu=mxfp4_xpu, ) diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_config.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_config.py index b1cc298ee0fb..4d7c77538a71 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_config.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_config.py @@ -10,10 +10,11 @@ import triton from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import get_device_name, is_hip +from sglang.srt.utils import get_device_name, is_hip, is_xpu logger = logging.getLogger(__name__) _is_hip = is_hip() +_is_xpu = is_xpu() def get_config_file_name( @@ -190,6 +191,31 @@ def get_default_config( "num_warps": 4, "num_stages": 2 if _is_hip else 3, } + if _is_xpu: + # Intel Arc / Xe-HPG GPUs have ~64 KiB of SLM per work-group. + # The CUDA defaults above (large BLOCK_SIZE_N and num_stages>=3) + # exceed that budget and trigger + # UR_RESULT_ERROR_OUT_OF_RESOURCES at kernel launch. Use a + # conservative tile that fits, while still respecting the + # block_shape divisibility constraints when present. + if block_shape is None: + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + } + else: + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + } else: config = { "BLOCK_SIZE_M": 64, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 3b009ea4977b..5c2df000695f 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -436,6 +436,8 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor: return gating_output.softmax(dim=-1) elif scoring_func == "sigmoid": return gating_output.sigmoid() + elif scoring_func == "sqrtsoftplus": + return torch.nn.functional.softplus(gating_output).sqrt() else: raise ValueError(f"Invalid scoring function: {scoring_func}") @@ -451,13 +453,9 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor: assert ( hidden_states.shape[0] == gating_output.shape[0] ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" - M, _ = hidden_states.shape - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - topk_weights = scoring_func_impl(gating_output.float()) - topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + scores = scoring_func_impl(gating_output.float()) + topk_weights, topk_ids = torch.topk(scores, topk, dim=-1) + topk_ids = topk_ids.to(torch.int32) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py index 90a3de66f4aa..2f0d32d89ae9 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.utils import log_info_on_rank0 -from sglang.srt.utils.common import is_sm90_supported +from sglang.srt.utils.common import get_device_sm, is_sm90_supported if TYPE_CHECKING: from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput @@ -63,10 +63,43 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_mega_moe_weights_built", False): return - if not is_sm90_supported(): + _sm = get_device_sm() + if not is_sm90_supported() and _sm // 10 != 12: raise RuntimeError( "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." ) + + # SM120: Skip Marlin repacking, keep original weight format + # for PyTorch dequant fallback (Marlin kernel produces NaN on SM120) + if _sm // 10 == 12: + from torch.nn import Parameter + + log_info_on_rank0( + logger, + f"SM120 detected: using PyTorch MXFP4 MoE fallback " + f"(layer: {self.prefix})...", + ) + # Keep weights in original packed int8 format + # Normalize scales to float32 for direct use in dequant + w13_s = layer.w13_weight_scale_inv.data + w2_s = layer.w2_weight_scale_inv.data + if w13_s.dtype == torch.float8_e8m0fnu: + pass # already in e8m0 format, will convert at runtime + elif w13_s.dtype in (torch.uint8, torch.int8): + layer.w13_weight_scale_inv = Parameter( + w13_s.view(torch.uint8) + .view(torch.float8_e8m0fnu) + .to(torch.float32), + requires_grad=False, + ) + layer.w2_weight_scale_inv = Parameter( + w2_s.view(torch.uint8).view(torch.float8_e8m0fnu).to(torch.float32), + requires_grad=False, + ) + # else: float32 scales are already usable directly + layer._dsv4_mxfp4_backend = "sm120_fallback" + return + if not check_moe_marlin_supports_layer(layer, 32): raise RuntimeError( "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." @@ -99,6 +132,43 @@ def apply( if not TopKOutputChecker.format_is_standard(topk_output): raise ValueError(f"Unsupported topk output format: {topk_output.format}") + # SM120 fallback: use Triton fused dequant+GEMM (or PyTorch fallback) + if getattr(layer, "_dsv4_mxfp4_backend", None) == "sm120_fallback": + from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( + mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback, + ) + + hidden_states = dispatch_output.hidden_states + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + + output = mxfp4_moe_forward_fallback( + hidden_states=hidden_states, + w13_packed=w13, + w2_packed=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + topk_ids=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + routed_scaling_factor=( + self.runner.config.routed_scaling_factor + if hasattr(self.runner, "config") + else None + ), + clamp_limit=( + self.runner.config.swiglu_limit + if hasattr(self.runner, "config") + else None + ), + ) + return StandardCombineInput(hidden_states=output) + quant_info = MarlinMoeQuantInfo( w13_qweight=layer.w13_weight, w2_qweight=layer.w2_weight, diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py index 778556baaa6d..e450da6a4c7f 100644 --- a/python/sglang/srt/models/deepseek_v4.py +++ b/python/sglang/srt/models/deepseek_v4.py @@ -704,6 +704,127 @@ def forward( return o +# ----------------------------------------------------------------------------- +# Triton fallback for `hc_split_sinkhorn` (TileLang kernel in +# sglang.srt.layers.mhc). Used on platforms / builds where TileLang is +# unavailable. Functionally equivalent (numerically may differ within float +# eps). +# ----------------------------------------------------------------------------- + + +@triton.jit +def _hc_split_sinkhorn_triton_kernel( + mixes_ptr, # [N, (2 + HC) * HC] float32 + hc_scale_ptr, # [3] float32 + hc_base_ptr, # [(2 + HC) * HC] float32 + pre_ptr, # [N, HC] float32 + post_ptr, # [N, HC] float32 + comb_ptr, # [N, HC, HC] float32 + N, + HC: tl.constexpr, + SINKHORN_ITERS: tl.constexpr, + EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= N: + return + + HC2: tl.constexpr = HC * HC + MIX_HC: tl.constexpr = (2 + HC) * HC + + s0 = tl.load(hc_scale_ptr + 0) + s1 = tl.load(hc_scale_ptr + 1) + s2 = tl.load(hc_scale_ptr + 2) + + h = tl.arange(0, HC) + + pre_mix = tl.load(mixes_ptr + pid * MIX_HC + h) + pre_base = tl.load(hc_base_ptr + h) + pre = tl.sigmoid(pre_mix * s0 + pre_base) + EPS + tl.store(pre_ptr + pid * HC + h, pre) + + post_mix = tl.load(mixes_ptr + pid * MIX_HC + HC + h) + post_base = tl.load(hc_base_ptr + HC + h) + post = 2.0 * tl.sigmoid(post_mix * s1 + post_base) + tl.store(post_ptr + pid * HC + h, post) + + j = tl.arange(0, HC)[:, None] + k = tl.arange(0, HC)[None, :] + idx = j * HC + k + comb_mix = tl.load(mixes_ptr + pid * MIX_HC + 2 * HC + idx) + comb_base = tl.load(hc_base_ptr + 2 * HC + idx) + comb = comb_mix * s2 + comb_base + + row_max = tl.max(comb, axis=1)[:, None] + comb = tl.exp(comb - row_max) + row_sum = tl.sum(comb, axis=1)[:, None] + comb = comb / row_sum + EPS + + col_sum = tl.sum(comb, axis=0)[None, :] + comb = comb / (col_sum + EPS) + + for _ in tl.static_range(SINKHORN_ITERS - 1): + row_sum = tl.sum(comb, axis=1)[:, None] + comb = comb / (row_sum + EPS) + col_sum = tl.sum(comb, axis=0)[None, :] + comb = comb / (col_sum + EPS) + + tl.store(comb_ptr + pid * HC2 + idx, comb) + + +def _hc_split_sinkhorn_triton( + mixes: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + hc_mult: int, + sinkhorn_iters: int, + eps: float, +): + """Triton fallback for hc_split_sinkhorn. + + Same outputs (``pre``, ``post``, ``comb``) as the TileLang kernel in + ``sglang.srt.layers.mhc``. ``hc_mult`` must be a power of two for + ``tl.arange`` to be valid (typical values are 4 / 8). + """ + assert mixes.dtype == torch.float32, "mixes must be float32" + assert hc_scale.dtype == torch.float32 and hc_base.dtype == torch.float32 + assert hc_mult & (hc_mult - 1) == 0, "hc_mult must be a power of two" + assert sinkhorn_iters >= 1 + + b, s, last = mixes.size() + assert last == (2 + hc_mult) * hc_mult + + n = b * s + pre = mixes.new_empty(b, s, hc_mult) + post = mixes.new_empty(b, s, hc_mult) + comb = mixes.new_empty(b, s, hc_mult, hc_mult) + + if n == 0: + return pre, post, comb + + mixes_flat = mixes.reshape(n, (2 + hc_mult) * hc_mult).contiguous() + pre_flat = pre.view(n, hc_mult) + post_flat = post.view(n, hc_mult) + comb_flat = comb.view(n, hc_mult, hc_mult) + hc_base_c = hc_base.contiguous() + hc_scale_c = hc_scale.contiguous() + + _hc_split_sinkhorn_triton_kernel[(n,)]( + mixes_flat, + hc_scale_c, + hc_base_c, + pre_flat, + post_flat, + comb_flat, + n, + HC=hc_mult, + SINKHORN_ITERS=sinkhorn_iters, + EPS=eps, + num_warps=1, + ) + return pre, post, comb + + class DeepseekV4DecoderLayer(nn.Module): def __init__( self, @@ -839,16 +960,26 @@ def hc_pre_torch_impl(x, hc_fn): else: x_flat, mixes = hc_pre_torch_impl(x, hc_fn) - from sglang.srt.layers.mhc import hc_split_sinkhorn + if _is_cuda: + from sglang.srt.layers.mhc import hc_split_sinkhorn - pre, post, comb = hc_split_sinkhorn( - mixes, - hc_scale, - hc_base, - self.hc_mult, - self.hc_sinkhorn_iters, - self.hc_eps, - ) + pre, post, comb = hc_split_sinkhorn( + mixes, + hc_scale, + hc_base, + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) + else: + pre, post, comb = _hc_split_sinkhorn_triton( + mixes, + hc_scale, + hc_base, + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) y = (pre.squeeze(1).unsqueeze(-1) * x_flat.view(shape)).sum(dim=1) return y.to(dtype), post.squeeze(1), comb.squeeze(1), False diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 79d9a7584f2c..12d6d3890af1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1689,6 +1689,7 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: from sglang.srt.arg_groups.hisparse_hook import ( apply_hisparse_nsa_backend_defaults, ) + from sglang.srt.utils import is_sm120_supported user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None @@ -1702,7 +1703,13 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: self.nsa_prefill_backend = "tilelang" self.nsa_decode_backend = "tilelang" elif kv_cache_dtype == "fp8_e4m3": - if major >= 10: + if is_sm120_supported(): + # SM120: trtllm does not support SM120; use tilelang for both paths. + if not user_set_prefill: + self.nsa_prefill_backend = "tilelang" + if not user_set_decode: + self.nsa_decode_backend = "tilelang" + elif major >= 10: if not user_set_prefill: self.nsa_prefill_backend = "trtllm" if not user_set_decode: @@ -1715,7 +1722,13 @@ def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: self.nsa_decode_backend = "flashmla_kv" else: # set prefill/decode backends based on hardware architecture. - if major >= 10: + if is_sm120_supported(): + # SM120: trtllm does not support SM120; use tilelang (portable) + if not user_set_prefill: + self.nsa_prefill_backend = "tilelang" + if not user_set_decode: + self.nsa_decode_backend = "tilelang" + elif major >= 10: if not user_set_prefill: self.nsa_prefill_backend = "flashmla_sparse" if not user_set_decode: @@ -1932,6 +1945,14 @@ def _handle_model_specific_adjustments(self): logger.info( "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM" ) + elif is_sm120_supported(): + # SM120: DSv4-Flash uses MXFP4 experts; marlin backend dispatches + # to our SM120 Triton fallback in mxfp4_marlin_moe.py + if self.moe_runner_backend == "auto": + self.moe_runner_backend = "marlin" + logger.info( + "Use marlin as MoE runner backend on SM120 for DeepseekV3/V4" + ) elif is_hip(): if not self.enable_dp_attention and self.nnodes == 1: # TODO (Hubert): Put this back later diff --git a/python/sglang/test/test_sm120_mqa_fallback.py b/python/sglang/test/test_sm120_mqa_fallback.py new file mode 100644 index 000000000000..3e3a87c9fd10 --- /dev/null +++ b/python/sglang/test/test_sm120_mqa_fallback.py @@ -0,0 +1,279 @@ +""" +Unit tests for SM120 MQA fallback kernels. + +These tests verify correctness of the PyTorch-native fallback implementations +that replace DeepGEMM's fp8_paged_mqa_logits and fp8_mqa_logits on SM120. + +Run: python -m pytest python/sglang/test/test_sm120_mqa_fallback.py -v +""" + +import pytest +import torch + +from sglang.srt.layers.attention.nsa.sm120_mqa_fallback import ( + _dequant_fp8_with_scale_suffix, + compute_paged_mqa_schedule_metadata, + sm120_fp8_mqa_logits, + sm120_fp8_paged_mqa_logits, +) + + +def _make_fp8_with_scale(data_f32: torch.Tensor) -> torch.Tensor: + """Helper: pack float32 data into FP8 + appended scale suffix format. + + For testing, we use a scale of 1.0 so the FP8 values are the raw values. + The last 4 bytes of each row store the float32 scale. + """ + device = data_f32.device + shape = data_f32.shape + head_dim = shape[-1] + + # Clamp to FP8 E4M3 range + fp8_max = torch.finfo(torch.float8_e4m3fn).max + data_clamped = data_f32.clamp(-fp8_max, fp8_max) + data_fp8 = data_clamped.to(torch.float8_e4m3fn) + + # Scale = 1.0 as float32 -> 4 bytes + scale_val = torch.ones((*shape[:-1], 1), dtype=torch.float32, device=device) + scale_bytes = scale_val.view(torch.float8_e4m3fn) # reinterpret as 4 fp8 bytes + + # Concatenate: [data_fp8 | scale_bytes] + result = torch.cat([data_fp8, scale_bytes], dim=-1) + return result + + +class TestDequantFP8: + def test_roundtrip(self): + """Dequantized values should approximately match original float32.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + data = torch.randn(4, 128, device=device) + packed = _make_fp8_with_scale(data) + recovered = _dequant_fp8_with_scale_suffix(packed.unsqueeze(-2), 128) + recovered = recovered.squeeze(-2) + # FP8 E4M3 has limited precision, allow some tolerance + torch.testing.assert_close(recovered, data, atol=0.2, rtol=0.1) + + def test_scale_applied(self): + """Non-unity scale should be applied correctly.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + head_dim = 128 + data = torch.ones(2, head_dim, device=device) * 0.5 + data_fp8 = data.to(torch.float8_e4m3fn) + + # Scale = 2.0 + scale = torch.full((2, 1), 2.0, dtype=torch.float32, device=device) + scale_bytes = scale.view(torch.float8_e4m3fn) + packed = torch.cat([data_fp8, scale_bytes], dim=-1) + + result = _dequant_fp8_with_scale_suffix(packed.unsqueeze(-2), head_dim) + result = result.squeeze(-2) + expected = data.float() * 2.0 + torch.testing.assert_close(result, expected, atol=0.1, rtol=0.05) + + +class TestScheduleMetadata: + def test_returns_none(self): + """SM120 schedule metadata is always None (scheduling handled internally).""" + result = compute_paged_mqa_schedule_metadata( + torch.tensor([10, 20]), block_size=64, num_sms=84 + ) + assert result is None + + +class TestPagedMQALogits: + @pytest.fixture + def setup(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + batch = 2 + next_n = 1 + n_heads = 4 + head_dim = 128 + head_dim_with_sf = head_dim + 4 + block_kv = 64 + num_blocks = 8 + max_seq_len = 256 + + # Create random FP8 queries + q_raw = torch.randn(batch, next_n, n_heads, head_dim, device=device) * 0.1 + q_fp8 = _make_fp8_with_scale(q_raw) + + # Create random FP8 KV cache blocks + kv_raw = torch.randn(num_blocks, block_kv, 1, head_dim, device=device) * 0.1 + kv_fp8 = _make_fp8_with_scale(kv_raw) + + # Head weights + weights = torch.randn(batch, n_heads, device=device) + + # Sequence lengths + seqlens = torch.tensor([[100], [64]], dtype=torch.int32, device=device) + + # Block tables: batch 0 uses blocks [0,1], batch 1 uses blocks [2] + block_tables = torch.zeros(batch, 4, dtype=torch.int32, device=device) + block_tables[0, :2] = torch.tensor([0, 1]) + block_tables[1, :1] = torch.tensor([2]) + + return { + "q_fp8": q_fp8, + "kv_fp8": kv_fp8, + "weights": weights, + "seqlens": seqlens, + "block_tables": block_tables, + "max_seq_len": max_seq_len, + "batch": batch, + "next_n": next_n, + } + + def test_output_shape(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + expected_shape = ( + setup["batch"] * setup["next_n"], + setup["max_seq_len"], + ) + assert logits.shape == expected_shape + + def test_masked_positions_are_neginf(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + # Positions beyond seq_len should be -inf + seq_len_0 = setup["seqlens"][0, 0].item() + assert torch.all(logits[0, seq_len_0:] == float("-inf")) + + def test_valid_positions_are_finite(self, setup): + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + seq_len_0 = setup["seqlens"][0, 0].item() + assert torch.all(torch.isfinite(logits[0, :seq_len_0])) + + def test_zero_seqlen(self, setup): + """Batch element with zero seqlen should produce all -inf.""" + setup["seqlens"][1, 0] = 0 + logits = sm120_fp8_paged_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["seqlens"], + setup["block_tables"], + schedule_metadata=None, + max_seq_len=setup["max_seq_len"], + ) + assert torch.all(logits[1] == float("-inf")) + + +class TestContiguousMQALogits: + @pytest.fixture + def setup(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + num_q = 4 + n_heads = 4 + head_dim = 128 + head_dim_with_sf = head_dim + 4 + num_k = 200 + + # Queries with scale suffix + q_raw = torch.randn(num_q, n_heads, head_dim, device=device) * 0.1 + q_fp8 = _make_fp8_with_scale(q_raw) + + # KV with scale suffix + k_raw = torch.randn(num_k, head_dim, device=device) * 0.1 + k_fp8 = _make_fp8_with_scale(k_raw.unsqueeze(-2)).squeeze(-2) + k_scale = torch.ones(num_k, device=device) + + # Weights + weights = torch.randn(num_q, n_heads, device=device) + + # Ragged ranges + ks = torch.tensor([0, 50, 100, 150], dtype=torch.int32, device=device) + ke = torch.tensor([50, 100, 150, 200], dtype=torch.int32, device=device) + + return { + "q_fp8": q_fp8, + "kv_fp8": (k_fp8, k_scale), + "weights": weights, + "ks": ks, + "ke": ke, + "num_q": num_q, + "num_k": num_k, + } + + def test_output_shape(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + assert logits.shape[0] == setup["num_q"] + assert logits.shape[1] >= setup["num_k"] + + def test_masked_outside_range(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + # For q=0: valid range [0, 50), positions [50, num_k) should be -inf + assert torch.all(logits[0, 50 : setup["num_k"]] == float("-inf")) + + def test_valid_inside_range(self, setup): + logits = sm120_fp8_mqa_logits( + setup["q_fp8"], + setup["kv_fp8"], + setup["weights"], + setup["ks"], + setup["ke"], + ) + # For q=0: valid range [0, 50), should be finite + assert torch.all(torch.isfinite(logits[0, :50])) + + def test_empty_input(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + device = "cuda" + q_fp8 = torch.zeros(0, 4, 132, dtype=torch.float8_e4m3fn, device=device) + k_fp8 = torch.zeros(10, 132, dtype=torch.float8_e4m3fn, device=device) + k_scale = torch.ones(10, device=device) + weights = torch.zeros(0, 4, device=device) + ks = torch.zeros(0, dtype=torch.int32, device=device) + ke = torch.zeros(0, dtype=torch.int32, device=device) + + logits = sm120_fp8_mqa_logits(q_fp8, (k_fp8, k_scale), weights, ks, ke) + assert logits.shape[0] == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])