diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 90c59bd6ddf51..9aa44a1600ae6 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -15,6 +15,7 @@ Do not modify directly.*
* com.microsoft.BitmaskBiasDropout
* com.microsoft.BitmaskDropout
* com.microsoft.CDist
+ * com.microsoft.CausalConvWithState
* com.microsoft.ComplexMul
* com.microsoft.ComplexMulConj
* com.microsoft.ConvTransposeWithDynamicPads
@@ -49,6 +50,7 @@ Do not modify directly.*
* com.microsoft.GroupQueryAttention
* com.microsoft.Inverse
* com.microsoft.Irfft
+ * com.microsoft.LinearAttention
* com.microsoft.LongformerAttention
* com.microsoft.MatMulBnb4
* com.microsoft.MatMulFpQ4
@@ -900,6 +902,68 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.CausalConvWithState**
+
+ Stateful causal depthwise convolution, generalized to N spatial dimensions.
+
+ Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step.
+ Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation.
+
+ The convolution is causal (looks only at current and past positions along the last
+ spatial dimension) and depthwise (each channel is convolved independently with its own kernel).
+
+ Input layout is channels-first: (batch_size, channels, ...).
+ Weight layout: (channels, 1, k_1, ...) for depthwise convolution.
+ The carry state stores the last (k-1) positions along the causal axis for incremental decode.
+
+ The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is
+ enforced on the last spatial dimension only.
+
+ The optional activation attribute supports fused SiLU/Swish activation.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- activation : string
+- Fused activation function. One of: 'silu', 'swish', 'none'. Default is 'none'.
+- ndim : int
+- Spatial dimensionality: 1, 2, or 3. Default is 1.
+
+
+#### Inputs (2 - 4)
+
+
+- input : T
+- Input tensor with shape (batch_size, channels, ...). Channels-first layout. Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).
+- weight : T
+- Depthwise convolution kernel with shape (channels, 1, k_1, ...). Spatial kernel sizes: (k_1, ..., k_ndim).
+- bias (optional) : T
+- Optional per-channel bias with shape (channels).
+- past_state (optional) : T
+- Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). If not provided, padding is zero.
+
+
+#### Outputs
+
+
+- output : T
+- Convolution output with same shape as input.
+- present_state : T
+- Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). Contains the last (k-1) values from the virtual input along the causal axis.
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
+
+
+
### **com.microsoft.ComplexMul**
#### Version
@@ -2703,6 +2767,79 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.LinearAttention**
+
+ Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1).
+
+ All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always
+ required. The op internally unpacks to 4D for computation.
+
+ The update_rule attribute selects the recurrence type:
+ - "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
+ - "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
+ - "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t
+ - "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t
+
+ where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product.
+
+ Semantics: Equivalent to running the recurrent update sequentially for each token,
+ but may be implemented using chunk-parallel algorithms for GPU efficiency.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- chunk_size : int
+- Chunk size for the chunk-parallel WY decomposition during prefill (T>1). Tuning hint; does not affect output correctness.
+- kv_num_heads : int (required)
+- Number of key/value heads. Always required.
+- q_num_heads : int (required)
+- Number of query heads. Always required.
+- scale : float
+- Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads and uses 1/sqrt(d_k). Set explicitly to override.
+- update_rule : string
+- The update rule for the linear attention recurrence. One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.
+
+
+#### Inputs (3 - 6)
+
+
+- query : T
+- Query vectors with 3D packed shape (B, T, H_q * d_k). Heads are packed into the last dimension.
+- key : T
+- Key vectors with 3D packed shape (B, T, H_kv * d_k). Should be L2-normalized for delta/gated_delta modes.
+- value : T
+- Value vectors with 3D packed shape (B, T, H_kv * d_v).
+- past_state (optional) : S
+- Recurrent state from previous step with shape (B, H_kv, d_k, d_v). Always 4D. If not provided, defaults to zeros.
+- decay (optional) : T
+- Exponential decay gate in log-space. 3D packed shape: (B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or (B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). Required for 'gated' and 'gated_delta' modes.
+- beta (optional) : T
+- Update rate (sigmoid output). 3D packed shape: (B, T, H_kv) or (B, T, 1). Required for 'delta' and 'gated_delta' modes.
+
+
+#### Outputs
+
+
+- output : T
+- Attention output with 3D packed shape (B, T, H_q * d_v).
+- present_state : S
+- Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
+- S : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain state types to float tensors.
+
+
+
### **com.microsoft.LongformerAttention**
Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 092c05f9e081a..1209446c6a367 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -2217,5 +2217,247 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
}
}));
+constexpr const char* CausalConvWithState_ver1_doc = R"DOC(
+Stateful causal depthwise convolution, generalized to N spatial dimensions.
+
+Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step.
+Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation.
+
+The convolution is causal (looks only at current and past positions along the last
+spatial dimension) and depthwise (each channel is convolved independently with its own kernel).
+
+Input layout is channels-first: (batch_size, channels, ...).
+Weight layout: (channels, 1, k_1, ...) for depthwise convolution.
+The carry state stores the last (k-1) positions along the causal axis for incremental decode.
+
+The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is
+enforced on the last spatial dimension only.
+
+The optional activation attribute supports fused SiLU/Swish activation.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ CausalConvWithState, 1,
+ OpSchema()
+ .SetDoc(CausalConvWithState_ver1_doc)
+ .Attr("activation",
+ "Fused activation function. One of: 'silu', 'swish', 'none'. "
+ "Default is 'none'.",
+ AttributeProto::STRING,
+ std::string("none"))
+ .Attr("ndim",
+ "Spatial dimensionality: 1, 2, or 3. Default is 1.",
+ AttributeProto::INT,
+ static_cast<int64_t>(1))
+ .Input(0,
+ "input",
+ "Input tensor with shape (batch_size, channels, ...). Channels-first layout. "
+ "Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).",
+ "T")
+ .Input(1,
+ "weight",
+ "Depthwise convolution kernel with shape (channels, 1, k_1, ...). "
+ "Spatial kernel sizes: (k_1, ..., k_ndim).",
+ "T")
+ .Input(2,
+ "bias",
+ "Optional per-channel bias with shape (channels).",
+ "T",
+ OpSchema::Optional)
+ .Input(3,
+ "past_state",
+ "Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). "
+ "If not provided, padding is zero.",
+ "T",
+ OpSchema::Optional)
+ .Output(0,
+ "output",
+ "Convolution output with same shape as input.",
+ "T")
+ .Output(1,
+ "present_state",
+ "Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). "
+ "Contains the last (k-1) values from the virtual input along the causal axis.",
+ "T")
+ .TypeConstraint("T",
+ {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ propagateElemTypeFromInputToOutput(ctx, 0, 1);
+
+ // Output 0: same shape as input (batch_size, channels, ...)
+ propagateShapeFromInputToOutput(ctx, 0, 0);
+
+ // Output 1: state shape is (batch_size, channels, [non-causal spatial dims...], k_last - 1)
+ // For ndim=1: (B, C, k_1-1)
+ // For ndim=2: (B, C, input_H, k_2-1)
+ // For ndim=3: (B, C, input_D, input_H, k_3-1)
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
+ auto& input_shape = getInputShape(ctx, 0);
+ auto& weight_shape = getInputShape(ctx, 1);
+ int64_t ndim = getAttribute(ctx, "ndim", 1);
+ TensorShapeProto state_shape;
+ *state_shape.add_dim() = input_shape.dim(0); // batch_size
+ *state_shape.add_dim() = input_shape.dim(1); // channels
+ // Copy non-causal spatial dims from input (dims 2 .. 2+ndim-2)
+ for (int64_t i = 0; i < ndim - 1; ++i) {
+ *state_shape.add_dim() = input_shape.dim(static_cast<int>(2 + i));
+ }
+ // Causal (last) spatial dim: kernel_size - 1
+ int last_kernel_dim = weight_shape.dim_size() - 1;
+ if (weight_shape.dim(last_kernel_dim).has_dim_value()) {
+ state_shape.add_dim()->set_dim_value(weight_shape.dim(last_kernel_dim).dim_value() - 1);
+ } else {
+ state_shape.add_dim(); // unknown
+ }
+ updateOutputShape(ctx, 1, state_shape);
+ }
+ }));
+
+constexpr const char* LinearAttention_ver1_doc = R"DOC(
+Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1).
+
+All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always
+required. The op internally unpacks to 4D for computation.
+
+The update_rule attribute selects the recurrence type:
+- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
+- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
+- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t
+- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t
+
+where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product.
+
+Semantics: Equivalent to running the recurrent update sequentially for each token,
+but may be implemented using chunk-parallel algorithms for GPU efficiency.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ LinearAttention, 1,
+ OpSchema()
+ .SetDoc(LinearAttention_ver1_doc)
+ .Attr("update_rule",
+ "The update rule for the linear attention recurrence. "
+ "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.",
+ AttributeProto::STRING,
+ std::string("gated_delta"))
+ .Attr("scale",
+ "Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads "
+ "and uses 1/sqrt(d_k). Set explicitly to override.",
+ AttributeProto::FLOAT,
+ 0.0f)
+ .Attr("q_num_heads",
+ "Number of query heads. Always required.",
+ AttributeProto::INT)
+ .Attr("kv_num_heads",
+ "Number of key/value heads. Always required.",
+ AttributeProto::INT)
+ .Attr("chunk_size",
+ "Chunk size for the chunk-parallel WY decomposition during prefill (T>1). "
+ "Tuning hint; does not affect output correctness.",
+ AttributeProto::INT,
+ static_cast<int64_t>(64))
+ .Input(0,
+ "query",
+ "Query vectors with 3D packed shape (B, T, H_q * d_k). "
+ "Heads are packed into the last dimension.",
+ "T")
+ .Input(1,
+ "key",
+ "Key vectors with 3D packed shape (B, T, H_kv * d_k). "
+ "Should be L2-normalized for delta/gated_delta modes.",
+ "T")
+ .Input(2,
+ "value",
+ "Value vectors with 3D packed shape (B, T, H_kv * d_v).",
+ "T")
+ .Input(3,
+ "past_state",
+ "Recurrent state from previous step with shape (B, H_kv, d_k, d_v). "
+ "Always 4D. If not provided, defaults to zeros.",
+ "S",
+ OpSchema::Optional)
+ .Input(4,
+ "decay",
+ "Exponential decay gate in log-space. 3D packed shape: "
+ "(B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or "
+ "(B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). "
+ "Required for 'gated' and 'gated_delta' modes.",
+ "T",
+ OpSchema::Optional)
+ .Input(5,
+ "beta",
+ "Update rate (sigmoid output). 3D packed shape: "
+ "(B, T, H_kv) or (B, T, 1). "
+ "Required for 'delta' and 'gated_delta' modes.",
+ "T",
+ OpSchema::Optional)
+ .Output(0,
+ "output",
+ "Attention output with 3D packed shape (B, T, H_q * d_v).",
+ "T")
+ .Output(1,
+ "present_state",
+ "Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.",
+ "S")
+ .TypeConstraint("T",
+ {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint("S",
+ {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain state types to float tensors.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ propagateElemTypeFromInputToOutput(ctx, 0, 1);
+
+ // Read required attributes
+ auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
+ auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
+ int64_t q_num_heads = (q_num_heads_attr && q_num_heads_attr->has_i()) ? q_num_heads_attr->i() : 0;
+ int64_t kv_num_heads = (kv_num_heads_attr && kv_num_heads_attr->has_i()) ? kv_num_heads_attr->i() : 0;
+
+ // Output 0: (B, T, H_q * d_v) — 3D packed
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
+ auto& query_shape = getInputShape(ctx, 0);
+ auto& value_shape = getInputShape(ctx, 2);
+ TensorShapeProto output_shape;
+ *output_shape.add_dim() = query_shape.dim(0); // B
+ *output_shape.add_dim() = query_shape.dim(1); // T
+ // H_q * d_v: d_v = value.dim(2) / kv_num_heads, then H_q * d_v
+ if (value_shape.dim(2).has_dim_value()) {
+ int64_t d_v = value_shape.dim(2).dim_value() / kv_num_heads;
+ output_shape.add_dim()->set_dim_value(q_num_heads * d_v);
+ } else {
+ output_shape.add_dim(); // unknown
+ }
+ updateOutputShape(ctx, 0, output_shape);
+ }
+
+ // Output 1: present_state shape (B, H_kv, d_k, d_v) — 4D
+ if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
+ auto& query_shape = getInputShape(ctx, 0);
+ auto& value_shape = getInputShape(ctx, 2);
+ TensorShapeProto state_shape;
+ *state_shape.add_dim() = query_shape.dim(0); // B
+ state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv
+ // d_k = query.dim(2) / q_num_heads
+ if (query_shape.dim(2).has_dim_value()) {
+ state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads);
+ } else {
+ state_shape.add_dim();
+ }
+ // d_v = value.dim(2) / kv_num_heads
+ if (value_shape.dim(2).has_dim_value()) {
+ state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads);
+ } else {
+ state_shape.add_dim();
+ }
+ updateOutputShape(ctx, 1, state_shape);
+ } else if (hasInputShape(ctx, 3)) {
+ propagateShapeFromInputToOutput(ctx, 3, 1);
+ }
+ }));
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 6c20aae94d132..59f97c222ceb2 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -88,6 +88,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
@@ -199,6 +201,8 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention)>());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState)>());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc
new file mode 100644
index 0000000000000..2a7837dd1ce73
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc
@@ -0,0 +1,658 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cmath>
+#include <memory>
+#include
+#include "gtest/gtest.h"
+#include "core/common/logging/logging.h"
+#include "core/framework/kernel_registry.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+
+namespace onnxruntime {
+namespace test {
+
+namespace {
+enum class TensorType {
+ kFloat,
+ kFloat16
+};
+
+// Reference implementation for CausalConvWithState
+// Performs depthwise causal 1D convolution with optional state, bias, and activation.
+//
+// Input: (B, D, L) channels-first
+// Weight: (D, 1, K) depthwise
+// Bias: (D,) optional
+// past_state: (B, D, K-1) optional carry state
+//
+// Output: (B, D, L) convolution output (with optional activation)
+// present_state: (B, D, K-1) updated carry state
+void CausalConvWithStateReference(
+ const std::vector<float>& input,
+ const std::vector<float>& weight,
+ const std::vector<float>* bias,
+ const std::vector<float>* conv_state,
+ std::vector<float>& output,
+ std::vector<float>& present_state,
+ int batch_size,
+ int channels,
+ int input_length,
+ int kernel_size,
+ const std::string& activation) {
+ int state_length = kernel_size - 1;
+ int total_virtual_length = state_length + input_length;
+
+ output.resize(batch_size * channels * input_length);
+ present_state.resize(batch_size * channels * state_length);
+
+ for (int b = 0; b < batch_size; ++b) {
+ for (int d = 0; d < channels; ++d) {
+ int bd = b * channels + d;
+
+ // Build virtual input: [conv_state, input]
+ std::vector<float> virtual_input(total_virtual_length, 0.0f);
+ if (conv_state != nullptr) {
+ for (int s = 0; s < state_length; ++s) {
+ virtual_input[s] = (*conv_state)[bd * state_length + s];
+ }
+ }
+ for (int l = 0; l < input_length; ++l) {
+ virtual_input[state_length + l] = input[bd * input_length + l];
+ }
+
+ // Compute depthwise convolution
+ for (int pos = 0; pos < input_length; ++pos) {
+ float acc = 0.0f;
+ for (int j = 0; j < kernel_size; ++j) {
+ float val = virtual_input[pos + j];
+ float w = weight[d * kernel_size + j];
+ acc += val * w;
+ }
+ // Add bias
+ if (bias != nullptr) {
+ acc += (*bias)[d];
+ }
+ // Apply activation
+ if (activation == "silu" || activation == "swish") {
+ acc = acc / (1.0f + std::exp(-acc));
+ }
+ output[bd * input_length + pos] = acc;
+ }
+
+ // Compute present_state: last state_length values from virtual input
+ for (int s = 0; s < state_length; ++s) {
+ present_state[bd * state_length + s] =
+ virtual_input[input_length + s];
+ }
+ }
+ }
+}
+
+// Returns a WebGPU EP if it is available and has the CausalConvWithState kernel registered,
+// or nullptr otherwise.
+std::unique_ptr<IExecutionProvider> TryGetEpWithCausalConvWithState() {
+ auto ep = DefaultWebGpuExecutionProvider();
+ if (!ep) {
+ ep = DefaultCpuExecutionProvider();
+ }
+
+ auto kernel_registry = ep->GetKernelRegistry();
+ if (kernel_registry) {
+ const KernelCreateInfo* info = nullptr;
+ KernelRegistry::TypeConstraintMap type_constraints;
+ auto status = kernel_registry->TryFindKernel(
+ ep->Type(), "CausalConvWithState", kMSDomain, 1,
+ type_constraints, DefaultLoggingManager().DefaultLogger(), &info);
+ if (!status.IsOK()) return nullptr;
+ }
+ return ep;
+}
+
+} // anonymous namespace
+
+static void RunCausalConvWithStateTest(
+ const std::vector<float>& input_data,
+ const std::vector<float>& weight_data,
+ const std::vector<float>* bias_data,
+ const std::vector<float>* conv_state_data,
+ const std::vector<float>& expected_output,
+ const std::vector<float>& expected_state,
+ int batch_size,
+ int channels,
+ int input_length,
+ int kernel_size,
+ const std::string& activation,
+ TensorType tensor_type) {
+ auto ep = TryGetEpWithCausalConvWithState();
+ if (!ep) {
+ GTEST_SKIP() << "CausalConvWithState kernel not registered";
+ return;
+ }
+
+ int state_length = kernel_size - 1;
+
+ std::vector<int64_t> input_shape = {batch_size, channels, input_length};
+ std::vector<int64_t> weight_shape = {channels, 1, kernel_size};
+ std::vector<int64_t> bias_shape = {channels};
+ std::vector<int64_t> state_shape = {batch_size, channels, state_length};
+ std::vector<int64_t> output_shape = {batch_size, channels, input_length};
+
+ {
+ OpTester test("CausalConvWithState", 1, onnxruntime::kMSDomain);
+ test.AddAttribute("activation", activation);
+
+ if (tensor_type == TensorType::kFloat) {
+ test.AddInput<float>("input", input_shape, input_data);
+ test.AddInput<float>("weight", weight_shape, weight_data);
+
+ if (bias_data != nullptr) {
+ test.AddInput<float>("bias", bias_shape, *bias_data);
+ } else {
+ test.AddOptionalInputEdge<float>();
+ }
+
+ if (conv_state_data != nullptr) {
+ test.AddInput<float>("past_state", state_shape, *conv_state_data);
+ } else {
+ test.AddOptionalInputEdge<float>();
+ }
+
+ test.AddOutput<float>("output", output_shape, expected_output);
+ test.AddOutput<float>("present_state", state_shape, expected_state);
+ } else {
+ test.AddInput<MLFloat16>("input", input_shape, ToFloat16(input_data));
+ test.AddInput<MLFloat16>("weight", weight_shape, ToFloat16(weight_data));
+
+ if (bias_data != nullptr) {
+ test.AddInput<MLFloat16>("bias", bias_shape, ToFloat16(*bias_data));
+ } else {
+ test.AddOptionalInputEdge<MLFloat16>();
+ }
+
+ if (conv_state_data != nullptr) {
+ test.AddInput<MLFloat16>("past_state", state_shape, ToFloat16(*conv_state_data));
+ } else {
+ test.AddOptionalInputEdge<MLFloat16>();
+ }
+
+ test.AddOutput<MLFloat16>("output", output_shape, ToFloat16(expected_output));
+ test.AddOutput<MLFloat16>("present_state", state_shape, ToFloat16(expected_state));
+ }
+
+ test.SetOutputAbsErr("output", 0.01f);
+ test.SetOutputAbsErr("present_state", 0.01f);
+
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(std::move(ep));
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+ }
+}
+
+static void RunCausalConvWithStateTests(
+ const std::vector<float>& input_data,
+ const std::vector<float>& weight_data,
+ const std::vector<float>* bias_data,
+ const std::vector<float>* conv_state_data,
+ int batch_size,
+ int channels,
+ int input_length,
+ int kernel_size,
+ const std::string& activation = "silu") {
+ // Compute expected output using reference implementation
+ std::vector<float> expected_output;
+ std::vector<float> expected_state;
+ CausalConvWithStateReference(
+ input_data, weight_data, bias_data, conv_state_data,
+ expected_output, expected_state,
+ batch_size, channels, input_length, kernel_size, activation);
+
+ // FP32 test
+ RunCausalConvWithStateTest(
+ input_data, weight_data, bias_data, conv_state_data,
+ expected_output, expected_state,
+ batch_size, channels, input_length, kernel_size, activation,
+ TensorType::kFloat);
+
+ // FP16 test
+ RunCausalConvWithStateTest(
+ input_data, weight_data, bias_data, conv_state_data,
+ expected_output, expected_state,
+ batch_size, channels, input_length, kernel_size, activation,
+ TensorType::kFloat16);
+}
+
+// =============================================================================
+// Basic tests - simple cases
+// =============================================================================
+
+TEST(CausalConvWithStateTest, BasicNoStateNoBias) {
+ // B=1, D=2, L=4, K=3, activation=none
+ int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3;
+
+ // Input: (1, 2, 4)
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f, 4.0f, // channel 0
+ 0.5f, 1.5f, 2.5f, 3.5f}; // channel 1
+
+ // Weight: (2, 1, 3)
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f, // channel 0 kernel
+ 0.4f, 0.5f, 0.6f}; // channel 1 kernel
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, nullptr,
+ batch_size, channels, input_length, kernel_size, "none");
+}
+
+TEST(CausalConvWithStateTest, BasicWithBias) {
+ // B=1, D=2, L=4, K=3, activation=none
+ int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.5f, 1.5f, 2.5f, 3.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+ std::vector<float> bias_data = {0.1f, -0.2f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, &bias_data, nullptr,
+ batch_size, channels, input_length, kernel_size, "none");
+}
+
+TEST(CausalConvWithStateTest, BasicWithState) {
+ // B=1, D=2, L=3, K=3, activation=none
+ int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f,
+ 0.5f, 1.5f, 2.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+ // State: (1, 2, 2) - kernel_size - 1 = 2
+ std::vector<float> conv_state_data = {
+ -1.0f, 0.5f, // channel 0 state
+ 0.3f, -0.7f}; // channel 1 state
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "none");
+}
+
+TEST(CausalConvWithStateTest, WithStateAndBias) {
+ // B=1, D=2, L=3, K=3, activation=none
+ int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f,
+ 0.5f, 1.5f, 2.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+ std::vector<float> bias_data = {0.1f, -0.2f};
+ std::vector<float> conv_state_data = {
+ -1.0f, 0.5f,
+ 0.3f, -0.7f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, &bias_data, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "none");
+}
+
+// =============================================================================
+// SiLU activation tests
+// =============================================================================
+
+TEST(CausalConvWithStateTest, SiluActivationNoState) {
+ int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.5f, 1.5f, 2.5f, 3.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, nullptr,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+TEST(CausalConvWithStateTest, SiluActivationWithState) {
+ int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f,
+ 0.5f, 1.5f, 2.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+ std::vector<float> conv_state_data = {
+ -1.0f, 0.5f,
+ 0.3f, -0.7f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+TEST(CausalConvWithStateTest, SiluActivationWithBiasAndState) {
+ int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.5f, 1.5f, 2.5f, 3.5f};
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+ std::vector<float> bias_data = {0.1f, -0.2f};
+ std::vector<float> conv_state_data = {
+ -1.0f, 0.5f,
+ 0.3f, -0.7f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, &bias_data, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+// =============================================================================
+// Kernel size variations
+// =============================================================================
+
+TEST(CausalConvWithStateTest, KernelSize2) {
+ int batch_size = 1, channels = 2, input_length = 4, kernel_size = 2;
+
+ std::vector<float> input_data = {
+ 1.0f, 2.0f, 3.0f, 4.0f,
+ 0.5f, 1.5f, 2.5f, 3.5f};
+ std::vector<float> weight_data = {
+ 0.3f, 0.7f,
+ 0.4f, 0.6f};
+ // State: (1, 2, 1) - kernel_size - 1 = 1
+ std::vector<float> conv_state_data = {0.5f, -0.3f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+TEST(CausalConvWithStateTest, KernelSize4) {
+ int batch_size = 1, channels = 1, input_length = 5, kernel_size = 4;
+
+ std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
+ std::vector<float> weight_data = {0.1f, 0.2f, 0.3f, 0.4f};
+ // State: (1, 1, 3)
+ std::vector<float> conv_state_data = {-1.0f, 0.0f, 0.5f};
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "none");
+}
+
+// =============================================================================
+// Batch size > 1
+// =============================================================================
+
+TEST(CausalConvWithStateTest, MultiBatch) {
+ int batch_size = 2, channels = 2, input_length = 3, kernel_size = 3;
+
+ // Input: (2, 2, 3)
+ std::vector<float> input_data = {
+ // Batch 0
+ 1.0f, 2.0f, 3.0f, // ch 0
+ 0.5f, 1.5f, 2.5f, // ch 1
+ // Batch 1
+ -1.0f, 0.0f, 1.0f, // ch 0
+ 0.2f, 0.4f, 0.6f}; // ch 1
+
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+
+ std::vector<float> bias_data = {0.1f, -0.1f};
+
+ // State: (2, 2, 2)
+ std::vector<float> conv_state_data = {
+ // Batch 0
+ -0.5f, 0.5f, // ch 0
+ 0.3f, -0.3f, // ch 1
+ // Batch 1
+ 0.1f, -0.1f, // ch 0
+ 0.7f, 0.8f}; // ch 1
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, &bias_data, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+// =============================================================================
+// Single token decode (L=1) - the primary use case for incremental decoding
+// =============================================================================
+
+TEST(CausalConvWithStateTest, SingleTokenDecode) {
+ int batch_size = 1, channels = 4, input_length = 1, kernel_size = 4;
+
+ // Input: (1, 4, 1)
+ std::vector<float> input_data = {0.5f, -0.3f, 1.2f, 0.8f};
+
+ // Weight: (4, 1, 4)
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f, 0.4f,
+ 0.5f, 0.6f, 0.7f, 0.8f,
+ -0.1f, -0.2f, 0.1f, 0.2f,
+ 0.3f, 0.3f, 0.3f, 0.3f};
+
+ std::vector<float> bias_data = {0.0f, 0.1f, -0.1f, 0.0f};
+
+ // State: (1, 4, 3) - carrying the last 3 values per channel
+ std::vector<float> conv_state_data = {
+ 1.0f, 2.0f, 3.0f, // ch 0
+ -1.0f, 0.0f, 1.0f, // ch 1
+ 0.5f, 0.5f, 0.5f, // ch 2
+ -0.2f, 0.4f, -0.6f}; // ch 3
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, &bias_data, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+TEST(CausalConvWithStateTest, SingleTokenDecodeMultiBatch) {
+ int batch_size = 2, channels = 2, input_length = 1, kernel_size = 3;
+
+ // Input: (2, 2, 1)
+ std::vector<float> input_data = {
+ 0.5f, // B0, ch 0
+ -0.3f, // B0, ch 1
+ 1.2f, // B1, ch 0
+ 0.8f}; // B1, ch 1
+
+ std::vector<float> weight_data = {
+ 0.1f, 0.2f, 0.3f,
+ 0.4f, 0.5f, 0.6f};
+
+ // State: (2, 2, 2)
+ std::vector<float> conv_state_data = {
+ 1.0f, 2.0f, // B0, ch 0
+ -1.0f, 0.0f, // B0, ch 1
+ 0.5f, 0.5f, // B1, ch 0
+ -0.2f, 0.4f}; // B1, ch 1
+
+ RunCausalConvWithStateTests(
+ input_data, weight_data, nullptr, &conv_state_data,
+ batch_size, channels, input_length, kernel_size, "silu");
+}
+
+// =============================================================================
+// State continuity test: verify that present_state from one call can be used
+// as conv_state for the next call (simulating autoregressive decode)
+// =============================================================================
+
+// Feeds present_state of call N back in as conv_state of call N+1,
+// simulating three steps of autoregressive decode, then checks the
+// final carried state equals the last (k-1)=2 inputs seen: [2, 3].
+TEST(CausalConvWithStateTest, StateContinuity) {
+  // Process a sequence of single tokens and verify state propagation
+  int batch_size = 1, channels = 1, kernel_size = 3;
+  int input_length = 1;
+
+  std::vector<float> weight_data = {0.2f, 0.3f, 0.5f};
+  std::vector<float> bias_data = {0.1f};
+
+  // Initial state: zeros
+  std::vector<float> conv_state = {0.0f, 0.0f};
+
+  // First token
+  std::vector<float> input1 = {1.0f};
+  std::vector<float> expected_output1;
+  std::vector<float> expected_state1;
+  CausalConvWithStateReference(input1, weight_data, &bias_data, &conv_state,
+                               expected_output1, expected_state1,
+                               batch_size, channels, input_length, kernel_size, "none");
+
+  RunCausalConvWithStateTest(input1, weight_data, &bias_data, &conv_state,
+                             expected_output1, expected_state1,
+                             batch_size, channels, input_length, kernel_size, "none",
+                             TensorType::kFloat);
+
+  // Second token, using present_state from first as conv_state
+  std::vector<float> input2 = {2.0f};
+  std::vector<float> expected_output2;
+  std::vector<float> expected_state2;
+  CausalConvWithStateReference(input2, weight_data, &bias_data, &expected_state1,
+                               expected_output2, expected_state2,
+                               batch_size, channels, input_length, kernel_size, "none");
+
+  RunCausalConvWithStateTest(input2, weight_data, &bias_data, &expected_state1,
+                             expected_output2, expected_state2,
+                             batch_size, channels, input_length, kernel_size, "none",
+                             TensorType::kFloat);
+
+  // Third token
+  std::vector<float> input3 = {3.0f};
+  std::vector<float> expected_output3;
+  std::vector<float> expected_state3;
+  CausalConvWithStateReference(input3, weight_data, &bias_data, &expected_state2,
+                               expected_output3, expected_state3,
+                               batch_size, channels, input_length, kernel_size, "none");
+
+  RunCausalConvWithStateTest(input3, weight_data, &bias_data, &expected_state2,
+                             expected_output3, expected_state3,
+                             batch_size, channels, input_length, kernel_size, "none",
+                             TensorType::kFloat);
+
+  // The present_state after processing [1, 2, 3] should be [2, 3]
+  EXPECT_NEAR(expected_state3[0], 2.0f, 1e-5f);
+  EXPECT_NEAR(expected_state3[1], 3.0f, 1e-5f);
+}
+
+// =============================================================================
+// Equivalence test: sequence processing should match token-by-token with state
+// =============================================================================
+
+// Equivalence check: processing the full length-4 sequence at once must
+// give the same outputs and final state as processing it token-by-token
+// while threading the state through.
+// Note: L1031 of the original patch was entity-garbled ("¤t_state");
+// restored to "&current_state" here.
+TEST(CausalConvWithStateTest, SequenceVsTokenByToken) {
+  int batch_size = 1, channels = 2, kernel_size = 3;
+
+  std::vector<float> weight_data = {
+      0.1f, 0.2f, 0.3f,
+      0.4f, 0.5f, 0.6f};
+  std::vector<float> bias_data = {0.05f, -0.05f};
+
+  // Initial state: zeros
+  std::vector<float> conv_state = {0.0f, 0.0f, 0.0f, 0.0f};  // (1, 2, 2)
+
+  // Full sequence: length 4
+  std::vector<float> full_input = {
+      1.0f, 2.0f, 3.0f, 4.0f,   // ch 0
+      0.5f, 1.5f, 2.5f, 3.5f};  // ch 1
+
+  // Process full sequence at once
+  std::vector<float> full_output;
+  std::vector<float> full_final_state;
+  CausalConvWithStateReference(full_input, weight_data, &bias_data, &conv_state,
+                               full_output, full_final_state,
+                               batch_size, channels, 4, kernel_size, "none");
+
+  // Process token by token
+  std::vector<float> current_state = conv_state;
+  std::vector<float> token_outputs;
+
+  for (int t = 0; t < 4; ++t) {
+    // Extract single token: (1, 2, 1)
+    std::vector<float> token_input = {
+        full_input[0 * 4 + t],   // ch 0
+        full_input[1 * 4 + t]};  // ch 1
+
+    std::vector<float> token_output;
+    std::vector<float> next_state;
+    CausalConvWithStateReference(token_input, weight_data, &bias_data, &current_state,
+                                 token_output, next_state,
+                                 batch_size, channels, 1, kernel_size, "none");
+
+    // Collect outputs
+    for (int d = 0; d < channels; ++d) {
+      token_outputs.push_back(token_output[d]);
+    }
+    current_state = next_state;
+  }
+
+  // Rearrange token_outputs from (T, D) to (D, T) layout for comparison
+  std::vector<float> token_outputs_dlt(channels * 4);
+  for (int t = 0; t < 4; ++t) {
+    for (int d = 0; d < channels; ++d) {
+      token_outputs_dlt[d * 4 + t] = token_outputs[t * channels + d];
+    }
+  }
+
+  // Compare outputs
+  for (int i = 0; i < channels * 4; ++i) {
+    EXPECT_NEAR(full_output[i], token_outputs_dlt[i], 1e-5f)
+        << "Mismatch at index " << i;
+  }
+
+  // Compare final states
+  for (int i = 0; i < channels * 2; ++i) {
+    EXPECT_NEAR(full_final_state[i], current_state[i], 1e-5f)
+        << "State mismatch at index " << i;
+  }
+}
+
+// =============================================================================
+// Larger dimension test with realistic sizes
+// =============================================================================
+
+// Larger, procedurally-generated tensors (B=2, C=8, L=16, k=4) to exercise
+// the kernel beyond tiny hand-written shapes.
+TEST(CausalConvWithStateTest, LargerDimensions) {
+  int batch_size = 2, channels = 8, input_length = 16, kernel_size = 4;
+
+  // Generate test data with a simple pattern
+  std::vector<float> input_data(batch_size * channels * input_length);
+  for (int i = 0; i < static_cast<int>(input_data.size()); ++i) {
+    input_data[i] = std::sin(static_cast<float>(i) * 0.1f);
+  }
+
+  std::vector<float> weight_data(channels * kernel_size);
+  for (int i = 0; i < static_cast<int>(weight_data.size()); ++i) {
+    weight_data[i] = std::cos(static_cast<float>(i) * 0.2f) * 0.5f;
+  }
+
+  std::vector<float> bias_data(channels);
+  for (int i = 0; i < channels; ++i) {
+    bias_data[i] = 0.01f * static_cast<float>(i);
+  }
+
+  int state_length = kernel_size - 1;
+  std::vector<float> conv_state_data(batch_size * channels * state_length);
+  for (int i = 0; i < static_cast<int>(conv_state_data.size()); ++i) {
+    conv_state_data[i] = std::sin(static_cast<float>(i) * 0.3f) * 0.5f;
+  }
+
+  RunCausalConvWithStateTests(
+      input_data, weight_data, &bias_data, &conv_state_data,
+      batch_size, channels, input_length, kernel_size, "silu");
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc
new file mode 100644
index 0000000000000..fd2b648c8badf
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc
@@ -0,0 +1,1201 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <cmath>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "core/common/logging/logging.h"
+#include "core/framework/kernel_registry.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/util/include/default_providers.h"
+
+using namespace onnxruntime::test;
+
+namespace onnxruntime {
+namespace test {
+
+namespace {
+
+// Reference implementation of the linear attention recurrence.
+// Processes all tokens sequentially and returns output + final_state.
+//
+// Layouts (4D): Q/K are (B, H, T, dk), V is (B, H, T, dv),
+// state is (B, H, dk, dv), output is (B, H, T, dv).
+// decay may be (B, H, T) (broadcast over dk) or (B, H, T, dk); it is in
+// log-space (state is multiplied by exp(decay)). beta is (B, H, T).
+void LinearAttentionReference(
+    const std::string& update_rule,
+    int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v,
+    float scale,
+    const std::vector<float>& query,
+    const std::vector<float>& key,
+    const std::vector<float>& value,
+    const std::vector<float>* initial_state,
+    const std::vector<float>* decay,
+    const std::vector<float>* beta,
+    std::vector<float>& output,
+    std::vector<float>& final_state) {
+  int bht = batch_size * num_heads * seq_length;
+  // decay of size B*H*T means one scalar per token, broadcast across dk.
+  bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);
+
+  // State: (B, H, dk, dv)
+  final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f);
+  output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f);
+
+  // Initialize state from initial_state if provided
+  if (initial_state != nullptr) {
+    final_state = *initial_state;
+  }
+
+  for (int b = 0; b < batch_size; b++) {
+    for (int h = 0; h < num_heads; h++) {
+      // State for this (b, h): dk x dv
+      auto state_offset = [&](int k, int v) {
+        return ((b * num_heads + h) * head_dim_k + k) * head_dim_v + v;
+      };
+
+      for (int t = 0; t < seq_length; t++) {
+        auto qkv_offset = [&](int dim) {
+          return ((b * num_heads + h) * seq_length + t) * dim;
+        };
+
+        // Load q, k for this token
+        std::vector<float> q_vec(head_dim_k), k_vec(head_dim_k), v_vec(head_dim_v);
+        for (int i = 0; i < head_dim_k; i++) {
+          q_vec[i] = query[qkv_offset(head_dim_k) + i];
+          k_vec[i] = key[qkv_offset(head_dim_k) + i];
+        }
+        for (int i = 0; i < head_dim_v; i++) {
+          v_vec[i] = value[qkv_offset(head_dim_v) + i];
+        }
+
+        // Step 1: Apply decay (gated, gated_delta)
+        if (update_rule == "gated" || update_rule == "gated_delta") {
+          for (int k = 0; k < head_dim_k; k++) {
+            float exp_g;
+            if (decay_broadcast_dk) {
+              int decay_idx = (b * num_heads + h) * seq_length + t;
+              exp_g = std::exp((*decay)[decay_idx]);
+            } else {
+              int decay_idx = ((b * num_heads + h) * seq_length + t) * head_dim_k + k;
+              exp_g = std::exp((*decay)[decay_idx]);
+            }
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] *= exp_g;
+            }
+          }
+        }
+
+        // Step 2: Compute state update
+        if (update_rule == "delta" || update_rule == "gated_delta") {
+          // retrieved = S^T @ k (for each v dimension)
+          std::vector<float> retrieved(head_dim_v, 0.0f);
+          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+            for (int k = 0; k < head_dim_k; k++) {
+              retrieved[v_idx] += final_state[state_offset(k, v_idx)] * k_vec[k];
+            }
+          }
+
+          // delta = beta * (v - retrieved)
+          int beta_idx = (b * num_heads + h) * seq_length + t;
+          float beta_val = (*beta)[beta_idx];
+          std::vector<float> delta(head_dim_v);
+          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+            delta[v_idx] = beta_val * (v_vec[v_idx] - retrieved[v_idx]);
+          }
+
+          // S += k ⊗ delta
+          for (int k = 0; k < head_dim_k; k++) {
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] += k_vec[k] * delta[v_idx];
+            }
+          }
+        } else {
+          // linear, gated: S += k ⊗ v
+          for (int k = 0; k < head_dim_k; k++) {
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] += k_vec[k] * v_vec[v_idx];
+            }
+          }
+        }
+
+        // Step 3: Compute output = scale * S^T @ q
+        for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+          float sum = 0.0f;
+          for (int k = 0; k < head_dim_k; k++) {
+            sum += final_state[state_offset(k, v_idx)] * q_vec[k];
+          }
+          int out_idx = ((b * num_heads + h) * seq_length + t) * head_dim_v + v_idx;
+          output[out_idx] = scale * sum;
+        }
+      }
+    }
+  }
+}
+
+// GQA-aware reference implementation.
+// Q has q_num_heads heads, K has n_k_heads heads, V/state have kv_num_heads heads.
+// Standard GQA: q_num_heads >= kv_num_heads, heads_per_group = q_num_heads / kv_num_heads.
+// K-to-KV sharing: kv_per_k_head = kv_num_heads / n_k_heads.
+void LinearAttentionGQAReference(
+    const std::string& update_rule,
+    int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads,
+    int seq_length, int head_dim_k, int head_dim_v,
+    float scale,
+    const std::vector<float>& query,          // (B, q_num_heads, T, dk)
+    const std::vector<float>& key,            // (B, n_k_heads, T, dk)
+    const std::vector<float>& value,          // (B, kv_num_heads, T, dv)
+    const std::vector<float>* initial_state,  // (B, kv_num_heads, dk, dv)
+    const std::vector<float>* decay,          // (B, kv_num_heads, T[, dk])
+    const std::vector<float>* beta,           // (B, kv_num_heads, T)
+    std::vector<float>& output,               // (B, kv_num_heads, T, dv)
+    std::vector<float>& final_state) {        // (B, kv_num_heads, dk, dv)
+  int bht_kv = batch_size * kv_num_heads * seq_length;
+  bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht_kv);
+  int kv_per_k_head = kv_num_heads / n_k_heads;
+  // "Inverse GQA": fewer Q heads than KV heads — Q is broadcast across KV heads.
+  bool inverse_gqa = q_num_heads < kv_num_heads;
+  int heads_per_group = inverse_gqa ? 0 : q_num_heads / kv_num_heads;
+
+  final_state.resize(batch_size * kv_num_heads * head_dim_k * head_dim_v, 0.0f);
+  // Output always indexed by kv_num_heads (matches schema: output_dim == V_dim)
+  output.resize(batch_size * kv_num_heads * seq_length * head_dim_v, 0.0f);
+
+  if (initial_state != nullptr) {
+    final_state = *initial_state;
+  }
+
+  for (int b = 0; b < batch_size; b++) {
+    for (int kv_h = 0; kv_h < kv_num_heads; kv_h++) {
+      int k_head = kv_h / kv_per_k_head;
+
+      auto state_offset = [&](int k, int v) {
+        return ((b * kv_num_heads + kv_h) * head_dim_k + k) * head_dim_v + v;
+      };
+
+      for (int t = 0; t < seq_length; t++) {
+        // Load k from the K-head that this KV-head maps to
+        std::vector<float> k_vec(head_dim_k), v_vec(head_dim_v);
+        int k_base = ((b * n_k_heads + k_head) * seq_length + t) * head_dim_k;
+        for (int i = 0; i < head_dim_k; i++) k_vec[i] = key[k_base + i];
+        int v_base = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v;
+        for (int i = 0; i < head_dim_v; i++) v_vec[i] = value[v_base + i];
+
+        // Step 1: Apply decay
+        if (update_rule == "gated" || update_rule == "gated_delta") {
+          for (int k = 0; k < head_dim_k; k++) {
+            float exp_g;
+            if (decay_broadcast_dk) {
+              exp_g = std::exp((*decay)[(b * kv_num_heads + kv_h) * seq_length + t]);
+            } else {
+              exp_g = std::exp((*decay)[((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_k + k]);
+            }
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] *= exp_g;
+            }
+          }
+        }
+
+        // Step 2: Update state
+        if (update_rule == "delta" || update_rule == "gated_delta") {
+          std::vector<float> retrieved(head_dim_v, 0.0f);
+          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+            for (int k = 0; k < head_dim_k; k++) {
+              retrieved[v_idx] += final_state[state_offset(k, v_idx)] * k_vec[k];
+            }
+          }
+          int beta_idx = (b * kv_num_heads + kv_h) * seq_length + t;
+          float beta_val = (*beta)[beta_idx];
+          std::vector<float> delta(head_dim_v);
+          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+            delta[v_idx] = beta_val * (v_vec[v_idx] - retrieved[v_idx]);
+          }
+          for (int k = 0; k < head_dim_k; k++) {
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] += k_vec[k] * delta[v_idx];
+            }
+          }
+        } else {
+          for (int k = 0; k < head_dim_k; k++) {
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              final_state[state_offset(k, v_idx)] += k_vec[k] * v_vec[v_idx];
+            }
+          }
+        }
+
+        // Step 3: Compute output
+        if (!inverse_gqa) {
+          // Standard GQA/MHA: one output per Q head
+          for (int g = 0; g < heads_per_group; g++) {
+            int q_h = kv_h * heads_per_group + g;
+            int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k;
+            for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+              float sum = 0.0f;
+              for (int k = 0; k < head_dim_k; k++) {
+                sum += final_state[state_offset(k, v_idx)] * query[q_base + k];
+              }
+              // For standard, output head == q head; since q==kv per schema, also == kv_h index
+              int out_idx = ((b * kv_num_heads + (kv_h * heads_per_group + g)) * seq_length + t) * head_dim_v + v_idx;
+              output[out_idx] = scale * sum;
+            }
+          }
+        } else {
+          // Inverse GQA: output indexed by kv_head, Q broadcast
+          int q_h = kv_h * q_num_heads / kv_num_heads;
+          int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k;
+          for (int v_idx = 0; v_idx < head_dim_v; v_idx++) {
+            float sum = 0.0f;
+            for (int k = 0; k < head_dim_k; k++) {
+              sum += final_state[state_offset(k, v_idx)] * query[q_base + k];
+            }
+            int out_idx = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v + v_idx;
+            output[out_idx] = scale * sum;
+          }
+        }
+      }
+    }
+  }
+}
+
+// Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout
+std::vector<float> PackBHTD_to_BTHD(const std::vector<float>& data_4d,
+                                    int B, int H, int T, int D) {
+  std::vector<float> packed(B * T * H * D);
+  for (int b = 0; b < B; b++) {
+    for (int h = 0; h < H; h++) {
+      for (int t = 0; t < T; t++) {
+        for (int d = 0; d < D; d++) {
+          int src_idx = ((b * H + h) * T + t) * D + d;
+          int dst_idx = (b * T + t) * (H * D) + h * D + d;
+          packed[dst_idx] = data_4d[src_idx];
+        }
+      }
+    }
+  }
+  return packed;
+}
+
+// Convert decay/beta from (B,H,T) layout to (B,T,H) layout
+std::vector<float> TransposeBHT_to_BTH(const std::vector<float>& data,
+                                       int B, int H, int T) {
+  std::vector<float> transposed(B * T * H);
+  for (int b = 0; b < B; b++) {
+    for (int h = 0; h < H; h++) {
+      for (int t = 0; t < T; t++) {
+        int src_idx = (b * H + h) * T + t;
+        int dst_idx = (b * T + t) * H + h;
+        transposed[dst_idx] = data[src_idx];
+      }
+    }
+  }
+  return transposed;
+}
+
+// Returns a WebGPU EP if it is available and has the LinearAttention kernel registered,
+// or nullptr otherwise. Falls back to the CPU EP when WebGPU is not built in.
+std::unique_ptr<IExecutionProvider> TryGetEpWithLinearAttention() {
+  auto ep = DefaultWebGpuExecutionProvider();
+  if (!ep) {
+    ep = DefaultCpuExecutionProvider();
+  }
+
+  auto kernel_registry = ep->GetKernelRegistry();
+  if (kernel_registry) {
+    const KernelCreateInfo* info = nullptr;
+    KernelRegistry::TypeConstraintMap type_constraints;
+    auto status = kernel_registry->TryFindKernel(
+        ep->Type(), "LinearAttention", kMSDomain, 1,
+        type_constraints, DefaultLoggingManager().DefaultLogger(), &info);
+    if (!status.IsOK()) return nullptr;
+  }
+  return ep;
+}
+
+// Test harness for the non-GQA case (q_num_heads == kv_num_heads):
+// computes the reference result in 4D layout, repacks everything to the
+// operator's 3D packed layout, and runs the op via OpTester.
+void RunLinearAttentionTest(
+    const std::string& update_rule,
+    int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v,
+    float scale,
+    const std::vector<float>& query,
+    const std::vector<float>& key,
+    const std::vector<float>& value,
+    const std::vector<float>* initial_state,
+    const std::vector<float>* decay,
+    const std::vector<float>* beta_data) {
+  auto ep = TryGetEpWithLinearAttention();
+  if (!ep) {
+    GTEST_SKIP() << "LinearAttention kernel not registered";
+    return;
+  }
+
+  // Compute reference output (reference works in 4D layout)
+  std::vector<float> expected_output_4d, expected_state;
+  LinearAttentionReference(update_rule, batch_size, num_heads, seq_length,
+                           head_dim_k, head_dim_v, scale,
+                           query, key, value, initial_state, decay, beta_data,
+                           expected_output_4d, expected_state);
+
+  int bht = batch_size * num_heads * seq_length;
+  bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);
+
+  // Convert from 4D (B,H,T,D) to 3D packed (B,T,H*D) for OpTester
+  auto query_3d = PackBHTD_to_BTHD(query, batch_size, num_heads, seq_length, head_dim_k);
+  auto key_3d = PackBHTD_to_BTHD(key, batch_size, num_heads, seq_length, head_dim_k);
+  auto value_3d = PackBHTD_to_BTHD(value, batch_size, num_heads, seq_length, head_dim_v);
+  auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, num_heads, seq_length, head_dim_v);
+
+  OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute("update_rule", update_rule);
+  tester.AddAttribute("scale", scale);
+  tester.AddAttribute("q_num_heads", static_cast<int64_t>(num_heads));
+  tester.AddAttribute("kv_num_heads", static_cast<int64_t>(num_heads));
+
+  // Add required inputs — 3D packed (B, T, H*D)
+  std::vector<int64_t> qk_dims = {batch_size, seq_length, num_heads * head_dim_k};
+  std::vector<int64_t> v_dims = {batch_size, seq_length, num_heads * head_dim_v};
+  tester.AddInput<float>("query", qk_dims, query_3d);
+  tester.AddInput<float>("key", qk_dims, key_3d);
+  tester.AddInput<float>("value", v_dims, value_3d);
+
+  // Optional: past_state (4D, same format as before)
+  if (initial_state != nullptr) {
+    std::vector<int64_t> state_dims = {batch_size, num_heads, head_dim_k, head_dim_v};
+    tester.AddInput<float>("past_state", state_dims, *initial_state);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  // Optional: decay — convert from (B,H,T[,dk]) to (B,T,H[*dk])
+  if (decay != nullptr) {
+    if (decay_broadcast_dk) {
+      // (B,H,T) → (B,T,H)
+      auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, num_heads, seq_length);
+      std::vector<int64_t> decay_dims = {batch_size, seq_length, num_heads};
+      tester.AddInput<float>("decay", decay_dims, decay_3d);
+    } else {
+      // (B,H,T,dk) → (B,T,H*dk)
+      auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, num_heads, seq_length, head_dim_k);
+      std::vector<int64_t> decay_dims = {batch_size, seq_length, num_heads * head_dim_k};
+      tester.AddInput<float>("decay", decay_dims, decay_3d);
+    }
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  // Optional: beta — convert from (B*H*T) flat to (B,T,H)
+  if (beta_data != nullptr) {
+    auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, num_heads, seq_length);
+    std::vector<int64_t> beta_dims = {batch_size, seq_length, num_heads};
+    tester.AddInput<float>("beta", beta_dims, beta_3d);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  // Add outputs — output is 3D packed, state is 4D
+  std::vector<int64_t> out_dims = {batch_size, seq_length, num_heads * head_dim_v};
+  std::vector<int64_t> state_dims = {batch_size, num_heads, head_dim_k, head_dim_v};
+  tester.AddOutput<float>("output", out_dims, output_3d, false, 0.005f, 0.005f);
+  tester.AddOutput<float>("present_state", state_dims, expected_state, false, 0.005f, 0.005f);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(ep));
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// GQA-aware test harness.
+// Q: (B, q_num_heads, T, dk), K: (B, n_k_heads, T, dk), V: (B, kv_num_heads, T, dv)
+void RunLinearAttentionGQATest(
+    const std::string& update_rule,
+    int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads,
+    int seq_length, int head_dim_k, int head_dim_v,
+    float scale,
+    const std::vector<float>& query,
+    const std::vector<float>& key,
+    const std::vector<float>& value,
+    const std::vector<float>* initial_state,
+    const std::vector<float>* decay,
+    const std::vector<float>* beta_data) {
+  auto ep = TryGetEpWithLinearAttention();
+  if (!ep) {
+    GTEST_SKIP() << "LinearAttention kernel not registered";
+    return;
+  }
+
+  std::vector<float> expected_output_4d, expected_state;
+  LinearAttentionGQAReference(update_rule, batch_size, q_num_heads, kv_num_heads, n_k_heads,
+                              seq_length, head_dim_k, head_dim_v, scale,
+                              query, key, value, initial_state, decay, beta_data,
+                              expected_output_4d, expected_state);
+
+  int bht_kv = batch_size * kv_num_heads * seq_length;
+  bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht_kv);
+
+  // Pack to 3D — each tensor uses its own head count
+  auto query_3d = PackBHTD_to_BTHD(query, batch_size, q_num_heads, seq_length, head_dim_k);
+  auto key_3d = PackBHTD_to_BTHD(key, batch_size, n_k_heads, seq_length, head_dim_k);
+  auto value_3d = PackBHTD_to_BTHD(value, batch_size, kv_num_heads, seq_length, head_dim_v);
+  // Output always indexed by kv_num_heads (matches schema)
+  auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, kv_num_heads, seq_length, head_dim_v);
+
+  OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute("update_rule", update_rule);
+  tester.AddAttribute("scale", scale);
+  tester.AddAttribute("q_num_heads", static_cast<int64_t>(q_num_heads));
+  tester.AddAttribute("kv_num_heads", static_cast<int64_t>(kv_num_heads));
+
+  tester.AddInput<float>("query", {batch_size, seq_length, q_num_heads * head_dim_k}, query_3d);
+  tester.AddInput<float>("key", {batch_size, seq_length, n_k_heads * head_dim_k}, key_3d);
+  tester.AddInput<float>("value", {batch_size, seq_length, kv_num_heads * head_dim_v}, value_3d);
+
+  if (initial_state != nullptr) {
+    tester.AddInput<float>("past_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v}, *initial_state);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  if (decay != nullptr) {
+    if (decay_broadcast_dk) {
+      auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, kv_num_heads, seq_length);
+      tester.AddInput<float>("decay", {batch_size, seq_length, kv_num_heads}, decay_3d);
+    } else {
+      auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, kv_num_heads, seq_length, head_dim_k);
+      tester.AddInput<float>("decay", {batch_size, seq_length, kv_num_heads * head_dim_k}, decay_3d);
+    }
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  if (beta_data != nullptr) {
+    auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, kv_num_heads, seq_length);
+    tester.AddInput<float>("beta", {batch_size, seq_length, kv_num_heads}, beta_3d);
+  } else {
+    tester.AddOptionalInputEdge<float>();
+  }
+
+  tester.AddOutput<float>("output", {batch_size, seq_length, kv_num_heads * head_dim_v},
+                          output_3d, false, 0.005f, 0.005f);
+  tester.AddOutput<float>("present_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v},
+                          expected_state, false, 0.005f, 0.005f);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(std::move(ep));
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+} // namespace
+// ===========================================================================
+// Linear rule, T=1: state is simply k ⊗ v; output is scale * S^T q.
+TEST(ContribOpLinearAttentionTest, LinearRule_SingleToken) {
+  const int B = 1, H = 1, T = 1, dk = 4, dv = 4;
+  float scale = 1.0f / std::sqrt(static_cast<float>(dk));
+
+  std::vector<float> query = {1.0f, 0.0f, 0.5f, -0.5f};
+  std::vector<float> key = {0.5f, 0.5f, 0.0f, 1.0f};
+  std::vector<float> value = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  RunLinearAttentionTest("linear", B, H, T, dk, dv, scale,
+                         query, key, value,
+                         nullptr, nullptr, nullptr);
+}
+
+// Linear rule over three tokens: state accumulates k ⊗ v across steps.
+TEST(ContribOpLinearAttentionTest, LinearRule_MultiToken) {
+  const int B = 1, H = 1, T = 3, dk = 4, dv = 4;
+  float scale = 1.0f / std::sqrt(static_cast<float>(dk));
+
+  std::vector<float> query = {
+      1.0f, 0.0f, 0.5f, -0.5f,
+      0.5f, 1.0f, -0.5f, 0.0f,
+      0.0f, -1.0f, 1.0f, 0.5f};
+  std::vector<float> key = {
+      0.5f, 0.5f, 0.0f, 1.0f,
+      1.0f, 0.0f, 1.0f, 0.5f,
+      -0.5f, 1.0f, 0.5f, 0.0f};
+  std::vector<float> value = {
+      1.0f, 2.0f, 3.0f, 4.0f,
+      2.0f, 1.0f, 0.0f, 3.0f,
+      3.0f, 0.0f, 1.0f, 2.0f};
+
+  RunLinearAttentionTest("linear", B, H, T, dk, dv, scale,
+                         query, key, value,
+                         nullptr, nullptr, nullptr);
+}
+
+// Linear rule with a non-zero past_state fed in, exercising state carry-over.
+TEST(ContribOpLinearAttentionTest, LinearRule_WithInitialState) {
+  const int B = 1, H = 1, T = 2, dk = 4, dv = 4;
+  float scale = 1.0f / std::sqrt(static_cast<float>(dk));
+
+  std::vector<float> query = {
+      1.0f, 0.0f, 0.5f, -0.5f,
+      0.5f, 1.0f, -0.5f, 0.0f};
+  std::vector<float> key = {
+      0.5f, 0.5f, 0.0f, 1.0f,
+      1.0f, 0.0f, 1.0f, 0.5f};
+  std::vector<float> value = {
+      1.0f, 2.0f, 3.0f, 4.0f,
+      2.0f, 1.0f, 0.0f, 3.0f};
+
+  // Non-zero initial state
+  std::vector<float> initial_state(dk * dv, 0.1f);
+
+  RunLinearAttentionTest("linear", B, H, T, dk, dv, scale,
+                         query, key, value,
+                         &initial_state, nullptr, nullptr);
+}
+
+// ===========================================================================
+// Test: Gated update rule (decay, no beta)
+// ===========================================================================
+// Gated rule, T=1: per-dk log-space decay applied to a non-zero initial state.
+TEST(ContribOpLinearAttentionTest, GatedRule_SingleToken) {
+  const int B = 1, H = 1, T = 1, dk = 4, dv = 4;
+  float scale = 1.0f / std::sqrt(static_cast<float>(dk));
+
+  std::vector<float> query = {1.0f, 0.0f, 0.5f, -0.5f};
+  std::vector<float> key = {0.5f, 0.5f, 0.0f, 1.0f};
+  std::vector<float> value = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  // Decay in log-space (small negative values for slight decay)
+  std::vector<float> decay = {-0.1f, -0.2f, -0.05f, -0.15f};
+
+  // Initial state (needed to see decay effect)
+  std::vector<float> initial_state(dk * dv, 1.0f);
+
+  RunLinearAttentionTest("gated", B, H, T, dk, dv, scale,
+                         query, key, value,
+                         &initial_state, &decay, nullptr);
+}
+
+// Gated rule over three tokens with per-(token, dk) decay values.
+TEST(ContribOpLinearAttentionTest, GatedRule_MultiToken) {
+  const int B = 1, H = 1, T = 3, dk = 4, dv = 4;
+  float scale = 1.0f / std::sqrt(static_cast<float>(dk));
+
+  std::vector<float> query = {
+      1.0f, 0.0f, 0.5f, -0.5f,
+      0.5f, 1.0f, -0.5f, 0.0f,
+      0.0f, -1.0f, 1.0f, 0.5f};
+  std::vector<float> key = {
+      0.5f, 0.5f, 0.0f, 1.0f,
+      1.0f, 0.0f, 1.0f, 0.5f,
+      -0.5f, 1.0f, 0.5f, 0.0f};
+  std::vector<float> value = {
+      1.0f, 2.0f, 3.0f, 4.0f,
+      2.0f, 1.0f, 0.0f, 3.0f,
+      3.0f, 0.0f, 1.0f, 2.0f};
+  std::vector<float> decay = {
+      -0.1f, -0.2f, -0.05f, -0.15f,
+      -0.2f, -0.1f, -0.3f, -0.05f,
+      -0.05f, -0.15f, -0.1f, -0.2f};
+
+  std::vector<float> initial_state(dk * dv, 0.5f);
+
+  RunLinearAttentionTest("gated", B, H, T, dk, dv, scale,
+                         query, key, value,
+                         &initial_state, &decay, nullptr);
+}
+
+// ===========================================================================
+// Test: Delta update rule (no decay, uses beta)
+// ===========================================================================
+TEST(ContribOpLinearAttentionTest, DeltaRule_SingleToken) {
+ const int B = 1, H = 1, T = 1, dk = 4, dv = 4;
+ float scale = 1.0f / std::sqrt(static_cast(dk));
+
+ std::vector query = {1.0f, 0.0f, 0.5f, -0.5f};
+ std::vector key = {0.5f, 0.5f, 0.0f, 1.0f};
+ std::vector