# Reconstructed from a whitespace-mangled diff:
# tests/kernels/test_fused_recurrent_packed_decode.py (new file).
#
# NOTE(review): the same diff region also registered
# VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE in vllm/envs.py and exported the new
# op from vllm/model_executor/layers/fla/ops/__init__.py; those hunks are patch
# scaffolding and are not reproduced here. The envs getter used
# bool(int(os.getenv(..., "1"))), which raises on "true"/"false" input and is
# inconsistent with the neighboring VLLM_DISABLE_PYNCCL pattern — confirm.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fla.ops import (
    fused_recurrent_gated_delta_rule,
    fused_recurrent_gated_delta_rule_packed_decode,
)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
def test_fused_recurrent_packed_decode_matches_reference(
    dtype: torch.dtype, strided_mixed_qkv: bool
):
    """Packed-decode kernel must match the unpacked reference path.

    The reference path materializes contiguous Q/K/V and computes the gating
    terms (g, beta) explicitly in PyTorch, then calls the existing
    ``fused_recurrent_gated_delta_rule``. The packed path consumes the raw
    ``mixed_qkv`` projection plus raw gating inputs (a, b, A_log, dt_bias)
    directly. Outputs and updated SSM states must agree within dtype tolerance.
    """
    torch.manual_seed(0)

    # Small but representative GDN config (Qwen3Next defaults are K=128, V=128).
    B = 32
    H = 4
    HV = 8  # grouped value attention: HV must be divisible by H
    K = 128
    V = 128
    qkv_dim = 2 * (H * K) + (HV * V)

    device = torch.device("cuda")

    if strided_mixed_qkv:
        # Simulate a packed view into a larger projection buffer:
        # mixed_qkv.stride(0) > mixed_qkv.shape[1]
        proj = torch.randn((B, qkv_dim + 64), device=device, dtype=dtype)
        mixed_qkv = proj[:, :qkv_dim]
    else:
        mixed_qkv = torch.randn((B, qkv_dim), device=device, dtype=dtype)

    a = torch.randn((B, HV), device=device, dtype=dtype)
    b = torch.randn((B, HV), device=device, dtype=dtype)
    A_log = torch.randn((HV,), device=device, dtype=dtype)
    dt_bias = torch.randn((HV,), device=device, dtype=dtype)

    # Continuous batching indices (include PAD_SLOT_ID=-1 cases).
    ssm_state_indices = torch.arange(B, device=device, dtype=torch.int32)
    ssm_state_indices[-3:] = -1

    state0 = torch.randn((B, HV, V, K), device=device, dtype=dtype)
    state_ref = state0.clone()
    state_packed = state0.clone()

    out_packed = torch.empty((B, 1, HV, V), device=device, dtype=dtype)

    # Reference path: materialize contiguous Q/K/V + explicit gating.
    q, k, v = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
    q = q.view(B, H, K).unsqueeze(1).contiguous()
    k = k.view(B, H, K).unsqueeze(1).contiguous()
    v = v.view(B, HV, V).unsqueeze(1).contiguous()

    # Numerically-safe softplus: linear beyond the threshold, log1p below it.
    x = a.float() + dt_bias.float()
    softplus_x = torch.where(
        x <= 20.0, torch.log1p(torch.exp(torch.clamp(x, max=20.0))), x
    )
    g = (-torch.exp(A_log.float()) * softplus_x).unsqueeze(1)
    beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)

    out_ref, state_ref = fused_recurrent_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=K**-0.5,
        initial_state=state_ref,
        inplace_final_state=True,
        cu_seqlens=None,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    # Packed path: fused gating + recurrent directly from packed mixed_qkv.
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        scale=K**-0.5,
        initial_state=state_packed,
        out=out_packed,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    atol = 2e-2 if dtype != torch.float32 else 1e-4
    rtol = 1e-2 if dtype != torch.float32 else 1e-4
    torch.testing.assert_close(out_packed, out_ref, rtol=rtol, atol=atol)
    torch.testing.assert_close(state_packed, state_ref, rtol=rtol, atol=atol)
# Reconstructed from a whitespace-mangled diff:
# vllm/model_executor/layers/fla/ops/fused_recurrent.py (added kernel + wrapper).
# NOTE(review): the preceding fla/ops/__init__.py hunk (re-export of this op)
# is patch scaffolding and is not reproduced here.


