diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 90c59bd6ddf51..9aa44a1600ae6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -15,6 +15,7 @@ Do not modify directly.* * com.microsoft.BitmaskBiasDropout * com.microsoft.BitmaskDropout * com.microsoft.CDist + * com.microsoft.CausalConvWithState * com.microsoft.ComplexMul * com.microsoft.ComplexMulConj * com.microsoft.ConvTransposeWithDynamicPads @@ -49,6 +50,7 @@ Do not modify directly.* * com.microsoft.GroupQueryAttention * com.microsoft.Inverse * com.microsoft.Irfft + * com.microsoft.LinearAttention * com.microsoft.LongformerAttention * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 @@ -900,6 +902,68 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.CausalConvWithState** + + Stateful causal depthwise convolution, generalized to N spatial dimensions. + + Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step. + Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation. + + The convolution is causal (looks only at current and past positions along the last + spatial dimension) and depthwise (each channel is convolved independently with its own kernel). + + Input layout is channels-first: (batch_size, channels, ...). + Weight layout: (channels, 1, k_1, ...) for depthwise convolution. + The carry state stores the last (k-1) positions along the causal axis for incremental decode. + + The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is + enforced on the last spatial dimension only. + + The optional activation attribute supports fused SiLU/Swish activation. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Fused activation function. One of: 'silu', 'swish', 'none'. Default is 'none'.
+
ndim : int
+
Spatial dimensionality: 1, 2, or 3. Default is 1.
+
+ +#### Inputs (2 - 4) + +
+
input : T
+
Input tensor with shape (batch_size, channels, ...). Channels-first layout. Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).
+
weight : T
+
Depthwise convolution kernel with shape (channels, 1, k_1, ...). Spatial kernel sizes: (k_1, ..., k_ndim).
+
bias (optional) : T
+
Optional per-channel bias with shape (channels).
+
past_state (optional) : T
+
Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). If not provided, padding is zero.
+
+ +#### Outputs + +
+
output : T
+
Convolution output with same shape as input.
+
present_state : T
+
Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). Contains the last (k-1) values from the virtual input along the causal axis.
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
+ + ### **com.microsoft.ComplexMul** #### Version @@ -2703,6 +2767,79 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.LinearAttention** + + Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1). + + All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always + required. The op internally unpacks to 4D for computation. + + The update_rule attribute selects the recurrence type: + - "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t + - "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t + - "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t + - "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t + + where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product. + + Semantics: Equivalent to running the recurrent update sequentially for each token, + but may be implemented using chunk-parallel algorithms for GPU efficiency. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
chunk_size : int
+
Chunk size for the chunk-parallel WY decomposition during prefill (T>1). Tuning hint; does not affect output correctness.
+
kv_num_heads : int (required)
+
Number of key/value heads. Always required.
+
q_num_heads : int (required)
+
Number of query heads. Always required.
+
scale : float
+
Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set explicitly to override.
+
update_rule : string
+
The update rule for the linear attention recurrence. One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.
+
+ +#### Inputs (3 - 6) + +
+
query : T
+
Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.
+
key : T
+
Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.
+
value : T
+
Value vectors with 3D packed shape (B, T, H_kv * d_v).
+
past_state (optional) : S
+
Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.
+
decay (optional) : T
+
Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for 'gated' and 'gated_delta' modes.
+
beta (optional) : T
+
Update rate (sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for 'delta' and 'gated_delta' modes.
+
+ +#### Outputs + +
+
output : T
+
Attention output with 3D packed shape (B, T, H_q * d_v).
+
present_state : S
+
Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
S : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain state types to float tensors.
+
+ + ### **com.microsoft.LongformerAttention** Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 092c05f9e081a..1209446c6a367 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2217,5 +2217,247 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); +constexpr const char* CausalConvWithState_ver1_doc = R"DOC( +Stateful causal depthwise convolution, generalized to N spatial dimensions. + +Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step. +Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation. + +The convolution is causal (looks only at current and past positions along the last +spatial dimension) and depthwise (each channel is convolved independently with its own kernel). + +Input layout is channels-first: (batch_size, channels, ...). +Weight layout: (channels, 1, k_1, ...) for depthwise convolution. +The carry state stores the last (k-1) positions along the causal axis for incremental decode. + +The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is +enforced on the last spatial dimension only. + +The optional activation attribute supports fused SiLU/Swish activation. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + CausalConvWithState, 1, + OpSchema() + .SetDoc(CausalConvWithState_ver1_doc) + .Attr("activation", + "Fused activation function. One of: 'silu', 'swish', 'none'. " + "Default is 'none'.", + AttributeProto::STRING, + std::string("none")) + .Attr("ndim", + "Spatial dimensionality: 1, 2, or 3. Default is 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "input", + "Input tensor with shape (batch_size, channels, ...). Channels-first layout. 
" + "Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).", + "T") + .Input(1, + "weight", + "Depthwise convolution kernel with shape (channels, 1, k_1, ...). " + "Spatial kernel sizes: (k_1, ..., k_ndim).", + "T") + .Input(2, + "bias", + "Optional per-channel bias with shape (channels).", + "T", + OpSchema::Optional) + .Input(3, + "past_state", + "Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). " + "If not provided, padding is zero.", + "T", + OpSchema::Optional) + .Output(0, + "output", + "Convolution output with same shape as input.", + "T") + .Output(1, + "present_state", + "Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). " + "Contains the last (k-1) values from the virtual input along the causal axis.", + "T") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + + // Output 0: same shape as input (batch_size, channels, ...) + propagateShapeFromInputToOutput(ctx, 0, 0); + + // Output 1: state shape is (batch_size, channels, [non-causal spatial dims...], k_last - 1) + // For ndim=1: (B, C, k_1-1) + // For ndim=2: (B, C, input_H, k_2-1) + // For ndim=3: (B, C, input_D, input_H, k_3-1) + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + auto& input_shape = getInputShape(ctx, 0); + auto& weight_shape = getInputShape(ctx, 1); + int64_t ndim = getAttribute(ctx, "ndim", 1); + TensorShapeProto state_shape; + *state_shape.add_dim() = input_shape.dim(0); // batch_size + *state_shape.add_dim() = input_shape.dim(1); // channels + // Copy non-causal spatial dims from input (dims 2 .. 
2+ndim-2) + for (int64_t i = 0; i < ndim - 1; ++i) { + *state_shape.add_dim() = input_shape.dim(static_cast(2 + i)); + } + // Causal (last) spatial dim: kernel_size - 1 + int last_kernel_dim = weight_shape.dim_size() - 1; + if (weight_shape.dim(last_kernel_dim).has_dim_value()) { + state_shape.add_dim()->set_dim_value(weight_shape.dim(last_kernel_dim).dim_value() - 1); + } else { + state_shape.add_dim(); // unknown + } + updateOutputShape(ctx, 1, state_shape); + } + })); + +constexpr const char* LinearAttention_ver1_doc = R"DOC( +Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1). + +All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always +required. The op internally unpacks to 4D for computation. + +The update_rule attribute selects the recurrence type: +- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t +- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t +- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t +- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t + +where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product. + +Semantics: Equivalent to running the recurrent update sequentially for each token, +but may be implemented using chunk-parallel algorithms for GPU efficiency. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + LinearAttention, 1, + OpSchema() + .SetDoc(LinearAttention_ver1_doc) + .Attr("update_rule", + "The update rule for the linear attention recurrence. " + "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", + AttributeProto::STRING, + std::string("gated_delta")) + .Attr("scale", + "Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads " + "and uses 1/sqrt(d_k). 
Set explicitly to override.", + AttributeProto::FLOAT, + 0.0f) + .Attr("q_num_heads", + "Number of query heads. Always required.", + AttributeProto::INT) + .Attr("kv_num_heads", + "Number of key/value heads. Always required.", + AttributeProto::INT) + .Attr("chunk_size", + "Chunk size for the chunk-parallel WY decomposition during prefill (T>1). " + "Tuning hint; does not affect output correctness.", + AttributeProto::INT, + static_cast(64)) + .Input(0, + "query", + "Query vectors with 3D packed shape (B, T, H_q * d_k). " + "Heads are packed into the last dimension.", + "T") + .Input(1, + "key", + "Key vectors with 3D packed shape (B, T, H_kv * d_k). " + "Should be L2-normalized for delta/gated_delta modes.", + "T") + .Input(2, + "value", + "Value vectors with 3D packed shape (B, T, H_kv * d_v).", + "T") + .Input(3, + "past_state", + "Recurrent state from previous step with shape (B, H_kv, d_k, d_v). " + "Always 4D. If not provided, defaults to zeros.", + "S", + OpSchema::Optional) + .Input(4, + "decay", + "Exponential decay gate in log-space. 3D packed shape: " + "(B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or " + "(B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). " + "Required for 'gated' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Input(5, + "beta", + "Update rate (sigmoid output). 3D packed shape: " + "(B, T, H_kv) or (B, T, 1). " + "Required for 'delta' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Output(0, + "output", + "Attention output with 3D packed shape (B, T, H_q * d_v).", + "T") + .Output(1, + "present_state", + "Updated recurrent state with shape (B, H_kv, d_k, d_v). 
Always 4D.", + "S") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("S", + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain state types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + + // Read required attributes + auto* q_num_heads_attr = ctx.getAttribute("q_num_heads"); + auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads"); + int64_t q_num_heads = (q_num_heads_attr && q_num_heads_attr->has_i()) ? q_num_heads_attr->i() : 0; + int64_t kv_num_heads = (kv_num_heads_attr && kv_num_heads_attr->has_i()) ? kv_num_heads_attr->i() : 0; + + // Output 0: (B, T, H_q * d_v) — 3D packed + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { + auto& query_shape = getInputShape(ctx, 0); + auto& value_shape = getInputShape(ctx, 2); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_shape.dim(0); // B + *output_shape.add_dim() = query_shape.dim(1); // T + // H_q * d_v: d_v = value.dim(2) / kv_num_heads, then H_q * d_v + if (value_shape.dim(2).has_dim_value()) { + int64_t d_v = value_shape.dim(2).dim_value() / kv_num_heads; + output_shape.add_dim()->set_dim_value(kv_num_heads * d_v); + } else { + output_shape.add_dim(); // unknown + } + updateOutputShape(ctx, 0, output_shape); + } + + // Output 1: present_state shape (B, H_kv, d_k, d_v) — 4D + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { + auto& query_shape = getInputShape(ctx, 0); + auto& value_shape = getInputShape(ctx, 2); + TensorShapeProto state_shape; + *state_shape.add_dim() = query_shape.dim(0); // B + state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv + // d_k = query.dim(2) / q_num_heads + if (query_shape.dim(2).has_dim_value()) 
{ + state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads); + } else { + state_shape.add_dim(); + } + // d_v = value.dim(2) / kv_num_heads + if (value_shape.dim(2).has_dim_value()) { + state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads); + } else { + state_shape.add_dim(); + } + updateOutputShape(ctx, 1, state_shape); + } else if (hasInputShape(ctx, 3)) { + propagateShapeFromInputToOutput(ctx, 3, 1); + } + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 6c20aae94d132..59f97c222ceb2 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -88,6 +88,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -199,6 +201,8 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc new file mode 100644 index 0000000000000..2a7837dd1ce73 --- /dev/null +++ b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc @@ -0,0 +1,658 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "gtest/gtest.h" +#include "core/common/logging/logging.h" +#include "core/framework/kernel_registry.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16 +}; + +// Reference implementation for CausalConvWithState +// Performs depthwise causal 1D convolution with optional state, bias, and activation. +// +// Input: (B, D, L) channels-first +// Weight: (D, 1, K) depthwise +// Bias: (D,) optional +// past_state: (B, D, K-1) optional carry state +// +// Output: (B, D, L) convolution output (with optional activation) +// present_state: (B, D, K-1) updated carry state +void CausalConvWithStateReference( + const std::vector& input, + const std::vector& weight, + const std::vector* bias, + const std::vector* conv_state, + std::vector& output, + std::vector& present_state, + int batch_size, + int channels, + int input_length, + int kernel_size, + const std::string& activation) { + int state_length = kernel_size - 1; + int total_virtual_length = state_length + input_length; + + output.resize(batch_size * channels * input_length); + present_state.resize(batch_size * channels * state_length); + + for (int b = 0; b < batch_size; ++b) { + for (int d = 0; d < channels; ++d) { + int bd = b * channels + d; + + // Build virtual input: [conv_state, input] + std::vector virtual_input(total_virtual_length, 0.0f); + if (conv_state != nullptr) { + for (int s = 0; s < state_length; ++s) { + virtual_input[s] = (*conv_state)[bd * state_length + s]; + } + } + for (int l = 0; l < input_length; ++l) { + virtual_input[state_length + l] = input[bd * input_length + l]; + } + + // Compute depthwise convolution + 
for (int pos = 0; pos < input_length; ++pos) { + float acc = 0.0f; + for (int j = 0; j < kernel_size; ++j) { + float val = virtual_input[pos + j]; + float w = weight[d * kernel_size + j]; + acc += val * w; + } + // Add bias + if (bias != nullptr) { + acc += (*bias)[d]; + } + // Apply activation + if (activation == "silu" || activation == "swish") { + acc = acc / (1.0f + std::exp(-acc)); + } + output[bd * input_length + pos] = acc; + } + + // Compute present_state: last state_length values from virtual input + for (int s = 0; s < state_length; ++s) { + present_state[bd * state_length + s] = + virtual_input[input_length + s]; + } + } + } +} + +// Returns a WebGPU EP if it is available and has the CausalConvWithState kernel registered, +// or nullptr otherwise. +std::unique_ptr TryGetEpWithCausalConvWithState() { + auto ep = DefaultWebGpuExecutionProvider(); + if (!ep) { + ep = DefaultCpuExecutionProvider(); + } + + auto kernel_registry = ep->GetKernelRegistry(); + if (kernel_registry) { + const KernelCreateInfo* info = nullptr; + KernelRegistry::TypeConstraintMap type_constraints; + auto status = kernel_registry->TryFindKernel( + ep->Type(), "CausalConvWithState", kMSDomain, 1, + type_constraints, DefaultLoggingManager().DefaultLogger(), &info); + if (!status.IsOK()) return nullptr; + } + return ep; +} + +} // anonymous namespace + +static void RunCausalConvWithStateTest( + const std::vector& input_data, + const std::vector& weight_data, + const std::vector* bias_data, + const std::vector* conv_state_data, + const std::vector& expected_output, + const std::vector& expected_state, + int batch_size, + int channels, + int input_length, + int kernel_size, + const std::string& activation, + TensorType tensor_type) { + auto ep = TryGetEpWithCausalConvWithState(); + if (!ep) { + GTEST_SKIP() << "CausalConvWithState kernel not registered"; + return; + } + + int state_length = kernel_size - 1; + + std::vector input_shape = {batch_size, channels, input_length}; + std::vector 
weight_shape = {channels, 1, kernel_size}; + std::vector bias_shape = {channels}; + std::vector state_shape = {batch_size, channels, state_length}; + std::vector output_shape = {batch_size, channels, input_length}; + + { + OpTester test("CausalConvWithState", 1, onnxruntime::kMSDomain); + test.AddAttribute("activation", activation); + + if (tensor_type == TensorType::kFloat) { + test.AddInput("input", input_shape, input_data); + test.AddInput("weight", weight_shape, weight_data); + + if (bias_data != nullptr) { + test.AddInput("bias", bias_shape, *bias_data); + } else { + test.AddOptionalInputEdge(); + } + + if (conv_state_data != nullptr) { + test.AddInput("past_state", state_shape, *conv_state_data); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, expected_output); + test.AddOutput("present_state", state_shape, expected_state); + } else { + test.AddInput("input", input_shape, ToFloat16(input_data)); + test.AddInput("weight", weight_shape, ToFloat16(weight_data)); + + if (bias_data != nullptr) { + test.AddInput("bias", bias_shape, ToFloat16(*bias_data)); + } else { + test.AddOptionalInputEdge(); + } + + if (conv_state_data != nullptr) { + test.AddInput("past_state", state_shape, ToFloat16(*conv_state_data)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, ToFloat16(expected_output)); + test.AddOutput("present_state", state_shape, ToFloat16(expected_state)); + } + + test.SetOutputAbsErr("output", 0.01f); + test.SetOutputAbsErr("present_state", 0.01f); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +static void RunCausalConvWithStateTests( + const std::vector& input_data, + const std::vector& weight_data, + const std::vector* bias_data, + const std::vector* conv_state_data, + int batch_size, + int channels, + int input_length, + int kernel_size, + 
const std::string& activation = "silu") { + // Compute expected output using reference implementation + std::vector expected_output; + std::vector expected_state; + CausalConvWithStateReference( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation); + + // FP32 test + RunCausalConvWithStateTest( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation, + TensorType::kFloat); + + // FP16 test + RunCausalConvWithStateTest( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation, + TensorType::kFloat16); +} + +// ============================================================================= +// Basic tests - simple cases +// ============================================================================= + +TEST(CausalConvWithStateTest, BasicNoStateNoBias) { + // B=1, D=2, L=4, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + // Input: (1, 2, 4) + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, // channel 0 + 0.5f, 1.5f, 2.5f, 3.5f}; // channel 1 + + // Weight: (2, 1, 3) + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, // channel 0 kernel + 0.4f, 0.5f, 0.6f}; // channel 1 kernel + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, nullptr, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConvWithStateTest, BasicWithBias) { + // B=1, D=2, L=4, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + + RunCausalConvWithStateTests( + input_data, weight_data, 
&bias_data, nullptr, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConvWithStateTest, BasicWithState) { + // B=1, D=2, L=3, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + // State: (1, 2, 2) - kernel_size - 1 = 2 + std::vector conv_state_data = { + -1.0f, 0.5f, // channel 0 state + 0.3f, -0.7f}; // channel 1 state + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConvWithStateTest, WithStateAndBias) { + // B=1, D=2, L=3, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConvWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +// ============================================================================= +// SiLU activation tests +// ============================================================================= + +TEST(CausalConvWithStateTest, SiluActivationNoState) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, nullptr, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConvWithStateTest, SiluActivationWithState) { + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 
3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConvWithStateTest, SiluActivationWithBiasAndState) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConvWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// Kernel size variations +// ============================================================================= + +TEST(CausalConvWithStateTest, KernelSize2) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 2; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.3f, 0.7f, + 0.4f, 0.6f}; + // State: (1, 2, 1) - kernel_size - 1 = 1 + std::vector conv_state_data = {0.5f, -0.3f}; + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConvWithStateTest, KernelSize4) { + int batch_size = 1, channels = 1, input_length = 5, kernel_size = 4; + + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + std::vector weight_data = {0.1f, 0.2f, 0.3f, 0.4f}; + // State: (1, 1, 3) + std::vector conv_state_data = {-1.0f, 0.0f, 0.5f}; + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, 
&conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +// ============================================================================= +// Batch size > 1 +// ============================================================================= + +TEST(CausalConvWithStateTest, MultiBatch) { + int batch_size = 2, channels = 2, input_length = 3, kernel_size = 3; + + // Input: (2, 2, 3) + std::vector input_data = { + // Batch 0 + 1.0f, 2.0f, 3.0f, // ch 0 + 0.5f, 1.5f, 2.5f, // ch 1 + // Batch 1 + -1.0f, 0.0f, 1.0f, // ch 0 + 0.2f, 0.4f, 0.6f}; // ch 1 + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + std::vector bias_data = {0.1f, -0.1f}; + + // State: (2, 2, 2) + std::vector conv_state_data = { + // Batch 0 + -0.5f, 0.5f, // ch 0 + 0.3f, -0.3f, // ch 1 + // Batch 1 + 0.1f, -0.1f, // ch 0 + 0.7f, 0.8f}; // ch 1 + + RunCausalConvWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// Single token decode (L=1) - the primary use case for incremental decoding +// ============================================================================= + +TEST(CausalConvWithStateTest, SingleTokenDecode) { + int batch_size = 1, channels = 4, input_length = 1, kernel_size = 4; + + // Input: (1, 4, 1) + std::vector input_data = {0.5f, -0.3f, 1.2f, 0.8f}; + + // Weight: (4, 1, 4) + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, 0.4f, + 0.5f, 0.6f, 0.7f, 0.8f, + -0.1f, -0.2f, 0.1f, 0.2f, + 0.3f, 0.3f, 0.3f, 0.3f}; + + std::vector bias_data = {0.0f, 0.1f, -0.1f, 0.0f}; + + // State: (1, 4, 3) - carrying the last 3 values per channel + std::vector conv_state_data = { + 1.0f, 2.0f, 3.0f, // ch 0 + -1.0f, 0.0f, 1.0f, // ch 1 + 0.5f, 0.5f, 0.5f, // ch 2 + -0.2f, 0.4f, -0.6f}; // ch 3 + + RunCausalConvWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + 
batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConvWithStateTest, SingleTokenDecodeMultiBatch) { + int batch_size = 2, channels = 2, input_length = 1, kernel_size = 3; + + // Input: (2, 2, 1) + std::vector input_data = { + 0.5f, // B0, ch 0 + -0.3f, // B0, ch 1 + 1.2f, // B1, ch 0 + 0.8f}; // B1, ch 1 + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + // State: (2, 2, 2) + std::vector conv_state_data = { + 1.0f, 2.0f, // B0, ch 0 + -1.0f, 0.0f, // B0, ch 1 + 0.5f, 0.5f, // B1, ch 0 + -0.2f, 0.4f}; // B1, ch 1 + + RunCausalConvWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// State continuity test: verify that present_state from one call can be used +// as conv_state for the next call (simulating autoregressive decode) +// ============================================================================= + +TEST(CausalConvWithStateTest, StateContinuity) { + // Process a sequence of single tokens and verify state propagation + int batch_size = 1, channels = 1, kernel_size = 3; + int input_length = 1; + + std::vector weight_data = {0.2f, 0.3f, 0.5f}; + std::vector bias_data = {0.1f}; + + // Initial state: zeros + std::vector conv_state = {0.0f, 0.0f}; + + // First token + std::vector input1 = {1.0f}; + std::vector expected_output1; + std::vector expected_state1; + CausalConvWithStateReference(input1, weight_data, &bias_data, &conv_state, + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConvWithStateTest(input1, weight_data, &bias_data, &conv_state, + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // Second token, using present_state from first as conv_state + std::vector input2 = {2.0f}; + std::vector expected_output2; 
+ std::vector expected_state2; + CausalConvWithStateReference(input2, weight_data, &bias_data, &expected_state1, + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConvWithStateTest(input2, weight_data, &bias_data, &expected_state1, + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // Third token + std::vector input3 = {3.0f}; + std::vector expected_output3; + std::vector expected_state3; + CausalConvWithStateReference(input3, weight_data, &bias_data, &expected_state2, + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConvWithStateTest(input3, weight_data, &bias_data, &expected_state2, + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // The present_state after processing [1, 2, 3] should be [2, 3] + EXPECT_NEAR(expected_state3[0], 2.0f, 1e-5f); + EXPECT_NEAR(expected_state3[1], 3.0f, 1e-5f); +} + +// ============================================================================= +// Equivalence test: sequence processing should match token-by-token with state +// ============================================================================= + +TEST(CausalConvWithStateTest, SequenceVsTokenByToken) { + int batch_size = 1, channels = 2, kernel_size = 3; + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.05f, -0.05f}; + + // Initial state: zeros + std::vector conv_state = {0.0f, 0.0f, 0.0f, 0.0f}; // (1, 2, 2) + + // Full sequence: length 4 + std::vector full_input = { + 1.0f, 2.0f, 3.0f, 4.0f, // ch 0 + 0.5f, 1.5f, 2.5f, 3.5f}; // ch 1 + + // Process full sequence at once + std::vector full_output; + std::vector full_final_state; + CausalConvWithStateReference(full_input, weight_data, &bias_data, &conv_state, + full_output, full_final_state, + batch_size, 
channels, 4, kernel_size, "none"); + + // Process token by token + std::vector current_state = conv_state; + std::vector token_outputs; + + for (int t = 0; t < 4; ++t) { + // Extract single token: (1, 2, 1) + std::vector token_input = { + full_input[0 * 4 + t], // ch 0 + full_input[1 * 4 + t]}; // ch 1 + + std::vector token_output; + std::vector next_state; + CausalConvWithStateReference(token_input, weight_data, &bias_data, ¤t_state, + token_output, next_state, + batch_size, channels, 1, kernel_size, "none"); + + // Collect outputs + for (int d = 0; d < channels; ++d) { + token_outputs.push_back(token_output[d]); + } + current_state = next_state; + } + + // Rearrange token_outputs from (T, D) to (D, T) layout for comparison + std::vector token_outputs_dlt(channels * 4); + for (int t = 0; t < 4; ++t) { + for (int d = 0; d < channels; ++d) { + token_outputs_dlt[d * 4 + t] = token_outputs[t * channels + d]; + } + } + + // Compare outputs + for (int i = 0; i < channels * 4; ++i) { + EXPECT_NEAR(full_output[i], token_outputs_dlt[i], 1e-5f) + << "Mismatch at index " << i; + } + + // Compare final states + for (int i = 0; i < channels * 2; ++i) { + EXPECT_NEAR(full_final_state[i], current_state[i], 1e-5f) + << "State mismatch at index " << i; + } +} + +// ============================================================================= +// Larger dimension test with realistic sizes +// ============================================================================= + +TEST(CausalConvWithStateTest, LargerDimensions) { + int batch_size = 2, channels = 8, input_length = 16, kernel_size = 4; + + // Generate test data with a simple pattern + std::vector input_data(batch_size * channels * input_length); + for (int i = 0; i < static_cast(input_data.size()); ++i) { + input_data[i] = std::sin(static_cast(i) * 0.1f); + } + + std::vector weight_data(channels * kernel_size); + for (int i = 0; i < static_cast(weight_data.size()); ++i) { + weight_data[i] = std::cos(static_cast(i) * 0.2f) * 
0.5f; + } + + std::vector bias_data(channels); + for (int i = 0; i < channels; ++i) { + bias_data[i] = 0.01f * static_cast(i); + } + + int state_length = kernel_size - 1; + std::vector conv_state_data(batch_size * channels * state_length); + for (int i = 0; i < static_cast(conv_state_data.size()); ++i) { + conv_state_data[i] = std::sin(static_cast(i) * 0.3f) * 0.5f; + } + + RunCausalConvWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc new file mode 100644 index 0000000000000..fd2b648c8badf --- /dev/null +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -0,0 +1,1201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "gtest/gtest.h" +#include "core/common/logging/logging.h" +#include "core/framework/kernel_registry.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +using namespace onnxruntime::test; + +namespace onnxruntime { +namespace test { + +namespace { + +// Reference implementation of the linear attention recurrence. +// Processes all tokens sequentially and returns output + final_state. 
// Reference implementation of the linear attention recurrence.
// Processes all tokens sequentially and returns output + final_state.
//
// Layouts (row-major):
//   query/key : (B, H, T, dk)    value : (B, H, T, dv)
//   state     : (B, H, dk, dv)   output: (B, H, T, dv)
//
// Per token t:
//   gated/gated_delta : S *= exp(decay_t)              (per-dk row or broadcast)
//   delta/gated_delta : S += k_t (x) (beta_t * (v_t - S^T k_t))
//   linear/gated      : S += k_t (x) v_t
//   output_t          : scale * S^T q_t
//
// `decay` may be shaped (B,H,T) — one log-decay broadcast across dk — or
// (B,H,T,dk); the two are distinguished by the vector's total size.
// NOTE(review): `decay` must be non-null for the gated rules and `beta`
// non-null for the delta rules; all callers in this file satisfy that.
void LinearAttentionReference(
    const std::string& update_rule,
    int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v,
    float scale,
    const std::vector<float>& query,
    const std::vector<float>& key,
    const std::vector<float>& value,
    const std::vector<float>* initial_state,
    const std::vector<float>* decay,
    const std::vector<float>* beta,
    std::vector<float>& output,
    std::vector<float>& final_state) {
  const int bht = batch_size * num_heads * seq_length;
  // (B,H,T)-sized decay broadcasts one value over every dk row.
  const bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);

  // assign (not resize) so a caller-reused vector cannot keep stale values.
  final_state.assign(static_cast<size_t>(batch_size) * num_heads * head_dim_k * head_dim_v, 0.0f);
  output.assign(static_cast<size_t>(batch_size) * num_heads * seq_length * head_dim_v, 0.0f);

  // Initialize state from initial_state if provided.
  if (initial_state != nullptr) {
    final_state = *initial_state;
  }

  for (int b = 0; b < batch_size; b++) {
    for (int h = 0; h < num_heads; h++) {
      const int bh = b * num_heads + h;
      // State for this (b, h): dk x dv.
      auto state_offset = [&](int k_idx, int v_idx) {
        return (bh * head_dim_k + k_idx) * head_dim_v + v_idx;
      };

      for (int t = 0; t < seq_length; t++) {
        const int token = bh * seq_length + t;
        const float* q_ptr = &query[static_cast<size_t>(token) * head_dim_k];
        const float* k_ptr = &key[static_cast<size_t>(token) * head_dim_k];
        const float* v_ptr = &value[static_cast<size_t>(token) * head_dim_v];

        // Step 1: Apply decay (gated, gated_delta).
        if (update_rule == "gated" || update_rule == "gated_delta") {
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            const float log_g = decay_broadcast_dk
                                    ? (*decay)[token]
                                    : (*decay)[static_cast<size_t>(token) * head_dim_k + k_idx];
            const float exp_g = std::exp(log_g);
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] *= exp_g;
            }
          }
        }

        // Step 2: Compute state update.
        if (update_rule == "delta" || update_rule == "gated_delta") {
          // retrieved = S^T @ k — what the state currently predicts for this key.
          std::vector<float> retrieved(head_dim_v, 0.0f);
          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
            for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
              retrieved[v_idx] += final_state[state_offset(k_idx, v_idx)] * k_ptr[k_idx];
            }
          }

          // delta = beta * (v - retrieved)
          const float beta_val = (*beta)[token];
          std::vector<float> delta(head_dim_v);
          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
            delta[v_idx] = beta_val * (v_ptr[v_idx] - retrieved[v_idx]);
          }

          // S += k (x) delta
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] += k_ptr[k_idx] * delta[v_idx];
            }
          }
        } else {
          // linear, gated: S += k (x) v
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] += k_ptr[k_idx] * v_ptr[v_idx];
            }
          }
        }

        // Step 3: Compute output = scale * S^T @ q.
        for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
          float sum = 0.0f;
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            sum += final_state[state_offset(k_idx, v_idx)] * q_ptr[k_idx];
          }
          output[static_cast<size_t>(token) * head_dim_v + v_idx] = scale * sum;
        }
      }
    }
  }
}

