Skip to content
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--mm-attention-backend` | Set multimodal attention backend. | `None` | `sdpa`, `fa3`, `fa4`, `triton_attn`, `ascend_attn`, `aiter_attn` |
| `--dsa-prefill-backend` | Choose the DSA backend for the prefill stage (overrides `--attention-backend` when running DeepSeek DSA-style attention). `--nsa-prefill-backend` is a deprecated alias. | `flashmla_sparse` | `flashmla_sparse`, `flashmla_kv`, `flashmla_auto`, `fa3`, `tilelang`, `aiter`, `trtllm` |
| `--dsa-decode-backend` | Choose the DSA backend for the decode stage when running DeepSeek DSA-style attention. Overrides `--attention-backend` for decoding. `--nsa-decode-backend` is a deprecated alias. | `fa3` | `flashmla_sparse`, `flashmla_kv`, `fa3`, `tilelang`, `aiter`, `trtllm` |
| `--dsa-topk-backend` | Choose the DSA indexer top-k backend. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`. | `sgl-kernel` | `sgl-kernel`, `torch`, `flashinfer` |
| `--fp8-gemm-backend` | Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only).| `auto` | `auto`, `deep_gemm`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_deepgemm`, `cutlass`, `triton`, `aiter` |
| `--fp4-gemm-backend` | Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback.| `flashinfer_cutlass` | `auto`, `flashinfer_cudnn`, `flashinfer_cutlass`, `flashinfer_trtllm` |
| `--disable-flashinfer-autotune` | Flashinfer autotune is enabled by default. Set this flag to disable the autotune. | `False` | bool flag (set to enable) |
Expand Down
2 changes: 2 additions & 0 deletions docs/references/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_DSA_FUSE_TOPK` | Fuse the operation of picking topk logits and picking topk indices from page table (`SGLANG_NSA_FUSE_TOPK` is a deprecated alias) | `true` |
| `SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC` | Use deterministic FlashInfer topk kernels when `--dsa-topk-backend=flashinfer` | `false` |
| `SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK` | Tie-break mode for FlashInfer DSA topk when `--dsa-topk-backend=flashinfer`: unset disables explicit tie-breaking, `small` prefers the smaller candidate index for equal scores, and `large` prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk. | `unset` |
| `SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA` | Precompute metadata that can be shared among different draft steps when MTP is enabled (`SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA` is a deprecated alias) | `true` |
| `SGLANG_USE_FUSED_METADATA_COPY` | Control whether to use fused metadata copy kernel for cuda graph replay | `true` |
| `SGLANG_DSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` | When the maximum kv len in current prefill batch exceeds this value, the sparse mla kernel will be applied, else it falls back to dense MHA implementation. Default to the index topk of model (2048 for DeepSeek V3.2) (`SGLANG_NSA_PREFILL_DENSE_ATTN_KV_LEN_THRESHOLD` is a deprecated alias) | `2048` |
Expand Down
6 changes: 6 additions & 0 deletions docs_new/docs/advanced_features/server_arguments.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,12 @@ Please consult the documentation below and [server_args.py](https://github.com/s
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>`fa3`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}><code>flashmla_sparse</code>, <code>flashmla_kv</code>, <code>fa3</code>, <code>tilelang</code>, <code>aiter</code>, <code>trtllm</code></td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--dsa-topk-backend`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Choose the DSA indexer top-k backend. The `torch` backend currently requires `SGLANG_DSA_FUSE_TOPK=false`.</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}>`sgl-kernel`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}><code>sgl-kernel</code>, <code>torch</code>, <code>flashinfer</code></td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--fp8-gemm-backend`</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Choose the runner backend for Blockwise FP8 GEMM operations. Options: 'auto' (default, auto-selects based on hardware), 'deep_gemm' (JIT-compiled; enabled by default on NVIDIA Hopper (SM90) and Blackwell (SM100) when DeepGEMM is installed), 'flashinfer_trtllm' (FlashInfer TRTLLM backend; SM100/SM103 only), 'flashinfer_cutlass' (FlashInfer CUTLASS backend, SM120 only), 'flashinfer_deepgemm' (Hopper SM90 only, uses swapAB optimization for small M dimensions in decoding), 'cutlass' (optimal for Hopper/Blackwell GPUs and high-throughput), 'triton' (fallback, widely compatible), 'aiter' (ROCm only).</td>
Expand Down
10 changes: 10 additions & 0 deletions docs_new/docs/references/environment_variables.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,16 @@ SGLang supports various environment variables that can be used to configure its
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Fuse the operation of picking topk logits and picking topk indices from page table. <code>SGLANG_NSA_FUSE_TOPK</code> is a deprecated alias.</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}><code>true</code></td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}><code>SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC</code></td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Use deterministic FlashInfer topk kernels when <code>--dsa-topk-backend=flashinfer</code>.</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}><code>false</code></td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}><code>SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK</code></td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Tie-break mode for FlashInfer DSA topk when <code>--dsa-topk-backend=flashinfer</code>: unset disables explicit tie-breaking, <code>small</code> prefers the smaller candidate index for equal scores, and <code>large</code> prefers the larger candidate index for equal scores. Setting this variable makes FlashInfer use deterministic topk.</td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}><code>unset</code></td>
</tr>
<tr>
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}><code>SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA</code></td>
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Precompute metadata that can be shared among different draft steps when MTP is enabled. <code>SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA</code> is a deprecated alias.</td>
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ class Envs:

