From f66524674e37944c8a6f3a365c422a59781b40d4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 20 Apr 2026 19:49:08 -0700 Subject: [PATCH 01/11] Initial impl --- docs/advanced_features/server_arguments.md | 1 + python/sglang/srt/environ.py | 6 + .../srt/layers/attention/dsa_backend.py | 286 ++++++++++++++-- python/sglang/srt/server_args.py | 23 ++ test/registered/kernels/test_dsa_indexer.py | 311 +++++++++++++++++- 5 files changed, 591 insertions(+), 36 deletions(-) diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index a880f518c13b..c042f0f73703 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -273,6 +273,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--mm-attention-backend` | Set multimodal attention backend. | `None` | `sdpa`, `fa3`, `fa4`, `triton_attn`, `ascend_attn`, `aiter_attn` | | `--dsa-prefill-backend` | Choose the DSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek DSA-style attention). `--nsa-prefill-backend` is a deprecated alias. | `flashmla_sparse` | `flashmla_sparse`, `flashmla_kv`, `flashmla_auto`, `fa3`, `tilelang`, `aiter`, `trtllm` | | `--dsa-decode-backend` | Choose the DSA backend for the decode stage when running DeepSeek DSA-style attention. Overrides `--attention-backend` for decoding. `--nsa-decode-backend` is a deprecated alias. | `fa3` | `flashmla_sparse`, `flashmla_kv`, `fa3`, `tilelang`, `aiter`, `trtllm` | +| `--dsa-topk-backend` | Choose the DSA indexer top-k backend. `--nsa-topk-backend` is a deprecated alias. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. | `sgl-kernel` | `sgl-kernel`, `torch`, `flashinfer` | | `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only).| `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` | | `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback.| `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` | | `--disable-flashinfer-autotune` | Flashinfer autotune is enabled by default. Set this flag to disable the autotune. | `False` | bool flag (set to enable) | diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 238d8d176e30..54f728aa86d7 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -465,6 +465,12 @@ class Envs: # DSA Backend (canonical names; fall back to SGLANG_NSA_* with deprecation warning) SGLANG_DSA_FUSE_TOPK = EnvBoolWithAlias(True, deprecated_name="SGLANG_NSA_FUSE_TOPK") + SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBoolWithAlias( + False, deprecated_name="SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC" + ) + SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvIntWithAlias( + 0, deprecated_name="SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK" + ) SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBoolWithAlias( True, deprecated_name="SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA" ) diff --git a/python/sglang/srt/layers/attention/dsa_backend.py b/python/sglang/srt/layers/attention/dsa_backend.py index a2c062b91f08..a017fd98ec16 100644 --- a/python/sglang/srt/layers/attention/dsa_backend.py +++ b/python/sglang/srt/layers/attention/dsa_backend.py @@ -1,8 +1,17 @@ from __future__ import annotations from dataclasses import dataclass -from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, TypeAlias +from enum import Enum, IntEnum, auto +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + TypeAlias, +) import torch @@ -86,6 +95,46 @@ def _to_2d_context_lens(seqlens_32: torch.Tensor, batch_size: int) -> torch.Tens _USE_FUSED_METADATA_COPY = envs.SGLANG_USE_FUSED_METADATA_COPY.get() and not _is_hip +def _dsa_topk_unfused( + score: torch.Tensor, + lengths: torch.Tensor, + topk: int, + row_starts: Optional[torch.Tensor] = None, + topk_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]] = torch.topk, + topk_op_kwargs: Optional[Dict[str, object]] = None, +) -> torch.Tensor: + batch_size, max_score_len = score.shape + topk_indices = score.new_full((batch_size, topk), -1, dtype=torch.int32) + if batch_size == 0 or topk == 0 or max_score_len == 0: + return topk_indices + + if row_starts is None: + row_starts = torch.zeros_like(lengths, dtype=torch.int32, device=score.device) + else: + row_starts = row_starts.to(dtype=torch.int32, device=score.device) + lengths = lengths.to(dtype=torch.int32, device=score.device) + + col_indices = torch.arange(max_score_len, dtype=torch.int32, device=score.device) + col_indices = col_indices.unsqueeze(0) + row_starts_unsqueezed = row_starts.unsqueeze(1) + row_ends_unsqueezed = (row_starts + lengths).unsqueeze(1) + valid_mask = (col_indices >= row_starts_unsqueezed) & ( + col_indices < row_ends_unsqueezed + ) + + masked_logits = score.masked_fill(~valid_mask, float("-inf")) + valid_topk = min(topk, max_score_len) + topk_kwargs = topk_op_kwargs or {} + topk_scores, topk_col_indices = topk_op(masked_logits, valid_topk, **topk_kwargs) + topk_local_indices = topk_col_indices.to(torch.int32) - row_starts_unsqueezed + topk_local_indices = topk_local_indices.masked_fill( + topk_scores == float("-inf"), -1 + ) + topk_indices[:, :valid_topk] = topk_local_indices + + return topk_indices + + @dataclass(frozen=True) class DSAFlashMLAMetadata: """Metadata only needed by FlashMLA""" @@ -165,6 +214,21 @@ class TopkTransformMethod(IntEnum): RAGGED = auto() +class DSATopKBackend(Enum): + SGL_KERNEL = "sgl-kernel" + TORCH = "torch" + FLASHINFER = "flashinfer" + + def is_sgl_kernel(self) -> bool: + return self == DSATopKBackend.SGL_KERNEL + + def is_torch(self) -> bool: + return self == DSATopKBackend.TORCH + + def is_flashinfer(self) -> bool: + return self == DSATopKBackend.FLASHINFER + + @torch.compile def _compiled_cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor: return torch.cat(tensors, dim=dim) @@ -190,6 +254,7 @@ def _cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor: class DSAIndexerMetadata(BaseIndexerMetadata): attn_metadata: DSAMetadata topk_transform_method: TopkTransformMethod + topk_backend: DSATopKBackend = DSATopKBackend.SGL_KERNEL paged_mqa_schedule_metadata: Optional[torch.Tensor] = None force_unfused_topk: bool = False @@ -223,6 +288,47 @@ def get_dsa_extend_len_cpu(self) -> List[int]: def get_token_to_batch_idx(self) -> torch.Tensor: return self.attn_metadata.token_to_batch_idx + def _build_flashinfer_paged_args( + self, + ks: Optional[torch.Tensor], + cu_seqlens_q_topk: Optional[torch.Tensor], + batch_idx_list: Optional[List[int]], + device: torch.device, + num_rows: Optional[int] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + row_to_batch = ( + torch.as_tensor(batch_idx_list, dtype=torch.int32, device=device) + if batch_idx_list is not None + else None + ) + + if row_to_batch is None and cu_seqlens_q_topk is not None: + # Decode-like case (one query row per batch) does not need an explicit mapping. + # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. + num_batches = cu_seqlens_q_topk.shape[0] - 1 + if not (ks is None and num_rows is not None and num_rows == num_batches): + q_lens = torch.diff(cu_seqlens_q_topk).to( + dtype=torch.int32, device=device + ) + row_to_batch = torch.repeat_interleave( + torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), + q_lens, + ) + + if ks is not None and row_to_batch is None: + raise RuntimeError( + "PAGED topk_transform with row_starts requires cu_seqlens_q metadata." + ) + + row_starts = ks + if row_starts is not None and row_to_batch is not None: + batch_base = self.attn_metadata.cu_seqlens_k.to( + dtype=torch.int32, device=device + )[:-1] + row_starts = row_starts - batch_base[row_to_batch] + + return row_to_batch, row_starts + def topk_transform( self, logits: torch.Tensor, @@ -233,12 +339,6 @@ def topk_transform( batch_idx_list: List[int] = None, topk_indices_offset_override: Optional[torch.Tensor] = None, ) -> torch.Tensor: - from sgl_kernel import ( - fast_topk_transform_fused, - fast_topk_transform_ragged_fused, - fast_topk_v2, - ) - if topk_indices_offset_override is not None: cu_topk_indices_offset = topk_indices_offset_override cu_seqlens_q_topk = None @@ -262,32 +362,114 @@ def topk_transform( page_table_size_1 = self.attn_metadata.page_table_1 if not envs.SGLANG_DSA_FUSE_TOPK.get() or self.force_unfused_topk: - return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks) - elif self.topk_transform_method == TopkTransformMethod.PAGED: - # NOTE(dark): if fused, we return a transformed page table directly - return fast_topk_transform_fused( - score=logits, - lengths=seq_lens_topk, - page_table_size_1=page_table_size_1, - cu_seqlens_q=cu_seqlens_q_topk, - topk=topk, - row_starts=ks, - ) - elif self.topk_transform_method == TopkTransformMethod.RAGGED: - if cu_topk_indices_offset is None: - raise RuntimeError( - "RAGGED topk_transform requires topk_indices_offset; " - "expected extend-without-speculative metadata." + # Unfused topk + if self.topk_backend.is_sgl_kernel(): + from sgl_kernel import fast_topk_v2 + + return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks) + elif self.topk_backend.is_torch(): + return _dsa_topk_unfused( + logits, + seq_lens_topk, + topk, + row_starts=ks, + topk_op=torch.topk, + topk_op_kwargs={"dim": -1}, + ) + elif self.topk_backend.is_flashinfer(): + import flashinfer + + return _dsa_topk_unfused( + logits, + seq_lens_topk, + topk, + row_starts=ks, + topk_op=flashinfer.top_k, + topk_op_kwargs={ + "sorted": False, + "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + "tie_break": envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), + "dsa_graph_safe": True, + }, ) - return fast_topk_transform_ragged_fused( - score=logits, - lengths=seq_lens_topk, - topk_indices_offset=cu_topk_indices_offset, - topk=topk, - row_starts=ks, - ) else: - assert False, f"Unsupported {self.topk_transform_method = }" + # Fused topk + if self.topk_backend.is_sgl_kernel(): + from sgl_kernel import ( + fast_topk_transform_fused, + fast_topk_transform_ragged_fused, + ) + + if self.topk_transform_method == TopkTransformMethod.PAGED: + # NOTE(dark): if fused, we return a transformed page table directly + return fast_topk_transform_fused( + score=logits, + lengths=seq_lens_topk, + page_table_size_1=page_table_size_1, + cu_seqlens_q=cu_seqlens_q_topk, + topk=topk, + row_starts=ks, + ) + elif self.topk_transform_method == TopkTransformMethod.RAGGED: + if cu_topk_indices_offset is None: + raise RuntimeError( + "RAGGED topk_transform requires topk_indices_offset; " + "expected extend-without-speculative metadata." + ) + return fast_topk_transform_ragged_fused( + score=logits, + lengths=seq_lens_topk, + topk_indices_offset=cu_topk_indices_offset, + topk=topk, + row_starts=ks, + ) + else: + assert False, f"Unsupported {self.topk_transform_method = }" + elif self.topk_backend.is_flashinfer(): + import flashinfer + + if self.topk_transform_method == TopkTransformMethod.PAGED: + row_to_batch, row_starts = self._build_flashinfer_paged_args( + ks=ks, + cu_seqlens_q_topk=cu_seqlens_q_topk, + batch_idx_list=batch_idx_list, + device=logits.device, + num_rows=logits.shape[0], + ) + + return flashinfer.top_k_page_table_transform( + logits.contiguous(), + self.attn_metadata.page_table_1.contiguous(), + seq_lens_topk.contiguous(), + topk, + row_to_batch=row_to_batch, + deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + tie_break=envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), + dsa_graph_safe=True, + row_starts=row_starts, + ) + elif self.topk_transform_method == TopkTransformMethod.RAGGED: + if cu_topk_indices_offset is None: + raise RuntimeError( + "RAGGED topk_transform requires topk_indices_offset; " + "expected extend-without-speculative metadata." + ) + return flashinfer.top_k_ragged_transform( + logits.contiguous(), + cu_topk_indices_offset.contiguous(), + seq_lens_topk.contiguous(), + topk, + deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + tie_break=envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), + dsa_graph_safe=True, + row_starts=ks, + ) + else: + assert False, f"Unsupported {self.topk_transform_method = }" + else: + assert ( + False + ), f"Unsupported {self.topk_backend = } for SGLANG_DSA_FUSE_TOPK." _DSA_IMPL_T: TypeAlias = Literal[ @@ -337,6 +519,9 @@ def __init__( model_runner.server_args.dsa_prefill_backend ) self.dsa_decode_impl: _DSA_IMPL_T = model_runner.server_args.dsa_decode_backend + self.dsa_topk_backend: DSATopKBackend = DSATopKBackend( + model_runner.server_args.dsa_topk_backend + ) if self.num_q_heads <= 64: self.flashmla_kv_num_q_heads = 64 elif self.num_q_heads <= 128: @@ -1427,7 +1612,15 @@ def forward_extend( forward_batch.forward_mode ) if envs.SGLANG_DSA_FUSE_TOPK.get(): - page_table_1 = topk_indices + if ( + self.dsa_topk_backend.is_sgl_kernel() + or self.dsa_topk_backend.is_flashinfer() + ): + page_table_1 = topk_indices + else: + assert ( + False + ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." else: if topk_transform_method == TopkTransformMethod.RAGGED: topk_indices_offset = metadata.topk_indices_offset @@ -1617,7 +1810,15 @@ def forward_decode( layer.layer_id, ) elif envs.SGLANG_DSA_FUSE_TOPK.get(): - page_table_1 = topk_indices + if ( + self.dsa_topk_backend.is_sgl_kernel() + or self.dsa_topk_backend.is_flashinfer() + ): + page_table_1 = topk_indices + else: + assert ( + False + ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." else: page_table_1 = transform_index_page_table_decode( page_table=metadata.page_table_1, @@ -2126,7 +2327,15 @@ def _forward_trtllm( topk_indices = self._pad_topk_indices(topk_indices, q.shape[0]) if envs.SGLANG_DSA_FUSE_TOPK.get(): - page_table_1 = topk_indices + if ( + self.dsa_topk_backend.is_sgl_kernel() + or self.dsa_topk_backend.is_flashinfer() + ): + page_table_1 = topk_indices + else: + assert ( + False + ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." elif is_prefill: page_table_1 = transform_index_page_table_prefill( page_table=metadata.page_table_1, @@ -2280,6 +2489,7 @@ def get_indexer_metadata( topk_transform_method=self.get_topk_transform_method( forward_batch.forward_mode ), + topk_backend=self.dsa_topk_backend, paged_mqa_schedule_metadata=self.forward_metadata.paged_mqa_schedule_metadata, force_unfused_topk=force_unfused, ) @@ -2526,3 +2736,9 @@ def init_forward_metadata_replay_cuda_graph( DSAMetadata = DSAMetadata DSAFlashMLAMetadata = DSAFlashMLAMetadata DSAIndexerMetadata = DSAIndexerMetadata +NativeSparseAttnBackend = DeepseekSparseAttnBackend +NativeSparseAttnMultiStepBackend = DeepseekSparseAttnMultiStepBackend +NSAMetadata = DSAMetadata +NSAFlashMLAMetadata = DSAFlashMLAMetadata +NSAIndexerMetadata = DSAIndexerMetadata +NSATopKBackend = DSATopKBackend diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ae4cb7c1211d..8e839bdb9813 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -266,6 +266,9 @@ ] NSA_CHOICES = DSA_CHOICES # deprecated alias +DSA_TOPK_BACKEND_CHOICES = ["sgl-kernel", "torch", "flashinfer"] +NSA_TOPK_BACKEND_CHOICES = DSA_TOPK_BACKEND_CHOICES # deprecated alias + MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"] MAMBA_BACKEND_CHOICES = ["triton", "flashinfer"] @@ -545,6 +548,7 @@ class ServerArgs: dsa_decode_backend: Optional[str] = ( None # auto-detect based on hardware/kv_cache_dtype ) + dsa_topk_backend: str = "sgl-kernel" disable_flashinfer_autotune: bool = False mamba_backend: str = "triton" @@ -5424,6 +5428,25 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=DSA_CHOICES, help="[Deprecated] Use --dsa-decode-backend instead.", ) + parser.add_argument( + "--dsa-topk-backend", + dest="dsa_topk_backend", + default=ServerArgs.dsa_topk_backend, + type=str, + choices=DSA_TOPK_BACKEND_CHOICES, + help="DSA indexer top-k backend. Options: 'sgl-kernel', 'torch', 'flashinfer'. " + "The 'torch' backend currently requires SGLANG_DSA_FUSE_TOPK=false.", + ) + parser.add_argument( + "--nsa-topk-backend", + dest="dsa_topk_backend", + action=DeprecatedAliasStoreAction, + new_flag="--dsa-topk-backend", + default=argparse.SUPPRESS, + type=str, + choices=DSA_TOPK_BACKEND_CHOICES, + help="[Deprecated] Use --dsa-topk-backend instead.", + ) parser.add_argument( "--fp8-gemm-backend", type=str, diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index 09021180fa17..bdeba8aa1439 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -4,6 +4,7 @@ import torch +from sglang.srt.environ import envs from sglang.srt.layers import dp_attention as _dp_attn from sglang.test.ci.ci_register import register_cuda_ci @@ -16,7 +17,13 @@ Indexer, rotate_activation, ) -from sglang.srt.layers.attention.dsa_backend import DeepseekSparseAttnBackend +from sglang.srt.layers.attention.dsa_backend import ( + DeepseekSparseAttnBackend, + DSAIndexerMetadata, + DSAMetadata, + DSATopKBackend, + TopkTransformMethod, +) from sglang.srt.layers.layernorm import LayerNorm from sglang.srt.layers.linear import LinearBase from sglang.srt.mem_cache.memory_pool import DSATokenToKVPool @@ -250,6 +257,7 @@ def __init__(self, config=None): "enable_deterministic_inference": False, "dsa_prefill_backend": "flashmla_sparse", "dsa_decode_backend": "fa3", + "dsa_topk_backend": "sgl-kernel", }, )() @@ -417,6 +425,246 @@ def _verify_topk_output(self, topk_indices, batch_size, q_len, topk): "Output should have padding or exact topk size", ) + def _make_tie_free_logits(self, batch_size: int, max_score_len: int) -> torch.Tensor: + perm = torch.argsort( + torch.randn( + batch_size, max_score_len, dtype=torch.float32, device=self.device + ), + dim=-1, + ) + return torch.gather( + torch.arange(max_score_len, device=self.device, dtype=torch.float32) + .unsqueeze(0) + .expand(batch_size, -1), + dim=1, + index=perm, + ) + + def _run_unfused_topk_backend_validity_test( + self, + batch_size: int, + max_score_len: int, + topk: int, + topk_backend: DSATopKBackend, + with_row_starts: bool, + ): + logits = self._make_tie_free_logits(batch_size, max_score_len) + + if with_row_starts: + row_starts = torch.randint( + 0, + max_score_len - 1, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + max_lengths = max_score_len - row_starts + random_lengths = torch.randint( + 0, + max_score_len - 1, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + seq_lens_expanded = torch.minimum(max_lengths, random_lengths) + else: + row_starts = None + seq_lens_expanded = torch.randint( + 0, + max_score_len - 1, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + + seq_lens_expanded = seq_lens_expanded.to(dtype=torch.int32, device=self.device) + max_seq_len_k = int(seq_lens_expanded.max().item()) + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=self.device + ) + dsa_cu_seqlens_k = torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ) + dsa_cu_seqlens_k[1:] = torch.cumsum(seq_lens_expanded, dim=0) + page_table_1 = ( + torch.arange(max_seq_len_k, dtype=torch.int32, device=self.device) + .unsqueeze(0) + .expand(batch_size, -1) + .contiguous() + ) + metadata = DSAIndexerMetadata( + attn_metadata=DSAMetadata( + page_size=1, + cache_seqlens_int32=seq_lens_expanded.clone(), + max_seq_len_q=1, + max_seq_len_k=max_seq_len_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_q.clone(), + page_table_1=page_table_1, + real_page_table=page_table_1, + dsa_cache_seqlens_int32=seq_lens_expanded.clone(), + dsa_cu_seqlens_q=cu_seqlens_q.clone(), + dsa_cu_seqlens_k=dsa_cu_seqlens_k, + dsa_extend_seq_lens_list=seq_lens_expanded.cpu().tolist(), + dsa_seqlens_expanded=seq_lens_expanded, + ), + topk_transform_method=TopkTransformMethod.PAGED, + topk_backend=topk_backend, + ) + + with envs.SGLANG_DSA_FUSE_TOPK.override(False): + topk_test = metadata.topk_transform(logits, topk, ks=row_starts) + self.assertEqual(topk_test.shape, (batch_size, topk)) + self.assertEqual(topk_test.dtype, torch.int32) + expected_valid = torch.minimum( + seq_lens_expanded, + torch.full_like(seq_lens_expanded, topk), + ) + actual_valid = (topk_test >= 0).sum(dim=-1).to(torch.int32) + self.assertTrue(torch.equal(actual_valid, expected_valid)) + + starts = ( + row_starts.to(torch.int32) + if row_starts is not None + else torch.zeros( + (topk_test.shape[0],), dtype=torch.int32, device=topk_test.device + ) + ) + for row in range(topk_test.shape[0]): + test_row = topk_test[row] + valid_test = test_row[test_row >= 0] + expected_k = int(expected_valid[row].item()) + self.assertEqual(valid_test.numel(), expected_k) + if expected_k == 0: + continue + start = int(starts[row].item()) + row_len = int(seq_lens_expanded[row].item()) + self.assertTrue(torch.all((valid_test >= 0) & (valid_test < row_len))) + self.assertEqual(torch.unique(valid_test).numel(), valid_test.numel()) + + row_scores = logits[row, start : start + row_len] + ref_topk = torch.topk(row_scores, expected_k, dim=-1, sorted=False).indices + self.assertTrue( + torch.equal( + torch.sort(valid_test.to(torch.int32)).values, + torch.sort(ref_topk.to(torch.int32)).values, + ) + ) + + def _run_fused_topk_backend_equivalence_test( + self, + batch_size: int, + max_score_len: int, + topk: int, + topk_transform_method: TopkTransformMethod, + with_row_starts: bool, + ): + logits = self._make_tie_free_logits(batch_size, max_score_len) + + if with_row_starts: + row_starts = torch.randint( + 0, + max_score_len - 1, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + max_lengths = max_score_len - row_starts + random_lengths = torch.randint( + 1, + max_score_len, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + seq_lens_expanded = torch.minimum(max_lengths, random_lengths) + else: + row_starts = None + seq_lens_expanded = torch.randint( + 1, + max_score_len, + (batch_size,), + dtype=torch.int32, + device=self.device, + ) + + topk_indices_offset = ( + torch.arange(batch_size, dtype=torch.int32, device=self.device) * 1024 + ) + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=self.device + ) + cu_seqlens_k = torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ) + dsa_cu_seqlens_k = torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ) + dsa_cu_seqlens_k[1:] = torch.cumsum(seq_lens_expanded, dim=0) + + page_table_1 = ( + ( + torch.arange(max_score_len, dtype=torch.int32, device=self.device) + .unsqueeze(0) + .expand(batch_size, -1) + ) + + ( + torch.arange( + batch_size, dtype=torch.int32, device=self.device + ).unsqueeze(1) + * max_score_len + ) + ).contiguous() + + attn_metadata = DSAMetadata( + page_size=1, + cache_seqlens_int32=seq_lens_expanded.clone(), + max_seq_len_q=1, + max_seq_len_k=max_score_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + page_table_1=page_table_1, + real_page_table=page_table_1, + dsa_cache_seqlens_int32=seq_lens_expanded.clone(), + dsa_cu_seqlens_q=cu_seqlens_q.clone(), + dsa_cu_seqlens_k=dsa_cu_seqlens_k, + dsa_extend_seq_lens_list=seq_lens_expanded.cpu().tolist(), + dsa_seqlens_expanded=seq_lens_expanded, + topk_indices_offset=( + topk_indices_offset + if topk_transform_method == TopkTransformMethod.RAGGED + else None + ), + ) + + metadata_sgl = DSAIndexerMetadata( + attn_metadata=attn_metadata, + topk_transform_method=topk_transform_method, + topk_backend=DSATopKBackend.SGL_KERNEL, + ) + metadata_flashinfer = DSAIndexerMetadata( + attn_metadata=attn_metadata, + topk_transform_method=topk_transform_method, + topk_backend=DSATopKBackend.FLASHINFER, + ) + + with envs.SGLANG_DSA_FUSE_TOPK.override(True): + out_sgl = metadata_sgl.topk_transform(logits, topk, ks=row_starts) + out_flashinfer = metadata_flashinfer.topk_transform( + logits, topk, ks=row_starts + ) + + self.assertEqual(out_sgl.shape, out_flashinfer.shape) + self.assertEqual(out_sgl.dtype, out_flashinfer.dtype) + self.assertEqual(out_sgl.dtype, torch.int32) + + self.assertTrue( + torch.equal( + torch.sort(out_sgl, dim=-1).values, + torch.sort(out_flashinfer, dim=-1).values, + ) + ) + @patch("sglang.srt.layers.attention.dsa.dsa_indexer.deep_gemm") def test_indexer_basic_creation(self, mock_deep_gemm): """Test basic indexer creation and initialization.""" @@ -626,6 +874,67 @@ def test_indexer_metadata_interface(self): topk_indices = metadata.topk_transform(logits, topk) self.assertEqual(topk_indices.shape, (batch_size, topk)) + def test_topk_backends_unfused(self): + batch_size = 8 + max_score_len = 16 * 1024 + topk = 2048 + for topk_backend in [ + DSATopKBackend.SGL_KERNEL, + DSATopKBackend.TORCH, + DSATopKBackend.FLASHINFER, + ]: + tie_break_values = ( + [0, 1, 2] if topk_backend == DSATopKBackend.FLASHINFER else [0] + ) + for tie_break in tie_break_values: + for with_row_starts in [False, True]: + with self.subTest( + topk_backend=topk_backend.value, + tie_break=tie_break, + with_row_starts=with_row_starts, + ): + with envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.override( + tie_break + ): + self._run_unfused_topk_backend_validity_test( + batch_size, + max_score_len, + topk, + topk_backend=topk_backend, + with_row_starts=with_row_starts, + ) + + def test_topk_backends_fused(self): + batch_size = 8 + max_score_len = 16 * 1024 + topk = 2048 + for tie_break in [0, 1, 2]: + for topk_transform_method in [ + TopkTransformMethod.PAGED, + TopkTransformMethod.RAGGED, + ]: + for with_row_starts in [False, True]: + if ( + topk_transform_method == TopkTransformMethod.PAGED + and with_row_starts + ): + continue + with self.subTest( + tie_break=tie_break, + topk_transform_method=topk_transform_method.name, + with_row_starts=with_row_starts, + ): + with envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.override( + tie_break + ): + self._run_fused_topk_backend_equivalence_test( + batch_size=batch_size, + max_score_len=max_score_len, + topk=topk, + topk_transform_method=topk_transform_method, + with_row_starts=with_row_starts, + ) + # TODO: enable this test after indexer accuracy aligned # @patch("sglang.srt.layers.attention.dsa.dsa_indexer.deep_gemm") # def test_indexer_with_different_topk(self, mock_deep_gemm): From 4d052c0286844072b68188cdc2b2cfe9315e73d1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 14 May 2026 11:01:42 -0700 Subject: [PATCH 02/11] Add docs --- docs/references/environment_variables.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 87e085880231..d9a5038d2590 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its | Environment Variable | Description | Default Value | | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | +| `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` (`SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC` is a deprecated alias) | `false` | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`. Valid values are `0`, `1`, and `2`; non-zero values imply deterministic topk in FlashInfer. (`SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK` is a deprecated alias) | `0` | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | From c70bc30da49f0777f03fc0b0b81e6ee0ca796e8d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 16 May 2026 11:36:54 -0700 Subject: [PATCH 03/11] Add docs_new --- docs_new/docs/advanced_features/server_arguments.mdx | 6 ++++++ docs_new/docs/references/environment_variables.mdx | 10 ++++++++++ 2 files changed, 16 insertions(+) diff --git a/docs_new/docs/advanced_features/server_arguments.mdx b/docs_new/docs/advanced_features/server_arguments.mdx index d40f0dce0d68..b6b87c734a33 100644 --- a/docs_new/docs/advanced_features/server_arguments.mdx +++ b/docs_new/docs/advanced_features/server_arguments.mdx @@ -1206,6 +1206,12 @@ Please consult the documentation below and [server_args.py](https://github.com/s `fa3` flashmla_sparse, flashmla_kv, fa3, tilelang, aiter, trtllm + + `--dsa-topk-backend` + Choose the DSA indexer top-k backend. `--nsa-topk-backend` is a deprecated alias. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. + `sgl-kernel` + sgl-kernel, torch, flashinfer + `--fp8-gemm-backend` Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only). diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index 61dae0e07095..e92fd35bdfa6 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -416,6 +416,16 @@ SGLang supports various environment variables that can be used to configure its Fuse the operation of picking topk logits and picking topk indices from page table. SGLANG_NSA_FUSE_TOPK is a deprecated alias. true + + SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC + Use deterministic FlashInfer topk kernels when --dsa-topk-backend=flashinfer. SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC is a deprecated alias. + false + + + SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK + Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer. Valid values are 0, 1, and 2; non-zero values imply deterministic topk in FlashInfer. SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK is a deprecated alias. + 0 + SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA Precompute metadata that can be shared among different draft steps when MTP is enabled. SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA is a deprecated alias. From 92a2e688a72593543b446d274f5c2174ca76cd7c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 01:37:33 -0700 Subject: [PATCH 04/11] Minor reformat --- test/registered/kernels/test_dsa_indexer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index bdeba8aa1439..04e0c9ea57d1 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -425,7 +425,9 @@ def _verify_topk_output(self, topk_indices, batch_size, q_len, topk): "Output should have padding or exact topk size", ) - def _make_tie_free_logits(self, batch_size: int, max_score_len: int) -> torch.Tensor: + def _make_tie_free_logits( + self, batch_size: int, max_score_len: int + ) -> torch.Tensor: perm = torch.argsort( torch.randn( batch_size, max_score_len, dtype=torch.float32, device=self.device From 660d7ef6d25d427dc3d7ac6276f87912197dbd4c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 01:58:56 -0700 Subject: [PATCH 05/11] Clean up after rebase --- docs/advanced_features/server_arguments.md | 2 +- docs/references/environment_variables.md | 4 +- .../advanced_features/server_arguments.mdx | 2 +- .../docs/references/environment_variables.mdx | 4 +- python/sglang/srt/environ.py | 8 +-- .../srt/layers/attention/dsa_backend.py | 70 ++++++++----------- .../srt/layers/attention/nsa_backend.py | 5 -- python/sglang/srt/server_args.py | 11 --- test/registered/kernels/test_dsa_indexer.py | 65 +++++++++++++---- 9 files changed, 88 insertions(+), 83 deletions(-) diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index c042f0f73703..067865d3aa08 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -273,7 +273,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--mm-attention-backend` | Set multimodal attention backend. | `None` | `sdpa`, `fa3`, `fa4`, `triton_attn`, `ascend_attn`, `aiter_attn` | | `--dsa-prefill-backend` | Choose the DSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek DSA-style attention). `--nsa-prefill-backend` is a deprecated alias. | `flashmla_sparse` | `flashmla_sparse`, `flashmla_kv`, `flashmla_auto`, `fa3`, `tilelang`, `aiter`, `trtllm` | | `--dsa-decode-backend` | Choose the DSA backend for the decode stage when running DeepSeek DSA-style attention. Overrides `--attention-backend` for decoding. `--nsa-decode-backend` is a deprecated alias. | `fa3` | `flashmla_sparse`, `flashmla_kv`, `fa3`, `tilelang`, `aiter`, `trtllm` | -| `--dsa-topk-backend` | Choose the DSA indexer top-k backend. `--nsa-topk-backend` is a deprecated alias. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. | `sgl-kernel` | `sgl-kernel`, `torch`, `flashinfer` | +| `--dsa-topk-backend` | Choose the DSA indexer top-k backend. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. | `sgl-kernel` | `sgl-kernel`, `torch`, `flashinfer` | | `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only).| `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` | | `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback.| `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` | | `--disable-flashinfer-autotune` | Flashinfer autotune is enabled by default. Set this flag to disable the autotune. | `False` | bool flag (set to enable) | diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index d9a5038d2590..3c8625e14aab 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -95,8 +95,8 @@ SGLang supports various environment variables that can be used to configure its | Environment Variable | Description | Default Value | | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | -| `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` (`SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC` is a deprecated alias) | `false` | -| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`. Valid values are `0`, `1`, and `2`; non-zero values imply deterministic topk in FlashInfer. (`SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK` is a deprecated alias) | `0` | +| `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`. Valid values are `0`, `1`, and `2`; non-zero values imply deterministic topk in FlashInfer. | `0` | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | diff --git a/docs_new/docs/advanced_features/server_arguments.mdx b/docs_new/docs/advanced_features/server_arguments.mdx index b6b87c734a33..68f19455fdca 100644 --- a/docs_new/docs/advanced_features/server_arguments.mdx +++ b/docs_new/docs/advanced_features/server_arguments.mdx @@ -1208,7 +1208,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s `--dsa-topk-backend` - Choose the DSA indexer top-k backend. `--nsa-topk-backend` is a deprecated alias. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. + Choose the DSA indexer top-k backend. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. `sgl-kernel` sgl-kernel, torch, flashinfer diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index e92fd35bdfa6..f24cd67f1f2e 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -418,12 +418,12 @@ SGLang supports various environment variables that can be used to configure its SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC - Use deterministic FlashInfer topk kernels when --dsa-topk-backend=flashinfer. SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC is a deprecated alias. + Use deterministic FlashInfer topk kernels when --dsa-topk-backend=flashinfer. false SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK - Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer. Valid values are 0, 1, and 2; non-zero values imply deterministic topk in FlashInfer. SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK is a deprecated alias. + Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer. Valid values are 0, 1, and 2; non-zero values imply deterministic topk in FlashInfer. 0 diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 54f728aa86d7..c0cf6f0ae052 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -465,12 +465,8 @@ class Envs: # DSA Backend (canonical names; fall back to SGLANG_NSA_* with deprecation warning) SGLANG_DSA_FUSE_TOPK = EnvBoolWithAlias(True, deprecated_name="SGLANG_NSA_FUSE_TOPK") - SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBoolWithAlias( - False, deprecated_name="SGLANG_NSA_TOPK_FLASHINFER_DETERMINISTIC" - ) - SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvIntWithAlias( - 0, deprecated_name="SGLANG_NSA_TOPK_FLASHINFER_TIE_BREAK" - ) + SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBool(False) + SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvInt(0) SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBoolWithAlias( True, deprecated_name="SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA" ) diff --git a/python/sglang/srt/layers/attention/dsa_backend.py b/python/sglang/srt/layers/attention/dsa_backend.py index a017fd98ec16..66bfb114fff6 100644 --- a/python/sglang/srt/layers/attention/dsa_backend.py +++ b/python/sglang/srt/layers/attention/dsa_backend.py @@ -302,6 +302,15 @@ def _build_flashinfer_paged_args( else None ) + if ( + row_to_batch is not None + and cu_seqlens_q_topk is not None + and num_rows is not None + and row_to_batch.shape[0] != num_rows + ): + q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) + row_to_batch = torch.repeat_interleave(row_to_batch, q_lens) + if row_to_batch is None and cu_seqlens_q_topk is not None: # Decode-like case (one query row per batch) does not need an explicit mapping. # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. @@ -322,10 +331,7 @@ def _build_flashinfer_paged_args( row_starts = ks if row_starts is not None and row_to_batch is not None: - batch_base = self.attn_metadata.cu_seqlens_k.to( - dtype=torch.int32, device=device - )[:-1] - row_starts = row_starts - batch_base[row_to_batch] + row_starts = row_starts - self.attn_metadata.cu_seqlens_k[:-1][row_to_batch] return row_to_batch, row_starts @@ -424,7 +430,7 @@ def topk_transform( row_starts=ks, ) else: - assert False, f"Unsupported {self.topk_transform_method = }" + raise RuntimeError(f"Unsupported {self.topk_transform_method = }") elif self.topk_backend.is_flashinfer(): import flashinfer @@ -465,11 +471,11 @@ def topk_transform( row_starts=ks, ) else: - assert False, f"Unsupported {self.topk_transform_method = }" + raise RuntimeError(f"Unsupported {self.topk_transform_method = }") else: - assert ( - False - ), f"Unsupported {self.topk_backend = } for SGLANG_DSA_FUSE_TOPK." + raise RuntimeError( + f"Unsupported {self.topk_backend = } for SGLANG_DSA_FUSE_TOPK." + ) _DSA_IMPL_T: TypeAlias = Literal[ @@ -577,6 +583,16 @@ def __init__( else: self.workspace_buffer = None + def _get_fused_topk_page_table(self, topk_indices: torch.Tensor) -> torch.Tensor: + if ( + self.dsa_topk_backend.is_sgl_kernel() + or self.dsa_topk_backend.is_flashinfer() + ): + return topk_indices + raise RuntimeError( + f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." + ) + def get_device_int32_arange(self, l: int) -> torch.Tensor: if l > len(self._arange_buf): next_pow_of_2 = 1 << (l - 1).bit_length() @@ -1612,15 +1628,7 @@ def forward_extend( forward_batch.forward_mode ) if envs.SGLANG_DSA_FUSE_TOPK.get(): - if ( - self.dsa_topk_backend.is_sgl_kernel() - or self.dsa_topk_backend.is_flashinfer() - ): - page_table_1 = topk_indices - else: - assert ( - False - ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." + page_table_1 = self._get_fused_topk_page_table(topk_indices) else: if topk_transform_method == TopkTransformMethod.RAGGED: topk_indices_offset = metadata.topk_indices_offset @@ -1810,15 +1818,7 @@ def forward_decode( layer.layer_id, ) elif envs.SGLANG_DSA_FUSE_TOPK.get(): - if ( - self.dsa_topk_backend.is_sgl_kernel() - or self.dsa_topk_backend.is_flashinfer() - ): - page_table_1 = topk_indices - else: - assert ( - False - ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." + page_table_1 = self._get_fused_topk_page_table(topk_indices) else: page_table_1 = transform_index_page_table_decode( page_table=metadata.page_table_1, @@ -2327,15 +2327,7 @@ def _forward_trtllm( topk_indices = self._pad_topk_indices(topk_indices, q.shape[0]) if envs.SGLANG_DSA_FUSE_TOPK.get(): - if ( - self.dsa_topk_backend.is_sgl_kernel() - or self.dsa_topk_backend.is_flashinfer() - ): - page_table_1 = topk_indices - else: - assert ( - False - ), f"Unsupported {self.dsa_topk_backend = } for SGLANG_DSA_FUSE_TOPK." + page_table_1 = self._get_fused_topk_page_table(topk_indices) elif is_prefill: page_table_1 = transform_index_page_table_prefill( page_table=metadata.page_table_1, @@ -2736,9 +2728,3 @@ def init_forward_metadata_replay_cuda_graph( DSAMetadata = DSAMetadata DSAFlashMLAMetadata = DSAFlashMLAMetadata DSAIndexerMetadata = DSAIndexerMetadata -NativeSparseAttnBackend = DeepseekSparseAttnBackend -NativeSparseAttnMultiStepBackend = DeepseekSparseAttnMultiStepBackend -NSAMetadata = DSAMetadata -NSAFlashMLAMetadata = DSAFlashMLAMetadata -NSAIndexerMetadata = DSAIndexerMetadata -NSATopKBackend = DSATopKBackend diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index b2554684a59d..74f51bdb9651 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -15,9 +15,4 @@ DSAFlashMLAMetadata, DSAIndexerMetadata, DSAMetadata, - NativeSparseAttnBackend, - NativeSparseAttnMultiStepBackend, - NSAFlashMLAMetadata, - NSAIndexerMetadata, - NSAMetadata, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8e839bdb9813..a8037056909c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -267,7 +267,6 @@ NSA_CHOICES = DSA_CHOICES # deprecated alias DSA_TOPK_BACKEND_CHOICES = ["sgl-kernel", "torch", "flashinfer"] -NSA_TOPK_BACKEND_CHOICES = DSA_TOPK_BACKEND_CHOICES # deprecated alias MAMBA_SCHEDULER_STRATEGY_CHOICES = ["auto", "no_buffer", "extra_buffer"] @@ -5437,16 +5436,6 @@ def add_cli_args(parser: argparse.ArgumentParser): help="DSA indexer top-k backend. Options: 'sgl-kernel', 'torch', 'flashinfer'. " "The 'torch' backend currently requires SGLANG_DSA_FUSE_TOPK=false.", ) - parser.add_argument( - "--nsa-topk-backend", - dest="dsa_topk_backend", - action=DeprecatedAliasStoreAction, - new_flag="--dsa-topk-backend", - default=argparse.SUPPRESS, - type=str, - choices=DSA_TOPK_BACKEND_CHOICES, - help="[Deprecated] Use --dsa-topk-backend instead.", - ) parser.add_argument( "--fp8-gemm-backend", type=str, diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index 04e0c9ea57d1..ad7d19eeb0b0 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -560,14 +560,16 @@ def _run_fused_topk_backend_equivalence_test( topk: int, topk_transform_method: TopkTransformMethod, with_row_starts: bool, + query_lens: Optional[List[int]] = None, ): - logits = self._make_tie_free_logits(batch_size, max_score_len) + num_rows = sum(query_lens) if query_lens is not None else batch_size + logits = self._make_tie_free_logits(num_rows, max_score_len) if with_row_starts: row_starts = torch.randint( 0, max_score_len - 1, - (batch_size,), + (num_rows,), dtype=torch.int32, device=self.device, ) @@ -575,7 +577,7 @@ def _run_fused_topk_backend_equivalence_test( random_lengths = torch.randint( 1, max_score_len, - (batch_size,), + (num_rows,), dtype=torch.int32, device=self.device, ) @@ -585,22 +587,32 @@ def _run_fused_topk_backend_equivalence_test( seq_lens_expanded = torch.randint( 1, max_score_len, - (batch_size,), + (num_rows,), dtype=torch.int32, device=self.device, ) topk_indices_offset = ( - torch.arange(batch_size, dtype=torch.int32, device=self.device) * 1024 - ) - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=self.device + torch.arange(num_rows, dtype=torch.int32, device=self.device) * 1024 ) + if query_lens is None: + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=self.device + ) + q_lens = None + batch_idx_list = None + else: + q_lens = torch.tensor(query_lens, dtype=torch.int32, device=self.device) + cu_seqlens_q = torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ) + cu_seqlens_q[1:] = torch.cumsum(q_lens, dim=0) + batch_idx_list = list(range(batch_size)) cu_seqlens_k = torch.zeros( batch_size + 1, dtype=torch.int32, device=self.device ) dsa_cu_seqlens_k = torch.zeros( - batch_size + 1, dtype=torch.int32, device=self.device + num_rows + 1, dtype=torch.int32, device=self.device ) dsa_cu_seqlens_k[1:] = torch.cumsum(seq_lens_expanded, dim=0) @@ -651,9 +663,19 @@ def _run_fused_topk_backend_equivalence_test( ) with envs.SGLANG_DSA_FUSE_TOPK.override(True): - out_sgl = metadata_sgl.topk_transform(logits, topk, ks=row_starts) + out_sgl = metadata_sgl.topk_transform( + logits, + topk, + ks=row_starts, + cu_seqlens_q=q_lens, + batch_idx_list=batch_idx_list, + ) out_flashinfer = metadata_flashinfer.topk_transform( - logits, topk, ks=row_starts + logits, + topk, + ks=row_starts, + cu_seqlens_q=q_lens, + batch_idx_list=batch_idx_list, ) self.assertEqual(out_sgl.shape, out_flashinfer.shape) @@ -876,7 +898,7 @@ def test_indexer_metadata_interface(self): topk_indices = metadata.topk_transform(logits, topk) self.assertEqual(topk_indices.shape, (batch_size, topk)) - def test_topk_backends_unfused(self): + def test_topk_unfused_backends_valid_selection(self): batch_size = 8 max_score_len = 16 * 1024 topk = 2048 @@ -906,7 +928,7 @@ def test_topk_backends_unfused(self): with_row_starts=with_row_starts, ) - def test_topk_backends_fused(self): + def test_topk_fused_backends_equivalence(self): batch_size = 8 max_score_len = 16 * 1024 topk = 2048 @@ -920,6 +942,8 @@ def test_topk_backends_fused(self): topk_transform_method == TopkTransformMethod.PAGED and with_row_starts ): + # The synthetic paged fixture uses the decode-like row mapping. + # Ragged fused and unfused cases cover shifted row windows. continue with self.subTest( tie_break=tie_break, @@ -936,6 +960,21 @@ def test_topk_backends_fused(self): topk_transform_method=topk_transform_method, with_row_starts=with_row_starts, ) + with self.subTest( + tie_break=tie_break, + topk_transform_method=TopkTransformMethod.PAGED.name, + with_row_starts=False, + query_lens="multi", + ): + with envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.override(tie_break): + self._run_fused_topk_backend_equivalence_test( + batch_size=batch_size, + max_score_len=max_score_len, + topk=topk, + topk_transform_method=TopkTransformMethod.PAGED, + with_row_starts=False, + query_lens=[1, 2, 3, 1, 2, 1, 3, 2], + ) # TODO: enable this test after indexer accuracy aligned # @patch("sglang.srt.layers.attention.dsa.dsa_indexer.deep_gemm") From f4e3705df6321314790ceb7632768b8f2d858734 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 02:52:27 -0700 Subject: [PATCH 06/11] Update tie break docs --- docs/references/environment_variables.md | 2 +- docs_new/docs/references/environment_variables.mdx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 3c8625e14aab..4cb0571a568a 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -96,7 +96,7 @@ SGLang supports various environment variables that can be used to configure its | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | | `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` | -| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`. Valid values are `0`, `1`, and `2`; non-zero values imply deterministic topk in FlashInfer. | `0` | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: `0` disables explicit tie-breaking, `1` prefers the smaller candidate index for equal scores, and `2` prefers the larger candidate index for equal scores. Non-zero values make FlashInfer use deterministic topk. | `0` | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index f24cd67f1f2e..c71c9b88ad56 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -423,7 +423,7 @@ SGLang supports various environment variables that can be used to configure its SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK - Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer. Valid values are 0, 1, and 2; non-zero values imply deterministic topk in FlashInfer. + Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: 0 disables explicit tie-breaking, 1 prefers the smaller candidate index for equal scores, and 2 prefers the larger candidate index for equal scores. Non-zero values make FlashInfer use deterministic topk. 0 From dcce26dbd0815752f1c93271fe76f69680f1d8ce Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 18:17:51 -0700 Subject: [PATCH 07/11] Refactor --- docs/references/environment_variables.md | 2 +- .../docs/references/environment_variables.mdx | 4 +- python/sglang/srt/environ.py | 2 +- .../layers/attention/dsa/dsa_topk_backend.py | 267 ++++++++++++++++++ .../srt/layers/attention/dsa_backend.py | 241 ++-------------- test/registered/kernels/test_dsa_indexer.py | 6 +- 6 files changed, 291 insertions(+), 231 deletions(-) create mode 100644 python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 4cb0571a568a..af6cb6070cf7 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -96,7 +96,7 @@ SGLang supports various environment variables that can be used to configure its | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | | `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` | -| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: `0` disables explicit tie-breaking, `1` prefers the smaller candidate index for equal scores, and `2` prefers the larger candidate index for equal scores. Non-zero values make FlashInfer use deterministic topk. | `0` | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: `none` disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Non-`none` values make FlashInfer use deterministic topk. | `none` | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index c71c9b88ad56..58eff5b3f3d6 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -423,8 +423,8 @@ SGLang supports various environment variables that can be used to configure its SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK - Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: 0 disables explicit tie-breaking, 1 prefers the smaller candidate index for equal scores, and 2 prefers the larger candidate index for equal scores. Non-zero values make FlashInfer use deterministic topk. - 0 + Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: none disables explicit tie-breaking, small prefers the smaller candidate index for equal scores, and large prefers the larger candidate index for equal scores. Non-none values make FlashInfer use deterministic topk. + none SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index c0cf6f0ae052..49270352c55e 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -466,7 +466,7 @@ class Envs: # DSA Backend (canonical names; fall back to SGLANG_NSA_* with deprecation warning) SGLANG_DSA_FUSE_TOPK = EnvBoolWithAlias(True, deprecated_name="SGLANG_NSA_FUSE_TOPK") SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBool(False) - SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvInt(0) + SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvStr("none") SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBoolWithAlias( True, deprecated_name="SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA" ) diff --git a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py new file mode 100644 index 000000000000..e3ed3ecf5dbb --- /dev/null +++ b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +from enum import Enum, IntEnum, auto +from typing import Callable, Dict, List, Optional, Tuple + +import torch + +from sglang.srt.environ import envs + + +class TopkTransformMethod(IntEnum): + # Transform topk indices to indices to the page table (page_size = 1) + PAGED = auto() + # Transform topk indices to indices to ragged kv (non-paged) + RAGGED = auto() + + +class DSATopKBackend(Enum): + SGL_KERNEL = "sgl-kernel" + TORCH = "torch" + FLASHINFER = "flashinfer" + + def is_sgl_kernel(self) -> bool: + return self == DSATopKBackend.SGL_KERNEL + + def is_torch(self) -> bool: + return self == DSATopKBackend.TORCH + + def is_flashinfer(self) -> bool: + return self == DSATopKBackend.FLASHINFER + + def topk_func( + self, + score: torch.Tensor, + lengths: torch.Tensor, + topk: int, + row_starts: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.is_sgl_kernel(): + from sgl_kernel import fast_topk_v2 + + return fast_topk_v2(score, lengths, topk, row_starts=row_starts) + if self.is_torch(): + return _topk_unfused( + score, + lengths, + topk, + row_starts=row_starts, + topk_op=torch.topk, + topk_op_kwargs={"dim": -1}, + ) + if self.is_flashinfer(): + import flashinfer + + return _topk_unfused( + score, + lengths, + topk, + row_starts=row_starts, + topk_op=flashinfer.top_k, + topk_op_kwargs={ + "sorted": False, + **_flashinfer_topk_kwargs(), + }, + ) + raise RuntimeError(f"Unsupported {self = }.") + + def topk_transform( + self, + logits: torch.Tensor, + lengths: torch.Tensor, + topk: int, + topk_transform_method: TopkTransformMethod, + attn_metadata, + cu_seqlens_q_topk: Optional[torch.Tensor] = None, + topk_indices_offset: Optional[torch.Tensor] = None, + row_starts: Optional[torch.Tensor] = None, + batch_idx_list: Optional[List[int]] = None, + force_unfused_topk: bool = False, + ) -> torch.Tensor: + if not envs.SGLANG_DSA_FUSE_TOPK.get() or force_unfused_topk: + return self.topk_func(logits, lengths, topk, row_starts=row_starts) + + if self.is_sgl_kernel(): + from sgl_kernel import ( + fast_topk_transform_fused, + fast_topk_transform_ragged_fused, + ) + + if topk_transform_method == TopkTransformMethod.PAGED: + page_table_size_1 = ( + attn_metadata.page_table_1[batch_idx_list] + if batch_idx_list is not None + else attn_metadata.page_table_1 + ) + return fast_topk_transform_fused( + score=logits, + lengths=lengths, + page_table_size_1=page_table_size_1, + cu_seqlens_q=cu_seqlens_q_topk, + topk=topk, + row_starts=row_starts, + ) + if topk_transform_method == TopkTransformMethod.RAGGED: + if topk_indices_offset is None: + raise RuntimeError( + "RAGGED topk_transform requires topk_indices_offset; " + "expected extend-without-speculative metadata." + ) + return fast_topk_transform_ragged_fused( + score=logits, + lengths=lengths, + topk_indices_offset=topk_indices_offset, + topk=topk, + row_starts=row_starts, + ) + raise RuntimeError(f"Unsupported {topk_transform_method = }.") + + if self.is_flashinfer(): + import flashinfer + + if topk_transform_method == TopkTransformMethod.PAGED: + row_to_batch, local_row_starts = _build_flashinfer_paged_args( + attn_metadata=attn_metadata, + row_starts=row_starts, + cu_seqlens_q_topk=cu_seqlens_q_topk, + batch_idx_list=batch_idx_list, + device=logits.device, + num_rows=logits.shape[0], + ) + return flashinfer.top_k_page_table_transform( + logits.contiguous(), + attn_metadata.page_table_1.contiguous(), + lengths.contiguous(), + topk, + row_to_batch=row_to_batch, + **_flashinfer_topk_kwargs(), + row_starts=local_row_starts, + ) + if topk_transform_method == TopkTransformMethod.RAGGED: + if topk_indices_offset is None: + raise RuntimeError( + "RAGGED topk_transform requires topk_indices_offset; " + "expected extend-without-speculative metadata." + ) + return flashinfer.top_k_ragged_transform( + logits.contiguous(), + topk_indices_offset.contiguous(), + lengths.contiguous(), + topk, + **_flashinfer_topk_kwargs(), + row_starts=row_starts, + ) + raise RuntimeError(f"Unsupported {topk_transform_method = }.") + + raise RuntimeError(f"Unsupported {self = } for SGLANG_DSA_FUSE_TOPK.") + + +def _topk_unfused( + score: torch.Tensor, + lengths: torch.Tensor, + topk: int, + row_starts: Optional[torch.Tensor] = None, + topk_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]] = torch.topk, + topk_op_kwargs: Optional[Dict[str, object]] = None, +) -> torch.Tensor: + batch_size, max_score_len = score.shape + topk_indices = score.new_full((batch_size, topk), -1, dtype=torch.int32) + if batch_size == 0 or topk == 0 or max_score_len == 0: + return topk_indices + + if row_starts is None: + row_starts = torch.zeros_like(lengths, dtype=torch.int32, device=score.device) + else: + row_starts = row_starts.to(dtype=torch.int32, device=score.device) + lengths = lengths.to(dtype=torch.int32, device=score.device) + + col_indices = torch.arange(max_score_len, dtype=torch.int32, device=score.device) + col_indices = col_indices.unsqueeze(0) + row_starts_unsqueezed = row_starts.unsqueeze(1) + row_ends_unsqueezed = (row_starts + lengths).unsqueeze(1) + valid_mask = (col_indices >= row_starts_unsqueezed) & ( + col_indices < row_ends_unsqueezed + ) + + masked_logits = score.masked_fill(~valid_mask, float("-inf")) + valid_topk = min(topk, max_score_len) + topk_kwargs = topk_op_kwargs or {} + topk_scores, topk_col_indices = topk_op(masked_logits, valid_topk, **topk_kwargs) + topk_local_indices = topk_col_indices.to(torch.int32) - row_starts_unsqueezed + topk_local_indices = topk_local_indices.masked_fill( + topk_scores == float("-inf"), -1 + ) + topk_indices[:, :valid_topk] = topk_local_indices + + return topk_indices + + +def _build_flashinfer_paged_args( + attn_metadata, + row_starts: Optional[torch.Tensor], + cu_seqlens_q_topk: Optional[torch.Tensor], + batch_idx_list: Optional[List[int]], + device: torch.device, + num_rows: Optional[int] = None, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + row_to_batch = ( + torch.as_tensor(batch_idx_list, dtype=torch.int32, device=device) + if batch_idx_list is not None + else None + ) + + if ( + row_to_batch is not None + and cu_seqlens_q_topk is not None + and num_rows is not None + and row_to_batch.shape[0] != num_rows + ): + q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) + row_to_batch = torch.repeat_interleave(row_to_batch, q_lens) + + if row_to_batch is None and cu_seqlens_q_topk is not None: + # Decode-like case (one query row per batch) does not need an explicit mapping. + # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. + num_batches = cu_seqlens_q_topk.shape[0] - 1 + if not (row_starts is None and num_rows is not None and num_rows == num_batches): + q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) + row_to_batch = torch.repeat_interleave( + torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), + q_lens, + ) + + if row_starts is not None and row_to_batch is None: + raise RuntimeError( + "PAGED topk_transform with row_starts requires cu_seqlens_q metadata." + ) + + local_row_starts = row_starts + if local_row_starts is not None and row_to_batch is not None: + local_row_starts = local_row_starts - attn_metadata.cu_seqlens_k[:-1][ + row_to_batch + ] + + return row_to_batch, local_row_starts + + +def _flashinfer_topk_kwargs() -> Dict[str, object]: + return { + "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + "tie_break": _flashinfer_tie_break_value(), + "dsa_graph_safe": True, + } + + +def _flashinfer_tie_break_value() -> int: + mode = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get().lower() + tie_break_values = { + "none": 0, + "small": 1, + "large": 2, + } + if mode not in tie_break_values: + raise RuntimeError( + "SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK must be one of " + f"{tuple(tie_break_values)}, got {mode!r}." + ) + return tie_break_values[mode] diff --git a/python/sglang/srt/layers/attention/dsa_backend.py b/python/sglang/srt/layers/attention/dsa_backend.py index 66bfb114fff6..27e6f46ed6f3 100644 --- a/python/sglang/srt/layers/attention/dsa_backend.py +++ b/python/sglang/srt/layers/attention/dsa_backend.py @@ -1,10 +1,8 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum, IntEnum, auto from typing import ( TYPE_CHECKING, - Callable, Dict, List, Literal, @@ -25,6 +23,10 @@ compute_cu_seqlens, ) from sglang.srt.layers.attention.dsa.dsa_indexer import BaseIndexerMetadata +from sglang.srt.layers.attention.dsa.dsa_topk_backend import ( + DSATopKBackend, + TopkTransformMethod, +) from sglang.srt.layers.attention.dsa.quant_k_cache import quantize_k_cache from sglang.srt.layers.attention.dsa.transform_index import ( transform_index_page_table_decode, @@ -95,46 +97,6 @@ def _to_2d_context_lens(seqlens_32: torch.Tensor, batch_size: int) -> torch.Tens _USE_FUSED_METADATA_COPY = envs.SGLANG_USE_FUSED_METADATA_COPY.get() and not _is_hip -def _dsa_topk_unfused( - score: torch.Tensor, - lengths: torch.Tensor, - topk: int, - row_starts: Optional[torch.Tensor] = None, - topk_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]] = torch.topk, - topk_op_kwargs: Optional[Dict[str, object]] = None, -) -> torch.Tensor: - batch_size, max_score_len = score.shape - topk_indices = score.new_full((batch_size, topk), -1, dtype=torch.int32) - if batch_size == 0 or topk == 0 or max_score_len == 0: - return topk_indices - - if row_starts is None: - row_starts = torch.zeros_like(lengths, dtype=torch.int32, device=score.device) - else: - row_starts = row_starts.to(dtype=torch.int32, device=score.device) - lengths = lengths.to(dtype=torch.int32, device=score.device) - - col_indices = torch.arange(max_score_len, dtype=torch.int32, device=score.device) - col_indices = col_indices.unsqueeze(0) - row_starts_unsqueezed = row_starts.unsqueeze(1) - row_ends_unsqueezed = (row_starts + lengths).unsqueeze(1) - valid_mask = (col_indices >= row_starts_unsqueezed) & ( - col_indices < row_ends_unsqueezed - ) - - masked_logits = score.masked_fill(~valid_mask, float("-inf")) - valid_topk = min(topk, max_score_len) - topk_kwargs = topk_op_kwargs or {} - topk_scores, topk_col_indices = topk_op(masked_logits, valid_topk, **topk_kwargs) - topk_local_indices = topk_col_indices.to(torch.int32) - row_starts_unsqueezed - topk_local_indices = topk_local_indices.masked_fill( - topk_scores == float("-inf"), -1 - ) - topk_indices[:, :valid_topk] = topk_local_indices - - return topk_indices - - @dataclass(frozen=True) class DSAFlashMLAMetadata: """Metadata only needed by FlashMLA""" @@ -207,28 +169,6 @@ class DSAMetadata: token_to_batch_idx: Optional[torch.Tensor] = None -class TopkTransformMethod(IntEnum): - # Transform topk indices to indices to the page table (page_size = 1) - PAGED = auto() - # Transform topk indices to indices to ragged kv (non-paged) - RAGGED = auto() - - -class DSATopKBackend(Enum): - SGL_KERNEL = "sgl-kernel" - TORCH = "torch" - FLASHINFER = "flashinfer" - - def is_sgl_kernel(self) -> bool: - return self == DSATopKBackend.SGL_KERNEL - - def is_torch(self) -> bool: - return self == DSATopKBackend.TORCH - - def is_flashinfer(self) -> bool: - return self == DSATopKBackend.FLASHINFER - - @torch.compile def _compiled_cat(tensors: list[torch.Tensor], dim: int = -1) -> torch.Tensor: return torch.cat(tensors, dim=dim) @@ -288,53 +228,6 @@ def get_dsa_extend_len_cpu(self) -> List[int]: def get_token_to_batch_idx(self) -> torch.Tensor: return self.attn_metadata.token_to_batch_idx - def _build_flashinfer_paged_args( - self, - ks: Optional[torch.Tensor], - cu_seqlens_q_topk: Optional[torch.Tensor], - batch_idx_list: Optional[List[int]], - device: torch.device, - num_rows: Optional[int] = None, - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - row_to_batch = ( - torch.as_tensor(batch_idx_list, dtype=torch.int32, device=device) - if batch_idx_list is not None - else None - ) - - if ( - row_to_batch is not None - and cu_seqlens_q_topk is not None - and num_rows is not None - and row_to_batch.shape[0] != num_rows - ): - q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) - row_to_batch = torch.repeat_interleave(row_to_batch, q_lens) - - if row_to_batch is None and cu_seqlens_q_topk is not None: - # Decode-like case (one query row per batch) does not need an explicit mapping. - # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. - num_batches = cu_seqlens_q_topk.shape[0] - 1 - if not (ks is None and num_rows is not None and num_rows == num_batches): - q_lens = torch.diff(cu_seqlens_q_topk).to( - dtype=torch.int32, device=device - ) - row_to_batch = torch.repeat_interleave( - torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), - q_lens, - ) - - if ks is not None and row_to_batch is None: - raise RuntimeError( - "PAGED topk_transform with row_starts requires cu_seqlens_q metadata." - ) - - row_starts = ks - if row_starts is not None and row_to_batch is not None: - row_starts = row_starts - self.attn_metadata.cu_seqlens_k[:-1][row_to_batch] - - return row_to_batch, row_starts - def topk_transform( self, logits: torch.Tensor, @@ -362,120 +255,18 @@ def topk_transform( seq_lens_topk = ke_offset else: seq_lens_topk = self.get_seqlens_expanded() - if batch_idx_list is not None: - page_table_size_1 = self.attn_metadata.page_table_1[batch_idx_list] - else: - page_table_size_1 = self.attn_metadata.page_table_1 - - if not envs.SGLANG_DSA_FUSE_TOPK.get() or self.force_unfused_topk: - # Unfused topk - if self.topk_backend.is_sgl_kernel(): - from sgl_kernel import fast_topk_v2 - - return fast_topk_v2(logits, seq_lens_topk, topk, row_starts=ks) - elif self.topk_backend.is_torch(): - return _dsa_topk_unfused( - logits, - seq_lens_topk, - topk, - row_starts=ks, - topk_op=torch.topk, - topk_op_kwargs={"dim": -1}, - ) - elif self.topk_backend.is_flashinfer(): - import flashinfer - - return _dsa_topk_unfused( - logits, - seq_lens_topk, - topk, - row_starts=ks, - topk_op=flashinfer.top_k, - topk_op_kwargs={ - "sorted": False, - "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), - "tie_break": envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), - "dsa_graph_safe": True, - }, - ) - else: - # Fused topk - if self.topk_backend.is_sgl_kernel(): - from sgl_kernel import ( - fast_topk_transform_fused, - fast_topk_transform_ragged_fused, - ) - - if self.topk_transform_method == TopkTransformMethod.PAGED: - # NOTE(dark): if fused, we return a transformed page table directly - return fast_topk_transform_fused( - score=logits, - lengths=seq_lens_topk, - page_table_size_1=page_table_size_1, - cu_seqlens_q=cu_seqlens_q_topk, - topk=topk, - row_starts=ks, - ) - elif self.topk_transform_method == TopkTransformMethod.RAGGED: - if cu_topk_indices_offset is None: - raise RuntimeError( - "RAGGED topk_transform requires topk_indices_offset; " - "expected extend-without-speculative metadata." - ) - return fast_topk_transform_ragged_fused( - score=logits, - lengths=seq_lens_topk, - topk_indices_offset=cu_topk_indices_offset, - topk=topk, - row_starts=ks, - ) - else: - raise RuntimeError(f"Unsupported {self.topk_transform_method = }") - elif self.topk_backend.is_flashinfer(): - import flashinfer - - if self.topk_transform_method == TopkTransformMethod.PAGED: - row_to_batch, row_starts = self._build_flashinfer_paged_args( - ks=ks, - cu_seqlens_q_topk=cu_seqlens_q_topk, - batch_idx_list=batch_idx_list, - device=logits.device, - num_rows=logits.shape[0], - ) - - return flashinfer.top_k_page_table_transform( - logits.contiguous(), - self.attn_metadata.page_table_1.contiguous(), - seq_lens_topk.contiguous(), - topk, - row_to_batch=row_to_batch, - deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), - tie_break=envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), - dsa_graph_safe=True, - row_starts=row_starts, - ) - elif self.topk_transform_method == TopkTransformMethod.RAGGED: - if cu_topk_indices_offset is None: - raise RuntimeError( - "RAGGED topk_transform requires topk_indices_offset; " - "expected extend-without-speculative metadata." - ) - return flashinfer.top_k_ragged_transform( - logits.contiguous(), - cu_topk_indices_offset.contiguous(), - seq_lens_topk.contiguous(), - topk, - deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), - tie_break=envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get(), - dsa_graph_safe=True, - row_starts=ks, - ) - else: - raise RuntimeError(f"Unsupported {self.topk_transform_method = }") - else: - raise RuntimeError( - f"Unsupported {self.topk_backend = } for SGLANG_DSA_FUSE_TOPK." - ) + return self.topk_backend.topk_transform( + logits=logits, + lengths=seq_lens_topk, + topk=topk, + topk_transform_method=self.topk_transform_method, + attn_metadata=self.attn_metadata, + cu_seqlens_q_topk=cu_seqlens_q_topk, + topk_indices_offset=cu_topk_indices_offset, + row_starts=ks, + batch_idx_list=batch_idx_list, + force_unfused_topk=self.force_unfused_topk, + ) _DSA_IMPL_T: TypeAlias = Literal[ diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index ad7d19eeb0b0..a422a1e65e92 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -908,7 +908,9 @@ def test_topk_unfused_backends_valid_selection(self): DSATopKBackend.FLASHINFER, ]: tie_break_values = ( - [0, 1, 2] if topk_backend == DSATopKBackend.FLASHINFER else [0] + ["none", "small", "large"] + if topk_backend == DSATopKBackend.FLASHINFER + else ["none"] ) for tie_break in tie_break_values: for with_row_starts in [False, True]: @@ -932,7 +934,7 @@ def test_topk_fused_backends_equivalence(self): batch_size = 8 max_score_len = 16 * 1024 topk = 2048 - for tie_break in [0, 1, 2]: + for tie_break in ["none", "small", "large"]: for topk_transform_method in [ TopkTransformMethod.PAGED, TopkTransformMethod.RAGGED, From f20dcdd67b9d7f8462d99cef0ac13c0861886438 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 18:24:51 -0700 Subject: [PATCH 08/11] Clean up --- .../layers/attention/dsa/dsa_topk_backend.py | 47 ++++++++++--------- .../srt/layers/attention/nsa_backend.py | 5 ++ test/registered/kernels/test_dsa_indexer.py | 9 ++-- 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py index e3ed3ecf5dbb..052240aef22f 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py @@ -7,6 +7,12 @@ from sglang.srt.environ import envs +_FLASHINFER_TIE_BREAK_VALUES = { + "none": 0, + "small": 1, + "large": 2, +} + class TopkTransformMethod(IntEnum): # Transform topk indices to indices to the page table (page_size = 1) @@ -60,7 +66,9 @@ def topk_func( topk_op=flashinfer.top_k, topk_op_kwargs={ "sorted": False, - **_flashinfer_topk_kwargs(), + "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + "tie_break": _flashinfer_tie_break_value(), + "dsa_graph_safe": True, }, ) raise RuntimeError(f"Unsupported {self = }.") @@ -134,7 +142,9 @@ def topk_transform( lengths.contiguous(), topk, row_to_batch=row_to_batch, - **_flashinfer_topk_kwargs(), + deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + tie_break=_flashinfer_tie_break_value(), + dsa_graph_safe=True, row_starts=local_row_starts, ) if topk_transform_method == TopkTransformMethod.RAGGED: @@ -148,7 +158,9 @@ def topk_transform( topk_indices_offset.contiguous(), lengths.contiguous(), topk, - **_flashinfer_topk_kwargs(), + deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), + tie_break=_flashinfer_tie_break_value(), + dsa_graph_safe=True, row_starts=row_starts, ) raise RuntimeError(f"Unsupported {topk_transform_method = }.") @@ -223,7 +235,9 @@ def _build_flashinfer_paged_args( # Decode-like case (one query row per batch) does not need an explicit mapping. # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. num_batches = cu_seqlens_q_topk.shape[0] - 1 - if not (row_starts is None and num_rows is not None and num_rows == num_batches): + if not ( + row_starts is None and num_rows is not None and num_rows == num_batches + ): q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) row_to_batch = torch.repeat_interleave( torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), @@ -237,31 +251,18 @@ def _build_flashinfer_paged_args( local_row_starts = row_starts if local_row_starts is not None and row_to_batch is not None: - local_row_starts = local_row_starts - attn_metadata.cu_seqlens_k[:-1][ - row_to_batch - ] + local_row_starts = ( + local_row_starts - attn_metadata.cu_seqlens_k[:-1][row_to_batch] + ) return row_to_batch, local_row_starts -def _flashinfer_topk_kwargs() -> Dict[str, object]: - return { - "deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(), - "tie_break": _flashinfer_tie_break_value(), - "dsa_graph_safe": True, - } - - def _flashinfer_tie_break_value() -> int: mode = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get().lower() - tie_break_values = { - "none": 0, - "small": 1, - "large": 2, - } - if mode not in tie_break_values: + if mode not in _FLASHINFER_TIE_BREAK_VALUES: raise RuntimeError( "SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK must be one of " - f"{tuple(tie_break_values)}, got {mode!r}." + f"{tuple(_FLASHINFER_TIE_BREAK_VALUES)}, got {mode!r}." ) - return tie_break_values[mode] + return _FLASHINFER_TIE_BREAK_VALUES[mode] diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 74f51bdb9651..b2554684a59d 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -15,4 +15,9 @@ DSAFlashMLAMetadata, DSAIndexerMetadata, DSAMetadata, + NativeSparseAttnBackend, + NativeSparseAttnMultiStepBackend, + NSAFlashMLAMetadata, + NSAIndexerMetadata, + NSAMetadata, ) diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index a422a1e65e92..772d1df10369 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -17,12 +17,14 @@ Indexer, rotate_activation, ) +from sglang.srt.layers.attention.dsa.dsa_topk_backend import ( + DSATopKBackend, + TopkTransformMethod, +) from sglang.srt.layers.attention.dsa_backend import ( DeepseekSparseAttnBackend, DSAIndexerMetadata, DSAMetadata, - DSATopKBackend, - TopkTransformMethod, ) from sglang.srt.layers.layernorm import LayerNorm from sglang.srt.layers.linear import LinearBase @@ -593,7 +595,8 @@ def _run_fused_topk_backend_equivalence_test( ) topk_indices_offset = ( - torch.arange(num_rows, dtype=torch.int32, device=self.device) * 1024 + torch.arange(num_rows, dtype=torch.int32, device=self.device) + * max_score_len ) if query_lens is None: cu_seqlens_q = torch.arange( From fb7fa680acab672894fd76a35b47962d99aeb89d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 18:30:16 -0700 Subject: [PATCH 09/11] Refactor to `[None, "small", "large"]` --- docs/references/environment_variables.md | 2 +- docs_new/docs/references/environment_variables.mdx | 4 ++-- python/sglang/srt/environ.py | 2 +- .../sglang/srt/layers/attention/dsa/dsa_topk_backend.py | 8 +++++--- test/registered/kernels/test_dsa_indexer.py | 6 +++--- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index af6cb6070cf7..b69c0e4e3113 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -96,7 +96,7 @@ SGLang supports various environment variables that can be used to configure its | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | | `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` | -| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: `none` disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Non-`none` values make FlashInfer use deterministic topk. | `none` | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: unset disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. | unset | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index 58eff5b3f3d6..7e7f0608fb13 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -423,8 +423,8 @@ SGLang supports various environment variables that can be used to configure its SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK - Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: none disables explicit tie-breaking, small prefers the smaller candidate index for equal scores, and large prefers the larger candidate index for equal scores. Non-none values make FlashInfer use deterministic topk. - none + Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: unset disables explicit tie-breaking, small prefers the smaller candidate index for equal scores, and large prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. + unset SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 49270352c55e..2658a581d022 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -466,7 +466,7 @@ class Envs: # DSA Backend (canonical names; fall back to SGLANG_NSA_* with deprecation warning) SGLANG_DSA_FUSE_TOPK = EnvBoolWithAlias(True, deprecated_name="SGLANG_NSA_FUSE_TOPK") SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBool(False) - SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvStr("none") + SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvStr(None) SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBoolWithAlias( True, deprecated_name="SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA" ) diff --git a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py index 052240aef22f..c5a5d3fd7e98 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py @@ -8,7 +8,6 @@ from sglang.srt.environ import envs _FLASHINFER_TIE_BREAK_VALUES = { - "none": 0, "small": 1, "large": 2, } @@ -259,10 +258,13 @@ def _build_flashinfer_paged_args( def _flashinfer_tie_break_value() -> int: - mode = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get().lower() + mode = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get() + if mode is None: + return 0 + mode = mode.lower() if mode not in _FLASHINFER_TIE_BREAK_VALUES: raise RuntimeError( "SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK must be one of " - f"{tuple(_FLASHINFER_TIE_BREAK_VALUES)}, got {mode!r}." + f"{tuple(_FLASHINFER_TIE_BREAK_VALUES)} or unset, got {mode!r}." ) return _FLASHINFER_TIE_BREAK_VALUES[mode] diff --git a/test/registered/kernels/test_dsa_indexer.py b/test/registered/kernels/test_dsa_indexer.py index 772d1df10369..354fceafbdf6 100644 --- a/test/registered/kernels/test_dsa_indexer.py +++ b/test/registered/kernels/test_dsa_indexer.py @@ -911,9 +911,9 @@ def test_topk_unfused_backends_valid_selection(self): DSATopKBackend.FLASHINFER, ]: tie_break_values = ( - ["none", "small", "large"] + [None, "small", "large"] if topk_backend == DSATopKBackend.FLASHINFER - else ["none"] + else [None] ) for tie_break in tie_break_values: for with_row_starts in [False, True]: @@ -937,7 +937,7 @@ def test_topk_fused_backends_equivalence(self): batch_size = 8 max_score_len = 16 * 1024 topk = 2048 - for tie_break in ["none", "small", "large"]: + for tie_break in [None, "small", "large"]: for topk_transform_method in [ TopkTransformMethod.PAGED, TopkTransformMethod.RAGGED, From aeeedc84a037625b88783a72d61c7f6043d23789 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 20 May 2026 18:42:40 -0700 Subject: [PATCH 10/11] Clean up --- docs/references/environment_variables.md | 2 +- docs_new/docs/references/environment_variables.mdx | 2 +- python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py | 7 ++----- python/sglang/srt/layers/attention/dsa_backend.py | 6 +++--- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index b69c0e4e3113..63cb9c837d5a 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -96,7 +96,7 @@ SGLang supports various environment variables that can be used to configure its | --- | --- | --- | | `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` | | `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` | -| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: unset disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. | unset | +| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: unset disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. | `unset` | | `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` | | `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` | | `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` | diff --git a/docs_new/docs/references/environment_variables.mdx b/docs_new/docs/references/environment_variables.mdx index 7e7f0608fb13..3efc356a1c59 100644 --- a/docs_new/docs/references/environment_variables.mdx +++ b/docs_new/docs/references/environment_variables.mdx @@ -424,7 +424,7 @@ SGLang supports various environment variables that can be used to configure its SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK Tie-break mode for FlashInfer DSA topk when --dsa-topk-backend=flashinfer: unset disables explicit tie-breaking, small prefers the smaller candidate index for equal scores, and large prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. - unset + unset SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA diff --git a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py index c5a5d3fd7e98..48064c70c716 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py @@ -213,7 +213,7 @@ def _build_flashinfer_paged_args( cu_seqlens_q_topk: Optional[torch.Tensor], batch_idx_list: Optional[List[int]], device: torch.device, - num_rows: Optional[int] = None, + num_rows: int, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: row_to_batch = ( torch.as_tensor(batch_idx_list, dtype=torch.int32, device=device) @@ -224,7 +224,6 @@ def _build_flashinfer_paged_args( if ( row_to_batch is not None and cu_seqlens_q_topk is not None - and num_rows is not None and row_to_batch.shape[0] != num_rows ): q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) @@ -234,9 +233,7 @@ def _build_flashinfer_paged_args( # Decode-like case (one query row per batch) does not need an explicit mapping. # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. num_batches = cu_seqlens_q_topk.shape[0] - 1 - if not ( - row_starts is None and num_rows is not None and num_rows == num_batches - ): + if not (row_starts is None and num_rows == num_batches): q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) row_to_batch = torch.repeat_interleave( torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), diff --git a/python/sglang/srt/layers/attention/dsa_backend.py b/python/sglang/srt/layers/attention/dsa_backend.py index 27e6f46ed6f3..b0e996c86772 100644 --- a/python/sglang/srt/layers/attention/dsa_backend.py +++ b/python/sglang/srt/layers/attention/dsa_backend.py @@ -233,9 +233,9 @@ def topk_transform( logits: torch.Tensor, topk: int, ks: Optional[torch.Tensor] = None, - cu_seqlens_q: torch.Tensor = None, - ke_offset: torch.Tensor = None, - batch_idx_list: List[int] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + ke_offset: Optional[torch.Tensor] = None, + batch_idx_list: Optional[List[int]] = None, topk_indices_offset_override: Optional[torch.Tensor] = None, ) -> torch.Tensor: if topk_indices_offset_override is not None: From c430da9c4a160419971ae921159bb19dc6e65969 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 22 May 2026 11:32:40 -0700 Subject: [PATCH 11/11] Address DSA topk review comments --- .../sglang/srt/layers/attention/dsa/dsa_topk_backend.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py index 48064c70c716..8b76557e26a5 100644 --- a/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py +++ b/python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py @@ -226,7 +226,9 @@ def _build_flashinfer_paged_args( and cu_seqlens_q_topk is not None and row_to_batch.shape[0] != num_rows ): - q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) + q_lens = (cu_seqlens_q_topk[1:] - cu_seqlens_q_topk[:-1]).to( + dtype=torch.int32, device=device + ) row_to_batch = torch.repeat_interleave(row_to_batch, q_lens) if row_to_batch is None and cu_seqlens_q_topk is not None: @@ -234,7 +236,9 @@ def _build_flashinfer_paged_args( # Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe. num_batches = cu_seqlens_q_topk.shape[0] - 1 if not (row_starts is None and num_rows == num_batches): - q_lens = torch.diff(cu_seqlens_q_topk).to(dtype=torch.int32, device=device) + q_lens = (cu_seqlens_q_topk[1:] - cu_seqlens_q_topk[:-1]).to( + dtype=torch.int32, device=device + ) row_to_batch = torch.repeat_interleave( torch.arange(q_lens.shape[0], dtype=torch.int32, device=device), q_lens,