diff --git a/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py b/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
index f0bf4a04a4dd..c7ff4d54a6ad 100644
--- a/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
+++ b/python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
@@ -3,7 +3,7 @@
 Both SM90 and SM100+ use the same pool layout: [pool, HV, V, K] (K-last).
 
 SM90 (Hopper): full support — decode, prefill, MTP. State dtype: fp32.
-SM100+ (Blackwell+): decode-only with bf16 state. More support on the way.
+SM100+ (Blackwell+): decode and prefill with bf16 state. MTP verify on the way.
 
 Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
 """
@@ -74,8 +74,7 @@ class FlashInferGDNKernel(LinearAttnKernelBase):
     """FlashInfer kernel for GDN with K-last SSM state layout.
 
     SM90 (Hopper): decode uses gather/scatter; prefill and MTP verify supported.
-    SM100+ (Blackwell+): decode uses pool API (initial_state_indices); prefill
-    and MTP verify are not supported (use Triton backend for those).
+    SM100+ (Blackwell+): decode and prefill supported; MTP verify not yet supported.
 
     Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
     """
@@ -97,7 +96,7 @@ def __init__(self):
             raise RuntimeError("FlashInfer GDN decode kernel is unavailable.")
 
         sm_major = torch.cuda.get_device_capability()[0]
-        self.use_state_pool = sm_major != 9
+        self.is_sm100plus = sm_major >= 10
 
         if sm_major == 9:
             if self._prefill_fn is None:
@@ -136,7 +135,7 @@ def decode(
         a_fi = a.view(batch_size, 1, num_v_heads)
         b_fi = b.view(batch_size, 1, num_v_heads)
 
-        if self.use_state_pool:
+        if self.is_sm100plus:
             output_fi, _ = self._decode_fn(
                 q=query_fi,
                 k=key_fi,
@@ -186,13 +185,6 @@ def extend(
         query_start_loc: torch.Tensor,
         **kwargs,
     ) -> tuple:
-        if self.use_state_pool:
-            raise NotImplementedError(
-                "FlashInfer GDN prefill is not supported on SM100+. "
-                "Use --linear-attn-prefill-backend triton."
-            )
-
-        # SM90: chunked prefill using FlashInfer GDN prefill kernel.
         from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
 
         total_seq_len = q.shape[1]
@@ -207,30 +199,50 @@ def extend(
         alpha_fi = torch.exp(g[0].to(torch.float32))
         beta_fi = beta[0].to(torch.float32)
 
-        cu_seqlens_fi = query_start_loc.to(torch.int64)
-
-        # Remap negative padding indices to sentinel slot
-        ssm_cache_indices = torch.where(
-            cache_indices >= 0,
-            cache_indices,
-            ssm_states.shape[0] - 1,
-        ).to(torch.int64)
-
-        # FlashInfer requires float32 initial state, K-last layout [B, HV, V, K]
-        initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)
-
-        output_fi, output_state_fi = self._prefill_fn(
-            q=q_fi,
-            k=k_fi,
-            v=v_fi,
-            g=alpha_fi,
-            beta=beta_fi,
-            scale=None,
-            initial_state=initial_state_fi,
-            output_final_state=True,
-            cu_seqlens=cu_seqlens_fi,
-            use_qk_l2norm_in_kernel=False,
-        )
+        if self.is_sm100plus:
+            # Negative indices (e.g. -1) are padding markers for slots not yet
+            # assigned to a real sequence; clamp them to 0 (the reserved dummy
+            # slot) so the FlashInfer kernel never reads out-of-bounds state.
+            ssm_cache_indices = cache_indices.clamp(min=0).to(torch.int64)
+            initial_state_fi = ssm_states[ssm_cache_indices].contiguous()
+            # Pre-allocate bf16 output_state so the kernel compiles and writes the
+            # bf16 state path directly, avoiding a fp32 allocation and a subsequent
+            # fp32->bf16 conversion in the scatter step.
+            output_state_fi = torch.empty_like(initial_state_fi)
+            output_fi, output_state_fi = self._prefill_fn(
+                q=q_fi,
+                k=k_fi,
+                v=v_fi,
+                g=alpha_fi,
+                beta=beta_fi,
+                scale=None,
+                initial_state=initial_state_fi,
+                output_final_state=True,
+                cu_seqlens=query_start_loc,  # already int32
+                use_qk_l2norm_in_kernel=False,
+                output_state=output_state_fi,
+            )
+        else:
+            # SM90: preserve original negative-index handling (remap to last slot).
+            ssm_cache_indices = torch.where(
+                cache_indices >= 0,
+                cache_indices,
+                ssm_states.shape[0] - 1,
+            ).to(torch.int64)
+            # State must be float32; kernel requires int64 cu_seqlens.
+            initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)
+            output_fi, output_state_fi = self._prefill_fn(
+                q=q_fi,
+                k=k_fi,
+                v=v_fi,
+                g=alpha_fi,
+                beta=beta_fi,
+                scale=None,
+                initial_state=initial_state_fi,
+                output_final_state=True,
+                cu_seqlens=query_start_loc.to(torch.int64),
+                use_qk_l2norm_in_kernel=False,
+            )
 
         # Write back state to pool
         ssm_states.index_copy_(
@@ -267,7 +279,7 @@ def target_verify(
         retrieve_parent_token: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-        if self.use_state_pool:
+        if self.is_sm100plus:
             raise NotImplementedError(
                 "FlashInfer GDN MTP verify is not yet supported on SM100+."
             )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8a49db505030..3d7ad910e487 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2970,6 +2970,22 @@ def _handle_linear_attn_backend(self):
                 f"got {self.mamba_ssm_dtype!r}"
             )
 
+        # SM100+ FlashInfer GDN prefill requires CUDA 13+ (CuTe DSL kernel)
+        # for correctness and best performance.
+        prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
+        cuda_version = torch.version.cuda
+        cuda_major = int(cuda_version.split(".")[0]) if cuda_version is not None else 0
+        if (
+            prefill == "flashinfer"
+            and torch.cuda.is_available()
+            and torch.cuda.get_device_capability()[0] >= 10
+            and cuda_major < 13
+        ):
+            raise ValueError(
+                "--linear-attn-prefill-backend flashinfer on SM100+ requires CUDA 13+, "
+                f"got CUDA {cuda_version or 'unknown'}"
+            )
+
     def _handle_context_parallelism(self):
         if self.attn_cp_size > 1:
             # The tp_size is the world size, not the real tensor parallel size
diff --git a/test/manual/4-gpu-models/test_qwen35_fp4_triton.py b/test/manual/4-gpu-models/test_qwen35_fp4_triton.py
index 8ed90d175f8a..88384c2c9bb7 100644
--- a/test/manual/4-gpu-models/test_qwen35_fp4_triton.py
+++ b/test/manual/4-gpu-models/test_qwen35_fp4_triton.py
@@ -48,12 +48,6 @@ def test_gsm8k(self):
                 extra_args=base_args,
                 variant="Triton",
             ),
-            # TODO: Fix this and re-enable it
-            # ModelLaunchSettings(
-            #     QWEN35_FP4_MODEL,
-            #     extra_args=base_args + ["--linear-attn-decode-backend", "flashinfer"],
-            #     variant="FlashInfer",
-            # ),
         ]
 
         run_combined_tests(
diff --git a/test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py b/test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py
new file mode 100644
index 000000000000..d8727a76e2ac
--- /dev/null
+++ b/test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py
@@ -0,0 +1,83 @@
+import unittest
+
+import torch
+
+from sglang.test.accuracy_test_runner import AccuracyTestParams
+from sglang.test.ci.ci_register import register_cuda_ci
+from sglang.test.run_combined_tests import run_combined_tests
+from sglang.test.test_utils import (
+    CustomTestCase,
+    ModelLaunchSettings,
+)
+
+register_cuda_ci(est_time=720, suite="stage-c-test-4-gpu-b200")
+
+QWEN35_FP4_MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"
+ACC_THRESHOLDS = {QWEN35_FP4_MODEL: {"gsm8k": 0.95}}
+
+_cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
+
+_is_sm100_cuda13 = (
+    torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 10
+    and _cuda_major >= 13
+)
+
+
+@unittest.skipUnless(_is_sm100_cuda13, "requires SM100+ GPU and CUDA 13+")
+class TestQwen35FP4FlashInfer(CustomTestCase):
+    def test_gsm8k(self):
+        base_args = [
+            "--tp-size",
+            "4",
+            "--chunked-prefill-size",
+            "2048",
+            "--mamba-scheduler-strategy",
+            "extra_buffer",
+            "--mamba-track-interval",
+            "128",
+            "--mamba-ssm-dtype",
+            "bfloat16",
+            "--max-running-requests",
+            "128",
+            "--reasoning-parser",
+            "qwen3",
+            "--attention-backend",
+            "trtllm_mha",
+            "--quantization",
+            "modelopt_fp4",
+            "--model-loader-extra-config",
+            '{"enable_multithread_load": true,"num_threads": 64}',
+            "--linear-attn-decode-backend",
+            "flashinfer",
+            "--linear-attn-prefill-backend",
+            "flashinfer",
+        ]
+
+        variants = [
+            ModelLaunchSettings(
+                QWEN35_FP4_MODEL,
+                extra_args=base_args,
+                variant="FlashInfer",
+            ),
+        ]
+
+        run_combined_tests(
+            models=variants,
+            test_name="Qwen3.5-397B-A17B-NVFP4",
+            accuracy_params=AccuracyTestParams(
+                dataset="gsm8k",
+                baseline_accuracy=ACC_THRESHOLDS[QWEN35_FP4_MODEL]["gsm8k"],
+                num_examples=200,
+                num_threads=128,
+                max_tokens=16000,
+                thinking_mode="qwen3",
+                temperature=0.6,
+                top_p=0.95,
+                top_k=20,
+            ),
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
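For readers cross-checking the new SM100+ prefill path, the standalone sketch below mirrors the gather/scatter pattern that extend() now uses on Blackwell: clamp negative padding indices, gather bf16 initial states from the K-last [pool, HV, V, K] pool, pre-allocate a bf16 output state, and scatter the final state back. The tensor sizes, the dummy-slot choice, and the simplified write-back are illustrative assumptions, not the kernel wrapper's exact pool bookkeeping.

import torch

# Hypothetical pool sizes for illustration only; the real pool comes from the
# linear-attention cache manager.
pool_size, HV, V, K = 8, 4, 16, 16
ssm_states = torch.zeros(pool_size, HV, V, K, dtype=torch.bfloat16)  # K-last layout

# Cache indices as the scheduler might hand them over: -1 marks unassigned slots.
cache_indices = torch.tensor([3, -1, 5])

# SM100+ path: clamp padding markers to slot 0 so the gather never reads
# out-of-bounds, and keep the state in bf16 (no fp32 round-trip).
idx = cache_indices.clamp(min=0).to(torch.int64)
initial_state = ssm_states[idx].contiguous()    # [B, HV, V, K], bf16
output_state = torch.empty_like(initial_state)  # pre-allocated bf16 output buffer

# ... the FlashInfer prefill kernel would fill output_state here ...

# Write the final states back into the pool (simplified; the real code also
# handles the dummy slot when scattering).
ssm_states.index_copy_(0, idx, output_state.to(ssm_states.dtype))
print(initial_state.shape, output_state.dtype)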