// GQA-aware reference implementation.
// Q has q_num_heads heads, K has n_k_heads heads, V/state have kv_num_heads heads.
// Standard GQA: q_num_heads >= kv_num_heads, heads_per_group = q_num_heads / kv_num_heads.
// K-to-KV sharing: kv_per_k_head = kv_num_heads / n_k_heads.
void LinearAttentionGQAReference(
    const std::string& update_rule,
    int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads,
    int seq_length, int head_dim_k, int head_dim_v,
    float scale,
    const std::vector<float>& query,          // (B, q_num_heads, T, dk)
    const std::vector<float>& key,            // (B, n_k_heads, T, dk)
    const std::vector<float>& value,          // (B, kv_num_heads, T, dv)
    const std::vector<float>* initial_state,  // (B, kv_num_heads, dk, dv)
    const std::vector<float>* decay,          // (B, kv_num_heads, T[, dk])
    const std::vector<float>* beta,           // (B, kv_num_heads, T)
    std::vector<float>& output,               // (B, out_heads, T, dv) — see below
    std::vector<float>& final_state) {        // (B, kv_num_heads, dk, dv)
  const int bht_kv = batch_size * kv_num_heads * seq_length;
  const bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht_kv);
  const int kv_per_k_head = kv_num_heads / n_k_heads;
  const bool inverse_gqa = q_num_heads < kv_num_heads;
  const int heads_per_group = inverse_gqa ? 0 : q_num_heads / kv_num_heads;

  // BUGFIX: in the standard branch the output is written at Q-head indices
  // (kv_h * heads_per_group + g, up to q_num_heads - 1), so the buffer must
  // be sized by q_num_heads; the previous kv_num_heads sizing wrote out of
  // bounds whenever heads_per_group > 1. For schema-valid cases (q == kv, or
  // inverse GQA) out_heads equals kv_num_heads and behavior is unchanged.
  const int out_heads = inverse_gqa ? kv_num_heads : q_num_heads;

  // assign (not resize) so caller-reused vectors cannot keep stale values.
  final_state.assign(static_cast<size_t>(batch_size) * kv_num_heads * head_dim_k * head_dim_v, 0.0f);
  output.assign(static_cast<size_t>(batch_size) * out_heads * seq_length * head_dim_v, 0.0f);

  if (initial_state != nullptr) {
    final_state = *initial_state;
  }

  for (int b = 0; b < batch_size; b++) {
    for (int kv_h = 0; kv_h < kv_num_heads; kv_h++) {
      // K-head feeding this KV-head.
      const int k_head = kv_h / kv_per_k_head;

      auto state_offset = [&](int k_idx, int v_idx) {
        return ((b * kv_num_heads + kv_h) * head_dim_k + k_idx) * head_dim_v + v_idx;
      };

      for (int t = 0; t < seq_length; t++) {
        // Load k from the mapped K-head and v from this KV-head.
        std::vector<float> k_vec(head_dim_k), v_vec(head_dim_v);
        const int k_base = ((b * n_k_heads + k_head) * seq_length + t) * head_dim_k;
        for (int i = 0; i < head_dim_k; i++) k_vec[i] = key[k_base + i];
        const int v_base = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v;
        for (int i = 0; i < head_dim_v; i++) v_vec[i] = value[v_base + i];

        // Step 1: Apply decay (gated, gated_delta).
        if (update_rule == "gated" || update_rule == "gated_delta") {
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            float exp_g;
            if (decay_broadcast_dk) {
              exp_g = std::exp((*decay)[(b * kv_num_heads + kv_h) * seq_length + t]);
            } else {
              exp_g = std::exp((*decay)[((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_k + k_idx]);
            }
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] *= exp_g;
            }
          }
        }

        // Step 2: Update state.
        if (update_rule == "delta" || update_rule == "gated_delta") {
          // retrieved = S^T @ k
          std::vector<float> retrieved(head_dim_v, 0.0f);
          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
            for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
              retrieved[v_idx] += final_state[state_offset(k_idx, v_idx)] * k_vec[k_idx];
            }
          }
          const int beta_idx = (b * kv_num_heads + kv_h) * seq_length + t;
          const float beta_val = (*beta)[beta_idx];
          // delta = beta * (v - retrieved); S += k (x) delta
          std::vector<float> delta(head_dim_v);
          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
            delta[v_idx] = beta_val * (v_vec[v_idx] - retrieved[v_idx]);
          }
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] += k_vec[k_idx] * delta[v_idx];
            }
          }
        } else {
          // linear, gated: S += k (x) v
          for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              final_state[state_offset(k_idx, v_idx)] += k_vec[k_idx] * v_vec[v_idx];
            }
          }
        }

        // Step 3: Compute output.
        if (!inverse_gqa) {
          // Standard GQA/MHA: one output row per Q head in this group.
          for (int g = 0; g < heads_per_group; g++) {
            const int q_h = kv_h * heads_per_group + g;
            const int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k;
            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
              float sum = 0.0f;
              for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
                sum += final_state[state_offset(k_idx, v_idx)] * query[q_base + k_idx];
              }
              // Output head == Q head (out_heads == q_num_heads here).
              const int out_idx = ((b * out_heads + q_h) * seq_length + t) * head_dim_v + v_idx;
              output[out_idx] = scale * sum;
            }
          }
        } else {
          // Inverse GQA: output indexed by KV head, Q head broadcast.
          const int q_h = kv_h * q_num_heads / kv_num_heads;
          const int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k;
          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
            float sum = 0.0f;
            for (int k_idx = 0; k_idx < head_dim_k; k_idx++) {
              sum += final_state[state_offset(k_idx, v_idx)] * query[q_base + k_idx];
            }
            const int out_idx = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v + v_idx;
            output[out_idx] = scale * sum;
          }
        }
      }
    }
  }
}

