diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 45defc6926ba..8459b20fae99 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1407,6 +1407,25 @@ def get_aiter_allreduce_max_size(cls) -> int | None: # https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273 return int(cls._ALL_REDUCE_MAX_SIZE / 2) + @classmethod + @if_aiter_supported + def are_gdn_triton_kernels_available(cls) -> bool: + """Check if AITER Triton kernels for GDN attention are importable. + + These are optional Triton kernels (conv1d fast-path, gated delta net) + used by GatedDeltaNetAttention's decode fast-path. They may be absent + in older aiter builds. + """ + if not cls._AITER_ENABLED: + return False + try: + import aiter.ops.triton.causal_conv1d_update_single_token # noqa: F401 + import aiter.ops.triton.gated_delta_net # noqa: F401 + + return True + except (ImportError, ModuleNotFoundError): + return False + @staticmethod @if_aiter_supported def register_ops_once() -> None: diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 02e48921d419..caf8b0c97654 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -32,6 +32,7 @@ def chunk_gated_delta_rule_fwd( cu_seqlens: torch.Tensor | None = None, chunk_indices: torch.Tensor | None = None, chunk_offsets: torch.Tensor | None = None, + core_attn_out: torch.Tensor | None = None, ): g = chunk_local_cumsum( g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices @@ -77,6 +78,7 @@ def chunk_gated_delta_rule_fwd( scale=scale, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, + core_attn_out=core_attn_out, ) if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None @@ -102,6 +104,7 @@ def forward( chunk_indices: torch.Tensor | None = None, chunk_offsets: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, + core_attn_out: torch.Tensor | None = None, ): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) @@ -119,9 +122,15 @@ def forward( cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, + core_attn_out=core_attn_out, ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + if core_attn_out is not None: + assert not torch.is_grad_enabled(), ( + "core_attn_out buffer reuse is only supported for inference" + ) + assert q.dtype == o.dtype, "Incompatible dtype for inplace computation" return o.to(q.dtype), final_state @@ -139,6 +148,7 @@ def chunk_gated_delta_rule( chunk_indices: torch.Tensor | None = None, chunk_offsets: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, + core_attn_out: torch.Tensor | None = None, ): r""" Args: @@ -230,5 +240,6 @@ def chunk_gated_delta_rule( chunk_indices, chunk_offsets, use_qk_l2norm_in_kernel, + core_attn_out, ) return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index d812ec433720..0c323b8ce215 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -148,6 +148,7 @@ def chunk_fwd_o( cu_seqlens: torch.Tensor | None = None, chunk_indices: torch.Tensor | None = None, chunk_size: int = FLA_CHUNK_SIZE, + core_attn_out: torch.Tensor | None = None, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] @@ -158,7 +159,13 @@ def chunk_fwd_o( if scale is None: scale = k.shape[-1] ** -0.5 - o = torch.empty_like(v) + if core_attn_out is not None: + assert core_attn_out.numel() >= v.numel(), ( + f"core_attn_out too small: {core_attn_out.numel()} < {v.numel()}" + ) + o = core_attn_out[: v.numel()].view(*v.shape) + else: + o = torch.empty_like(v) def grid(meta): return (triton.cdiv(V, meta["BV"]), NT, B * H) diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index b4699d4f0060..518e9d4f0cff 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -8,6 +8,7 @@ from transformers.activations import ACT2FN from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import ( VllmConfig, get_current_vllm_config, @@ -64,6 +65,20 @@ ) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata +# Optional ROCm AITER Triton kernels for the GDN decode fast-path. +# Availability is checked centrally via rocm_aiter_ops; the actual function +# references are imported here so that they can be called without per-call +# import overhead. +GDN_AITER_TRITON_AVAILABLE = rocm_aiter_ops.are_gdn_triton_kernels_available() + +if GDN_AITER_TRITON_AVAILABLE: + from aiter.ops.triton.causal_conv1d_update_single_token import ( + fused_reshape_causal_conv1d_update_single_token as gdn_aiter_fused_reshape_causal_conv1d_update_single_token, # noqa: E501 + ) + from aiter.ops.triton.gated_delta_net.fused_rearrange_sigmoid_gdr import ( + fused_rearrange_sigmoid_gated_delta_rule as gdn_aiter_fused_rearrange_sigmoid_gated_delta_rule, # noqa: E501 + ) + logger = init_logger(__name__) @@ -169,8 +184,9 @@ def forward_cuda( chunk_indices: torch.Tensor | None = None, chunk_offsets: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = True, + core_attn_out: torch.Tensor | None = None, ): - return fi_chunk_gated_delta_rule( + o, final_state = fi_chunk_gated_delta_rule( q=q, k=k, v=v, @@ -181,6 +197,11 @@ def forward_cuda( cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, ) + if core_attn_out is not None: + o_flat = o.squeeze(0).reshape(-1) + co_flat = core_attn_out.reshape(-1) + co_flat[: o_flat.numel()].copy_(o_flat) + return o, final_state def forward_native( self, @@ -195,6 +216,7 @@ def forward_native( chunk_indices: torch.Tensor | None = None, chunk_offsets: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = True, + core_attn_out: torch.Tensor | None = None, ): return fla_chunk_gated_delta_rule( q=q, @@ -208,6 +230,7 @@ def forward_native( chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + core_attn_out=core_attn_out, ) @@ -271,7 +294,6 @@ def __init__( else 0 ) self.gqa_interleaved_layout = gqa_interleaved_layout - self._forward_method = self.forward_cuda if current_platform.is_xpu(): self._forward_method = self.forward_xpu elif current_platform.is_cpu(): @@ -281,6 +303,10 @@ def __init__( register_cpu_gdn_attention_ops() self._forward_method = self.forward_cpu + elif current_platform.is_rocm(): + self._forward_method = self.forward_hip + else: + self._forward_method = self.forward_cuda # QKV self.conv_dim = self.key_dim * 2 + self.value_dim @@ -297,6 +323,7 @@ def __init__( # we need to create qkvz_proj adaptively here. # When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5), # in_proj_qkv and in_proj_z are created separately instead. + self.has_lora_projections = not create_in_proj_qkvz if create_in_proj_qkvz: self.in_proj_qkvz = self.create_qkvz_proj( hidden_size=self.hidden_size, @@ -497,24 +524,155 @@ def fix_query_key_value_ordering( return query, key, value, z, b, a - def rearrange_mixed_qkv(self, mixed_qkv): - if mixed_qkv is None: - return None, None, None - query, key, value = torch.split( - mixed_qkv, + @torch.compile(fullgraph=True) + def prepare_gdn_attention_core_inputs( + self, + mixed_qkvz: torch.Tensor, + mixed_ba: torch.Tensor, + num_tokens: int, + ): + """ + Derives mixed_qkv, z, b, a from projected qkvz/ba for the GDN custom op. + + For gqa_interleaved_layout (Qwen3-Next): unpack the interleaved + [ng, (hk + hk + np/ng*hv + np/ng*hv)] layout into contiguous qkv. + For non-interleaved layout (Qwen3.5): simple split along last dim. + """ + if not self.gqa_interleaved_layout: + # Qwen3.5: weights are in [q, k, v, z] order + assert num_tokens == mixed_qkvz.shape[0] + qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size + z_size = self.value_dim // self.tp_size + mixed_qkv, z_flat = mixed_qkvz.split([qkv_size, z_size], dim=-1) + n = mixed_qkvz.shape[0] + z_out = z_flat.reshape(n, -1, self.head_v_dim) + b, a = mixed_ba.chunk(2, dim=-1) + return mixed_qkv, z_out, b, a + + # Qwen3-Next: interleaved GQA layout + base_shape_qkvz = mixed_qkvz.size()[:-1] + base_shape_ba = mixed_ba.size()[:-1] + ng = self.num_k_heads // self.tp_size + + new_tensor_shape_qkvz = base_shape_qkvz + ( + ng, + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), + ) + new_tensor_shape_ba = base_shape_ba + ( + ng, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads, + ] + + (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=-1) + + mixed_qkv_logical = torch.cat( [ - self.key_dim // self.tp_size, - self.key_dim // self.tp_size, - self.value_dim // self.tp_size, + query.reshape(num_tokens, -1), + key.reshape(num_tokens, -1), + value.reshape(num_tokens, -1), ], dim=-1, ) - query, key = map( - lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), - (query, key), + + # The split above produces non-contiguous views into the interleaved + # buffer. Concatenating everything into a single flat tensor forces a + # contiguous copy, then slicing back out gives contiguous q/k/v/z/b/a + # tensors that downstream kernels require. Doing this in one cat+slice + # keeps torch.compile in a single Triton graph instead of emitting + # separate copy kernels per tensor. The original code used + # rearrange(...).contiguous() on each tensor individually. + fused = torch.cat( + [ + mixed_qkv_logical.reshape(-1), + z.reshape(-1), + b.reshape(-1), + a.reshape(-1), + ], + dim=0, + ) + + curr = 0 + qkv_numel = mixed_qkv_logical.numel() + z_numel = z.numel() + b_numel = b.numel() + a_numel = a.numel() + + mixed_qkv_out = fused[curr : curr + qkv_numel].view(num_tokens, -1) + curr += qkv_numel + + z_out = fused[curr : curr + z_numel].view( + num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim + ) + curr += z_numel + + b_out = fused[curr : curr + b_numel].view( + num_tokens, self.num_v_heads // self.tp_size + ) + curr += b_numel + + a_out = fused[curr : curr + a_numel].view( + num_tokens, self.num_v_heads // self.tp_size + ) + + return mixed_qkv_out, z_out, b_out, a_out + + @torch.compile(fullgraph=True) + def rearrange_mixed_qkv(self, mixed_qkv): + """Split packed qkv into contiguous (1, seq, heads, dim) tensors. + + The original code used ``rearrange(x, "l (h d) -> 1 l h d", d=...)`` + followed by ``.contiguous()`` on each tensor. This version flattens + all three splits into a single buffer via ``torch.cat`` so that + torch.compile emits one Triton copy kernel instead of three separate + contiguous() calls. + """ + if mixed_qkv is None: + return None, None, None + + seq_len = mixed_qkv.shape[0] + q_dim = self.key_dim // self.tp_size + k_dim = self.key_dim // self.tp_size + v_dim = self.value_dim // self.tp_size + + query, key, value = torch.split(mixed_qkv, [q_dim, k_dim, v_dim], dim=-1) + + fused = torch.cat( + [query.reshape(-1), key.reshape(-1), value.reshape(-1)], dim=0 ) - value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) - return query.contiguous(), key.contiguous(), value.contiguous() + + q_size = seq_len * q_dim + k_size = seq_len * k_dim + + q_contig = fused[0:q_size] + k_contig = fused[q_size : q_size + k_size] + v_contig = fused[q_size + k_size :] + + query = q_contig.view(1, seq_len, -1, self.head_k_dim) + key = k_contig.view(1, seq_len, -1, self.head_k_dim) + value = v_contig.view(1, seq_len, -1, self.head_v_dim) + + return query, key, value def forward( self, @@ -523,6 +681,63 @@ def forward( ): self._forward_method(hidden_states, output) + def _output_projection( + self, + core_attn_out: torch.Tensor, + z: torch.Tensor, + output: torch.Tensor, + num_tokens: int, + ): + """Part 3: RMSNormGated + output linear projection. + + The RMSNormGated + quant sequence is eligible for fusion + by the compilation pass when fuse_norm_quant is enabled. + """ + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + output[:num_tokens], _ = self.out_proj(core_attn_out) + + def forward_hip( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + ): + """ROCm forward using AITER Triton fused projection+attention when + available, otherwise falling back to the generic CUDA path.""" + if not self.has_lora_projections and GDN_AITER_TRITON_AVAILABLE: + num_tokens = hidden_states.size(0) + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + projected_states_qkvz = projected_states_qkvz.view(num_tokens, -1) + projected_states_ba = projected_states_ba.view(num_tokens, -1) + core_attn_out = torch.empty( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + z = torch.empty( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), + dtype=projected_states_qkvz.dtype, + device=projected_states_qkvz.device, + ) + + torch.ops.vllm.gdn_attention_core( + projected_states_qkvz, + projected_states_ba, + z, + core_attn_out, + fast_kernel=True, + layer_name=_encode_layer_name(self.prefix), + ) + + self._output_projection(core_attn_out, z, output, num_tokens) + else: + self.forward_cuda(hidden_states, output) + def forward_cuda( self, hidden_states: torch.Tensor, @@ -538,7 +753,7 @@ def forward_cuda( # ============================================================ # Part 1: Input Projection # ============================================================ - if hasattr(self, "in_proj_qkv"): + if self.has_lora_projections: # LoRA path (Qwen3.5 only): separate in_proj_qkv and in_proj_z mixed_qkv, _ = self.in_proj_qkv(hidden_states) ba, _ = self.in_proj_ba(hidden_states) @@ -586,20 +801,14 @@ def forward_cuda( b, a, core_attn_out, - _encode_layer_name(self.prefix), + fast_kernel=False, + layer_name=_encode_layer_name(self.prefix), ) # ============================================================ # Part 3: Output Projection # ============================================================ - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) + self._output_projection(core_attn_out, z, output, num_tokens) def forward_xpu( self, @@ -614,7 +823,7 @@ def forward_xpu( """ num_tokens = hidden_states.size(0) - assert not hasattr(self, "in_proj_qkv"), "lora isn't supported on XPU." + assert not self.has_lora_projections, "lora isn't supported on XPU." # ============================================================ # Part 1: Input Projection @@ -702,7 +911,7 @@ def forward_cpu( core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) - def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: + def _warmup_prefill_kernels(self, qkv_or_qkvz: torch.Tensor, v_dim: int) -> None: """Warm up GDN prefill kernels during V1 profiling. During V1 profile runs, ``_forward_core`` returns early because @@ -723,7 +932,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: ``BT = chunk_size`` (64). A single warmup pass with T = 64 is sufficient to populate the autotuner cache. - The decode path uses ``fused_sigmoid_gating_delta_rule_update`` + The decode path uses ``gdn_aiter_fused_rearrange_sigmoid_gated_delta_rule`` which has fixed kernel parameters (no autotuning), so only the prefill (chunked) path needs warming up. """ @@ -731,8 +940,8 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: return self._prefill_kernels_warmed_up = True - device = mixed_qkv.device - dtype = mixed_qkv.dtype + device = qkv_or_qkvz.device + dtype = qkv_or_qkvz.dtype num_k_heads = self.num_k_heads // self.tp_size num_v_heads = self.num_v_heads // self.tp_size _, state_dtype = self.get_state_dtype() @@ -743,7 +952,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: # then run chunk_gated_delta_rule with in-kernel L2 norm disabled. T = FLA_CHUNK_SIZE dummy_mixed_qkv = torch.randn( - T, mixed_qkv.shape[-1], device=device, dtype=dtype + T, qkv_or_qkvz.shape[-1] - v_dim, device=device, dtype=dtype ) dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) @@ -806,6 +1015,66 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: torch.accelerator.empty_cache() + def _forward_core_rocm( + self, + qkvz: torch.Tensor, + ba: torch.Tensor, + z_out: torch.Tensor, + core_attn_out: torch.Tensor, + ): + """ROCm AITER fast path: conv1d + recurrent attention from packed + qkvz/ba layout. + + For decode-only (no spec, no prefill), dispatches directly to + ``_forward_core_decode_fast``. Otherwise unpacks the packed + layout and falls through to ``_forward_core``. + + Args: + qkvz: packed [q, k, v, z] projection (num_tokens, qkvz_dim) + ba: packed [b, a] gating vectors (num_tokens, 2*num_heads) + z_out: **output** buffer for z (num_tokens, num_heads, + head_dim); mutated in-place. + core_attn_out: Pre-allocated output buffer for attention results. + """ + forward_context = get_forward_context() + attn_metadata_raw = forward_context.attn_metadata + + if attn_metadata_raw is None: + v_dim = core_attn_out.shape[-1] * core_attn_out.shape[-2] + self._warmup_prefill_kernels(qkvz, v_dim) + return + + assert isinstance(attn_metadata_raw, dict) + attn_metadata = attn_metadata_raw[self.prefix] # type: ignore[index] + assert isinstance(attn_metadata, GDNAttentionMetadata) + + if ( + attn_metadata.spec_sequence_masks is None + and attn_metadata.num_prefills == 0 + and attn_metadata.num_decodes > 0 + ): + return self._forward_core_decode_fast( + qkvz=qkvz, + ba=ba, + z_out=z_out, + core_attn_out=core_attn_out, + attn_metadata=attn_metadata, + ) + + core_attn_out.zero_() + z_out.zero_() + num_tokens_all = qkvz.shape[0] + mixed_qkv, z, b, a = self.prepare_gdn_attention_core_inputs( + qkvz, ba, num_tokens_all + ) + z_out[:] = z + self._forward_core( + mixed_qkv=mixed_qkv, + b=b, + a=a, + core_attn_out=core_attn_out, + ) + def _forward_core( self, mixed_qkv: torch.Tensor, @@ -813,13 +1082,19 @@ def _forward_core( a: torch.Tensor, core_attn_out: torch.Tensor, ): + """Core conv1d + recurrent attention (standard path). + + Args: + mixed_qkv: packed [q, k, v] projection (num_tokens, qkv_dim) + b: beta gating vector (num_tokens, num_heads) + a: alpha gating vector (num_tokens, num_heads) + core_attn_out: Pre-allocated output buffer for attention results. + """ forward_context = get_forward_context() attn_metadata_raw = forward_context.attn_metadata if attn_metadata_raw is None: - # V1 profile run — warm up prefill kernels so that - # autotuning completes before KV cache allocation. - self._warmup_prefill_kernels(mixed_qkv) + self._warmup_prefill_kernels(mixed_qkv, 0) return assert isinstance(attn_metadata_raw, dict) @@ -1065,6 +1340,72 @@ def _forward_core( else: core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) + def _forward_core_decode_fast( + self, + qkvz: torch.Tensor, + ba: torch.Tensor, + z_out: torch.Tensor, + core_attn_out: torch.Tensor, + attn_metadata: GDNAttentionMetadata, + ): + non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc + non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 + self_kv_cache = self.kv_cache + # conv_state must be (..., dim, width-1) for the conv kernels. + # DS layout stores it that way directly; SD layout needs a transpose. + conv_state = ( + self_kv_cache[0] + if is_conv_state_dim_first() + else self_kv_cache[0].transpose(-1, -2) + ) + ssm_state = self_kv_cache[1] + + # 1. Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + mixed_qkv_non_spec, b, a = ( + gdn_aiter_fused_reshape_causal_conv1d_update_single_token( + qkvz, + attn_metadata.num_actual_tokens, + self.num_k_heads // self.tp_size, + self.num_v_heads // self.tp_size, + self.head_k_dim, + self.head_v_dim, + ba, + z_out, + core_attn_out, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=non_spec_state_indices_tensor[ # type: ignore[index] + : attn_metadata.num_actual_tokens + ], + validate_data=True, + ) + ) + + # 2. Recurrent attention + gdn_aiter_fused_rearrange_sigmoid_gated_delta_rule( + A_log=self.A_log, + a=a, + b=b, + dt_bias=self.dt_bias, + qkv=mixed_qkv_non_spec, + key_dim=self.key_dim // self.tp_size, + value_dim=self.value_dim // self.tp_size, + head_k_dim=self.head_k_dim, + head_v_dim=self.head_v_dim, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], # type: ignore[index] + ssm_state_indices=non_spec_state_indices_tensor, + use_qk_l2norm_in_kernel=True, + core_attn_out=core_attn_out.reshape(-1), + ) + def _forward_core_decode_non_spec( self, mixed_qkv: torch.Tensor, @@ -1121,33 +1462,51 @@ def _forward_core_decode_non_spec( def gdn_attention_core( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, + qkv_or_qkvz: torch.Tensor, + b_or_ba: torch.Tensor, + a_or_z_out: torch.Tensor, core_attn_out: torch.Tensor, + fast_kernel: bool, layer_name: LayerNameType, ) -> None: - """ - Custom op for the core attention computation. - Only handles the convolution + recurrent attention part. - Input/output projections are handled outside this op. + """Custom op dispatching to _forward_core or _forward_core_rocm. + + Handles conv1d + recurrent attention only; input/output projections + are performed by the caller. + + When ``fast_kernel=False`` (standard path): + qkv_or_qkvz is [q, k, v], b_or_ba is b, a_or_z_out is a (read-only). + When ``fast_kernel=True`` (AITER Triton fast path, ROCm only): + qkv_or_qkvz is [q, k, v, z], b_or_ba is [b, a], a_or_z_out is the + z output buffer (mutated in-place). + + ``core_attn_out`` is always mutated in-place. """ layer_name = _resolve_layer_name(layer_name) forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self._forward_core( - mixed_qkv=mixed_qkv, - b=b, - a=a, - core_attn_out=core_attn_out, - ) + if fast_kernel: + self._forward_core_rocm( + qkvz=qkv_or_qkvz, + ba=b_or_ba, + z_out=a_or_z_out, + core_attn_out=core_attn_out, + ) + else: + self._forward_core( + mixed_qkv=qkv_or_qkvz, + b=b_or_ba, + a=a_or_z_out, + core_attn_out=core_attn_out, + ) def gdn_attention_core_fake( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, + qkv_or_qkvz: torch.Tensor, + b_or_ba: torch.Tensor, + a_or_z_out: torch.Tensor, core_attn_out: torch.Tensor, + fast_kernel: bool, layer_name: LayerNameType, ) -> None: """Fake implementation for torch.compile.""" @@ -1157,7 +1516,7 @@ def gdn_attention_core_fake( direct_register_custom_op( op_name="gdn_attention_core", op_func=gdn_attention_core, - mutates_args=["core_attn_out"], + mutates_args=["a_or_z_out", "core_attn_out"], fake_impl=gdn_attention_core_fake, )