# DSA Backend (canonical names; fall back to SGLANG_NSA_* with deprecation warning)
SGLANG_DSA_FUSE_TOPK = EnvBoolWithAlias(True, deprecated_name="SGLANG_NSA_FUSE_TOPK")
SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC = EnvBool(False)
Comment thread
zianglih marked this conversation as resolved.
SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK = EnvStr(None)
SGLANG_DSA_ENABLE_MTP_PRECOMPUTE_METADATA = EnvBoolWithAlias(
True, deprecated_name="SGLANG_NSA_ENABLE_MTP_PRECOMPUTE_METADATA"
)
Expand Down
271 changes: 271 additions & 0 deletions python/sglang/srt/layers/attention/dsa/dsa_topk_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from __future__ import annotations

from enum import Enum, IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.environ import envs

_FLASHINFER_TIE_BREAK_VALUES = {
"small": 1,
"large": 2,
}


class TopkTransformMethod(IntEnum):
# Transform topk indices to indices to the page table (page_size = 1)
PAGED = auto()
# Transform topk indices to indices to ragged kv (non-paged)
RAGGED = auto()


class DSATopKBackend(Enum):
SGL_KERNEL = "sgl-kernel"
TORCH = "torch"
FLASHINFER = "flashinfer"

def is_sgl_kernel(self) -> bool:
return self == DSATopKBackend.SGL_KERNEL

def is_torch(self) -> bool:
return self == DSATopKBackend.TORCH

def is_flashinfer(self) -> bool:
return self == DSATopKBackend.FLASHINFER