// Convert data from 4D (B,H,T,D) row-major layout to 3D packed (B,T,H*D) layout.
std::vector<float> PackBHTD_to_BTHD(const std::vector<float>& data_4d,
                                    int B, int H, int T, int D) {
  std::vector<float> packed(static_cast<size_t>(B) * T * H * D);
  for (int b = 0; b < B; b++) {
    for (int h = 0; h < H; h++) {
      for (int t = 0; t < T; t++) {
        for (int d = 0; d < D; d++) {
          const int src_idx = ((b * H + h) * T + t) * D + d;
          const int dst_idx = (b * T + t) * (H * D) + h * D + d;
          packed[dst_idx] = data_4d[src_idx];
        }
      }
    }
  }
  return packed;
}

// Convert decay/beta from (B,H,T) layout to (B,T,H) layout.
std::vector<float> TransposeBHT_to_BTH(const std::vector<float>& data,
                                       int B, int H, int T) {
  std::vector<float> transposed(static_cast<size_t>(B) * T * H);
  for (int b = 0; b < B; b++) {
    for (int h = 0; h < H; h++) {
      for (int t = 0; t < T; t++) {
        const int src_idx = (b * H + h) * T + t;
        const int dst_idx = (b * T + t) * H + h;
        transposed[dst_idx] = data[src_idx];
      }
    }
  }
  return transposed;
}

