# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Parity test: packed-QKV fused decode kernel vs. the unfused reference path."""

import pytest
import torch

from vllm.model_executor.layers.fla.ops import (
    fused_recurrent_gated_delta_rule,
    fused_recurrent_gated_delta_rule_packed_decode_fwd,
)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
def test_fused_recurrent_packed_decode_matches_reference(
    dtype: torch.dtype, strided_mixed_qkv: bool
):
    torch.manual_seed(0)

    # Small but representative GDN config (Qwen3Next defaults are K=128, V=128).
    batch = 32
    num_qk_heads = 4
    num_v_heads = 8  # grouped value attention: num_v_heads % num_qk_heads == 0
    head_k = 128
    head_v = 128
    packed_dim = 2 * (num_qk_heads * head_k) + (num_v_heads * head_v)

    device = torch.device("cuda")

    if strided_mixed_qkv:
        # Simulate a packed view into a larger projection buffer so that
        # mixed_qkv.stride(0) > mixed_qkv.shape[1].
        backing = torch.randn((batch, packed_dim + 64), device=device, dtype=dtype)
        mixed_qkv = backing[:, :packed_dim]
    else:
        mixed_qkv = torch.randn((batch, packed_dim), device=device, dtype=dtype)

    a = torch.randn((batch, num_v_heads), device=device, dtype=dtype)
    b = torch.randn((batch, num_v_heads), device=device, dtype=dtype)
    A_log = torch.randn((num_v_heads,), device=device, dtype=dtype)
    dt_bias = torch.randn((num_v_heads,), device=device, dtype=dtype)

    # Continuous-batching slot mapping; the tail uses PAD_SLOT_ID (-1).
    ssm_state_indices = torch.arange(batch, device=device, dtype=torch.int32)
    ssm_state_indices[-3:] = -1

    state0 = torch.randn(
        (batch, num_v_heads, head_v, head_k), device=device, dtype=dtype
    )
    state_ref = state0.clone()
    state_packed = state0.clone()

    out_ref = torch.empty((batch, 1, num_v_heads, head_v), device=device, dtype=dtype)
    out_packed = torch.empty(
        (batch, 1, num_v_heads, head_v), device=device, dtype=dtype
    )

    # Reference path: materialize contiguous Q/K/V plus explicit gating.
    q_flat, k_flat, v_flat = torch.split(
        mixed_qkv,
        [num_qk_heads * head_k, num_qk_heads * head_k, num_v_heads * head_v],
        dim=-1,
    )
    q = q_flat.view(batch, num_qk_heads, head_k).unsqueeze(1).contiguous()
    k = k_flat.view(batch, num_qk_heads, head_k).unsqueeze(1).contiguous()
    v = v_flat.view(batch, num_v_heads, head_v).unsqueeze(1).contiguous()

    # softplus with a linear tail for x > 20 (clamp keeps the untaken exp
    # branch finite); matches the fused kernel's gating math.
    x = a.float() + dt_bias.float()
    softplus_x = torch.where(
        x <= 20.0, torch.log1p(torch.exp(torch.clamp(x, max=20.0))), x
    )
    g = (-torch.exp(A_log.float()) * softplus_x).unsqueeze(1)
    beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)

    fused_recurrent_gated_delta_rule(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        initial_state=state_ref,
        out=out_ref,
        inplace_final_state=True,
        cu_seqlens=None,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    # Packed path: fused gating + recurrent update directly from mixed_qkv.
    fused_recurrent_gated_delta_rule_packed_decode_fwd(
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        scale=head_k**-0.5,
        initial_state=state_packed,
        out=out_packed,
        ssm_state_indices=ssm_state_indices,
        use_qk_l2norm_in_kernel=True,
    )

    torch.testing.assert_close(out_packed, out_ref, rtol=1e-2, atol=2e-2)
    torch.testing.assert_close(state_packed, state_ref, rtol=1e-2, atol=2e-2)
+ "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool( + int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "0")) + ), # Disable pynccl (using torch.distributed instead) "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py index c19cc14ba692..291eedfbd5e9 100644 --- a/vllm/model_executor/layers/fla/ops/__init__.py +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -7,11 +7,15 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang from .chunk import chunk_gated_delta_rule -from .fused_recurrent import fused_recurrent_gated_delta_rule +from .fused_recurrent import ( + fused_recurrent_gated_delta_rule, + fused_recurrent_gated_delta_rule_packed_decode_fwd, +) from .layernorm_guard import RMSNormGated __all__ = [ "RMSNormGated", "chunk_gated_delta_rule", "fused_recurrent_gated_delta_rule", + "fused_recurrent_gated_delta_rule_packed_decode_fwd", ] diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index 67d77e88294c..28f2536db4d8 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -106,16 +106,23 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 else: i_t = 0 - # Load state index and check for PAD_SLOT_ID (-1) - state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( - tl.int64 - ) - # Skip if state index is invalid (PAD_SLOT_ID = -1) + # Load state index and check for PAD_SLOT_ID (-1). + state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok + ).to(tl.int64) + # If state index is invalid (PAD_SLOT_ID = -1), write zeros to the + # output and return early. 
This is important for padded requests in + # CUDAGraph mode where `o` may be reused across replays. if state_idx < 0: + p_o_pad = p_o + zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty) + for _ in range(0, T): + tl.store(p_o_pad, zero, mask=mask_v) + p_o_pad += HV * V return p_h0 = h0 + state_idx * stride_init_state_token else: - p_h0 = h0 + bos * HV * V * K + p_h0 = h0 + i_n * stride_init_state_token p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) @@ -150,13 +157,20 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( # keep the states for multi-query tokens if INPLACE_FINAL_STATE: - # Load state index and check for PAD_SLOT_ID (-1) - final_state_idx = tl.load( - ssm_state_indices + i_n * stride_indices_seq + i_t - ).to(tl.int64) - # Only store if state index is valid (not PAD_SLOT_ID) - if final_state_idx >= 0: - p_ht = ht + final_state_idx * stride_final_state_token + if IS_CONTINUOUS_BATCHING: + # Load state index and check for PAD_SLOT_ID (-1). + final_state_idx = tl.load( + ssm_state_indices + + i_n * stride_indices_seq + + i_t * stride_indices_tok + ).to(tl.int64) + # Only store if state index is valid (not PAD_SLOT_ID). 
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_fwd_kernel(
    mixed_qkv,
    a,
    b,
    A_log,
    dt_bias,
    o,
    h0,
    ht,
    ssm_state_indices,
    scale,
    stride_mixed_qkv_tok: tl.constexpr,
    stride_a_tok: tl.constexpr,
    stride_b_tok: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    SOFTPLUS_THRESHOLD: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """Fused recurrent kernel for decode-uniform (B=num_tokens, T=1).

    - Reads Q/K/V directly from a packed `mixed_qkv` row: [Q | K | V].
    - Fuses gated-delta gating (a/b -> g/beta) inside the recurrent kernel.
    - Supports continuous batching via `ssm_state_indices` (PAD_SLOT_ID = -1).

    Grid: (cdiv(V, BV), B * HV); one program handles one (token, value-head)
    pair for one BV-wide slice of the value dimension.
    """
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map a value head to its (grouped) query/key head.
    i_h = i_hv // (HV // H)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)

    # Output tensor layout is [B, 1, HV, V] (contiguous). The size-1 time dim
    # is elided in the address computation.
    p_o = o + (i_n * HV + i_hv) * V + o_v

    # If state index is invalid (PAD_SLOT_ID = -1), write zeros to the output
    # and return early. This is important for padded requests in CUDAGraph
    # mode where `o` may be reused across replays.
    if state_idx < 0:
        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
        tl.store(p_o, zero, mask=mask_v)
        return

    # Load initial state for this sequence/head. The flat offset below
    # assumes the trailing [HV, V, K] dims of h0/ht are dense; the launcher
    # validates this before the kernel is dispatched.
    p_h0 = h0 + state_idx * stride_init_state_token
    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Load q/k/v from packed `mixed_qkv` for this token.
    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
    q_off = i_h * K + o_k
    k_off = (H * K) + i_h * K + o_k
    v_off = (2 * H * K) + i_hv * V + o_v
    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)

    if USE_QK_L2NORM_IN_KERNEL:
        b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
        b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
    b_q = b_q * scale

    # Fused gating:
    #   g = -exp(A_log) * softplus(a + dt_bias), beta = sigmoid(b)
    a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
    b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
    A_log_val = tl.load(A_log + i_hv).to(tl.float32)
    dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
    x = a_val + dt_bias_val
    # softplus with a linear tail above SOFTPLUS_THRESHOLD to avoid exp
    # overflow.
    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
    g_val = -tl.exp(A_log_val) * softplus_x
    beta_val = tl.sigmoid(b_val)
    # Match the existing behavior where `beta` is written out in `b.dtype` and
    # reloaded as float32 inside the recurrent kernel.
    beta_val = beta_val.to(b.dtype.element_ty).to(tl.float32)

    # Recurrent update for a single token (T=1).
    # NOTE(review): `exp` is assumed to be this module's fast-exp helper used
    # by the sibling fused_recurrent kernel — confirm it is in scope here.
    b_h *= exp(g_val)
    b_v -= tl.sum(b_h * b_k[None, :], 1)
    b_v *= beta_val
    b_h += b_v[:, None] * b_k[None, :]
    b_o = tl.sum(b_h * b_q[None, :], 1)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # Store final state (in-place).
    p_ht = ht + state_idx * stride_final_state_token
    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


def fused_recurrent_gated_delta_rule_packed_decode_fwd(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    out: torch.Tensor,
    ssm_state_indices: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode-only fast path (uniform batching): fused packed-QKV recurrent.

    Expects:
    - mixed_qkv: [B, 2*H*K + HV*V] (packed [Q|K|V] per token)
    - a/b: [B, HV]
    - out: [B, 1, HV, V] (contiguous)
    - initial_state: [S, HV, V, K], updated in-place using `ssm_state_indices`.

    Returns `(out, initial_state)`.

    Raises:
        ValueError: if any input fails the layout/shape/dtype/device
            validation below. Validation happens before the kernel launch,
            so a raising call leaves `out` and `initial_state` unmodified
            (callers may safely fall back to the unfused path).
    """
    if mixed_qkv.ndim != 2:
        raise ValueError(
            f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
        )
    if mixed_qkv.stride(-1) != 1:
        raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
        )
    if a.stride(-1) != 1 or b.stride(-1) != 1:
        raise ValueError("`a`/`b` must be contiguous in the last dim.")
    if ssm_state_indices.ndim != 1:
        raise ValueError(
            f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
        )
    # The kernel reinterprets indices as int64; a floating tensor would be
    # silently misread, so reject it up front.
    if ssm_state_indices.dtype not in (torch.int32, torch.int64):
        raise ValueError(
            f"`ssm_state_indices` must be int32/int64 (got {ssm_state_indices.dtype})."
        )
    if not out.is_contiguous():
        raise ValueError("`out` must be contiguous.")

    if not torch.is_floating_point(mixed_qkv):
        raise ValueError("`mixed_qkv` must be a floating tensor.")
    if not torch.is_floating_point(a) or not torch.is_floating_point(b):
        raise ValueError("`a`/`b` must be floating tensors.")
    if not torch.is_floating_point(A_log) or not torch.is_floating_point(dt_bias):
        raise ValueError("`A_log`/`dt_bias` must be floating tensors.")
    if not torch.is_floating_point(initial_state) or not torch.is_floating_point(out):
        raise ValueError("`initial_state`/`out` must be floating tensors.")

    dev = mixed_qkv.device
    if (
        a.device != dev
        or b.device != dev
        or A_log.device != dev
        or dt_bias.device != dev
        or initial_state.device != dev
        or out.device != dev
        or ssm_state_indices.device != dev
    ):
        raise ValueError("All inputs must be on the same device.")

    B = mixed_qkv.shape[0]
    if a.shape[0] != B or b.shape[0] != B:
        raise ValueError(
            f"Mismatched batch sizes: mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
        )
    if ssm_state_indices.shape[0] != B:
        raise ValueError(
            f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
        )

    if initial_state.ndim != 4:
        raise ValueError(
            f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
        )
    if initial_state.stride(-1) != 1:
        raise ValueError("`initial_state` must be contiguous in the last dim.")
    HV, V, K = initial_state.shape[-3:]
    # The kernel addresses the state as a dense [HV, V, K] slab per slot
    # (offset = i_hv*V*K + v*K + k), so the inner three dims must be fully
    # dense, not just the last one.
    if initial_state.stride(-2) != K or initial_state.stride(-3) != V * K:
        raise ValueError(
            "`initial_state` must have dense [HV, V, K] inner dims "
            f"(got strides {tuple(initial_state.stride())} for shape "
            f"{tuple(initial_state.shape)})."
        )
    if a.shape[1] != HV or b.shape[1] != HV:
        raise ValueError(
            f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
        )
    if A_log.numel() != HV or dt_bias.numel() != HV:
        raise ValueError(
            f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
        )
    # The kernel indexes these as dense 1-D buffers (base + i_hv).
    if not A_log.is_contiguous() or not dt_bias.is_contiguous():
        raise ValueError("`A_log`/`dt_bias` must be contiguous.")

    if out.shape != (B, 1, HV, V):
        raise ValueError(
            f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
        )

    # Infer H from the packed layout: [Q (H*K) | K (H*K) | V (HV*V)].
    qkv_dim = mixed_qkv.shape[1]
    qk_dim = qkv_dim - HV * V
    if qk_dim <= 0 or qk_dim % 2 != 0:
        raise ValueError(
            f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
        )
    q_dim = qk_dim // 2
    if q_dim % K != 0:
        raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
    H = q_dim // K
    if H <= 0 or HV % H != 0:
        raise ValueError(
            f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
        )

    BK = triton.next_power_of_2(K)
    if triton.cdiv(K, BK) != 1:
        raise ValueError(
            f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
        )

    BV = min(triton.next_power_of_2(V), 32)
    num_stages = 3
    num_warps = 1

    stride_mixed_qkv_tok = mixed_qkv.stride(0)
    stride_a_tok = a.stride(0)
    stride_b_tok = b.stride(0)
    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = initial_state.stride(0)
    stride_indices_seq = ssm_state_indices.stride(0)

    NV = triton.cdiv(V, BV)
    grid = (NV, B * HV)
    fused_recurrent_gated_delta_rule_packed_decode_fwd_kernel[grid](
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        o=out,
        h0=initial_state,
        ht=initial_state,
        ssm_state_indices=ssm_state_indices,
        scale=scale,
        stride_mixed_qkv_tok=stride_mixed_qkv_tok,
        stride_a_tok=stride_a_tok,
        stride_b_tok=stride_b_tok,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        SOFTPLUS_THRESHOLD=20.0,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return out, initial_state
scale: float, initial_state: torch.Tensor, + out: torch.Tensor | None = None, inplace_final_state: bool = True, cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, @@ -198,7 +473,22 @@ def fused_recurrent_gated_delta_rule_fwd( num_stages = 3 num_warps = 1 - o = q.new_empty(NK, *v.shape) + if out is None: + o = q.new_empty(NK, *v.shape) + else: + if out.shape != v.shape: + raise ValueError( + f"`out` must have the same shape as `v` (got out.shape={tuple(out.shape)}, v.shape={tuple(v.shape)})" + ) + if not torch.is_floating_point(out): + raise ValueError("`out` must be a floating tensor.") + if out.device != q.device: + raise ValueError( + f"`out` must be on the same device as `q` (got out.device={out.device}, q.device={q.device})" + ) + if not out.is_contiguous(): + raise ValueError("`out` must be contiguous.") + o = out.unsqueeze(0) if inplace_final_state: final_state = initial_state else: @@ -263,6 +553,7 @@ def forward( beta: torch.Tensor, scale: float, initial_state: torch.Tensor, + out: torch.Tensor | None = None, inplace_final_state: bool = True, cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, @@ -277,6 +568,7 @@ def forward( beta=beta.contiguous(), scale=scale, initial_state=initial_state, + out=out, inplace_final_state=inplace_final_state, cu_seqlens=cu_seqlens, ssm_state_indices=ssm_state_indices, @@ -295,6 +587,7 @@ def fused_recurrent_gated_delta_rule( beta: torch.Tensor = None, scale: float = None, initial_state: torch.Tensor = None, + out: torch.Tensor | None = None, inplace_final_state: bool = True, cu_seqlens: torch.LongTensor | None = None, ssm_state_indices: torch.Tensor | None = None, @@ -384,6 +677,7 @@ def fused_recurrent_gated_delta_rule( beta, scale, initial_state, + out, inplace_final_state, cu_seqlens, ssm_state_indices, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 7f1386d7be57..8c8f818af32c 100644 --- 
a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -10,6 +10,7 @@ from torch import nn from transformers.activations import ACT2FN +from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CacheConfig, @@ -35,6 +36,7 @@ ) from vllm.model_executor.layers.fla.ops import ( fused_recurrent_gated_delta_rule, + fused_recurrent_gated_delta_rule_packed_decode_fwd, ) from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd from vllm.model_executor.layers.fused_moe import SharedFusedMoE @@ -726,6 +728,41 @@ def _forward_core( else: mixed_qkv_non_spec = None + # Decode-uniform fast path (B=num_tokens, T=1): avoid materializing + # contiguous Q/K/V and the standalone gating kernel by directly feeding + # packed `mixed_qkv` into a fused recurrent kernel. + is_uniform_decode = ( + spec_sequence_masks is None + and attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes > 0 + ) + if ( + envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE + and is_uniform_decode + and mixed_qkv_non_spec is not None + ): + out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1) + try: + fused_recurrent_gated_delta_rule_packed_decode_fwd( + mixed_qkv=mixed_qkv_non_spec, + a=a, + b=b, + A_log=self.A_log, + dt_bias=self.dt_bias, + scale=self.head_k_dim**-0.5, + initial_state=ssm_state, + out=out_buf, + ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens], + use_qk_l2norm_in_kernel=True, + ) + return + except ValueError as exc: + logger.warning_once( + "Packed recurrent decode fast path unavailable; falling back " + "to default path: %s", + exc, + ) + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( mixed_qkv_non_spec @@ -771,6 +808,7 @@ def _forward_core( core_attn_out_spec, last_recurrent_state = None, None # 2.2: Process the remaining part + wrote_core_attn_out_non_spec = False if 
attn_metadata.num_prefills > 0: initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 @@ -793,22 +831,45 @@ def _forward_core( ssm_state.dtype ) elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[ - : attn_metadata.num_decodes + 1 - ], - ssm_state_indices=non_spec_state_indices_tensor, - use_qk_l2norm_in_kernel=True, + if spec_sequence_masks is None: + # Decode hot-path: write directly into the output buffer to avoid an + # extra allocation + copy. + out_buf = core_attn_out[:num_actual_tokens].unsqueeze(0) + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + out=out_buf, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) + ) + wrote_core_attn_out_non_spec = True + else: + core_attn_out_non_spec, last_recurrent_state = ( + fused_recurrent_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + ) ) - ) else: core_attn_out_non_spec, last_recurrent_state = None, None @@ -825,7 +886,8 @@ def _forward_core( elif spec_sequence_masks is not None: core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + if 
core_attn_out_non_spec is not None and not wrote_core_attn_out_non_spec: + core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) class Qwen3NextAttention(nn.Module):