@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
    mixed_qkv,
    a,
    b,
    A_log,
    dt_bias,
    o,
    h0,
    ht,
    ssm_state_indices,
    scale,
    stride_mixed_qkv_tok: tl.constexpr,
    stride_a_tok: tl.constexpr,
    stride_b_tok: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    SOFTPLUS_THRESHOLD: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    # Single-token (T=1) gated delta-rule decode step that reads Q/K/V straight
    # out of the packed [B, 2*H*K + HV*V] projection and fuses the sigmoid/
    # softplus gating. One program handles one (token, value-head) pair and one
    # BV-wide slice of V.
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Grouped-value attention: HV value heads share H query/key heads.
    i_h = i_hv // (HV // H)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
    p_o = o + (i_n * HV + i_hv) * V + o_v

    # PAD_SLOT_ID (-1): emit zeros and leave every state slot untouched.
    if state_idx < 0:
        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
        tl.store(p_o, zero, mask=mask_v)
        return

    # Load the recurrent state slice [BV, BK] for this value head.
    p_h0 = h0 + state_idx * stride_init_state_token
    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Packed layout per token: [Q (H*K) | K (H*K) | V (HV*V)].
    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
    q_off = i_h * K + o_k
    k_off = (H * K) + i_h * K + o_k
    v_off = (2 * H * K) + i_hv * V + o_v
    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)

    if USE_QK_L2NORM_IN_KERNEL:
        b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
        b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
    b_q = b_q * scale

    # Fused gating: g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b).
    a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
    b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
    A_log_val = tl.load(A_log + i_hv).to(tl.float32)
    dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
    x = a_val + dt_bias_val
    # Softplus switches to the identity beyond SOFTPLUS_THRESHOLD to avoid
    # overflow in exp().
    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
    g_val = -tl.exp(A_log_val) * softplus_x
    # Round-trip through b's storage dtype to match the reference path's
    # precision (the reference casts beta to the activation dtype).
    beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)

    # Delta-rule update (uses the module-level `exp` helper, like the other
    # kernels in this file):
    #   h <- h * exp(g); v <- beta * (v - h k^T); h <- h + v k^T; o = h q^T
    b_h *= exp(g_val)
    b_v -= tl.sum(b_h * b_k[None, :], 1)
    b_v *= beta_val
    b_h += b_v[:, None] * b_k[None, :]
    b_o = tl.sum(b_h * b_q[None, :], 1)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # Write the updated state back in place.
    p_ht = ht + state_idx * stride_final_state_token
    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