// Returns a WebGPU EP if it is available and has the LinearAttention kernel registered,
// or nullptr otherwise.
+std::unique_ptr TryGetEpWithLinearAttention() { + auto ep = DefaultWebGpuExecutionProvider(); + if (!ep) { + ep = DefaultCpuExecutionProvider(); + } + + auto kernel_registry = ep->GetKernelRegistry(); + if (kernel_registry) { + const KernelCreateInfo* info = nullptr; + KernelRegistry::TypeConstraintMap type_constraints; + auto status = kernel_registry->TryFindKernel( + ep->Type(), "LinearAttention", kMSDomain, 1, + type_constraints, DefaultLoggingManager().DefaultLogger(), &info); + if (!status.IsOK()) return nullptr; + } + return ep; +} + +void RunLinearAttentionTest( + const std::string& update_rule, + int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v, + float scale, + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector* initial_state, + const std::vector* decay, + const std::vector* beta_data) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered"; + return; + } + + // Compute reference output (reference works in 4D layout) + std::vector expected_output_4d, expected_state; + LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, + head_dim_k, head_dim_v, scale, + query, key, value, initial_state, decay, beta_data, + expected_output_4d, expected_state); + + int bht = batch_size * num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); + + // Convert from 4D (B,H,T,D) to 3D packed (B,T,H*D) for OpTester + auto query_3d = PackBHTD_to_BTHD(query, batch_size, num_heads, seq_length, head_dim_k); + auto key_3d = PackBHTD_to_BTHD(key, batch_size, num_heads, seq_length, head_dim_k); + auto value_3d = PackBHTD_to_BTHD(value, batch_size, num_heads, seq_length, head_dim_v); + auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, num_heads, seq_length, head_dim_v); + + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + 
tester.AddAttribute("update_rule", update_rule); + tester.AddAttribute("scale", scale); + tester.AddAttribute("q_num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(num_heads)); + + // Add required inputs — 3D packed (B, T, H*D) + std::vector qk_dims = {batch_size, seq_length, num_heads * head_dim_k}; + std::vector v_dims = {batch_size, seq_length, num_heads * head_dim_v}; + tester.AddInput("query", qk_dims, query_3d); + tester.AddInput("key", qk_dims, key_3d); + tester.AddInput("value", v_dims, value_3d); + + // Optional: past_state (4D, same format as before) + if (initial_state != nullptr) { + std::vector state_dims = {batch_size, num_heads, head_dim_k, head_dim_v}; + tester.AddInput("past_state", state_dims, *initial_state); + } else { + tester.AddOptionalInputEdge(); + } + + // Optional: decay — convert from (B,H,T[,dk]) to (B,T,H[*dk]) + if (decay != nullptr) { + if (decay_broadcast_dk) { + // (B,H,T) → (B,T,H) + auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, num_heads, seq_length); + std::vector decay_dims = {batch_size, seq_length, num_heads}; + tester.AddInput("decay", decay_dims, decay_3d); + } else { + // (B,H,T,dk) → (B,T,H*dk) + auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, num_heads, seq_length, head_dim_k); + std::vector decay_dims = {batch_size, seq_length, num_heads * head_dim_k}; + tester.AddInput("decay", decay_dims, decay_3d); + } + } else { + tester.AddOptionalInputEdge(); + } + + // Optional: beta — convert from (B*H*T) flat to (B,T,H) + if (beta_data != nullptr) { + auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, num_heads, seq_length); + std::vector beta_dims = {batch_size, seq_length, num_heads}; + tester.AddInput("beta", beta_dims, beta_3d); + } else { + tester.AddOptionalInputEdge(); + } + + // Add outputs — output is 3D packed, state is 4D + std::vector out_dims = {batch_size, seq_length, num_heads * head_dim_v}; + std::vector state_dims = {batch_size, num_heads, head_dim_k, 
head_dim_v}; + tester.AddOutput("output", out_dims, output_3d, false, 0.005f, 0.005f); + tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// GQA-aware test harness. +// Q: (B, q_num_heads, T, dk), K: (B, n_k_heads, T, dk), V: (B, kv_num_heads, T, dv) +void RunLinearAttentionGQATest( + const std::string& update_rule, + int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads, + int seq_length, int head_dim_k, int head_dim_v, + float scale, + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector* initial_state, + const std::vector* decay, + const std::vector* beta_data) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered"; + return; + } + + std::vector expected_output_4d, expected_state; + LinearAttentionGQAReference(update_rule, batch_size, q_num_heads, kv_num_heads, n_k_heads, + seq_length, head_dim_k, head_dim_v, scale, + query, key, value, initial_state, decay, beta_data, + expected_output_4d, expected_state); + + int bht_kv = batch_size * kv_num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht_kv); + + // Pack to 3D — each tensor uses its own head count + auto query_3d = PackBHTD_to_BTHD(query, batch_size, q_num_heads, seq_length, head_dim_k); + auto key_3d = PackBHTD_to_BTHD(key, batch_size, n_k_heads, seq_length, head_dim_k); + auto value_3d = PackBHTD_to_BTHD(value, batch_size, kv_num_heads, seq_length, head_dim_v); + // Output always indexed by kv_num_heads (matches schema) + auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, kv_num_heads, seq_length, head_dim_v); + + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + 
tester.AddAttribute("update_rule", update_rule); + tester.AddAttribute("scale", scale); + tester.AddAttribute("q_num_heads", static_cast(q_num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + + tester.AddInput("query", {batch_size, seq_length, q_num_heads * head_dim_k}, query_3d); + tester.AddInput("key", {batch_size, seq_length, n_k_heads * head_dim_k}, key_3d); + tester.AddInput("value", {batch_size, seq_length, kv_num_heads * head_dim_v}, value_3d); + + if (initial_state != nullptr) { + tester.AddInput("past_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v}, *initial_state); + } else { + tester.AddOptionalInputEdge(); + } + + if (decay != nullptr) { + if (decay_broadcast_dk) { + auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, kv_num_heads, seq_length); + tester.AddInput("decay", {batch_size, seq_length, kv_num_heads}, decay_3d); + } else { + auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, kv_num_heads, seq_length, head_dim_k); + tester.AddInput("decay", {batch_size, seq_length, kv_num_heads * head_dim_k}, decay_3d); + } + } else { + tester.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, kv_num_heads, seq_length); + tester.AddInput("beta", {batch_size, seq_length, kv_num_heads}, beta_3d); + } else { + tester.AddOptionalInputEdge(); + } + + tester.AddOutput("output", {batch_size, seq_length, kv_num_heads * head_dim_v}, + output_3d, false, 0.005f, 0.005f); + tester.AddOutput("present_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v}, + expected_state, false, 0.005f, 0.005f); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_SingleToken) { + const int B = 1, H = 1, 
T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +TEST(ContribOpLinearAttentionTest, LinearRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +TEST(ContribOpLinearAttentionTest, LinearRule_WithInitialState) { + const int B = 1, H = 1, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f}; + + // Non-zero initial state + std::vector initial_state(dk * dv, 0.1f); + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, nullptr, nullptr); +} + +// =========================================================================== +// Test: Gated update rule (decay, no beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + 
std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Decay in log-space (small negative values for slight decay) + std::vector decay = {-0.1f, -0.2f, -0.05f, -0.15f}; + + // Initial state (needed to see decay effect) + std::vector initial_state(dk * dv, 1.0f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector decay = { + -0.1f, -0.2f, -0.05f, -0.15f, + -0.2f, -0.1f, -0.3f, -0.05f, + -0.05f, -0.15f, -0.1f, -0.2f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} + +// =========================================================================== +// Test: Delta update rule (no decay, uses beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, DeltaRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector beta = {0.8f}; // shape: (1,1,1,1) + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, nullptr, &beta); +} + +TEST(ContribOpLinearAttentionTest, DeltaRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale 
= 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; // shape: (1,1,3,1) + + RunLinearAttentionTest("delta", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, &beta); +} + +// =========================================================================== +// Test: GatedDelta update rule (full - decay + beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector decay = {-0.1f, -0.2f, -0.05f, -0.15f}; + std::vector beta = {0.8f}; + + std::vector initial_state(dk * dv, 1.0f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector decay = { + -0.1f, -0.2f, -0.05f, -0.15f, + -0.2f, -0.1f, -0.3f, -0.05f, + -0.05f, -0.15f, -0.1f, -0.2f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; + + std::vector initial_state(dk * dv, 0.5f); + + 
RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Test: Gated rule with B,H,T decay (broadcast across dk) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedRule_BroadcastDecay) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + // Decay shape: (B, H, T) = (1, 1, 3) — one scalar per token + std::vector decay = {-0.1f, -0.2f, -0.05f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_BroadcastDecay) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + // Decay shape: (B, H, T) = (1, 1, 3) — one scalar per token + std::vector decay = {-0.1f, -0.2f, -0.05f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Test: 
Multi-batch, multi-head +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_MultiBatchMultiHead) { + const int B = 2, H = 2, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + // Total: B*H*T*dk = 2*2*2*4 = 32 values for q/k, B*H*T*dv = 32 for v + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + + // Fill with deterministic pattern + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = std::sin(static_cast(i) * 0.3f); + key[i] = std::cos(static_cast(i) * 0.5f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = std::sin(static_cast(i) * 0.7f + 1.0f); + } + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiBatchMultiHead) { + const int B = 2, H = 2, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = std::sin(static_cast(i) * 0.3f); + key[i] = std::cos(static_cast(i) * 0.5f); + decay[i] = -0.1f - 0.1f * std::sin(static_cast(i) * 0.2f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = std::sin(static_cast(i) * 0.7f + 1.0f); + } + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i)); + } + + std::vector initial_state(B * H * dk * dv, 0.1f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Test: Default scale (should use 1/sqrt(dk)) +// =========================================================================== 
+TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered on WebGPU EP (or EP not available)"; + return; + } + + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Compute with explicit scale for reference + float actual_scale = 1.0f / std::sqrt(static_cast(dk)); + std::vector expected_output, expected_state; + LinearAttentionReference("linear", B, H, T, dk, dv, actual_scale, + query, key, value, nullptr, nullptr, nullptr, + expected_output, expected_state); + + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("update_rule", std::string("linear")); + tester.AddAttribute("q_num_heads", static_cast(H)); + tester.AddAttribute("kv_num_heads", static_cast(H)); + // Don't set scale — use default (0.0 triggers 1/sqrt(dk)) + + // Convert to 3D packed for B=1, H=1 (flat data is identical) + std::vector qk_dims = {B, T, H * dk}; + std::vector v_dims = {B, T, H * dv}; + tester.AddInput("query", qk_dims, query); + tester.AddInput("key", qk_dims, key); + tester.AddInput("value", v_dims, value); + tester.AddOptionalInputEdge(); // past_state + tester.AddOptionalInputEdge(); // decay + tester.AddOptionalInputEdge(); // beta + + std::vector out_dims = {B, T, H * dv}; + std::vector state_dims = {B, H, dk, dv}; + tester.AddOutput("output", out_dims, expected_output, false, 0.005f, 0.005f); + tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// =========================================================================== +// Test: Longer sequence +// 
=========================================================================== +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LongerSequence) { + const int B = 1, H = 2, T = 16, dk = 8, dv = 8; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Test with Qwen3.5-like dimensions: dk=128, dv=128, broadcast decay +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_Qwen35Like) { + const int B = 1, H = 2, T = 8, dk = 128, dv = 128; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + // Broadcast decay: (B, H, T) — one scalar per head per token, like Qwen3.5 + std::vector decay(B * H * T); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.05f * std::sin(static_cast(i) * 0.013f); + key[i] = 0.05f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.05f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + 
std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Test with non-power-of-2 dk to trigger workgroup padding bug +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_NonPowerOf2DK) { + const int B = 1, H = 1, T = 3, dk = 3, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.3f); + key[i] = 0.5f * std::cos(static_cast(i) * 0.5f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.7f + 1.0f); + } + for (int i = 0; i < B * H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.5f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Tests: Larger dimensions exercising multi-tile vec4 path (tile_v > 1) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_LargerDims) { + const int B = 1, H = 2, T = 4, dk = 16, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + 
RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedRule_LargerDims) { + const int B = 1, H = 2, T = 4, dk = 32, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LargerDims) { + const int B = 2, H = 2, T = 4, dk = 32, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// 
Tests: GQA (Grouped Query Attention) — q_num_heads != kv_num_heads +// =========================================================================== +// Tests: GQA — K has fewer heads than KV (n_k < kv_num_heads) +// Schema requires q_num_heads == kv_num_heads; K head count is derived from +// the key tensor shape. Multiple KV heads share one K head via kv_per_k_head. +// =========================================================================== + +// Small K-GQA: q=kv=4, n_k=2 → each K head serves 2 KV heads +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_KGQA_Small) { + const int B = 1, q_H = 4, kv_H = 4, n_k = 2, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.1f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Linear rule with K-GQA: q=kv=4, n_k=2 +TEST(ContribOpLinearAttentionTest, LinearRule_KGQA) { + const int B = 1, q_H = 4, kv_H = 4, n_k = 2, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 
0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + RunLinearAttentionGQATest("linear", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +// Qwen3.5 9B-like: q=kv=32, n_k=16 (K has half the heads), +// dk=128, dv=128, broadcast decay +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_Qwen35_KGQA) { + const int B = 1, q_H = 32, kv_H = 32, n_k = 16, T = 4, dk = 128, dv = 128; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.05f * std::sin(static_cast(i) * 0.013f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.05f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.05f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.01f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Tests: Inverse GQA — q_num_heads < kv_num_heads +// Each KV head has its own output slot; Q is broadcast across KV groups. 
+// =========================================================================== + +// Small inverse GQA: q=2, kv=4 → each Q head shared by 2 KV heads +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_InverseGQA_Small) { + const int B = 1, q_H = 2, kv_H = 4, n_k = 4, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.1f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Linear rule with inverse GQA: q=2, kv=4 +TEST(ContribOpLinearAttentionTest, LinearRule_InverseGQA) { + const int B = 1, q_H = 2, kv_H = 4, n_k = 4, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + RunLinearAttentionGQATest("linear", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, 
key, value, + nullptr, nullptr, nullptr); +} + +// Larger inverse GQA with K-head sharing: q=2, kv=8, n_k=4, dk=16, dv=64 +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_InverseGQA_LargerDims) { + const int B = 1, q_H = 2, kv_H = 8, n_k = 4, T = 4, dk = 16, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.013f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.1f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.01f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +} // namespace test +} // namespace onnxruntime