diff --git a/aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py b/aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py new file mode 100644 index 0000000000..4f2f0356d9 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# Kernels for causal_conv1d **update** single-token paths: ``conv_state`` is updated in place. + +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_update_single_token_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_APC_ENABLED: + # Get the state from the initial_state_idx + conv_state_init = tl.load(initial_state_idx + idx_seq) + current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) + else: + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init + ).to(tl.int64) + + if USE_PAD_SLOT: # noqa + if conv_states_input_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + # IS_VARLEN is False + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + # IS_SPEC_DECODING is False + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + # note: NP2_STATELEN = triton.next_power_of_2(KERNEL_WIDTH - 1) + idx_cols = tl.arange(0, NP2_STATELEN) + conv_state_ptrs_cols = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[:, None] + + (idx_cols * stride_conv_state_tok)[None, :] + ) # [BLOCK_N, NP2_STATELEN] + mask_cols = ( + (conv_states_input_coord < num_cache_lines) + & (idx_feats < dim)[:, None] + & (idx_cols < KERNEL_WIDTH - 
1)[None, :] + ) + cols = tl.load(conv_state_ptrs_cols, mask_cols, other=0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Get the state from the initial_state_idx + # cache_idx + conv_states_offset = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index + ).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator, not necessary + # if HAS_BIAS: + # bias = bias_ptr + idx_feats + # mask_bias = idx_feats < dim + # acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + # tl.float32 + # ) # [BLOCK_N] + # else: + # acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # LOAD WEIGHTS and compute + w_cols_ptrs = ( + w_ptr + + (idx_feats * stride_w_dim)[:, None] + + (idx_cols * stride_w_width)[None, :] + ) + mask_w_cols = (idx_feats < dim)[:, None] & (idx_cols < KERNEL_WIDTH - 1)[None, :] + w_cols = tl.load(w_cols_ptrs, mask_w_cols, other=0.0) # [BLOCK_N, NP2_STATELEN] + + w_last_ptrs = ( + w_ptr + (idx_feats * stride_w_dim) + (KERNEL_WIDTH - 1) * stride_w_width + ) + w_last = tl.load(w_last_ptrs, idx_feats < dim, other=0.0) # [BLOCK_N] + + # For the convolution output: dot(weights, [state_cols | x]) + # cols is [BLOCK_N, NP2_STATELEN] = conv_state history + # We need x as 1D [BLOCK_N] for the last weight column + x_1d = tl.load(x_base, mask=(idx_feats < dim), other=0.0) # [BLOCK_N], reload as 1D + acc = tl.sum((w_cols * cols).to(tl.float32), axis=1) + (w_last * x_1d).to( + tl.float32 + ) + + if HAS_BIAS: + bias = bias_ptr + idx_feats + acc += tl.load(bias, idx_feats < dim, other=0.0).to(tl.float32) # [BLOCK_N] + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = idx_feats < dim + o_ptrs = o_ptr + o_offset + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +@triton.jit() +def _reshape_causal_conv1d_update_single_token_kernel( + # Pointers to matrices + x_ptr, # (num_tokens, dim+z_dim, seqlen) where seqlen=1 + ba_ptr, + z_ptr, # (num_tokens, num_v_heads, 
head_v_dim) + core_attn_out_ptr, # (num_tokens, num_v_heads, head_v_dim) + b_ptr, # (num_accepted_tokens, num_v_heads) + a_ptr, # (num_accepted_tokens, num_v_heads) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + conv_state_indices_ptr, + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + o_ptr, # (num_accepted_tokens, dim, seqlen) + # Matrix dimensions + batch: int, + num_tokens: int, + num_k_heads: tl.constexpr, + num_v_heads: tl.constexpr, + head_k_dim: tl.constexpr, + head_v_dim: tl.constexpr, + dim: tl.constexpr, + head_qkvz_dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + stride_z_seq: tl.constexpr, + stride_ba_seq: tl.constexpr, + stride_ba_token: tl.constexpr, + stride_b_seq: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + num_program_write_z: tl.constexpr, + BLOCK_Z: tl.constexpr, + HV: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + ## write b, a + if tl.program_id(1) == 0: + ## HV = triton.next_power_of_2(num_v_heads) + idx_hv = tl.arange(0, HV) + ## map idx_hv to source idx + idx_h = idx_hv // (num_v_heads // num_k_heads) + idx_v = idx_hv % (num_v_heads // num_k_heads) + b_source_offset = idx_h * (2 * num_v_heads // num_k_heads) + idx_v + a_source_offset = ( + idx_h * (2 * num_v_heads // num_k_heads) + + num_v_heads // num_k_heads + + idx_v + ) + + b_source_ptrs = ( + ba_ptr + idx_seq * stride_ba_seq + b_source_offset * stride_ba_token + ) + a_source_ptrs = ( + ba_ptr + idx_seq * stride_ba_seq + a_source_offset * stride_ba_token + ) + mask_ba = idx_hv < num_v_heads + b = tl.load(b_source_ptrs, mask=mask_ba, other=0.0) + a = tl.load(a_source_ptrs, mask=mask_ba, other=0.0) + ## b, a should be contiguous so the last stride is 1 + b_ptrs = b_ptr + idx_seq * stride_b_seq + idx_hv + a_ptrs = a_ptr + idx_seq * stride_b_seq + idx_hv + tl.store(b_ptrs, b, mask_ba) + tl.store(a_ptrs, a, mask_ba) + ## write z + elif tl.program_id(1) < 1 + num_program_write_z: + idx_z = (tl.program_id(1) - 1) * BLOCK_Z + tl.arange(0, BLOCK_Z) + ## map idx_z to source idx + idx_z_x = ( + idx_z // (num_v_heads // num_k_heads * head_v_dim) * head_qkvz_dim + + 2 * head_k_dim + + num_v_heads // num_k_heads * head_v_dim + + idx_z % (num_v_heads // num_k_heads * head_v_dim) + ) + z_source_ptrs = x_ptr + idx_seq * stride_x_seq + idx_z_x * stride_x_dim + mask_z = idx_z < num_v_heads * head_v_dim + z = tl.load(z_source_ptrs, mask=mask_z, other=0.0) + z_ptrs = z_ptr + idx_seq * stride_z_seq + idx_z + tl.store(z_ptrs, z, mask=mask_z) + + ## zero-fill core_attn_out + # first, zero_fill [0, batch) for core_attn_out + core_attn_out_ptrs = core_attn_out_ptr + idx_seq * stride_z_seq + idx_z + tl.store(core_attn_out_ptrs, 0.0, mask=mask_z) + # second, zero_fill [batch, num_tokens) for both z and 
core_attn_out
+        n_repeat = (num_tokens - 1) // batch
+        for idx_repeat in tl.range(n_repeat):
+            idx_seq_remain = batch * (1 + idx_repeat) + idx_seq
+            z_ptrs = z_ptr + idx_seq_remain * stride_z_seq + idx_z
+            core_attn_out_ptrs = (
+                core_attn_out_ptr + idx_seq_remain * stride_z_seq + idx_z
+            )
+            mask_remain = (idx_seq_remain < num_tokens) & mask_z
+            tl.store(z_ptrs, 0.0, mask=mask_remain)
+            tl.store(core_attn_out_ptrs, 0.0, mask=mask_remain)
+    ## do regular causal conv1d update
+    else:
+        # [BLOCK_N,] elements along the feature-dimension (channel)
+        idx_feats = (tl.program_id(1) - 1 - num_program_write_z) * BLOCK_N + tl.arange(
+            0, BLOCK_N
+        )
+        ## map idx_feats to idx_feats_x
+        idx_feats_x = (
+            (idx_feats < num_k_heads * head_k_dim).to(tl.int64)
+            * (idx_feats // head_k_dim * head_qkvz_dim + idx_feats % head_k_dim)
+            + (
+                (idx_feats >= num_k_heads * head_k_dim)
+                & (idx_feats < num_k_heads * head_k_dim * 2)
+            ).to(tl.int64)
+            * (
+                (idx_feats - num_k_heads * head_k_dim) // head_k_dim * head_qkvz_dim
+                + head_k_dim
+                + (idx_feats - num_k_heads * head_k_dim) % head_k_dim
+            )
+            + (idx_feats >= num_k_heads * head_k_dim * 2).to(tl.int64)
+            * (
+                (idx_feats - num_k_heads * head_k_dim * 2)
+                // (num_v_heads // num_k_heads * head_v_dim)
+                * head_qkvz_dim
+                + 2 * head_k_dim
+                + (idx_feats - num_k_heads * head_k_dim * 2)
+                % (num_v_heads // num_k_heads * head_v_dim)
+            )
+        )
+
+        if IS_APC_ENABLED:
+            # Get the state from the initial_state_idx
+            conv_state_init = tl.load(initial_state_idx + idx_seq)
+            current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
+        else:
+            conv_state_init = 0
+            current_last_index = 0
+
+        # cache_idx
+        conv_states_input_coord = tl.load(
+            conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
+        ).to(tl.int64)
+
+        if USE_PAD_SLOT:  # noqa
+            if conv_states_input_coord == pad_slot_id:
+                # not processing as this is not the actual sequence
+                return
+
+        # IS_VARLEN is False
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+        if query_start_index == query_end_index:
+            return
+
+        # STEP 1: READ init_state data
+        # note: NP2_STATELEN = triton.next_power_of_2(KERNEL_WIDTH - 1)
+        idx_cols = tl.arange(0, NP2_STATELEN)
+        conv_state_ptrs_cols = (
+            conv_state_ptr
+            + (conv_states_input_coord * stride_conv_state_seq)
+            + (idx_feats * stride_conv_state_dim)[:, None]
+            + (idx_cols * stride_conv_state_tok)[None, :]
+        )  # [BLOCK_N, NP2_STATELEN]
+        mask_cols = (
+            (conv_states_input_coord < num_cache_lines)
+            & (idx_feats < dim)[:, None]
+            & (idx_cols < KERNEL_WIDTH - 1)[None, :]
+        )
+        cols = tl.load(conv_state_ptrs_cols, mask_cols, other=0.0)
+
+        # STEP 2: assume state_len > seqlen
+        idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
+
+        # With speculative decoding, the conv_state update works in a sliding
+        # window manner: at each forward pass the tokens are shifted by 1, so we
+        # load starting from idx_tokens + 1.
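+        # The cached window is shifted left by `seqlen` (1 here): old entries
+        # [seqlen:state_len) are re-read below, the incoming token lands in the
+        # last slot via `loaded_x`, and `new_conv_state` merges both before the
+        # store back to the target cache line.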
+ conv_state_ptrs_source = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats_x * stride_x_dim) # [BLOCK_N] + + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + # Get the state from the initial_state_idx + # cache_idx + conv_states_offset = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index + ).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[ + None, : + ] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[ + :, None + ] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator, not necessary + # if HAS_BIAS: + # bias = bias_ptr + idx_feats + # mask_bias = idx_feats < dim + # acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + # tl.float32 + # ) # [BLOCK_N] + # else: + # acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # LOAD WEIGHTS and compute + w_cols_ptrs = ( + w_ptr + + (idx_feats * stride_w_dim)[:, None] + + (idx_cols * stride_w_width)[None, :] + ) + mask_w_cols = (idx_feats < dim)[:, None] & (idx_cols < KERNEL_WIDTH - 1)[ + None, : + ] + w_cols = tl.load(w_cols_ptrs, mask_w_cols, other=0.0) # [BLOCK_N, NP2_STATELEN] + + w_last_ptrs = ( + w_ptr + (idx_feats * stride_w_dim) + (KERNEL_WIDTH - 1) * stride_w_width + ) + w_last = tl.load(w_last_ptrs, idx_feats < dim, other=0.0) # [BLOCK_N] + + # For the convolution output: dot(weights, [state_cols | x]) + # cols is [BLOCK_N, NP2_STATELEN] = conv_state history + # We need x as 1D [BLOCK_N] for the last weight column + x_1d = tl.load( + x_base, mask=(idx_feats < dim), other=0.0 + ) # [BLOCK_N], reload as 1D + acc = tl.sum((w_cols * cols).to(tl.float32), axis=1) + (w_last * x_1d).to( + tl.float32 + ) + + if HAS_BIAS: + bias = bias_ptr + idx_feats + acc += tl.load(bias, idx_feats < dim, other=0.0).to(tl.float32) # [BLOCK_N] + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = idx_feats < dim + o_ptrs = o_ptr + o_offset + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) diff --git a/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py b/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py index 89cf712988..1882ff663e 100644 --- a/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py +++ b/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/__init__.py @@ -8,10 +8,14 @@ This module provides optimized Triton kernels for decode/inference operations. 
""" +from .fused_rearrange_sigmoid_gdr import ( + fused_rearrange_sigmoid_gated_delta_rule_update_kernel, +) from .fused_recurrent import _fused_recurrent_gated_delta_rule_fwd_kernel from .fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update __all__ = [ "_fused_recurrent_gated_delta_rule_fwd_kernel", + "fused_rearrange_sigmoid_gated_delta_rule_update_kernel", "fused_sigmoid_gating_delta_rule_update", ] diff --git a/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py b/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py new file mode 100644 index 0000000000..495fb6ad00 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_rearrange_sigmoid_gated_delta_rule_update_kernel( + A_log, + a, + b, + dt_bias, + beta, + threshold, + qkv, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_qkv_l: tl.constexpr, + stride_qkv_hd: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = qkv + bos * stride_qkv_l + ((i_h * K) + o_k) * stride_qkv_hd + p_k = qkv + bos * stride_qkv_l + (H * K + (i_h * K) + o_k) * stride_qkv_hd + p_v = qkv + bos * stride_qkv_l + (2 * H * K + (i_hv * V) + o_v) * stride_qkv_hd + + p_A_log = A_log + i_hv + if not IS_KDA: + p_a = a + bos * HV + i_hv + p_dt_bias = dt_bias + i_hv + else: + p_a = a + (bos * HV + i_hv) * K + o_k + p_dt_bias = dt_bias + i_hv * K + o_k + + p_b = b + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + 
mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_v[:, None] & mask_k[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok + ).to(tl.int64) + if state_idx < 0: + return + p_h0 = h0 + state_idx * stride_init_state_token + else: + p_h0 = h0 + bos * HV * V * K + p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b).to(tl.float32) + + x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x + + b_beta = tl.sigmoid(b_b.to(tl.float32)) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q * tl.rsqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k * tl.rsqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + if not IS_KDA: + b_h *= tl.exp(b_g) + else: + b_h *= tl.exp(b_g[None, :]) + b_v -= tl.sum(b_h * b_k[None, :], 1) + b_v *= b_beta + b_h += b_v[:, None] * b_k[None, :] + b_o = tl.sum(b_h * b_q[None, :], 1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + if INPLACE_FINAL_STATE: + final_state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok + ).to(tl.int64) + if final_state_idx >= 0: + p_ht = ht + final_state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += stride_qkv_l + p_k += stride_qkv_l + p_v += stride_qkv_l + p_o += HV * V + p_b += HV + p_a += HV diff --git a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py index 6f4cb40294..5acc36b81f 100644 --- a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py +++ b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py @@ -766,3 +766,127 @@ def _fused_silu_mul_fp8_per_tensor_static_quant_kernel( quant_fp8_out.to(out_fp8_ptr.dtype.element_ty), mask=mask, ) + + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
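+#
+# _fused_rms_gated_fp8_group_quant_kernel below: each program handles ROWS_PER_BLOCK
+# rows. It first accumulates a full-row sum of squares in fp32 (applying the z-gate
+# before the statistics when NORM_BEFORE_GATE is False), then walks the row group by
+# group: normalize with the row rstd, apply weight/bias, gate after the norm when
+# NORM_BEFORE_GATE is True, take the per-group absmax to form one scale per
+# (row, group) (optionally rounded up to a power of two for UE8M0), and store the
+# clamped FP8 values plus the float32 scales.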
+ + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + } +) +@triton.jit +def _fused_rms_gated_fp8_group_quant_kernel( + X, + W, + B, + Z, + Y_quant, + Scales, + stride_x_row, + stride_z_row, + stride_y_row, + stride_s_row, + stride_s_g, + M, + N: tl.constexpr, + eps, + RMS_TILE: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, + GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, + BLOCK_G: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + USE_UE8M0: tl.constexpr, + FP8_MIN_SCALING_FACTOR: tl.constexpr, + ACTIVATION: tl.constexpr, +): + row_start = tl.program_id(0) * ROWS_PER_BLOCK + rows = row_start + tl.arange(0, ROWS_PER_BLOCK) + row_mask_1d = rows < M + + # --- Full-row RMS: accumulate sum of squares in float32 --- + sumsq = tl.zeros([ROWS_PER_BLOCK], dtype=tl.float32) + off = 0 + while off < N: + cols = tl.arange(0, RMS_TILE) + off + col_mask = cols < N + mask = row_mask_1d[:, None] & col_mask[None, :] + row_offsets = rows[:, None] * stride_x_row + col_offsets = cols[None, :] + X_base = X + row_offsets + col_offsets + x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) + if ACTIVATION == "swish" or ACTIVATION == "silu": + x *= z * tl.sigmoid(z) + elif ACTIVATION == "sigmoid": + x *= tl.sigmoid(z) + xbar = tl.where(mask, x, 0.0) + sumsq += tl.sum(xbar * xbar, axis=1) + off += RMS_TILE + + var = sumsq / N + rstd = tl.rsqrt(var + eps) + + # --- Per-group: normalize (when NORM_BEFORE_GATE), linear, optional gate, FP8 --- + for g in range(NUM_GROUPS): + col0 = g * GROUP_SIZE + cols = tl.arange(0, BLOCK_G) + col0 + col_mask = cols < N + mask = row_mask_1d[:, None] & col_mask[None, :] + row_offsets = rows[:, None] * stride_x_row + col_offsets = cols[None, :] + X_base = X + row_offsets + col_offsets + x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32) + + if HAS_Z and not NORM_BEFORE_GATE: + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) + if ACTIVATION == "swish" or ACTIVATION == "silu": + x *= z * tl.sigmoid(z) + elif ACTIVATION == "sigmoid": + x *= tl.sigmoid(z) + + x_hat = x * rstd[:, None] + + w_mask = cols < N + w = tl.load(W + cols, mask=w_mask, other=0.0).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=w_mask, other=0.0).to(tl.float32) + y = x_hat * w[None, :] + b[None, :] + else: + y = x_hat * w[None, :] + + if HAS_Z and NORM_BEFORE_GATE: + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) + if ACTIVATION == "swish" or ACTIVATION == "silu": + y *= z * tl.sigmoid(z) + elif ACTIVATION == "sigmoid": + y *= tl.sigmoid(z) + + abs_y = tl.where(mask, tl.abs(y), 0.0) + absmax = tl.max(abs_y, axis=1) + scales_raw = absmax / FP8_MAX + if USE_UE8M0: + scales_raw = tl.exp2(tl.ceil(tl.log2(scales_raw))) + scales = tl.maximum(scales_raw, FP8_MIN_SCALING_FACTOR) + + y_scaled = y / scales[:, None] + y_quant = tl.maximum(tl.minimum(y_scaled, FP8_MAX), FP8_MIN) + + Y_base = Y_quant + rows[:, None] * stride_y_row + col_offsets + tl.store(Y_base, y_quant.to(Y_quant.dtype.element_ty), mask=mask) + + S_ptr = Scales + rows * stride_s_row + g * stride_s_g + tl.store(S_ptr, scales, mask=row_mask_1d) diff --git 
a/aiter/ops/triton/causal_conv1d_update_single_token.py b/aiter/ops/triton/causal_conv1d_update_single_token.py new file mode 100644 index 0000000000..07ac393568 --- /dev/null +++ b/aiter/ops/triton/causal_conv1d_update_single_token.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Launchers for **causal conv1d update** single-token paths. + +``causal_conv1d_update_single_token`` updates ``conv_state`` **in place** inside the Triton kernel +(the "update" in the name), then writes the convolution output into ``x``/``out`` as in vLLM. +""" + +from __future__ import annotations + +import torch +import triton + +from aiter.ops.triton._triton_kernels.causal_conv1d import PAD_SLOT_ID +from aiter.ops.triton._triton_kernels.causal_conv1d_update_single_token import ( + _causal_conv1d_update_single_token_kernel, + _reshape_causal_conv1d_update_single_token_kernel, +) + + +def _default_conv_state_indices(batch: int, device: torch.device) -> torch.Tensor: + return torch.arange(batch, device=device, dtype=torch.int32) + + +def causal_conv1d_update_single_token( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + activation: bool | str | None = None, + conv_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + query_start_loc: torch.Tensor | None = None, + max_query_len: int = -1, + pad_slot_id: int = PAD_SLOT_ID, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, + validate_data: bool = False, +) -> torch.Tensor: + assert ( + num_accepted_tokens is None + ), f"num_accepted_tokens must be None, got {num_accepted_tokens}" + assert ( + query_start_loc is None + ), f"query_start_loc must be None, got {query_start_loc}" + if validate_data: + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + seqlen = max_query_len + assert ( + seqlen == 1 + ), f"the single_token version only support seqlen to be 1, got {seqlen}" + _, width = weight.shape + num_cache_lines, _, state_len = conv_state.size() + + if conv_state_indices is None: + conv_state_indices = _default_conv_state_indices(batch, x.device) + + if validate_data: + assert dim == weight.size(0) + assert conv_state.stride(-2) == 1, ( + f"ERROR: expect contiguous along feat-dim of conv_state " + f"(currently stride={conv_state.stride()})" + ) + assert state_len >= width - 1 + assert dim == conv_state.size(1) + assert (batch,) == conv_state_indices.shape + assert num_cache_lines >= batch + assert weight.stride(1) == 1 + + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = 
out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return (batch, triton.cdiv(dim, META["BLOCK_N"])) + + _causal_conv1d_update_single_token_kernel[grid]( + x, + weight, + bias, + conv_state, + conv_state_indices, + block_idx_last_scheduled_token, + initial_state_idx, + out, + batch, + dim, + seqlen, + state_len, + num_cache_lines, + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + pad_slot_id, + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) + + +def fused_reshape_causal_conv1d_update_single_token( + x: torch.Tensor, + num_actual_tokens: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + ba: torch.Tensor, + z_out: torch.Tensor, + core_attn_out: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + activation: bool | str | None = None, + conv_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + query_start_loc: torch.Tensor | None = None, + max_query_len: int = -1, + pad_slot_id: int = PAD_SLOT_ID, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, + validate_data: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert ( + num_accepted_tokens is None + ), f"num_accepted_tokens must be None, got {num_accepted_tokens}" + assert ( + query_start_loc is None + ), f"query_start_loc must be None, got {query_start_loc}" + assert z_out.is_contiguous(), "z_out should be contiguous" + assert core_attn_out.is_contiguous(), "core_attn_out should be contiguous" + x = x.view(x.shape[0], -1) + ba = ba.view(ba.shape[0], -1) + assert z_out.size() == core_attn_out.size() + original_z_shape = z_out.shape + num_tokens = z_out.shape[0] + z_out = z_out.view(original_z_shape[0], -1) + core_attn_out = core_attn_out.view(original_z_shape[0], -1) + if validate_data: + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + _, qkvz_dim, seqlen = x.shape + assert ( + seqlen == 1 + ), f"the single_token version only support seqlen to be 1, got {seqlen}" + batch = num_actual_tokens + _, width = weight.shape + head_dim = head_k_dim + head_k_dim + head_v_dim * num_v_heads // num_k_heads + head_qkvz_dim = head_dim + head_v_dim * num_v_heads // num_k_heads + dim = num_k_heads * head_dim + expected_qkvz_dim = num_k_heads * head_qkvz_dim + assert ( + qkvz_dim == expected_qkvz_dim + ), f"ERROR: expect qkvz_dim to be {expected_qkvz_dim}, got {qkvz_dim}" + num_cache_lines, _, 
state_len = conv_state.size() + + if conv_state_indices is None: + conv_state_indices = _default_conv_state_indices(batch, x.device) + + if validate_data: + assert dim == weight.size(0) + assert conv_state.stride(-2) == 1, ( + f"ERROR: expect contiguous along feat-dim of conv_state " + f"(currently stride={conv_state.stride()})" + ) + assert state_len >= width - 1 + assert dim == conv_state.size(1) + assert (batch,) == conv_state_indices.shape + assert num_cache_lines >= batch + assert weight.stride(1) == 1 + + out = torch.empty((num_actual_tokens, dim, seqlen), dtype=x.dtype, device=x.device) + b_out = torch.empty( + (num_actual_tokens, num_v_heads), dtype=ba.dtype, device=ba.device + ) + a_out = torch.empty( + (num_actual_tokens, num_v_heads), dtype=ba.dtype, device=ba.device + ) + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_z_seq = z_out.stride(0) + stride_ba_seq, stride_ba_token = ba.stride() + stride_b_seq = b_out.stride(0) + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + HV = triton.next_power_of_2(num_v_heads) + BLOCK_Z = 512 + num_program_write_z = triton.cdiv(num_v_heads * head_v_dim, BLOCK_Z) + + def grid(META): + return ( + batch, + 1 + num_program_write_z + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _reshape_causal_conv1d_update_single_token_kernel[grid]( + x, + ba, + z_out, + core_attn_out, + b_out, + a_out, + weight, + bias, + conv_state, + conv_state_indices, + block_idx_last_scheduled_token, + initial_state_idx, + out, + batch, + num_tokens, + num_k_heads, + num_v_heads, + head_k_dim, + head_v_dim, + dim, + head_qkvz_dim, + seqlen, + state_len, + num_cache_lines, + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + stride_z_seq, + stride_ba_seq, + stride_ba_token, + stride_b_seq, + pad_slot_id, + num_program_write_z, + BLOCK_Z, + HV=HV, + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + z_out = z_out.view(original_z_shape) + core_attn_out = core_attn_out.view(original_z_shape) + return out.to(original_x_dtype), b_out, a_out diff --git a/aiter/ops/triton/gated_delta_net/__init__.py b/aiter/ops/triton/gated_delta_net/__init__.py index 066c0d063e..bd201eab45 100644 --- a/aiter/ops/triton/gated_delta_net/__init__.py +++ b/aiter/ops/triton/gated_delta_net/__init__.py @@ -8,6 +8,7 @@ This module provides high-level Triton implementations for gated delta rule. 
""" +from .fused_rearrange_sigmoid_gdr import fused_rearrange_sigmoid_gated_delta_rule from .gated_delta_rule import ( chunk_gated_delta_rule, chunk_gated_delta_rule_opt, @@ -16,8 +17,9 @@ ) __all__ = [ - "fused_recurrent_gated_delta_rule", "chunk_gated_delta_rule", + "fused_rearrange_sigmoid_gated_delta_rule", + "fused_recurrent_gated_delta_rule", "chunk_gated_delta_rule_opt", "chunk_gated_delta_rule_opt_vk", ] diff --git a/aiter/ops/triton/gated_delta_net/fused_rearrange_sigmoid_gdr.py b/aiter/ops/triton/gated_delta_net/fused_rearrange_sigmoid_gdr.py new file mode 100644 index 0000000000..8dc22a6669 --- /dev/null +++ b/aiter/ops/triton/gated_delta_net/fused_rearrange_sigmoid_gdr.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# +# Adapted from flash-linear-attention / vLLM (see _triton_kernels copy). + +from __future__ import annotations + +import torch +import triton + +from aiter.ops.triton._triton_kernels.gated_delta_rule.decode.fused_rearrange_sigmoid_gdr import ( + fused_rearrange_sigmoid_gated_delta_rule_update_kernel, +) + + +def fused_rearrange_sigmoid_gated_delta_rule( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + qkv: torch.Tensor, + key_dim: int, + value_dim: int, + head_k_dim: int, + head_v_dim: int, + beta: float = 1.0, + threshold: float = 20.0, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + is_kda: bool = False, + core_attn_out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fused Triton sigmoid-gated delta rule over packed QKV (decode-oriented). + """ + expected_shape = (qkv.shape[0], key_dim * 2 + value_dim) + assert ( + qkv.shape == expected_shape + ), f"expect qkv to be in shape {expected_shape}, got {qkv.shape}" + if scale is None: + scale = head_k_dim**-0.5 + else: + assert scale > 0, "scale must be positive" + + B = 1 + T = qkv.shape[0] + H = key_dim // head_k_dim + HV = value_dim // head_v_dim + K = head_k_dim + V = head_v_dim + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 4 + + if inplace_final_state and ssm_state_indices is None: + raise ValueError( + "ssm_state_indices is required when inplace_final_state=True " + "(kernel indexes final state slots per token)." 
+ ) + + o = ( + core_attn_out[: NK * B * T * HV * V].view(NK, B, T, HV, V) + if core_attn_out is not None + else qkv.new_empty(NK, B, T, HV, V) + ) + if inplace_final_state: + if initial_state is None: + raise ValueError("initial_state is required when inplace_final_state=True") + final_state = initial_state + else: + st_dtype = initial_state.dtype if initial_state is not None else qkv.dtype + final_state = qkv.new_empty(T, HV, V, K, dtype=st_dtype) + + stride_init_state_token = ( + int(initial_state.stride(0)) if initial_state is not None else 0 + ) + stride_final_state_token = int(final_state.stride(0)) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + stride_qkv_l, stride_qkv_hd = qkv.stride() + + grid = (NK, NV, N * HV) + fused_rearrange_sigmoid_gated_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a.contiguous(), + b=b.contiguous(), + dt_bias=dt_bias, + beta=beta, + threshold=threshold, + qkv=qkv, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_qkv_l=stride_qkv_l, + stride_qkv_hd=stride_qkv_hd, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + INPLACE_FINAL_STATE=inplace_final_state, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + IS_KDA=is_kda, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state diff --git a/aiter/ops/triton/quant/__init__.py b/aiter/ops/triton/quant/__init__.py index 144a0247dd..5a1eec8cac 100644 --- a/aiter/ops/triton/quant/__init__.py +++ b/aiter/ops/triton/quant/__init__.py @@ -7,11 +7,14 @@ ) from .fused_fp8_quant import ( + calc_rows_per_block, fused_rms_fp8_per_tensor_static_quant, fused_rms_fp8_group_quant, + fused_rms_gated_fp8_group_quant, fused_flatten_fp8_group_quant, fused_reduce_act_mul_fp8_group_quant, fused_reduce_rms_fp8_group_quant, + get_fp8_min_max_bounds, ) from .fused_mxfp4_quant import ( @@ -30,8 +33,11 @@ "dynamic_mxfp4_quant", "_mxfp4_quant_op", # fused_fp8_quant.py exports + "calc_rows_per_block", + "get_fp8_min_max_bounds", "fused_rms_fp8_per_tensor_static_quant", "fused_rms_fp8_group_quant", + "fused_rms_gated_fp8_group_quant", "fused_flatten_fp8_group_quant", "fused_reduce_act_mul_fp8_group_quant", "fused_reduce_rms_fp8_group_quant", diff --git a/aiter/ops/triton/quant/fused_fp8_quant.py b/aiter/ops/triton/quant/fused_fp8_quant.py index 7575310770..f0583a86b3 100644 --- a/aiter/ops/triton/quant/fused_fp8_quant.py +++ b/aiter/ops/triton/quant/fused_fp8_quant.py @@ -1,3 +1,4 @@ +from functools import cache from typing import Optional import torch import triton @@ -5,11 +6,13 @@ from aiter.ops.triton._triton_kernels.quant.fused_fp8_quant import ( _fused_rms_fp8_per_tensor_static_quant_kernel, _fused_rms_fp8_group_quant_kernel, + _fused_rms_gated_fp8_group_quant_kernel, _fused_flatten_fp8_group_quant_kernel, _fused_reduce_act_mul_fp8_group_quant, _fused_reduce_rms_fp8_group_quant_kernel, _fused_silu_mul_fp8_per_tensor_static_quant_kernel, ) +from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype from 
aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) @@ -329,6 +332,167 @@ def fused_rms_fp8_group_quant( return (out1_fp8, out1_bs), out1, out2, out_res1 +def get_fp8_min_max_bounds(fp8_dtype: torch.dtype) -> tuple[float, float]: + """Match vLLM ``quant_utils.get_fp8_min_max`` for ``fp8_dtype`` (incl. ROCm fnuz ±224).""" + if fp8_dtype == torch.float8_e4m3fnuz: + return -224.0, 224.0 + finfo = torch.finfo(fp8_dtype) + return float(finfo.min), float(finfo.max) + + +@cache +def _num_compute_units(device_id: int = 0) -> int: + """Match vLLM ``vllm.utils.platform_utils.num_compute_units`` (``current_platform.num_compute_units``).""" + return torch.cuda.get_device_properties(device_id).multi_processor_count + + +def calc_rows_per_block(M: int, device: torch.device) -> int: + """Same heuristic as vLLM ``input_quant_fp8.calc_rows_per_block``.""" + if device.type != "cuda": + raise ValueError( + "fused_rms_gated_fp8_group_quant targets AMD ROCm (HIP); expected a CUDA/HIP device." + ) + device_id = ( + device.index if device.index is not None else torch.cuda.current_device() + ) + sm_count = max(int(_num_compute_units(device_id)), 1) + rows_per_block = triton.next_power_of_2(triton.cdiv(M, 2 * sm_count)) + return min(int(rows_per_block), 4) + + +def fused_rms_gated_fp8_group_quant( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + z: torch.Tensor, + eps: float, + *, + norm_before_gate: bool = True, + use_ue8m0: bool = False, + activation: str = "silu", + out_dtype: torch.dtype | None = None, + fp8_min: float | None = None, + fp8_max: float | None = None, + fp8_min_scaling_factor: float | None = None, + group_size: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fused RMSNorm (with optional bias), optional multiplicative gate from ``z``, + and FP8 quantization (same contract as vLLM ``_rmsnorm_quantize_group_native`` for + ``group_size == N``). + + Comparison with ``fused_rms_fp8_group_quant``: + Use ``fused_rms_fp8_group_quant`` when you need optional **two-stream** RMSNorm + (``inp1`` / optional ``inp2`` with separate weights and epsilons), optional + **residual** fused into ``inp1`` (``res1``), FP8 group quantization on the **first** + normalized stream only, the richer return tuple (quantized FP8, block scales, + optional unquantized ``inp1``, second RMS output, residual output), and optional + ``transpose_scale`` layout for scales. + + Use **this** function for **single** hidden ``x``, one RMS **weight** (and optional + **bias**), plus ``z`` for **elementwise multiplicative gating** (SiLU / sigmoid-style + activations on ``z``) matching ``x``'s shape; optional ``norm_before_gate`` ordering; + vLLM-aligned FP8 bounds / optional UE8M0 / ``group_size`` (``None`` = one scale per + row, else per-column-group scales). Returns only ``(x_quant_fp8, scales)``. Suited to + gated RMSNorm input quantization (e.g. SwiGLU-style / vLLM + ``_rmsnorm_quantize_group_native`` contracts), not the two-stream + residual pattern + above. + + ``x`` and ``z`` must be 2D contiguous with identical shape ``(M, N)``. + Returns ``(x_quant_fp8, scales)`` where ``scales`` is ``(M,)`` float32 if + ``group_size`` is ``None`` (one scale per row), or ``(M, N // group_size)`` float32 + when ``group_size`` divides ``N`` (one scale per row per column group). + + ``fp8_min`` / ``fp8_max`` / ``fp8_min_scaling_factor`` default from ``out_dtype`` (or + ``get_fp8_e4m3_dtype()``) using the same rules as vLLM ``get_fp8_min_max`` and + ``1.0 / (_FP8_MAX * 512)``. 
Pass them explicitly when you want to pin values (e.g. from + vLLM's ``get_fp8_min_max()`` at model init). + + Raises: + ValueError: if ``group_size`` is not ``None`` and ``group_size > N``, + ``group_size <= 0``, or ``N`` is not divisible by ``group_size``. + """ + assert x.is_contiguous() and z.is_contiguous() + assert x.shape == z.shape, "x and z must have the same shape" + fp8_dtype = out_dtype if out_dtype is not None else get_fp8_e4m3_dtype() + if (fp8_min is None) ^ (fp8_max is None): + raise ValueError("fp8_min and fp8_max must be passed together or both omitted.") + if fp8_min is None: + fp8_min, fp8_max = get_fp8_min_max_bounds(fp8_dtype) + if fp8_min_scaling_factor is None: + fp8_min_scaling_factor = 1.0 / (fp8_max * 512.0) + + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + + M, N = x.shape + if group_size is not None: + if group_size <= 0: + raise ValueError(f"group_size must be positive, got {group_size}") + if group_size > N: + raise ValueError( + f"group_size ({group_size}) must be less than or equal to hidden size " + f"N ({N}); per-column FP8 groups cannot exceed the row width." + ) + if N % group_size != 0: + raise ValueError( + f"hidden size N ({N}) must be divisible by group_size ({group_size})." + ) + + effective_gs = N if group_size is None else int(group_size) + num_groups = N // effective_gs + + MAX_FUSED_SIZE = 65536 // x.element_size() + if N > MAX_FUSED_SIZE: + raise RuntimeError("This RMSNorm quant kernel does not support N >= 64KB.") + + rms_tile = min(512, triton.next_power_of_2(N)) + block_g = triton.next_power_of_2(effective_gs) + rows_per_block = calc_rows_per_block(M, x.device) + num_warps = min(max(block_g // 256, 1), 8) + + x_quant = torch.empty(M, N, dtype=fp8_dtype, device=x.device) + if group_size is None: + scales = torch.empty(M, dtype=torch.float32, device=x.device) + stride_s_row = int(scales.stride(0)) + stride_s_g = 0 + else: + scales = torch.empty(M, num_groups, dtype=torch.float32, device=x.device) + stride_s_row, stride_s_g = (int(scales.stride(0)), int(scales.stride(1))) + + grid = (triton.cdiv(M, rows_per_block),) + _fused_rms_gated_fp8_group_quant_kernel[grid]( + x, + weight, + bias, + z, + x_quant, + scales, + x.stride(0), + z.stride(0), + x_quant.stride(0), + stride_s_row, + stride_s_g, + M, + N, + eps, + RMS_TILE=rms_tile, + ROWS_PER_BLOCK=rows_per_block, + GROUP_SIZE=effective_gs, + NUM_GROUPS=num_groups, + BLOCK_G=block_g, + NORM_BEFORE_GATE=norm_before_gate, + FP8_MIN=fp8_min, + FP8_MAX=fp8_max, + USE_UE8M0=use_ue8m0, + FP8_MIN_SCALING_FACTOR=fp8_min_scaling_factor, + num_warps=num_warps, + ACTIVATION=activation, + ) + return x_quant, scales + + def fused_flatten_fp8_group_quant( x: torch.Tensor, group_size, diff --git a/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py b/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py new file mode 100644 index 0000000000..0653eccc01 --- /dev/null +++ b/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
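+#
+# Typical call under test (shapes illustrative): for bf16 ``x``/``z`` of shape (M, N)
+# and an RMS weight of shape (N,), ``fused_rms_gated_fp8_group_quant(x, w, None, z,
+# 1e-5, group_size=128)`` returns the FP8 tensor of shape (M, N) plus float32 scales
+# of shape (M, N // 128); with ``group_size=None`` the scales collapse to one per row.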
+ +"""Tests for ``fused_rms_gated_fp8_group_quant`` (kernel in ``_triton_kernels/quant/fused_fp8_quant``).""" + +import pytest +import torch + +from aiter.ops.triton.quant.fused_fp8_quant import ( + fused_rms_gated_fp8_group_quant, + get_fp8_min_max_bounds, +) +from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype + +cuda_ok = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA/HIP device required" +) + + +def ref_rmsnorm_quant( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + z: torch.Tensor, + eps: float, + norm_before_gate: bool, + activation: str, + fmin: float, + fmax: float, + group_size: int | None, +) -> tuple[torch.Tensor, torch.Tensor]: + x32 = x.float() + z32 = z.float() + var = x32.pow(2).mean(-1, keepdim=True) + x_hat = x32 * torch.rsqrt(var + eps) + y = x_hat * weight.float() + if bias is not None: + y = y + bias.float() + if norm_before_gate: + if activation in ("silu", "swish"): + y = y * (z32 * torch.sigmoid(z32)) + elif activation == "sigmoid": + y = y * torch.sigmoid(z32) + fp8_dtype = get_fp8_e4m3_dtype() + gs = x.shape[1] if group_size is None else group_size + ng = x.shape[1] // gs + yg = y.view(y.shape[0], ng, gs) + scales = yg.abs().amax(dim=-1).clamp_min(1e-12) / fmax + y_scaled = yg / scales.unsqueeze(-1) + q = y_scaled.clamp(fmin, fmax).to(fp8_dtype).view_as(y) + if group_size is None: + scales = scales.squeeze(-1) + return q, scales + + +def _scale_broadcast( + scales: torch.Tensor, N: int, group_size: int | None +) -> torch.Tensor: + if group_size is None: + return scales.unsqueeze(-1).expand(-1, N) + return scales.repeat_interleave(group_size, dim=1) + + +@cuda_ok +def test_fused_rms_gated_fp8_group_quant_matches_ref(): + device = "cuda" + torch.manual_seed(0) + M, N = 32, 64 + x = torch.randn(M, N, device=device, dtype=torch.bfloat16) + z = torch.randn(M, N, device=device, dtype=torch.bfloat16) + w = torch.randn(N, device=device, dtype=torch.bfloat16) + bias = torch.randn(N, device=device, dtype=torch.bfloat16) + + fp8_dtype = get_fp8_e4m3_dtype() + fmin, fmax = get_fp8_min_max_bounds(fp8_dtype) + scale_floor = 1.0 / (fmax * 512.0) + + y_q, scales_t = fused_rms_gated_fp8_group_quant( + x, + w, + bias, + z, + 1e-5, + norm_before_gate=True, + use_ue8m0=False, + activation="silu", + fp8_min=fmin, + fp8_max=fmax, + fp8_min_scaling_factor=scale_floor, + ) + y_ref, scales_ref = ref_rmsnorm_quant( + x, w, bias, z, 1e-5, True, "silu", fmin, fmax, None + ) + + torch.testing.assert_close(scales_t, scales_ref, rtol=1e-3, atol=1e-3) + sb = _scale_broadcast(scales_ref, N, None) + dq = y_q.float() * sb + dq_ref = y_ref.float() * sb + torch.testing.assert_close(dq, dq_ref, rtol=0.15, atol=0.15) + + y_default, scales_default = fused_rms_gated_fp8_group_quant( + x, + w, + bias, + z, + 1e-5, + norm_before_gate=True, + use_ue8m0=False, + activation="silu", + ) + torch.testing.assert_close(scales_t, scales_default, rtol=0.0, atol=0.0) + torch.testing.assert_close(y_q.float(), y_default.float(), rtol=0.0, atol=0.0) + + +_MS = [1, 3, 4, 512, 1024, 2048, 4096] +_NS = [128, 256] +_GROUP_SIZES = { + 128: [1, 2, 4, 8, 16, 32, 64, 128], + 256: [1, 2, 4, 8, 16, 32, 64, 128, 256], +} + + +def _sweep_cases(): + out = [] + for N in _NS: + for M in _MS: + for g in _GROUP_SIZES[N]: + out.append(pytest.param(M, N, g, id=f"M{M}-N{N}-g{g}")) + return out + + +@cuda_ok +@pytest.mark.parametrize(("M", "N", "group_size"), _sweep_cases()) +def test_fused_rms_gated_fp8_group_quant_sweep(M: int, N: int, group_size: int): + device = "cuda" + 
torch.manual_seed(1) + x = torch.randn(M, N, device=device, dtype=torch.bfloat16) + z = torch.randn(M, N, device=device, dtype=torch.bfloat16) + w = torch.randn(N, device=device, dtype=torch.bfloat16) + bias = torch.randn(N, device=device, dtype=torch.bfloat16) + fmin, fmax = get_fp8_min_max_bounds(get_fp8_e4m3_dtype()) + scale_floor = 1.0 / (fmax * 512.0) + + y_q, scales_t = fused_rms_gated_fp8_group_quant( + x, + w, + bias, + z, + 1e-5, + norm_before_gate=True, + use_ue8m0=False, + activation="silu", + fp8_min=fmin, + fp8_max=fmax, + fp8_min_scaling_factor=scale_floor, + group_size=group_size, + ) + y_ref, scales_ref = ref_rmsnorm_quant( + x, w, bias, z, 1e-5, True, "silu", fmin, fmax, group_size + ) + + assert scales_t.shape == scales_ref.shape + torch.testing.assert_close(scales_t, scales_ref, rtol=1e-3, atol=1e-3) + sb = _scale_broadcast(scales_ref, N, group_size) + dq = y_q.float() * sb + dq_ref = y_ref.float() * sb + torch.testing.assert_close(dq, dq_ref, rtol=0.15, atol=0.15) + + +@cuda_ok +def test_fused_rms_gated_fp8_group_quant_group_size_errors(): + device = "cuda" + x = torch.randn(2, 128, device=device, dtype=torch.bfloat16) + z = torch.randn_like(x) + w = torch.randn(128, device=device, dtype=torch.bfloat16) + b = torch.randn(128, device=device, dtype=torch.bfloat16) + with pytest.raises(ValueError, match="less than or equal to hidden size"): + fused_rms_gated_fp8_group_quant(x, w, b, z, 1e-5, group_size=256) + with pytest.raises(ValueError, match="divisible by group_size"): + fused_rms_gated_fp8_group_quant(x, w, b, z, 1e-5, group_size=48) + with pytest.raises(ValueError, match="positive"): + fused_rms_gated_fp8_group_quant(x, w, b, z, 1e-5, group_size=0) diff --git a/op_tests/triton_tests/test_causal_conv1d_update_single_token.py b/op_tests/triton_tests/test_causal_conv1d_update_single_token.py new file mode 100644 index 0000000000..ad59e58427 --- /dev/null +++ b/op_tests/triton_tests/test_causal_conv1d_update_single_token.py @@ -0,0 +1,414 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Tests for ``causal_conv1d_update_single_token`` / ``fused_reshape_causal_conv1d_update_single_token``. + +``causal_conv1d_update_single_token`` updates ``conv_state`` in place; the reference mirrors +``_causal_conv1d_update_single_token_kernel`` (non-APC), not ``causal_conv1d_update_ref``. +Shape extras that used to live in smoke tests are folded into +``test_causal_conv1d_update_single_token_matches_ref`` (see ``_causal_conv1d_update_single_token_ref_cases``). 
+""" + +from __future__ import annotations + +import random + +import numpy as np +import pytest +import torch +import triton + +from aiter.ops.triton._triton_kernels.causal_conv1d import PAD_SLOT_ID +from aiter.ops.triton.causal_conv1d_update_single_token import ( + causal_conv1d_update_single_token, + fused_reshape_causal_conv1d_update_single_token, +) + +cuda_ok = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA/HIP device required" +) + + +def seed_everything(seed: int = 0) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def ref_causal_conv1d_update_single_token( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + activation: str | None, + conv_state_indices: torch.Tensor, + pad_slot_id: int | None, +) -> torch.Tensor: + """Python port of ``_causal_conv1d_update_single_token_kernel`` (non-APC, 1D indices). + + Mutates ``conv_state`` in place (like the Triton kernel). Clones ``x`` only for + ``out`` leaves non-updated timesteps equal to the input. + """ + out = x.clone() + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = width - 1 + np2 = triton.next_power_of_2(state_len) + num_cache_lines = conv_state.shape[0] + silu = activation in ("silu", "swish") + + if conv_state_indices.ndim != 1: + raise NotImplementedError("reference supports 1D conv_state_indices only") + + for b in range(batch): + coord_read = int(conv_state_indices[b].item()) + if pad_slot_id is not None and coord_read == pad_slot_id: + continue + coord_write = int(conv_state_indices[b].item()) + val = state_len - seqlen + + for f in range(dim): + cols_hist = [] + for j in range(np2): + if j < width - 1: + cols_hist.append(float(conv_state[coord_read, f, j].item())) + else: + cols_hist.append(0.0) + + for j in range(np2): + mask_cs = (coord_read < num_cache_lines) and (j + seqlen < state_len) + v_cs = ( + float(conv_state[coord_read, f, j + seqlen].item()) + if mask_cs + else 0.0 + ) + t = j - val + mask_x = (0 <= t) and (t < seqlen) + v_x = float(x[b, f, t].item()) if mask_x else 0.0 + new_v = v_cs if mask_cs else v_x + if j < state_len: + conv_state[coord_write, f, j] = torch.tensor( + new_v, dtype=conv_state.dtype, device=conv_state.device + ) + + acc = 0.0 + for j in range(np2): + wj = float(weight[f, j].item()) if j < width - 1 else 0.0 + acc += wj * cols_hist[j] + w_last = float(weight[f, width - 1].item()) + x0 = float(x[b, f, 0].item()) + acc += w_last * x0 + if bias is not None: + acc += float(bias[f].item()) + if silu: + acc = acc / (1.0 + np.exp(-acc)) + out[b, f, 0] = torch.tensor(acc, dtype=out.dtype, device=out.device) + + return out + + +def _logical_feat_to_qkvz_col_v2( + idx_feats: int, + num_k_heads: int, + head_k_dim: int, + head_v_dim: int, + head_qkvz_dim: int, + hv_ratio: int, +) -> int: + nk, hk, hv = num_k_heads, head_k_dim, head_v_dim + if idx_feats < nk * hk: + h = idx_feats // hk + r = idx_feats % hk + return h * head_qkvz_dim + r + if idx_feats < nk * hk * 2: + rel = idx_feats - nk * hk + h = rel // hk + r = rel % hk + return h * head_qkvz_dim + hk + r + rel = idx_feats - nk * hk * 2 + gs = hv_ratio * hv + h = rel // gs + r = rel % gs + return h * head_qkvz_dim + 2 * hk + r + + +def ref_fused_reshape_causal_conv1d_update_single_token( + x: torch.Tensor, + num_actual_tokens: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + ba: torch.Tensor, + z_out: torch.Tensor, + core_attn_out: torch.Tensor, + conv_state: torch.Tensor, + weight: 
torch.Tensor, + bias: torch.Tensor | None, + activation: str | None, + conv_state_indices: torch.Tensor | None, + pad_slot_id: int | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Reference: extract b/a/z like the kernel, build logical QKV, run ``ref_causal_conv1d_update_single_token``.""" + num_tokens = x.shape[0] + hv_ratio = num_v_heads // num_k_heads + head_dim = head_k_dim + head_k_dim + head_v_dim * hv_ratio + head_qkvz_dim = head_dim + head_v_dim * hv_ratio + dim = num_k_heads * head_dim + seqlen = x.shape[2] + device = x.device + dtype = x.dtype + + b_out = torch.empty(num_actual_tokens, num_v_heads, device=device, dtype=ba.dtype) + a_out = torch.empty_like(b_out) + for idx_seq in range(num_actual_tokens): + for idx_hv in range(num_v_heads): + idx_h = idx_hv // hv_ratio + idx_v = idx_hv % hv_ratio + b_off = idx_h * (2 * hv_ratio) + idx_v + a_off = idx_h * (2 * hv_ratio) + hv_ratio + idx_v + b_out[idx_seq, idx_hv] = ba[idx_seq, b_off] + a_out[idx_seq, idx_hv] = ba[idx_seq, a_off] + + z_flat = z_out.reshape(num_tokens, -1).clone() + core_flat = core_attn_out.reshape(num_tokens, -1).clone() + gs = hv_ratio * head_v_dim + for idx_seq in range(num_tokens): + for idx_z in range(num_v_heads * head_v_dim): + idx_z_x = ( + (idx_z // gs) * head_qkvz_dim + + 2 * head_k_dim + + hv_ratio * head_v_dim + + (idx_z % gs) + ) + z_flat[idx_seq, idx_z] = x[idx_seq, idx_z_x, 0] + + n_repeat = (num_tokens - 1) // num_actual_tokens if num_actual_tokens else 0 + for idx_repeat in range(n_repeat): + for idx_seq in range(num_actual_tokens): + idx_seq_remain = num_actual_tokens * (1 + idx_repeat) + idx_seq + if idx_seq_remain < num_tokens: + z_flat[idx_seq_remain].zero_() + core_flat[idx_seq_remain].zero_() + + x_lin = torch.zeros(num_actual_tokens, dim, seqlen, device=device, dtype=dtype) + for b in range(num_actual_tokens): + for f in range(dim): + col = _logical_feat_to_qkvz_col_v2( + f, num_k_heads, head_k_dim, head_v_dim, head_qkvz_dim, hv_ratio + ) + for t in range(seqlen): + x_lin[b, f, t] = x[b, col, t] + + cs = conv_state.clone() + if conv_state_indices is None: + cidx = torch.arange(num_actual_tokens, device=device, dtype=torch.int32) + else: + cidx = conv_state_indices + out_lin = ref_causal_conv1d_update_single_token( + x_lin, + cs, + weight, + bias, + activation, + cidx, + pad_slot_id, + ) + if seqlen == 1: + out_lin = out_lin.squeeze(-1) + return ( + out_lin, + b_out, + a_out, + z_flat.view_as(z_out), + core_flat.view_as(core_attn_out), + cs, + ) + + +def _causal_conv1d_update_single_token_ref_cases(): + """Cartesian core grid plus former smoke shapes (width=3, small dim); seqlen fixed to 1 for single-token API.""" + out = [] + seqlen = 1 + for itype in (torch.float32, torch.bfloat16): + for silu_activation in (True, False): + for has_bias in (True, False): + for width in (2, 4): + out.append( + pytest.param( + 1, + 1024, + width, + seqlen, + itype, + silu_activation, + has_bias, + id=f"b1-d1024-w{width}-s{seqlen}-" + f"silu{silu_activation}-bias{has_bias}-" + f"{'fp32' if itype == torch.float32 else 'bf16'}", + ) + ) + out.extend( + [ + pytest.param( + 2, + 64, + 3, + 1, + torch.bfloat16, + True, + True, + id="smoke-b2-d64-w3-s1-bf16", + ), + pytest.param( + 1, + 128, + 4, + 1, + torch.bfloat16, + True, + True, + id="smoke-b1-d128-w4-s1-bf16", + ), + ] + ) + return out + + +@cuda_ok +@pytest.mark.parametrize( + ( + "batch", + "dim", + "width", + "seqlen", + "itype", + "silu_activation", + "has_bias", + ), + 
_causal_conv1d_update_single_token_ref_cases(), +) +def test_causal_conv1d_update_single_token_matches_ref( + batch, dim, width, seqlen, itype, silu_activation, has_bias +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 6e-2 + seed_everything(0) + x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_tr = x.clone() + conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) + conv_tr = conv_state.clone() + conv_ref = conv_state.clone() + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + activation = None if not silu_activation else "silu" + cidx = torch.arange(batch, dtype=torch.int32, device=device) + + out_ref = ref_causal_conv1d_update_single_token( + x, conv_ref, weight, bias, activation, cidx, PAD_SLOT_ID + ) + out_tr = causal_conv1d_update_single_token( + x_tr, + conv_tr, + weight, + bias, + activation=activation, + conv_state_indices=cidx, + pad_slot_id=PAD_SLOT_ID, + ) + torch.testing.assert_close(conv_tr, conv_ref, rtol=0.0, atol=0.0) + torch.testing.assert_close(out_tr, out_ref, rtol=rtol, atol=atol) + + +@cuda_ok +@pytest.mark.parametrize( + "num_k_heads,num_v_heads,head_k_dim,head_v_dim,num_tokens,num_actual_tokens,width", + [ + (2, 2, 8, 8, 4, 2, 3), + (2, 4, 16, 8, 6, 3, 4), + ], +) +def test_fused_reshape_causal_conv1d_update_single_token_matches_ref( + num_k_heads, + num_v_heads, + head_k_dim, + head_v_dim, + num_tokens, + num_actual_tokens, + width, +): + device = "cuda" + torch.manual_seed(1) + hv_ratio = num_v_heads // num_k_heads + assert hv_ratio * num_k_heads == num_v_heads + head_dim = head_k_dim + head_k_dim + head_v_dim * hv_ratio + head_qkvz_dim = head_dim + head_v_dim * hv_ratio + qkvz_dim = num_k_heads * head_qkvz_dim + dim = num_k_heads * head_dim + seqlen = 1 + dtype = torch.bfloat16 + rtol, atol = 1e-2, 6e-2 + + x = torch.randn(num_tokens, qkvz_dim, seqlen, device=device, dtype=dtype) + ba = torch.randn(num_tokens, 2 * num_v_heads, device=device, dtype=dtype) + z_out = torch.zeros(num_tokens, num_v_heads, head_v_dim, device=device, dtype=dtype) + core = torch.zeros_like(z_out) + conv_state = torch.randn( + num_actual_tokens, dim, width - 1, device=device, dtype=dtype + ) + weight = torch.randn(dim, width, device=device, dtype=dtype) + bias = torch.randn(dim, device=device, dtype=dtype) + + z_ref = z_out.clone() + core_ref = core.clone() + cs_ref_init = conv_state.clone() + out_ref, b_ref, a_ref, z_r, c_r, cs_ref = ( + ref_fused_reshape_causal_conv1d_update_single_token( + x, + num_actual_tokens, + num_k_heads, + num_v_heads, + head_k_dim, + head_v_dim, + ba, + z_ref, + core_ref, + cs_ref_init, + weight, + bias, + "silu", + None, + PAD_SLOT_ID, + ) + ) + + z_tr = z_out.clone() + core_tr = core.clone() + cs_tr = conv_state.clone() + out_tr, b_tr, a_tr = fused_reshape_causal_conv1d_update_single_token( + x, + num_actual_tokens, + num_k_heads, + num_v_heads, + head_k_dim, + head_v_dim, + ba, + z_tr, + core_tr, + cs_tr, + weight, + bias, + activation="silu", + conv_state_indices=None, + pad_slot_id=PAD_SLOT_ID, + ) + + torch.testing.assert_close(out_tr.float(), out_ref.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(b_tr, b_ref, rtol=0.0, atol=0.0) + torch.testing.assert_close(a_tr, a_ref, rtol=0.0, atol=0.0) + torch.testing.assert_close(z_tr, z_r, rtol=rtol, atol=atol) + torch.testing.assert_close(core_tr, c_r, rtol=0.0, atol=0.0) + 
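+    # conv_state is updated in place by both the Triton path and the reference
+    # with the same written values, so it is compared exactly as well.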
torch.testing.assert_close(cs_tr, cs_ref, rtol=0.0, atol=0.0) diff --git a/op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py b/op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py new file mode 100644 index 0000000000..8990816ea8 --- /dev/null +++ b/op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops.triton.gated_delta_net import fused_rearrange_sigmoid_gated_delta_rule + +cuda_ok = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA/HIP device required" +) + + +def _softplus(x: torch.Tensor, beta: float, threshold: float) -> torch.Tensor: + return torch.where( + beta * x <= threshold, + (1.0 / beta) * torch.log1p(torch.exp(beta * x)), + x, + ) + + +def ref_fused_rearrange_sigmoid_gdr( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + qkv: torch.Tensor, + key_dim: int, + value_dim: int, + head_k_dim: int, + head_v_dim: int, + beta: float, + threshold: float, + scale: float, + initial_state: torch.Tensor | None, + use_qk_l2norm_in_kernel: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + """Float reference for decode path (B=1, one sequence), including GQA (HV >= H).""" + T = qkv.shape[0] + H = key_dim // head_k_dim + HV = value_dim // head_v_dim + K = head_k_dim + V = head_v_dim + if HV % H != 0: + raise ValueError(f"reference expects HV divisible by H, got H={H}, HV={HV}") + group = HV // H + B = 1 + o = torch.empty(B, T, HV, V, dtype=torch.float32, device=qkv.device) + h_state = torch.zeros(HV, V, K, dtype=torch.float32, device=qkv.device) + if initial_state is not None: + h_state = initial_state[0].to(torch.float32).clone() + + for t in range(T): + row = qkv[t] + for hv in range(HV): + i_h = hv // group + q_vec = row[i_h * K : (i_h + 1) * K].float() + k_vec = row[H * K + i_h * K : H * K + (i_h + 1) * K].float() + v_vec = row[2 * H * K + hv * V : 2 * H * K + (hv + 1) * V].float() + b_gate = b[t, hv].float() + x = a[t, hv].float() + dt_bias[hv].float() + sp = _softplus(x, beta, threshold) + g = -torch.exp(A_log[hv].float()) * sp + beta_out = torch.sigmoid(b_gate) + if use_qk_l2norm_in_kernel: + q_vec = q_vec * torch.rsqrt((q_vec * q_vec).sum() + 1e-6) + k_vec = k_vec * torch.rsqrt((k_vec * k_vec).sum() + 1e-6) + q_vec = q_vec * scale + h_sub = h_state[hv] + h_sub = h_sub * torch.exp(g) + v_adj = v_vec - (h_sub * k_vec.unsqueeze(0)).sum(dim=-1) + v_adj = v_adj * beta_out + h_sub = h_sub + v_adj.unsqueeze(-1) * k_vec.unsqueeze(0) + out_vec = (h_sub * q_vec.unsqueeze(0)).sum(dim=-1) + o[0, t, hv] = out_vec + h_state[hv] = h_sub + return o, h_state.unsqueeze(0) + + +# Shapes aligned with ``test_gated_delta_rule.test_fused_recurrent``; dtypes are +# half-precision only — long packed ``T`` with float32 activations tends to blow +# up the recurrent reference / kernel without tighter dynamic-range clamps. +# Each row ends with ``use_qk_l2norm_in_kernel`` (True for stable long-T sweep). +# One small bf16 row uses False to cover the no–L2-norm path (replaces former ``basic``). 
+_FUSED_GDR_SWEEP = [ + (63, 1, 1, 64, 1, 1, torch.float16, True), + (500, 4, 4, 60, 1, 1, torch.float16, True), + (1000, 2, 8, 128, 1, 0.1, torch.float16, True), + (1024, 2, 2, 128, 0.1, 1, torch.float16, True), + (1024, 3, 3, 128, 1, 10, torch.float16, True), + (2048, 4, 4, 64, 0.1, 1, torch.float16, True), + (1024, 4, 4, 128, 1, 0.1, torch.float16, True), + (1024, 4, 8, 128, 1, 10, torch.float16, True), + (1024, 4, 4, 128, 1, 0.1, torch.bfloat16, True), + (1024, 4, 8, 128, 1, 1, torch.bfloat16, True), + (2048, 4, 8, 64, 0.1, 1, torch.bfloat16, True), + (8, 4, 4, 16, 16**-0.5, 1, torch.bfloat16, False), +] + + +@cuda_ok +@pytest.mark.parametrize( + ( + "T", + "H", + "HV", + "D", + "scale", + "gate_logit_normalizer", + "dtype", + "use_qk_l2norm_in_kernel", + ), + [ + pytest.param( + *row, + id="T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}-l2{}".format(*row), + ) + for row in _FUSED_GDR_SWEEP + ], +) +def test_fused_rearrange_sigmoid_gdr_sweep( + T: int, + H: int, + HV: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype: torch.dtype, + use_qk_l2norm_in_kernel: bool, +): + """Shape/dtype sweep aligned with ``test_gated_delta_rule.test_fused_recurrent``.""" + if HV % H != 0: + pytest.skip("reference/kernel GQA mapping needs HV divisible by H") + device = "cuda" + K = V = D + key_dim = H * K + value_dim = HV * V + + if use_qk_l2norm_in_kernel: + torch.manual_seed(42) + qkv = torch.randn(T, key_dim * 2 + value_dim, device=device, dtype=dtype) * 0.05 + A_log = ( + torch.randn(HV, device=device, dtype=torch.float32).clamp(-2.0, 0.5) * 0.02 + ) + a = (torch.randn(T, HV, device=device, dtype=dtype) * 0.05).clamp(-1.0, 1.0) + a = a / gate_logit_normalizer + b_gate = (torch.randn(T, HV, device=device, dtype=dtype) * 0.05).clamp( + -1.0, 1.0 + ) + dt_bias = (torch.randn(HV, device=device, dtype=dtype) * 0.005).clamp(-0.5, 0.5) + initial = torch.randn(1, HV, V, K, device=device, dtype=dtype) * 0.05 + else: + torch.manual_seed(0) + qkv = torch.randn(T, key_dim * 2 + value_dim, device=device, dtype=dtype) + A_log = torch.randn(HV, device=device, dtype=torch.float32) * 0.02 + a = torch.randn(T, HV, device=device, dtype=dtype) * 0.1 + a = a / gate_logit_normalizer + b_gate = torch.randn(T, HV, device=device, dtype=dtype) * 0.1 + dt_bias = torch.randn(HV, device=device, dtype=dtype) * 0.01 + initial = torch.randn(1, HV, V, K, device=device, dtype=dtype) + + o_ref, h_ref = ref_fused_rearrange_sigmoid_gdr( + A_log, + a, + b_gate, + dt_bias, + qkv, + key_dim, + value_dim, + K, + V, + 1.0, + 20.0, + scale, + initial, + use_qk_l2norm_in_kernel, + ) + + core = torch.empty(1 * 1 * T * HV * V, device=device, dtype=dtype) + o_tr, h_tr = fused_rearrange_sigmoid_gated_delta_rule( + A_log, + a, + b_gate, + dt_bias, + qkv, + key_dim, + value_dim, + K, + V, + beta=1.0, + threshold=20.0, + scale=scale, + initial_state=initial, + inplace_final_state=False, + cu_seqlens=None, + ssm_state_indices=None, + num_accepted_tokens=None, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + is_kda=False, + core_attn_out=core, + ) + + if dtype == torch.bfloat16: + rtol, atol = 0.05, 0.1 + elif dtype == torch.float16: + rtol, atol = 0.03, 0.08 + else: + rtol, atol = 0.02, 0.05 + + if use_qk_l2norm_in_kernel: + assert torch.isfinite(o_tr.float()).all(), "non-finite Triton output" + assert torch.isfinite(h_tr.float()).all(), "non-finite Triton final_state" + torch.testing.assert_close(o_tr.float(), o_ref, rtol=rtol, atol=atol) + torch.testing.assert_close(h_tr[-1].float(), h_ref[0], rtol=rtol, atol=atol)
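
The reference above walks the recurrence one head and one timestep at a time; the sketch below restates the same per-timestep gated-delta-rule update in vectorized form (single sequence, gates already computed, no GQA grouping or qkv unpacking), which may make the math in ``ref_fused_rearrange_sigmoid_gdr`` easier to follow. It is a minimal illustration under those assumptions, not part of the kernel API; the function name and shapes are hypothetical.

    import torch


    def gated_delta_rule_step(
        h: torch.Tensor,     # (HV, V, K) recurrent state
        q: torch.Tensor,     # (HV, K) query, pre-scaled (and L2-normalized if enabled)
        k: torch.Tensor,     # (HV, K) key
        v: torch.Tensor,     # (HV, V) value
        g: torch.Tensor,     # (HV,) log-decay: g = -exp(A_log) * softplus(a + dt_bias)
        beta: torch.Tensor,  # (HV,) write gate: beta = sigmoid(b)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Decay the state, form the delta-rule correction against the decayed state,
        # write the corrected value back along k, then read out with q.
        h = h * torch.exp(g)[:, None, None]
        v_adj = (v - torch.einsum("hvk,hk->hv", h, k)) * beta[:, None]
        h = h + v_adj[:, :, None] * k[:, None, :]
        o = torch.einsum("hvk,hk->hv", h, q)
        return o, h

Iterating this step over t = 0..T-1, with g and beta derived from ``A_log``, ``a``, ``dt_bias`` and ``b`` exactly as in the reference loop, reproduces the per-head updates performed above.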