def topk_func(
self,
score: torch.Tensor,
lengths: torch.Tensor,
topk: int,
row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.is_sgl_kernel():
from sgl_kernel import fast_topk_v2

return fast_topk_v2(score, lengths, topk, row_starts=row_starts)
if self.is_torch():
return _topk_unfused(
score,
lengths,
topk,
row_starts=row_starts,
topk_op=torch.topk,
topk_op_kwargs={"dim": -1},
)
if self.is_flashinfer():
import flashinfer

return _topk_unfused(
score,
lengths,
topk,
row_starts=row_starts,
topk_op=flashinfer.top_k,
topk_op_kwargs={
"sorted": False,
"deterministic": envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
"tie_break": _flashinfer_tie_break_value(),
"dsa_graph_safe": True,
},
)
raise RuntimeError(f"Unsupported {self = }.")

def topk_transform(
self,
logits: torch.Tensor,
lengths: torch.Tensor,
topk: int,
topk_transform_method: TopkTransformMethod,
attn_metadata,
cu_seqlens_q_topk: Optional[torch.Tensor] = None,
topk_indices_offset: Optional[torch.Tensor] = None,
row_starts: Optional[torch.Tensor] = None,
batch_idx_list: Optional[List[int]] = None,
force_unfused_topk: bool = False,
) -> torch.Tensor:
if not envs.SGLANG_DSA_FUSE_TOPK.get() or force_unfused_topk:
return self.topk_func(logits, lengths, topk, row_starts=row_starts)

if self.is_sgl_kernel():
from sgl_kernel import (
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
)

if topk_transform_method == TopkTransformMethod.PAGED:
page_table_size_1 = (
attn_metadata.page_table_1[batch_idx_list]
if batch_idx_list is not None
else attn_metadata.page_table_1
)
return fast_topk_transform_fused(
score=logits,
lengths=lengths,
page_table_size_1=page_table_size_1,
cu_seqlens_q=cu_seqlens_q_topk,
topk=topk,
row_starts=row_starts,
)
if topk_transform_method == TopkTransformMethod.RAGGED:
if topk_indices_offset is None:
raise RuntimeError(
"RAGGED topk_transform requires topk_indices_offset; "
"expected extend-without-speculative metadata."
)
return fast_topk_transform_ragged_fused(
score=logits,
lengths=lengths,
topk_indices_offset=topk_indices_offset,
topk=topk,
row_starts=row_starts,
)
raise RuntimeError(f"Unsupported {topk_transform_method = }.")

if self.is_flashinfer():
import flashinfer

if topk_transform_method == TopkTransformMethod.PAGED:
row_to_batch, local_row_starts = _build_flashinfer_paged_args(
attn_metadata=attn_metadata,
row_starts=row_starts,
cu_seqlens_q_topk=cu_seqlens_q_topk,
batch_idx_list=batch_idx_list,
device=logits.device,
num_rows=logits.shape[0],
)
return flashinfer.top_k_page_table_transform(
logits.contiguous(),
attn_metadata.page_table_1.contiguous(),
lengths.contiguous(),
topk,
row_to_batch=row_to_batch,
deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
tie_break=_flashinfer_tie_break_value(),
dsa_graph_safe=True,
row_starts=local_row_starts,
)
if topk_transform_method == TopkTransformMethod.RAGGED:
if topk_indices_offset is None:
raise RuntimeError(
"RAGGED topk_transform requires topk_indices_offset; "
"expected extend-without-speculative metadata."
)
return flashinfer.top_k_ragged_transform(
logits.contiguous(),
topk_indices_offset.contiguous(),
lengths.contiguous(),
topk,
deterministic=envs.SGLANG_DSA_TOPK_FLASHINFER_DETERMINISTIC.get(),
tie_break=_flashinfer_tie_break_value(),
dsa_graph_safe=True,
row_starts=row_starts,
)
raise RuntimeError(f"Unsupported {topk_transform_method = }.")

raise RuntimeError(f"Unsupported {self = } for SGLANG_DSA_FUSE_TOPK.")


def _topk_unfused(
score: torch.Tensor,
lengths: torch.Tensor,
topk: int,
row_starts: Optional[torch.Tensor] = None,
topk_op: Callable[..., Tuple[torch.Tensor, torch.Tensor]] = torch.topk,
topk_op_kwargs: Optional[Dict[str, object]] = None,
) -> torch.Tensor:
batch_size, max_score_len = score.shape
topk_indices = score.new_full((batch_size, topk), -1, dtype=torch.int32)
if batch_size == 0 or topk == 0 or max_score_len == 0:
return topk_indices

if row_starts is None:
row_starts = torch.zeros_like(lengths, dtype=torch.int32, device=score.device)
else:
row_starts = row_starts.to(dtype=torch.int32, device=score.device)
lengths = lengths.to(dtype=torch.int32, device=score.device)

col_indices = torch.arange(max_score_len, dtype=torch.int32, device=score.device)
col_indices = col_indices.unsqueeze(0)
row_starts_unsqueezed = row_starts.unsqueeze(1)
row_ends_unsqueezed = (row_starts + lengths).unsqueeze(1)
valid_mask = (col_indices >= row_starts_unsqueezed) & (
col_indices < row_ends_unsqueezed
)

masked_logits = score.masked_fill(~valid_mask, float("-inf"))
valid_topk = min(topk, max_score_len)
topk_kwargs = topk_op_kwargs or {}
topk_scores, topk_col_indices = topk_op(masked_logits, valid_topk, **topk_kwargs)
topk_local_indices = topk_col_indices.to(torch.int32) - row_starts_unsqueezed
topk_local_indices = topk_local_indices.masked_fill(
topk_scores == float("-inf"), -1
)
topk_indices[:, :valid_topk] = topk_local_indices

return topk_indices


def _build_flashinfer_paged_args(
attn_metadata,
row_starts: Optional[torch.Tensor],
cu_seqlens_q_topk: Optional[torch.Tensor],
batch_idx_list: Optional[List[int]],
device: torch.device,
num_rows: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
row_to_batch = (
torch.as_tensor(batch_idx_list, dtype=torch.int32, device=device)
if batch_idx_list is not None
else None
)

if (
row_to_batch is not None
and cu_seqlens_q_topk is not None
and row_to_batch.shape[0] != num_rows
):
q_lens = (cu_seqlens_q_topk[1:] - cu_seqlens_q_topk[:-1]).to(
dtype=torch.int32, device=device
)
row_to_batch = torch.repeat_interleave(row_to_batch, q_lens)

if row_to_batch is None and cu_seqlens_q_topk is not None:
# Decode-like case (one query row per batch) does not need an explicit mapping.
# Avoid dynamic tensor construction in this branch to keep CUDA graph capture safe.
num_batches = cu_seqlens_q_topk.shape[0] - 1
if not (row_starts is None and num_rows == num_batches):
q_lens = (cu_seqlens_q_topk[1:] - cu_seqlens_q_topk[:-1]).to(
dtype=torch.int32, device=device
)
row_to_batch = torch.repeat_interleave(
torch.arange(q_lens.shape[0], dtype=torch.int32, device=device),
q_lens,
)

if row_starts is not None and row_to_batch is None:
raise RuntimeError(
"PAGED topk_transform with row_starts requires cu_seqlens_q metadata."
)

local_row_starts = row_starts
if local_row_starts is not None and row_to_batch is not None:
local_row_starts = (
local_row_starts - attn_metadata.cu_seqlens_k[:-1][row_to_batch]
)

return row_to_batch, local_row_starts


def _flashinfer_tie_break_value() -> int:
mode = envs.SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK.get()
if mode is None:
return 0
mode = mode.lower()
if mode not in _FLASHINFER_TIE_BREAK_VALUES:
raise RuntimeError(
"SGLANG_DSA_TOPK_FLASHINFER_TIE_BREAK must be one of "
f"{tuple(_FLASHINFER_TIE_BREAK_VALUES)} or unset, got {mode!r}."
)
return _FLASHINFER_TIE_BREAK_VALUES[mode]
Loading
Loading