@@ -3,7 +3,7 @@
Both SM90 and SM100+ use the same pool layout: [pool, HV, V, K] (K-last).

SM90 (Hopper): full support — decode, prefill, MTP. State dtype: fp32.
SM100+ (Blackwell+): decode-only with bf16 state. More support on the way.
SM100+ (Blackwell+): decode and prefill with bf16 state. MTP verify on the way.

Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
"""
@@ -74,8 +74,7 @@ class FlashInferGDNKernel(LinearAttnKernelBase):
"""FlashInfer kernel for GDN with K-last SSM state layout.

SM90 (Hopper): decode uses gather/scatter; prefill and MTP verify supported.
SM100+ (Blackwell+): decode uses pool API (initial_state_indices); prefill
and MTP verify are not supported (use Triton backend for those).
SM100+ (Blackwell+): decode and prefill supported; MTP verify not yet supported.

Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
"""
@@ -97,7 +96,7 @@ def __init__(self):
raise RuntimeError("FlashInfer GDN decode kernel is unavailable.")

sm_major = torch.cuda.get_device_capability()[0]
self.use_state_pool = sm_major != 9
self.is_sm100plus = sm_major >= 10

if sm_major == 9:
if self._prefill_fn is None:
@@ -136,7 +135,7 @@ def decode(
a_fi = a.view(batch_size, 1, num_v_heads)
b_fi = b.view(batch_size, 1, num_v_heads)

if self.use_state_pool:
if self.is_sm100plus:
output_fi, _ = self._decode_fn(
q=query_fi,
k=key_fi,
@@ -186,13 +185,6 @@ def extend(
query_start_loc: torch.Tensor,
**kwargs,
) -> tuple:
if self.use_state_pool:
raise NotImplementedError(
"FlashInfer GDN prefill is not supported on SM100+. "
"Use --linear-attn-prefill-backend triton."
)

# SM90: chunked prefill using FlashInfer GDN prefill kernel.
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd

total_seq_len = q.shape[1]
@@ -207,30 +199,50 @@
alpha_fi = torch.exp(g[0].to(torch.float32))
beta_fi = beta[0].to(torch.float32)

cu_seqlens_fi = query_start_loc.to(torch.int64)

# Remap negative padding indices to sentinel slot
ssm_cache_indices = torch.where(
cache_indices >= 0,
cache_indices,
ssm_states.shape[0] - 1,
).to(torch.int64)

# FlashInfer requires float32 initial state, K-last layout [B, HV, V, K]
initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)

output_fi, output_state_fi = self._prefill_fn(
q=q_fi,
k=k_fi,
v=v_fi,
g=alpha_fi,
beta=beta_fi,
scale=None,
initial_state=initial_state_fi,
output_final_state=True,
cu_seqlens=cu_seqlens_fi,
use_qk_l2norm_in_kernel=False,
)
if self.is_sm100plus:
# Negative indices (e.g. -1) are padding markers for slots not yet
# assigned to a real sequence; clamp them to 0 (the reserved dummy
# slot) so the FlashInfer kernel never reads out-of-bounds state.
ssm_cache_indices = cache_indices.clamp(min=0).to(torch.int64)
initial_state_fi = ssm_states[ssm_cache_indices].contiguous()
# Pre-allocate a bf16 output_state so the kernel compiles for and writes the
# bf16 state path directly, avoiding an fp32 allocation and a subsequent
# fp32->bf16 conversion in the scatter step.
output_state_fi = torch.empty_like(initial_state_fi)
output_fi, output_state_fi = self._prefill_fn(
q=q_fi,
k=k_fi,
v=v_fi,
g=alpha_fi,
beta=beta_fi,
scale=None,
initial_state=initial_state_fi,
output_final_state=True,
cu_seqlens=query_start_loc, # already int32
use_qk_l2norm_in_kernel=False,
output_state=output_state_fi,
)
else:
# SM90: preserve original negative-index handling (remap to last slot).
ssm_cache_indices = torch.where(
cache_indices >= 0,
cache_indices,
ssm_states.shape[0] - 1,
).to(torch.int64)
# State must be float32; kernel requires int64 cu_seqlens.
initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)
output_fi, output_state_fi = self._prefill_fn(
q=q_fi,
k=k_fi,
v=v_fi,
g=alpha_fi,
beta=beta_fi,
scale=None,
initial_state=initial_state_fi,
output_final_state=True,
cu_seqlens=query_start_loc.to(torch.int64),
use_qk_l2norm_in_kernel=False,
)
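# Worked illustration (not used by the kernel) of the two negative-index
# strategies above, assuming torch is already imported. Padding entries in
# cache_indices are marked with -1; the pool here has 16 slots.
_example_indices = torch.tensor([3, -1, 7, -1])
_example_pool_size = 16
# SM100+: clamp padding markers to the reserved dummy slot 0.
_sm100plus_indices = _example_indices.clamp(min=0).to(torch.int64)  # [3, 0, 7, 0]
# SM90: remap padding markers to the last slot of the pool.
_sm90_indices = torch.where(
    _example_indices >= 0, _example_indices, _example_pool_size - 1
).to(torch.int64)  # [3, 15, 7, 15]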

# Write back state to pool
ssm_states.index_copy_(
@@ -267,7 +279,7 @@ def target_verify(
retrieve_parent_token: torch.Tensor,
**kwargs,
) -> torch.Tensor:
if self.use_state_pool:
if self.is_sm100plus:
raise NotImplementedError(
"FlashInfer GDN MTP verify is not yet supported on SM100+."
)
16 changes: 16 additions & 0 deletions python/sglang/srt/server_args.py
@@ -2970,6 +2970,22 @@ def _handle_linear_attn_backend(self):
f"got {self.mamba_ssm_dtype!r}"
)

# SM100+ FlashInfer GDN prefill requires CUDA 13+ (CuTe DSL kernel)
# for correctness and best performance.
prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
cuda_version = torch.version.cuda
cuda_major = int(cuda_version.split(".")[0]) if cuda_version is not None else 0
if (
prefill == "flashinfer"
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 10
and cuda_major < 13
):
raise ValueError(
"--linear-attn-prefill-backend flashinfer on SM100+ requires CUDA 13+, "
f"got CUDA {cuda_version or 'unknown'}"
)

def _handle_context_parallelism(self):
if self.attn_cp_size > 1:
# The tp_size is the world size, not the real tensor parallel size
6 changes: 0 additions & 6 deletions test/manual/4-gpu-models/test_qwen35_fp4_triton.py
@@ -48,12 +48,6 @@ def test_gsm8k(self):
extra_args=base_args,
variant="Triton",
),
# TODO: Fix this and re-enable it
# ModelLaunchSettings(
# QWEN35_FP4_MODEL,
# extra_args=base_args + ["--linear-attn-decode-backend", "flashinfer"],
# variant="FlashInfer",
# ),
]

run_combined_tests(
83 changes: 83 additions & 0 deletions test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py
@@ -0,0 +1,83 @@
import unittest

import torch

from sglang.test.accuracy_test_runner import AccuracyTestParams
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_combined_tests import run_combined_tests
from sglang.test.test_utils import (
CustomTestCase,
ModelLaunchSettings,
)

register_cuda_ci(est_time=720, suite="stage-c-test-4-gpu-b200")

QWEN35_FP4_MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"
ACC_THRESHOLDS = {QWEN35_FP4_MODEL: {"gsm8k": 0.95}}

_cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0

_is_sm100_cuda13 = (
torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 10
and _cuda_major >= 13
)


@unittest.skipUnless(_is_sm100_cuda13, "requires SM100+ GPU and CUDA 13+")
class TestQwen35FP4FlashInfer(CustomTestCase):
def test_gsm8k(self):
base_args = [
"--tp-size",
"4",
"--chunked-prefill-size",
"2048",
"--mamba-scheduler-strategy",
"extra_buffer",
"--mamba-track-interval",
"128",
"--mamba-ssm-dtype",
"bfloat16",
"--max-running-requests",
"128",
"--reasoning-parser",
"qwen3",
"--attention-backend",
"trtllm_mha",
"--quantization",
"modelopt_fp4",
"--model-loader-extra-config",
'{"enable_multithread_load": true,"num_threads": 64}',
"--linear-attn-decode-backend",
"flashinfer",
"--linear-attn-prefill-backend",
"flashinfer",
]

variants = [
ModelLaunchSettings(
QWEN35_FP4_MODEL,
extra_args=base_args,
variant="FlashInfer",
),
]

run_combined_tests(
models=variants,
test_name="Qwen3.5-397B-A17B-NVFP4",
accuracy_params=AccuracyTestParams(
dataset="gsm8k",
baseline_accuracy=ACC_THRESHOLDS[QWEN35_FP4_MODEL]["gsm8k"],
num_examples=200,
num_threads=128,
max_tokens=16000,
thinking_mode="qwen3",
temperature=0.6,
top_p=0.95,
top_k=20,
),
)


if __name__ == "__main__":
unittest.main()