def fused_recurrent_gated_delta_rule_packed_decode(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    out: torch.Tensor,
    ssm_state_indices: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode-only (T=1 per sequence) gated delta rule on a packed projection.

    Args:
        mixed_qkv: [B, 2*H*K + HV*V] packed Q/K/V rows; last dim contiguous
            (rows may be strided, e.g. a view into a larger buffer).
        a, b: [B, HV] raw gating inputs (pre-softplus / pre-sigmoid).
        A_log, dt_bias: [HV] per-head gating parameters.
        scale: query scaling (typically K**-0.5).
        initial_state: [num_slots, HV, V, K] SSM state, updated in place.
        out: [B, 1, HV, V] contiguous output buffer, written in place.
        ssm_state_indices: [B] int slot indices; -1 (PAD_SLOT_ID) rows produce
            zero output and do not touch any state slot.
        use_qk_l2norm_in_kernel: L2-normalize q/k inside the kernel.

    Returns:
        (out, initial_state) — the same tensors, mutated in place.
    """
    if mixed_qkv.ndim != 2:
        raise ValueError(
            f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
        )
    if mixed_qkv.stride(-1) != 1:
        raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
        )
    if a.stride(-1) != 1 or b.stride(-1) != 1:
        raise ValueError("`a`/`b` must be contiguous in the last dim.")
    if A_log.ndim != 1 or dt_bias.ndim != 1:
        raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
    if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
        raise ValueError("`A_log`/`dt_bias` must be contiguous.")
    if ssm_state_indices.ndim != 1:
        raise ValueError(
            f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
        )
    if not out.is_contiguous():
        raise ValueError("`out` must be contiguous.")

    dev = mixed_qkv.device
    if (
        a.device != dev
        or b.device != dev
        or A_log.device != dev
        or dt_bias.device != dev
        or initial_state.device != dev
        or out.device != dev
        or ssm_state_indices.device != dev
    ):
        raise ValueError("All inputs must be on the same device.")

    B = mixed_qkv.shape[0]
    if a.shape[0] != B or b.shape[0] != B:
        raise ValueError(
            "Mismatched batch sizes: "
            f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
        )
    if ssm_state_indices.shape[0] != B:
        raise ValueError(
            f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
        )

    if initial_state.ndim != 4:
        raise ValueError(
            f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
        )
    if initial_state.stride(-1) != 1:
        raise ValueError("`initial_state` must be contiguous in the last dim.")
    HV, V, K = initial_state.shape[-3:]
    if a.shape[1] != HV or b.shape[1] != HV:
        raise ValueError(
            f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
        )
    if A_log.numel() != HV or dt_bias.numel() != HV:
        raise ValueError(
            f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
        )
    if out.shape != (B, 1, HV, V):
        raise ValueError(
            f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
        )

    # Infer H from the packed width: [Q (H*K) | K (H*K) | V (HV*V)].
    qkv_dim = mixed_qkv.shape[1]
    qk_dim = qkv_dim - HV * V
    if qk_dim <= 0 or qk_dim % 2 != 0:
        raise ValueError(
            f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
        )
    q_dim = qk_dim // 2
    if q_dim % K != 0:
        raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
    H = q_dim // K
    if H <= 0 or HV % H != 0:
        raise ValueError(
            f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
        )

    # The kernel assumes the whole K dimension fits in one tile (NK == 1).
    BK = triton.next_power_of_2(K)
    if triton.cdiv(K, BK) != 1:
        raise ValueError(
            f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
        )
    BV = min(triton.next_power_of_2(V), 32)
    num_stages = 3
    num_warps = 1

    stride_mixed_qkv_tok = mixed_qkv.stride(0)
    stride_a_tok = a.stride(0)
    stride_b_tok = b.stride(0)
    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = initial_state.stride(0)
    stride_indices_seq = ssm_state_indices.stride(0)

    NV = triton.cdiv(V, BV)
    grid = (NV, B * HV)
    fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        o=out,
        h0=initial_state,
        ht=initial_state,  # final state written in place over the initial state
        ssm_state_indices=ssm_state_indices,
        scale=scale,
        stride_mixed_qkv_tok=stride_mixed_qkv_tok,
        stride_a_tok=stride_a_tok,
        stride_b_tok=stride_b_tok,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        SOFTPLUS_THRESHOLD=20.0,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out, initial_state
b/vllm/model_executor/models/qwen3_next.py @@ -10,6 +10,7 @@ from torch import nn from transformers.activations import ACT2FN +from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CacheConfig, @@ -34,6 +35,7 @@ chunk_gated_delta_rule as fla_chunk_gated_delta_rule, ) from vllm.model_executor.layers.fla.ops import ( + fused_recurrent_gated_delta_rule_packed_decode, fused_sigmoid_gating_delta_rule_update, ) from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd @@ -474,6 +476,9 @@ def __init__( ) self.chunk_gated_delta_rule = ChunkGatedDeltaRule() + self.enable_packed_recurrent_decode = ( + envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -652,9 +657,6 @@ def _forward_core( a: torch.Tensor, core_attn_out: torch.Tensor, ): - """ - Core attention computation (called by custom op). - """ forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata @@ -665,6 +667,22 @@ def _forward_core( assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, GDNAttentionMetadata) + + if ( + self.enable_packed_recurrent_decode + and attn_metadata.spec_sequence_masks is None + and attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes > 0 + ): + return self._forward_core_decode_non_spec( + mixed_qkv=mixed_qkv, + b=b, + a=a, + core_attn_out=core_attn_out, + attn_metadata=attn_metadata, + virtual_engine=forward_context.virtual_engine, + ) + has_initial_state = attn_metadata.has_initial_state spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc @@ -849,6 +867,55 @@ def _forward_core( else: core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + def _forward_core_decode_non_spec( + self, + mixed_qkv: 
torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, + core_attn_out: torch.Tensor, + attn_metadata: GDNAttentionMetadata, + virtual_engine: int, + ): + """ + Core attention computation with a packed non-spec decode fast path. + """ + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache[virtual_engine] + conv_state = self_kv_cache[0].transpose(-1, -2) + ssm_state = self_kv_cache[1] + num_actual_tokens = attn_metadata.num_actual_tokens + + mixed_qkv = mixed_qkv[:num_actual_tokens] + b = b[:num_actual_tokens] + a = a[:num_actual_tokens] + + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + mixed_qkv_non_spec = causal_conv1d_update( + mixed_qkv, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + validate_data=False, + ) + out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1) + fused_recurrent_gated_delta_rule_packed_decode( + mixed_qkv=mixed_qkv_non_spec, + a=a, + b=b, + A_log=self.A_log, + dt_bias=self.dt_bias, + scale=self.head_k_dim**-0.5, + initial_state=ssm_state, + out=out_buf, + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + use_qk_l2norm_in_kernel=True, + ) + return + class Qwen3NextAttention(nn.Module): def __init__(