Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,19 +2145,6 @@ def _handle_mamba_radix_cache(
not self.enable_mamba_extra_buffer()
), f"mamba extra_buffer is not supported for {model_arch} model"

# FlashInfer GDN decode is incompatible with no_buffer scheduling.
# See https://github.com/sgl-project/sglang/issues/20791
if (
self.linear_attn_decode_backend == "flashinfer"
and self.mamba_scheduler_strategy == "no_buffer"
):
raise ValueError(
"FlashInfer GDN decode (--linear-attn-decode-backend flashinfer) is not "
"compatible with --mamba-scheduler-strategy no_buffer. "
"Please use --mamba-scheduler-strategy extra_buffer instead. "
"See https://github.com/sgl-project/sglang/issues/20791"
)

if self.enable_mamba_extra_buffer(): # extra_buffer
if self.disable_radix_cache:
raise ValueError(
Expand Down Expand Up @@ -2577,9 +2564,27 @@ def _handle_mamba_backend(self):
)

def _handle_linear_attn_backend(self):
# SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32.
import torch

# SM100+: default to FlashInfer GDN decode when the user hasn't
# explicitly chosen a decode backend and mamba-ssm-dtype is bf16
# (required by FlashInfer GDN on SM100+).
# Fixed in FlashInfer v0.6.7: flashinfer-ai/flashinfer#2810
# Skipped whenever any speculative algorithm is configured (the guard
# below requires `speculative_algorithm is None`), because FlashInfer
# GDN MTP verify is not yet supported on SM100+.
if (
self.linear_attn_decode_backend is None
and is_sm100_supported()
and self.mamba_ssm_dtype == "bfloat16"
and self.speculative_algorithm is None
):
self.linear_attn_decode_backend = "flashinfer"
logger.info(
"SM100+ detected with mamba-ssm-dtype=bfloat16, "
"defaulting --linear-attn-decode-backend to flashinfer."
)

# SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32.
decode = self.linear_attn_decode_backend or self.linear_attn_backend
if (
decode == "flashinfer"
Expand Down
Loading