diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4229f40932f7..2c7fe51ff808 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2145,19 +2145,6 @@ def _handle_mamba_radix_cache( not self.enable_mamba_extra_buffer() ), f"mamba extra_buffer is not supported for {model_arch} model" - # FlashInfer GDN decode is incompatible with no_buffer scheduling. - # See https://github.com/sgl-project/sglang/issues/20791 - if ( - self.linear_attn_decode_backend == "flashinfer" - and self.mamba_scheduler_strategy == "no_buffer" - ): - raise ValueError( - "FlashInfer GDN decode (--linear-attn-decode-backend flashinfer) is not " - "compatible with --mamba-scheduler-strategy no_buffer. " - "Please use --mamba-scheduler-strategy extra_buffer instead. " - "See https://github.com/sgl-project/sglang/issues/20791" - ) - if self.enable_mamba_extra_buffer(): # extra_buffer if self.disable_radix_cache: raise ValueError( @@ -2577,9 +2564,27 @@ def _handle_mamba_backend(self): ) def _handle_linear_attn_backend(self): - # SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32. import torch + # SM100+: default to FlashInfer GDN decode when the user hasn't + # explicitly chosen a decode backend and mamba-ssm-dtype is bf16 + # (required by FlashInfer GDN on SM100+). + # Fixed in FlashInfer v0.6.7: flashinfer-ai/flashinfer#2810 + # Excluded when MTP speculative decoding is enabled because + # FlashInfer GDN MTP verify is not yet supported on SM100+. + if ( + self.linear_attn_decode_backend is None + and is_sm100_supported() + and self.mamba_ssm_dtype == "bfloat16" + and self.speculative_algorithm is None + ): + self.linear_attn_decode_backend = "flashinfer" + logger.info( + "SM100+ detected with mamba-ssm-dtype=bfloat16, " + "defaulting --linear-attn-decode-backend to flashinfer." + ) + + # SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32. decode = self.linear_attn_decode_backend or self.linear_attn_backend if ( decode == "flashinfer"