diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 88f6a5a2b1be6..b8b7fa06d7f12 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -564,6 +564,7 @@ Do not modify directly.*
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
+|CausalConvWithState|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* past_state:**T**<br> *out* output:**T**<br> *out* present_state:**T**|1+|**T** = tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)|
@@ -584,6 +585,7 @@ Do not modify directly.*
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LinearAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_state:**S**<br> *in* decay:**T**<br> *in* beta:**T**<br> *out* output:**T**<br> *out* present_state:**S**|1+|**T** = tensor(float)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
@@ -1032,6 +1034,7 @@ Do not modify directly.*
|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BitmaskBiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)|
|BitmaskDropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)|
+|CausalConvWithState|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* past_state:**T**<br> *out* output:**T**<br> *out* present_state:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
@@ -1056,6 +1059,7 @@ Do not modify directly.*
|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LinearAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_state:**S**<br> *in* decay:**T**<br> *in* beta:**T**<br> *out* output:**T**<br> *out* present_state:**S**|1+|**T** = tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc
new file mode 100644
index 0000000000000..72a97c60df84f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc
@@ -0,0 +1,326 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/causal_conv_with_state.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/common/safeint.h"
+#include "core/platform/threadpool.h"
+
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+using onnxruntime::concurrency::ThreadPool;
+
+namespace onnxruntime {
+namespace contrib {
+
+// These ops are internal-only, so register outside of onnx
+// Note: Only float is registered for CPU. The op schema allows float16/bfloat16
+// for CUDA compatibility, but the CPU kernel computes in float32 internally.
+// MLFloat16 CPU support would require input/output conversion buffers
+// (MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer).
+//
+// MLAS usage: No MLAS kernels are used currently. The depthwise causal conv
+// is implemented with scalar loops. Potential future optimization: use
+// MlasConv1D or vectorized MLAS routines for the 1D convolution.
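+//
+// Reference formula (a restatement of the kernel below, for clarity): with
+// x_padded = concat(past_state, input) along the length axis (zeros when past_state is absent),
+//   output[b, c, l]        = act(bias[c] + sum_{k=0..K-1} weight[c, 0, k] * x_padded[b, c, l + k])
+//   present_state[b, c, :] = x_padded[b, c, L : L + K - 1]
+// where act is the identity or SiLU depending on the "activation" attribute.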
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ CausalConvWithState, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      CausalConvWithState<T>);
+
+REGISTER_KERNEL_TYPED(float)
+
+template <typename T>
+CausalConvWithState<T>::CausalConvWithState(const OpKernelInfo& info) : OpKernel(info) {
+  int64_t ndim = info.GetAttrOrDefault<int64_t>("ndim", 1);
+  ORT_ENFORCE(ndim == 1, "CPU CausalConvWithState only supports ndim=1");
+  ndim_ = static_cast<int>(ndim);
+
+  activation_ = info.GetAttrOrDefault<std::string>("activation", "none");
+ ORT_ENFORCE(activation_ == "none" || activation_ == "silu" || activation_ == "swish",
+ "activation must be one of: none, silu, swish");
+}
+
+namespace {
+
+inline float ApplySilu(float x) {
+ return x / (1.0f + std::exp(-x));
+}
+
+template <int K>
+inline void ProcessChannelDecodeFixedK(
+ const float* past_row,
+ const float* input_val,
+ const float* w,
+ float bias_val,
+ bool apply_silu,
+ float* out_val,
+ float* present_row) {
+ constexpr int pad = K - 1;
+ float sum = bias_val;
+ if (past_row != nullptr) {
+ for (int k = 0; k < pad; ++k) {
+ sum += w[k] * past_row[k];
+ }
+ }
+ sum += w[pad] * input_val[0];
+
+ if (apply_silu) {
+ sum = ApplySilu(sum);
+ }
+ out_val[0] = sum;
+
+ if constexpr (pad > 0) {
+ if (past_row != nullptr) {
+ if constexpr (pad > 1) {
+        std::memcpy(present_row, past_row + 1, static_cast<size_t>(pad - 1) * sizeof(float));
+ }
+ } else {
+ if constexpr (pad > 1) {
+        std::memset(present_row, 0, static_cast<size_t>(pad - 1) * sizeof(float));
+ }
+ }
+ present_row[pad - 1] = input_val[0];
+ }
+}
+
+// Decode fast-path: L=1, no padded buffer needed.
+// The "visible window" for position 0 is [past_state(K-1 values), input(1 value)] = K values.
+// Compute dot(weight, window), shift state left by 1, append new input.
+void ProcessChannelDecode(
+ const float* past_row, // past_state for this (b,c): [K-1] or nullptr
+ const float* input_val, // &input[b,c,0] — single value
+ const float* w, // weight for this channel: [K]
+ float bias_val,
+ bool apply_silu,
+ float* out_val, // &output[b,c,0] — single value
+ float* present_row, // present_state for this (b,c): [K-1]
+ int64_t K) {
+ int64_t pad = K - 1;
+
+ // Dot product over the window: [past_state..., input]
+ float sum = bias_val;
+ // First K-1 elements come from past_state
+ if (past_row != nullptr) {
+ for (int64_t k = 0; k < pad; ++k) {
+ sum += w[k] * past_row[k];
+ }
+ }
+ // Last element is the current input
+ sum += w[pad] * input_val[0];
+
+ if (apply_silu) {
+ sum = ApplySilu(sum);
+ }
+ out_val[0] = sum;
+
+ // Update present_state: shift past_state left by 1, append input
+ if (pad > 0) {
+ if (past_row != nullptr && pad > 1) {
+      std::memcpy(present_row, past_row + 1, static_cast<size_t>(pad - 1) * sizeof(float));
+    } else if (pad > 1) {
+      std::memset(present_row, 0, static_cast<size_t>(pad - 1) * sizeof(float));
+ }
+ present_row[pad - 1] = input_val[0];
+ }
+}
+
+// Prefill path: L>1, uses padded buffer for the convolution window.
+void ProcessChannelPrefill(
+ const float* past_row, // past_state for this (b,c): [K-1] or nullptr
+ const float* in_row, // input for this (b,c): [L]
+ const float* w, // weight for this channel: [K]
+ float bias_val,
+ bool apply_silu,
+ float* out_row, // output for this (b,c): [L]
+ float* present_row, // present_state for this (b,c): [K-1]
+ float* padded_row, // scratch buffer: [K-1 + L]
+ int64_t L,
+ int64_t K) {
+ int64_t pad = K - 1;
+ int64_t padded_len = pad + L;
+
+ // Build padded window: [past_state | input]
+ if (past_row != nullptr) {
+    std::memcpy(padded_row, past_row, static_cast<size_t>(pad) * sizeof(float));
+  } else {
+    std::memset(padded_row, 0, static_cast<size_t>(pad) * sizeof(float));
+  }
+  std::memcpy(padded_row + pad, in_row, static_cast<size_t>(L) * sizeof(float));
+
+ // Depthwise 1D convolution
+ for (int64_t l = 0; l < L; ++l) {
+ float sum = bias_val;
+ for (int64_t k = 0; k < K; ++k) {
+ sum += w[k] * padded_row[l + k];
+ }
+ if (apply_silu) {
+ sum = ApplySilu(sum);
+ }
+ out_row[l] = sum;
+ }
+
+ // Save present_state: last K-1 elements of (past_state | input)
+  std::memcpy(present_row, padded_row + padded_len - pad, static_cast<size_t>(pad) * sizeof(float));
+}
+
+} // anonymous namespace
+
+template <typename T>
+Status CausalConvWithState<T>::Compute(OpKernelContext* context) const {
+  const Tensor* input_tensor = context->Input<Tensor>(0);
+  const Tensor* weight_tensor = context->Input<Tensor>(1);
+  const Tensor* bias_tensor = context->Input<Tensor>(2);        // optional
+  const Tensor* past_state_tensor = context->Input<Tensor>(3);  // optional
+
+ ORT_RETURN_IF_NOT(input_tensor != nullptr, "input is required");
+ ORT_RETURN_IF_NOT(weight_tensor != nullptr, "weight is required");
+
+ const auto& input_shape = input_tensor->Shape();
+ const auto& weight_shape = weight_tensor->Shape();
+
+  ORT_RETURN_IF_NOT(static_cast<int>(input_shape.NumDimensions()) == 2 + ndim_,
+                    "input must have ", 2 + ndim_, " dimensions for ndim=", ndim_);
+  ORT_RETURN_IF_NOT(static_cast<int>(weight_shape.NumDimensions()) == 2 + ndim_,
+                    "weight must have ", 2 + ndim_, " dimensions for ndim=", ndim_);
+
+ const int64_t batch_size = input_shape[0];
+ const int64_t channels = input_shape[1];
+
+ ORT_RETURN_IF_NOT(weight_shape[0] == channels, "weight channels must match input channels");
+ ORT_RETURN_IF_NOT(weight_shape[1] == 1, "weight must be depthwise (group=1)");
+
+ if (bias_tensor != nullptr) {
+ ORT_RETURN_IF_NOT(bias_tensor->Shape().NumDimensions() == 1 &&
+ bias_tensor->Shape()[0] == channels,
+ "bias must be 1D with size C");
+ }
+
+ // ==== ndim=1 implementation: (B, C, L) with kernel (C, 1, K) ====
+ if (ndim_ == 1) {
+ const int64_t L = input_shape[2];
+ const int64_t K = weight_shape[2];
+ const int64_t pad = K - 1;
+
+ if (past_state_tensor != nullptr) {
+ const auto& ps_shape = past_state_tensor->Shape();
+ ORT_RETURN_IF_NOT(ps_shape.NumDimensions() == 3 &&
+ ps_shape[0] == batch_size &&
+ ps_shape[1] == channels &&
+ ps_shape[2] == pad,
+ "past_state must be (B, C, K-1)");
+ }
+
+ // ==== Allocate outputs ====
+    Tensor* output_tensor = context->Output(0, input_shape);
+    float* output_data = output_tensor->MutableData<float>();
+
+    TensorShape state_shape({batch_size, channels, pad});
+    Tensor* present_state_tensor = context->Output(1, state_shape);
+    float* present_data = present_state_tensor->MutableData<float>();
+
+    const float* input_data = input_tensor->Data<float>();
+    const float* weight_data = weight_tensor->Data<float>();
+    const float* bias_data = bias_tensor ? bias_tensor->Data<float>() : nullptr;
+    const float* past_data = past_state_tensor ? past_state_tensor->Data<float>() : nullptr;
+ bool apply_silu = (activation_ == "silu" || activation_ == "swish");
+
+ // ==== Thread-parallel over (batch, channel) pairs ====
+ // Depthwise conv: each channel is fully independent.
+ int64_t total_tasks = batch_size * channels;
+    double cost_per_task = static_cast<double>(L * K);  // FLOPs per channel
+
+ auto* tp = context->GetOperatorThreadPool();
+
+ if (L == 1) {
+ // ==== Decode fast-path: no padded buffer needed ====
+ ThreadPool::TryParallelFor(
+ tp,
+          static_cast<std::ptrdiff_t>(total_tasks),
+ cost_per_task,
+ [&](std::ptrdiff_t first, std::ptrdiff_t last) {
+ for (std::ptrdiff_t task = first; task < last; ++task) {
+ int64_t b = task / channels;
+ int64_t c = task % channels;
+
+ const float* past_row = past_data
+ ? past_data + (b * channels + c) * pad
+ : nullptr;
+ const float* input_val = input_data + (b * channels + c) * L;
+ const float* w = weight_data + c * K;
+ float bias_val = bias_data ? bias_data[c] : 0.0f;
+ float* out_val = output_data + (b * channels + c) * L;
+ float* present_row = present_data + (b * channels + c) * pad;
+ switch (K) {
+ case 2:
+ ProcessChannelDecodeFixedK<2>(past_row, input_val, w, bias_val, apply_silu,
+ out_val, present_row);
+ break;
+ case 3:
+ ProcessChannelDecodeFixedK<3>(past_row, input_val, w, bias_val, apply_silu,
+ out_val, present_row);
+ break;
+ case 4:
+ ProcessChannelDecodeFixedK<4>(past_row, input_val, w, bias_val, apply_silu,
+ out_val, present_row);
+ break;
+ case 5:
+ ProcessChannelDecodeFixedK<5>(past_row, input_val, w, bias_val, apply_silu,
+ out_val, present_row);
+ break;
+ default:
+ ProcessChannelDecode(past_row, input_val, w, bias_val, apply_silu,
+ out_val, present_row, K);
+ break;
+ }
+ }
+ });
+ } else {
+ // ==== Prefill path: uses per-thread scratch buffer ====
+ ThreadPool::TryParallelFor(
+ tp,
+          static_cast<std::ptrdiff_t>(total_tasks),
+ cost_per_task,
+ [&](std::ptrdiff_t first, std::ptrdiff_t last) {
+ // Per-thread scratch buffer for padded input
+          std::vector<float> padded_buf(static_cast<size_t>(pad + L));
+
+ for (std::ptrdiff_t task = first; task < last; ++task) {
+ int64_t b = task / channels;
+ int64_t c = task % channels;
+
+ const float* past_row = past_data
+ ? past_data + (b * channels + c) * pad
+ : nullptr;
+ const float* in_row = input_data + (b * channels + c) * L;
+ const float* w = weight_data + c * K;
+ float bias_val = bias_data ? bias_data[c] : 0.0f;
+ float* out_row = output_data + (b * channels + c) * L;
+ float* present_row = present_data + (b * channels + c) * pad;
+
+ ProcessChannelPrefill(past_row, in_row, w, bias_val, apply_silu,
+ out_row, present_row, padded_buf.data(), L, K);
+ }
+ });
+ }
+
+ return Status::OK();
+ }
+
+ // ==== ndim=2 or ndim=3: not yet implemented ====
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "CausalConvWithState with ndim=", ndim_,
+ " is not yet implemented. "
+ "Currently only ndim=1 is supported.");
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h
new file mode 100644
index 0000000000000..e859f69677e80
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+#include <string>
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class CausalConvWithState final : public OpKernel {
+ public:
+ CausalConvWithState(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int ndim_;
+ std::string activation_;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc
new file mode 100644
index 0000000000000..052e7df8bda14
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc
@@ -0,0 +1,509 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/linear_attention.h"
+
+#include "core/framework/tensorprotoutils.h"
+#include "core/common/safeint.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/platform/threadpool.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+using onnxruntime::concurrency::ThreadPool;
+
+namespace onnxruntime {
+namespace contrib {
+
+// These ops are internal-only, so register outside of onnx
+// Note: Only float is registered for CPU. The op schema allows float16/bfloat16
+// for CUDA compatibility, but the CPU kernel computes in float32 internally.
+// MLFloat16 CPU support would require input/output conversion buffers
+// (MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer).
+//
+// MLAS usage: MlasGemm is used for retrieval (S^T @ k), state update (k ⊗ delta),
+// and query readout (S^T @ q) when d_k * d_v >= 4096. Smaller dimensions use
+// scalar loops to avoid MLAS overhead.
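+// For example, d_k = d_v = 128 gives d_k * d_v = 16384 >= 4096, so the GEMV paths are taken,
+// while d_k = d_v = 32 gives 1024 < 4096 and falls back to the scalar loops.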
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ LinearAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      LinearAttention<T>);
+
+REGISTER_KERNEL_TYPED(float)
+
+template <typename T>
+LinearAttention<T>::LinearAttention(const OpKernelInfo& info) : OpKernel(info) {
+ int64_t q_num_heads = 0;
+  ORT_ENFORCE(info.GetAttr<int64_t>("q_num_heads", &q_num_heads).IsOK() && q_num_heads > 0,
+              "q_num_heads must be a positive integer");
+  q_num_heads_ = static_cast<int>(q_num_heads);
+
+  int64_t kv_num_heads = 0;
+  ORT_ENFORCE(info.GetAttr<int64_t>("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0,
+              "kv_num_heads must be a positive integer");
+  kv_num_heads_ = static_cast<int>(kv_num_heads);
+
+  update_rule_ = info.GetAttrOrDefault<std::string>("update_rule", "gated_delta");
+  ORT_ENFORCE(update_rule_ == "linear" || update_rule_ == "gated" ||
+                  update_rule_ == "delta" || update_rule_ == "gated_delta",
+              "update_rule must be one of: linear, gated, delta, gated_delta");
+
+  scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+
+  int64_t chunk_size = info.GetAttrOrDefault<int64_t>("chunk_size", 64);
+  // chunk_size_ reserved for future chunk-parallel prefill algorithm; not yet used.
+  chunk_size_ = static_cast<int>(chunk_size);
+}
+
+namespace {
+
+// Process a single (batch, kv_head) pair across all timesteps.
+// This is the hot inner loop — called once per (b, h_kv) combination,
+// potentially from different threads.
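+//
+// Per-timestep recurrence implemented below (gated_delta shown; the other rules drop terms):
+//   S     <- S * exp(g_t)            decay (per head or per key dimension)
+//   r_t   =  S^T k_t                 retrieval, shape [d_v]
+//   delta =  beta_t * (v_t - r_t)
+//   S     <- S + k_t delta^T         rank-1 state update
+//   o_t   =  scale * S^T q_t         readout for each query head in the group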
+void ProcessHead(
+ float* S, // State matrix [d_k, d_v], in-place updated
+ const float* q_data, // Packed Q: (B, T, H_q*d_k)
+ const float* k_data, // Packed K: (B, T, n_k*d_k)
+ const float* v_data, // Packed V: (B, T, H_kv*d_v)
+ const float* decay_data, // Decay gates (may be nullptr)
+ const float* beta_data, // Beta rates (may be nullptr)
+ float* output_data, // Output: (B, T, H_q*d_v)
+ float* retrieved_buf, // Pre-allocated scratch buffer [d_v]
+ int64_t batch_idx,
+ int h_kv,
+ int h_k, // Key head index (may differ from h_kv when n_k != kv_num_heads)
+ int64_t seq_len,
+ int64_t d_k,
+ int64_t d_v,
+ int q_num_heads,
+ int kv_num_heads,
+ int n_k_heads,
+ int heads_per_group,
+ int64_t output_hidden,
+ float scale,
+ bool needs_decay,
+ bool decay_per_key_dim,
+ bool needs_beta,
+ bool beta_per_head,
+ bool needs_retrieval) {
+  const size_t dk = static_cast<size_t>(d_k);
+  const size_t dv = static_cast<size_t>(d_v);
+ const bool use_mlas = (d_k * d_v >= 4096);
+
+ for (int64_t t = 0; t < seq_len; ++t) {
+ // Pointers into packed 3D tensors at position [batch_idx, t, head*d]
+ const float* kt = k_data + (batch_idx * seq_len + t) * (n_k_heads * d_k) + h_k * d_k;
+ const float* vt = v_data + (batch_idx * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v;
+
+ // ---- Step 1: Apply decay S *= exp(g_t) ----
+ if (needs_decay) {
+ if (decay_per_key_dim) {
+ const float* gt = decay_data + (batch_idx * seq_len + t) * (kv_num_heads * d_k) + h_kv * d_k;
+ for (int64_t i = 0; i < d_k; ++i) {
+ float exp_g = std::exp(gt[i]);
+ // Scale row i of S by exp_g
+ for (int64_t j = 0; j < d_v; ++j) {
+ S[i * d_v + j] *= exp_g;
+ }
+ }
+ } else {
+ const float* gt = decay_data + (batch_idx * seq_len + t) * kv_num_heads + h_kv;
+ float exp_g = std::exp(gt[0]);
+ for (int64_t i = 0; i < d_k * d_v; ++i) {
+ S[i] *= exp_g;
+ }
+ }
+ }
+
+ // ---- Step 2: Retrieval = S^T @ k_t ----
+ if (needs_retrieval) {
+ if (use_mlas) {
+ MlasGemm(
+ CblasNoTrans,
+ CblasNoTrans,
+ 1,
+ dv,
+ dk,
+ 1.0f,
+ kt,
+ dk,
+ S,
+ dv,
+ 0.0f,
+ retrieved_buf,
+ dv,
+ nullptr,
+ nullptr);
+ } else {
+ for (int64_t j = 0; j < d_v; ++j) {
+ float acc = 0.0f;
+ for (int64_t i = 0; i < d_k; ++i) {
+ acc += S[i * d_v + j] * kt[i];
+ }
+        retrieved_buf[static_cast<size_t>(j)] = acc;
+ }
+ }
+ }
+
+ // ---- Step 3: State update ----
+ if (needs_beta) {
+ float bt;
+ if (beta_per_head) {
+ bt = beta_data[(batch_idx * seq_len + t) * kv_num_heads + h_kv];
+ } else {
+ bt = beta_data[(batch_idx * seq_len + t) * 1];
+ }
+ // Compute delta = beta * (v_t - retrieved) in-place into retrieved_buf
+ for (size_t j = 0; j < dv; ++j) {
+ retrieved_buf[j] = bt * (vt[j] - retrieved_buf[j]);
+ }
+ // S += k_t outer delta
+ if (use_mlas) {
+ MlasGemm(
+ CblasNoTrans,
+ CblasNoTrans,
+ dk,
+ dv,
+ 1,
+ 1.0f,
+ kt,
+ 1,
+ retrieved_buf,
+ dv,
+ 1.0f,
+ S,
+ dv,
+ nullptr,
+ nullptr);
+ } else {
+ for (int64_t i = 0; i < d_k; ++i) {
+ float* s_row = S + i * d_v;
+ const float ki = kt[i];
+ for (int64_t j = 0; j < d_v; ++j) {
+            s_row[j] += ki * retrieved_buf[static_cast<size_t>(j)];
+ }
+ }
+ }
+ } else {
+ // linear/gated: S += k_t outer v_t
+ if (use_mlas) {
+ MlasGemm(
+ CblasNoTrans,
+ CblasNoTrans,
+ dk,
+ dv,
+ 1,
+ 1.0f,
+ kt,
+ 1,
+ vt,
+ dv,
+ 1.0f,
+ S,
+ dv,
+ nullptr,
+ nullptr);
+ } else {
+ for (int64_t i = 0; i < d_k; ++i) {
+ float* s_row = S + i * d_v;
+ const float ki = kt[i];
+ for (int64_t j = 0; j < d_v; ++j) {
+ s_row[j] += ki * vt[j];
+ }
+ }
+ }
+ }
+
+ // ---- Step 4: Query readout for each q head in this kv group ----
+ // o_t = scale * q_t^T @ S -> [1, d_v]
+ //
+ // Standard GQA (heads_per_group > 0): multiple Q heads share this KV state.
+ // Inverse GQA (heads_per_group == 0): multiple KV states share one Q head.
+ if (heads_per_group > 0) {
+ for (int g = 0; g < heads_per_group; ++g) {
+ int h_q = h_kv * heads_per_group + g;
+ const float* qt = q_data + (batch_idx * seq_len + t) * (q_num_heads * d_k) + h_q * d_k;
+ float* ot = output_data + (batch_idx * seq_len + t) * output_hidden + h_q * d_v;
+
+ if (use_mlas) {
+ // Use alpha=1.0 to hit the MLAS M=1 gemv fast path, then scale output.
+ MlasGemm(
+ CblasNoTrans,
+ CblasNoTrans,
+ 1,
+ dv,
+ dk,
+ 1.0f,
+ qt,
+ dk,
+ S,
+ dv,
+ 0.0f,
+ ot,
+ dv,
+ nullptr,
+ nullptr);
+ if (scale != 1.0f) {
+ for (size_t j = 0; j < dv; ++j) {
+ ot[j] *= scale;
+ }
+ }
+ } else {
+ for (int64_t j = 0; j < d_v; ++j) {
+ float acc = 0.0f;
+ for (int64_t i = 0; i < d_k; ++i) {
+ acc += qt[i] * S[i * d_v + j];
+ }
+ ot[j] = scale * acc;
+ }
+ }
+ }
+ } else {
+ // Inverse GQA: this KV head's Q is determined by h_kv * q_num / kv_num
+ int h_q = h_kv * q_num_heads / kv_num_heads;
+ const float* qt = q_data + (batch_idx * seq_len + t) * (q_num_heads * d_k) + h_q * d_k;
+ float* ot = output_data + (batch_idx * seq_len + t) * output_hidden + h_kv * d_v;
+
+ if (use_mlas) {
+ // Use alpha=1.0 to hit the MLAS M=1 gemv fast path, then scale output.
+ MlasGemm(
+ CblasNoTrans,
+ CblasNoTrans,
+ 1,
+ dv,
+ dk,
+ 1.0f,
+ qt,
+ dk,
+ S,
+ dv,
+ 0.0f,
+ ot,
+ dv,
+ nullptr,
+ nullptr);
+ if (scale != 1.0f) {
+ for (size_t j = 0; j < dv; ++j) {
+ ot[j] *= scale;
+ }
+ }
+ } else {
+ for (int64_t j = 0; j < d_v; ++j) {
+ float acc = 0.0f;
+ for (int64_t i = 0; i < d_k; ++i) {
+ acc += qt[i] * S[i * d_v + j];
+ }
+ ot[j] = scale * acc;
+ }
+ }
+ }
+ }
+}
+
+} // anonymous namespace
+
+template <typename T>
+Status LinearAttention<T>::Compute(OpKernelContext* context) const {
+  // ==== Input Retrieval ====
+  const Tensor* query_tensor = context->Input<Tensor>(0);
+  const Tensor* key_tensor = context->Input<Tensor>(1);         // optional
+  const Tensor* value_tensor = context->Input<Tensor>(2);       // optional
+  const Tensor* past_state_tensor = context->Input<Tensor>(3);  // optional
+  const Tensor* decay_tensor = context->Input<Tensor>(4);       // optional
+  const Tensor* beta_tensor = context->Input<Tensor>(5);        // optional
+
+ ORT_RETURN_IF_NOT(query_tensor != nullptr, "query input is required");
+
+ const auto& query_shape = query_tensor->Shape();
+ ORT_RETURN_IF_NOT(query_shape.NumDimensions() == 3,
+ "query must be 3D [B, T, H*D], got ", query_shape.NumDimensions(), "D");
+
+ const int64_t batch_size = query_shape[0];
+ const int64_t seq_len = query_shape[1];
+ const int64_t query_hidden = query_shape[2];
+
+ // ==== Determine d_k and d_v ====
+ ORT_RETURN_IF_NOT(key_tensor != nullptr && value_tensor != nullptr,
+ "key and value inputs are required");
+
+ int64_t d_k, d_v;
+ int n_k_heads;
+ const float* q_data;
+ const float* k_data;
+ const float* v_data;
+
+ {
+ const auto& key_shape = key_tensor->Shape();
+ const auto& value_shape = value_tensor->Shape();
+ ORT_RETURN_IF_NOT(key_shape.NumDimensions() == 3 && value_shape.NumDimensions() == 3,
+ "key and value must be 3D");
+ ORT_RETURN_IF_NOT(key_shape[0] == batch_size && value_shape[0] == batch_size,
+ "batch size mismatch");
+ ORT_RETURN_IF_NOT(key_shape[1] == seq_len && value_shape[1] == seq_len,
+ "sequence length mismatch");
+
+ d_k = query_hidden / q_num_heads_;
+ ORT_RETURN_IF_NOT(query_hidden == q_num_heads_ * d_k,
+ "query hidden size must be divisible by q_num_heads");
+ ORT_RETURN_IF_NOT(key_shape[2] % d_k == 0,
+ "key hidden size must be divisible by d_k");
+    n_k_heads = static_cast<int>(key_shape[2] / d_k);
+ d_v = value_shape[2] / kv_num_heads_;
+ ORT_RETURN_IF_NOT(value_shape[2] == kv_num_heads_ * d_v,
+ "value hidden size must be divisible by kv_num_heads");
+
+    q_data = query_tensor->Data<float>();
+    k_data = key_tensor->Data<float>();
+    v_data = value_tensor->Data<float>();
+ }
+
+ // ==== Determine scale ====
+ float s = scale_;
+ if (s == 0.0f) {
+    s = 1.0f / std::sqrt(static_cast<float>(d_k));
+ }
+
+ // ==== Validate optional inputs based on update_rule ====
+ bool needs_decay = (update_rule_ == "gated" || update_rule_ == "gated_delta");
+ bool needs_beta = (update_rule_ == "delta" || update_rule_ == "gated_delta");
+ bool needs_retrieval = (update_rule_ == "delta" || update_rule_ == "gated_delta");
+
+ ORT_RETURN_IF_NOT(!needs_decay || decay_tensor != nullptr,
+ "decay input is required for update_rule=", update_rule_);
+ ORT_RETURN_IF_NOT(!needs_beta || beta_tensor != nullptr,
+ "beta input is required for update_rule=", update_rule_);
+
+  const float* decay_data = decay_tensor ? decay_tensor->Data<float>() : nullptr;
+  const float* beta_data = beta_tensor ? beta_tensor->Data<float>() : nullptr;
+
+ bool decay_per_key_dim = false;
+ if (decay_tensor != nullptr) {
+ const auto& decay_shape = decay_tensor->Shape();
+ ORT_RETURN_IF_NOT(decay_shape.NumDimensions() == 3,
+ "decay must be rank 3 (B, T, ...), got rank ", decay_shape.NumDimensions());
+ ORT_RETURN_IF_NOT(decay_shape[0] == batch_size && decay_shape[1] == seq_len,
+ "decay dims 0/1 must match (B=", batch_size, ", T=", seq_len,
+ "), got (", decay_shape[0], ", ", decay_shape[1], ")");
+ int64_t decay_last = decay_shape[2];
+ if (decay_last == kv_num_heads_ * d_k) {
+ decay_per_key_dim = true;
+ } else {
+ ORT_RETURN_IF_NOT(decay_last == kv_num_heads_,
+ "decay last dim must be H_kv or H_kv*d_k");
+ }
+ }
+
+ bool beta_per_head = false;
+ if (beta_tensor != nullptr) {
+ const auto& beta_shape = beta_tensor->Shape();
+ ORT_RETURN_IF_NOT(beta_shape.NumDimensions() == 3,
+ "beta must be rank 3 (B, T, ...), got rank ", beta_shape.NumDimensions());
+ ORT_RETURN_IF_NOT(beta_shape[0] == batch_size && beta_shape[1] == seq_len,
+ "beta dims 0/1 must match (B=", batch_size, ", T=", seq_len,
+ "), got (", beta_shape[0], ", ", beta_shape[1], ")");
+ int64_t beta_last = beta_shape[2];
+ if (beta_last == kv_num_heads_) {
+ beta_per_head = true;
+ } else {
+ ORT_RETURN_IF_NOT(beta_last == 1, "beta last dim must be H_kv or 1");
+ }
+ }
+
+ // ==== Initialize state: write directly into output present_state ====
+ // present_state: (B, H_kv, d_k, d_v)
+  TensorShape state_shape({batch_size, static_cast<int64_t>(kv_num_heads_), d_k, d_v});
+  Tensor* present_state_tensor = context->Output(1, state_shape);
+  float* state_data = present_state_tensor->MutableData<float>();
+ int64_t state_per_head = d_k * d_v;
+ int64_t total_state = batch_size * kv_num_heads_ * state_per_head;
+
+ if (past_state_tensor != nullptr) {
+ const auto& ps_shape = past_state_tensor->Shape();
+ ORT_RETURN_IF_NOT(ps_shape.NumDimensions() == 4 &&
+ ps_shape[0] == batch_size &&
+ ps_shape[1] == kv_num_heads_ &&
+ ps_shape[2] == d_k &&
+ ps_shape[3] == d_v,
+ "past_state must be (B, H_kv, d_k, d_v)");
+    const float* ps_data = past_state_tensor->Data<float>();
+    std::memcpy(state_data, ps_data, static_cast<size_t>(total_state) * sizeof(float));
+  } else {
+    std::memset(state_data, 0, static_cast<size_t>(total_state) * sizeof(float));
+ }
+
+ // ==== Allocate output ====
+ // Output hidden dim: max(q_num_heads, kv_num_heads) * d_v
+ // Standard GQA: q_num_heads * d_v; Inverse GQA: kv_num_heads * d_v
+ int64_t output_hidden = std::max(q_num_heads_, kv_num_heads_) * d_v;
+ TensorShape output_shape({batch_size, seq_len, output_hidden});
+ Tensor* output_tensor = context->Output(0, output_shape);
+  float* output_data = output_tensor->MutableData<float>();
+
+ // ==== GQA head mapping ====
+ // Standard GQA: q_num_heads >= kv_num_heads, multiple Q heads per KV group.
+ // Inverse GQA: q_num_heads < kv_num_heads (e.g., Qwen3.5 9B: n_k=16, n_kv=32).
+ // Also n_k_heads may differ from both (K has its own head count).
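+  // Example: q_num_heads=32, kv_num_heads=8 gives heads_per_group=4 (standard GQA);
+  // q_num_heads=16, kv_num_heads=32 gives heads_per_group=0, and KV head h_kv reads Q head h_kv/2.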
+ int heads_per_group; // Q heads per KV group (0 if inverse GQA)
+ if (q_num_heads_ >= kv_num_heads_) {
+ ORT_RETURN_IF_NOT(q_num_heads_ % kv_num_heads_ == 0,
+ "q_num_heads must be divisible by kv_num_heads");
+ heads_per_group = q_num_heads_ / kv_num_heads_;
+ } else {
+ ORT_RETURN_IF_NOT(kv_num_heads_ % q_num_heads_ == 0,
+ "kv_num_heads must be divisible by q_num_heads (inverse GQA)");
+ heads_per_group = 0; // signals inverse GQA to ProcessHead
+ }
+
+ // K-to-KV head mapping: when n_k < kv_num_heads, multiple KV heads share one K head
+ ORT_RETURN_IF_NOT(kv_num_heads_ % n_k_heads == 0,
+ "kv_num_heads must be divisible by n_k_heads");
+ int kv_per_k_head = kv_num_heads_ / n_k_heads;
+
+ // ==== Thread-parallel over (batch, kv_head) pairs ====
+ // Each (b, h_kv) pair is fully independent — the state matrix for each
+ // head is disjoint, and the sequential token dependency is within a
+ // single head only. This gives us batch_size * kv_num_heads parallel tasks.
+ int64_t total_tasks = batch_size * kv_num_heads_;
+
+ // Cost estimate: per task processes seq_len tokens, each doing ~3*d_k*d_v FLOPs
+  double cost_per_task = static_cast<double>(seq_len) * static_cast<double>(d_k * d_v) * 3.0;
+
+ auto* tp = context->GetOperatorThreadPool();
+
+ ThreadPool::TryParallelFor(
+ tp,
+      static_cast<std::ptrdiff_t>(total_tasks),
+ cost_per_task,
+ [&](std::ptrdiff_t first, std::ptrdiff_t last) {
+ // Pre-allocate scratch buffer per thread-batch to avoid malloc in hot loop
+        std::vector<float> retrieved_buf(static_cast<size_t>(d_v));
+
+ for (std::ptrdiff_t task = first; task < last; ++task) {
+ int64_t b = task / kv_num_heads_;
+          int h_kv = static_cast<int>(task % kv_num_heads_);
+ int h_k = h_kv / kv_per_k_head; // map KV head to K head
+
+ float* S = state_data + (b * kv_num_heads_ + h_kv) * state_per_head;
+
+ ProcessHead(
+ S, q_data, k_data, v_data, decay_data, beta_data, output_data,
+ retrieved_buf.data(),
+ b, h_kv, h_k, seq_len, d_k, d_v,
+ q_num_heads_, kv_num_heads_, n_k_heads, heads_per_group, output_hidden,
+ s, needs_decay, decay_per_key_dim, needs_beta, beta_per_head,
+ needs_retrieval);
+ }
+ });
+
+ return Status::OK();
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention.h b/onnxruntime/contrib_ops/cpu/bert/linear_attention.h
new file mode 100644
index 0000000000000..9aaa9f80a3369
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention.h
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+#include <string>
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class LinearAttention final : public OpKernel {
+ public:
+ LinearAttention(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int q_num_heads_;
+ int kv_num_heads_;
+ std::string update_rule_;
+ float scale_;
+ int chunk_size_;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 0949ad6c36f58..8588449a3895c 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -29,6 +29,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SparseAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CausalConvWithState);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
#if !defined(DISABLE_GENERATION_OPS)
@@ -319,6 +321,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttention)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CausalConvWithState)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
#if !defined(DISABLE_GENERATION_OPS)
diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc
new file mode 100644
index 0000000000000..e60a87afe5f25
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc
@@ -0,0 +1,120 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cuda/bert/causal_conv_with_state.h"
+#include "contrib_ops/cuda/bert/causal_conv_with_state_impl.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda; // CudaKernel, Stream, GetDeviceProp, ToCudaType
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ CausalConvWithState, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      CausalConvWithState<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+template <typename T>
+CausalConvWithState<T>::CausalConvWithState(const OpKernelInfo& info) : CudaKernel(info) {
+  int64_t ndim = info.GetAttrOrDefault<int64_t>("ndim", 1);
+  ORT_ENFORCE(ndim == 1, "CUDA CausalConvWithState only supports ndim=1");
+  ndim_ = static_cast<int>(ndim);
+
+  activation_ = info.GetAttrOrDefault<std::string>("activation", "none");
+ ORT_ENFORCE(activation_ == "none" || activation_ == "silu" || activation_ == "swish",
+ "activation must be one of: none, silu, swish");
+}
+
+template <typename T>
+Status CausalConvWithState<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input_tensor = context->Input<Tensor>(0);
+  const Tensor* weight_tensor = context->Input<Tensor>(1);
+  const Tensor* bias_tensor = context->Input<Tensor>(2);        // optional
+  const Tensor* past_state_tensor = context->Input<Tensor>(3);  // optional
+
+ ORT_RETURN_IF_NOT(input_tensor != nullptr, "input is required");
+ ORT_RETURN_IF_NOT(weight_tensor != nullptr, "weight is required");
+
+ const auto& input_shape = input_tensor->Shape();
+ const auto& weight_shape = weight_tensor->Shape();
+
+ // Validate input rank and weight rank
+ ORT_RETURN_IF_NOT(input_shape.NumDimensions() == 3,
+ "input must be rank 3 (batch, channels, length), got rank ", input_shape.NumDimensions());
+ ORT_RETURN_IF_NOT(weight_shape.NumDimensions() == 3,
+ "weight must be rank 3 (channels, 1, kernel_size), got rank ", weight_shape.NumDimensions());
+
+  const int batch_size = static_cast<int>(input_shape[0]);
+  const int channels = static_cast<int>(input_shape[1]);
+  const int L = static_cast<int>(input_shape[2]);
+  const int K = static_cast<int>(weight_shape[2]);
+ const int pad = K - 1;
+
+ // Validate weight shape compatibility
+ ORT_RETURN_IF_NOT(weight_shape[0] == channels,
+ "weight[0] (", weight_shape[0], ") must match input channels (", channels, ")");
+ ORT_RETURN_IF_NOT(weight_shape[1] == 1,
+ "weight[1] must be 1 for depthwise convolution, got ", weight_shape[1]);
+
+ // Validate optional bias shape
+ if (bias_tensor != nullptr) {
+ const auto& bias_shape = bias_tensor->Shape();
+ ORT_RETURN_IF_NOT(bias_shape.NumDimensions() == 1 && bias_shape[0] == channels,
+ "bias must have shape (", channels, "), got ", bias_shape.ToString());
+ }
+
+ // Validate optional past_state shape
+ if (past_state_tensor != nullptr) {
+ const auto& past_shape = past_state_tensor->Shape();
+ ORT_RETURN_IF_NOT(past_shape.NumDimensions() == 3,
+ "past_state must be rank 3 (batch, channels, kernel_size-1), got rank ", past_shape.NumDimensions());
+ ORT_RETURN_IF_NOT(past_shape[0] == batch_size && past_shape[1] == channels && past_shape[2] == pad,
+ "past_state shape mismatch: expected (", batch_size, ", ", channels, ", ", pad,
+ "), got (", past_shape[0], ", ", past_shape[1], ", ", past_shape[2], ")");
+ }
+
+ // Allocate outputs
+ Tensor* output_tensor = context->Output(0, input_shape);
+ TensorShape state_shape({batch_size, channels, pad});
+ Tensor* present_state_tensor = context->Output(1, state_shape);
+
+ // Note: no need to zero-initialize present_state — the kernel writes all
+ // positions unconditionally. When past_state is null, the kernel uses
+ // zeros for the padding region internally.
+ // Note: past_state pointer is passed to kernel; kernel reads it directly
+
+ bool apply_silu = (activation_ == "silu" || activation_ == "swish");
+
+  typedef typename OrtToCudaType<T>::type CudaT;
+
+ return LaunchCausalConvWithStateKernel(
+ Stream(context),
+      reinterpret_cast<const CudaT*>(input_tensor->Data<T>()),
+      reinterpret_cast<const CudaT*>(weight_tensor->Data<T>()),
+      bias_tensor ? reinterpret_cast<const CudaT*>(bias_tensor->Data<T>()) : nullptr,
+      past_state_tensor ? reinterpret_cast<const CudaT*>(past_state_tensor->Data<T>()) : nullptr,
+      reinterpret_cast<CudaT*>(output_tensor->MutableData<T>()),
+      reinterpret_cast<CudaT*>(present_state_tensor->MutableData<T>()),
+ batch_size,
+ channels,
+ L,
+ K,
+ apply_silu,
+ GetDeviceProp().maxThreadsPerBlock);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h
new file mode 100644
index 0000000000000..37a9c29b5e749
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+#include <string>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+class CausalConvWithState final : public onnxruntime::cuda::CudaKernel {
+ public:
+ CausalConvWithState(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int ndim_;
+ std::string activation_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu
new file mode 100644
index 0000000000000..87774f7205bc9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu
@@ -0,0 +1,459 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Fused causal depthwise conv1d CUDA kernel with stateful carry and optional SiLU activation.
+//
+// Design: two execution paths:
+//
+// 1. Decode (L=1): The convolution window is [past_state(K-1), input(1)].
+//    One thread per (batch, channel) loads K values into registers, computes a
+//    single dot product, and shifts the state. No shared memory is needed.
+//
+// 2. Prefill (L>1): One block per (batch, channel), or per group of CPB channels for
+//    short sequences, stages past_state + input in shared memory as a padded buffer,
+//    then each thread computes one or more output positions.
+//
+// State is stored in type T to match the op schema convention.
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <algorithm>
+#include "contrib_ops/cuda/bert/causal_conv_with_state_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+namespace {
+
+template <typename T>
+__device__ __forceinline__ float to_float(T val);
+
+template <>
+__device__ __forceinline__ float to_float(float val) { return val; }
+
+template <>
+__device__ __forceinline__ float to_float(half val) { return __half2float(val); }
+
+#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+template <>
+__device__ __forceinline__ float to_float(__nv_bfloat16 val) { return __bfloat162float(val); }
+#endif
+
+template <typename T>
+__device__ __forceinline__ T from_float(float val);
+
+template <>
+__device__ __forceinline__ float from_float(float val) { return val; }
+
+template <>
+__device__ __forceinline__ half from_float(float val) { return __float2half(val); }
+
+#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+template <>
+__device__ __forceinline__ __nv_bfloat16 from_float(float val) { return __float2bfloat16(val); }
+#endif
+
+__device__ __forceinline__ float silu_fn(float x) {
+ return x / (1.0f + expf(-x));
+}
+
+// =============================================================================
+// Decode kernel: L=1, one dot product per (batch, channel)
+// Grid: (batch_size * channels, 1, 1)
+// Block: (1, 1, 1) — one thread per (batch, channel)
+// No shared memory needed.
+// =============================================================================
+template <typename T>
+__global__ void CausalConvDecodeKernel(
+ const T* __restrict__ input, // [B, C, 1]
+ const T* __restrict__ weight, // [C, 1, K]
+ const T* __restrict__ bias, // [C] or nullptr
+ const T* __restrict__ past_state, // [B, C, K-1] or nullptr
+ T* __restrict__ output, // [B, C, 1]
+ T* __restrict__ present_state, // [B, C, K-1]
+ int batch_channels, // = batch_size * channels (actual element count)
+ int channels,
+ int kernel_size,
+ bool apply_silu) {
+ const int bc = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bc >= batch_channels) return;
+ const int b = bc / channels;
+ const int c = bc % channels;
+
+ const int pad = kernel_size - 1;
+
+ // Cache input value in register — avoids redundant global reads
+ const float input_val = to_float(input[(int64_t)b * channels + c]);
+
+ // Cache past_state base pointer for this (b, c)
+ const T* ps_in = (past_state != nullptr)
+ ? past_state + (int64_t)b * channels * pad + (int64_t)c * pad
+ : nullptr;
+
+ // Load weight for this channel: [K] values
+ // weight layout: [C, 1, K], so channel c starts at c * K
+ float sum = (bias != nullptr) ? to_float(bias[c]) : 0.0f;
+
+ // Convolution window: [past_state[0..K-2], input[0]]
+ for (int k = 0; k < pad; ++k) {
+ float wk = to_float(weight[c * kernel_size + k]);
+ float xk = (ps_in != nullptr) ? to_float(ps_in[k]) : 0.0f;
+ sum += wk * xk;
+ }
+ // Last element of window is current input
+ sum += to_float(weight[c * kernel_size + pad]) * input_val;
+
+ if (apply_silu) {
+ sum = silu_fn(sum);
+ }
+  output[(int64_t)b * channels + c] = from_float<T>(sum);
+
+  // Update present_state: shift left by 1, append input
+  T* ps_out = present_state + (int64_t)b * channels * pad + (int64_t)c * pad;
+  for (int k = 0; k < pad - 1; ++k) {
+    ps_out[k] = (ps_in != nullptr) ? ps_in[k + 1] : from_float<T>(0.0f);
+  }
+  if (pad > 0) {
+    ps_out[pad - 1] = from_float<T>(input_val);
+ }
+}
+
+template <typename T, int K>
+__global__ void CausalConvDecodeKernelFixedK(
+ const T* __restrict__ input,
+ const T* __restrict__ weight,
+ const T* __restrict__ bias,
+ const T* __restrict__ past_state,
+ T* __restrict__ output,
+ T* __restrict__ present_state,
+ int batch_channels,
+ int channels,
+ bool apply_silu) {
+ const int bc = blockIdx.x * blockDim.x + threadIdx.x;
+ if (bc >= batch_channels) return;
+
+ const int b = bc / channels;
+ const int c = bc % channels;
+ constexpr int pad = K - 1;
+
+ float sum = (bias != nullptr) ? to_float(bias[c]) : 0.0f;
+  const T* w = weight + static_cast<int64_t>(c) * K;
+  const T* ps_in = (past_state != nullptr)
+                       ? past_state + static_cast<int64_t>(b) * channels * pad + static_cast<int64_t>(c) * pad
+ : nullptr;
+
+ if (ps_in != nullptr) {
+#pragma unroll
+ for (int k = 0; k < pad; ++k) {
+ sum += to_float(w[k]) * to_float(ps_in[k]);
+ }
+ }
+  sum += to_float(w[pad]) * to_float(input[static_cast<int64_t>(b) * channels + c]);
+
+ if (apply_silu) {
+ sum = silu_fn(sum);
+ }
+  output[static_cast<int64_t>(b) * channels + c] = from_float<T>(sum);
+
+  T* ps_out = present_state + static_cast<int64_t>(b) * channels * pad + static_cast<int64_t>(c) * pad;
+  if constexpr (pad > 0) {
+#pragma unroll
+    for (int k = 0; k < pad - 1; ++k) {
+      ps_out[k] = (ps_in != nullptr) ? ps_in[k + 1] : from_float<T>(0.0f);
+    }
+    ps_out[pad - 1] = input[static_cast<int64_t>(b) * channels + c];
+ }
+}
+
+// =============================================================================
+// Prefill kernel: L>1, one thread per output position within a (batch, channel)
+// Grid: (batch_size, channels, 1)
+// Block: (min(L, max_threads), 1, 1)
+// Shared memory: padded input buffer [K-1 + L] floats + weight [K] floats
+// =============================================================================
+template <typename T>
+__global__ void CausalConvPrefillKernel(
+ const T* __restrict__ input, // [B, C, L]
+ const T* __restrict__ weight, // [C, 1, K]
+ const T* __restrict__ bias, // [C] or nullptr
+ const T* __restrict__ past_state, // [B, C, K-1] or nullptr
+ T* __restrict__ output, // [B, C, L]
+ T* __restrict__ present_state, // [B, C, K-1]
+ int seq_len,
+ int channels,
+ int kernel_size,
+ bool apply_silu) {
+ const int b = blockIdx.x;
+ const int c = blockIdx.y;
+ const int tid = threadIdx.x;
+
+ const int pad = kernel_size - 1;
+ const int padded_len = pad + seq_len;
+
+ // Shared memory: padded input [pad + L] floats + weight [K] floats
+ extern __shared__ float smem[];
+ float* s_padded = smem;
+ float* s_weight = smem + padded_len;
+
+ // Cooperatively load padded input into shared memory
+ // Past state portion: [0..pad-1]
+ for (int i = tid; i < pad; i += blockDim.x) {
+ if (past_state != nullptr) {
+ s_padded[i] = to_float(past_state[(int64_t)b * channels * pad + (int64_t)c * pad + i]);
+ } else {
+ s_padded[i] = 0.0f;
+ }
+ }
+ // Current input portion: [pad..pad+L-1]
+ for (int i = tid; i < seq_len; i += blockDim.x) {
+ s_padded[pad + i] = to_float(input[((int64_t)b * channels + c) * seq_len + i]);
+ }
+ // Load weight into shared memory
+ for (int i = tid; i < kernel_size; i += blockDim.x) {
+ s_weight[i] = to_float(weight[(int64_t)c * kernel_size + i]);
+ }
+ __syncthreads();
+
+ // Each thread computes one output position
+ float bias_val = (bias != nullptr) ? to_float(bias[c]) : 0.0f;
+ for (int l = tid; l < seq_len; l += blockDim.x) {
+ float sum = bias_val;
+ for (int k = 0; k < kernel_size; ++k) {
+ sum += s_weight[k] * s_padded[l + k];
+ }
+ if (apply_silu) {
+ sum = silu_fn(sum);
+ }
+    output[((int64_t)b * channels + c) * seq_len + l] = from_float<T>(sum);
+ }
+
+ // Save present_state: last K-1 elements of padded input
+ __syncthreads();
+ T* ps = present_state + (int64_t)b * channels * pad + (int64_t)c * pad;
+ for (int i = tid; i < pad; i += blockDim.x) {
+    ps[i] = from_float<T>(s_padded[padded_len - pad + i]);
+ }
+}
+
+// =============================================================================
+// Batched prefill kernel: processes CHANNELS_PER_BLOCK channels per block
+// to improve occupancy when per-channel work is small (short sequences).
+//
+// Grid: (batch_size, ceil(channels / CPB), 1)
+// Block: (threads, 1, 1) — threads are split across CPB channels
+//
+// Each channel gets (blockDim.x / CPB) threads. Weight, past state, and input are all
+// staged through shared memory (smem_per_ch = K-1 + L + K floats per channel).
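+// Illustrative sizing: CPB=4, K=4, L=128 needs 4 * ((3 + 128) + 4) = 540 floats
+// (2160 bytes) of dynamic shared memory per block.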
+// =============================================================================
+template <typename T, int CPB>
+__global__ void CausalConvPrefillKernelBatched(
+ const T* __restrict__ input, // [B, C, L]
+ const T* __restrict__ weight, // [C, 1, K]
+ const T* __restrict__ bias, // [C] or nullptr
+ const T* __restrict__ past_state, // [B, C, K-1] or nullptr
+ T* __restrict__ output, // [B, C, L]
+ T* __restrict__ present_state, // [B, C, K-1]
+ int seq_len,
+ int channels,
+ int kernel_size,
+ bool apply_silu) {
+ const int b = blockIdx.x;
+ const int c_base = blockIdx.y * CPB;
+ const int tid = threadIdx.x;
+
+ const int pad = kernel_size - 1;
+ const int padded_len = pad + seq_len;
+
+ // Which channel within this block's CPB group does this thread serve?
+ const int threads_per_channel = blockDim.x / CPB;
+ const int local_ch = tid / threads_per_channel; // 0..CPB-1
+ const int local_tid = tid % threads_per_channel; // thread index within channel
+ const int c = c_base + local_ch;
+
+ // Shared memory: CPB * (padded_len + kernel_size) floats
+ extern __shared__ float smem[];
+ const int smem_per_ch = padded_len + kernel_size;
+ float* s_padded = smem + local_ch * smem_per_ch;
+ float* s_weight = s_padded + padded_len;
+
+ if (c < channels) {
+ // Load past state
+ for (int i = local_tid; i < pad; i += threads_per_channel) {
+ if (past_state != nullptr) {
+ s_padded[i] = to_float(past_state[(int64_t)b * channels * pad + (int64_t)c * pad + i]);
+ } else {
+ s_padded[i] = 0.0f;
+ }
+ }
+ // Load input
+ for (int i = local_tid; i < seq_len; i += threads_per_channel) {
+ s_padded[pad + i] = to_float(input[((int64_t)b * channels + c) * seq_len + i]);
+ }
+ // Load weight
+ for (int i = local_tid; i < kernel_size; i += threads_per_channel) {
+ s_weight[i] = to_float(weight[(int64_t)c * kernel_size + i]);
+ }
+ }
+ __syncthreads();
+
+ if (c < channels) {
+ float bias_val = (bias != nullptr) ? to_float(bias[c]) : 0.0f;
+ for (int l = local_tid; l < seq_len; l += threads_per_channel) {
+ float sum = bias_val;
+ for (int k = 0; k < kernel_size; ++k) {
+ sum += s_weight[k] * s_padded[l + k];
+ }
+ if (apply_silu) {
+ sum = silu_fn(sum);
+ }
+        output[((int64_t)b * channels + c) * seq_len + l] = from_float<T>(sum);
+ }
+ }
+
+ // Unconditional barrier — s_padded is read-only after the cooperative load,
+ // so this is safe even when c >= channels. Hoisted out of the conditional
+ // to avoid divergent __syncthreads() (undefined behavior in CUDA).
+ __syncthreads();
+
+ if (c < channels) {
+ // Save present state
+ T* ps = present_state + (int64_t)b * channels * pad + (int64_t)c * pad;
+ for (int i = local_tid; i < pad; i += threads_per_channel) {
+      ps[i] = from_float<T>(s_padded[padded_len - pad + i]);
+ }
+ }
+}
+
+} // anonymous namespace
+
+template <typename T>
+Status LaunchCausalConvWithStateKernel(
+ cudaStream_t stream,
+ const T* input,
+ const T* weight,
+ const T* bias,
+ const T* past_state,
+ T* output,
+ T* present_state,
+ int batch_size,
+ int channels,
+ int seq_len,
+ int kernel_size,
+ bool apply_silu,
+ int max_threads_per_block) {
+ if (seq_len == 1) {
+ // Decode fast-path: one thread per (batch, channel)
+ int total = batch_size * channels;
+ int threads = 256;
+ int blocks = (total + threads - 1) / threads;
+ switch (kernel_size) {
+ case 2:
+        CausalConvDecodeKernelFixedK<T, 2><<<blocks, threads, 0, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ total, channels, apply_silu);
+ break;
+ case 3:
+        CausalConvDecodeKernelFixedK<T, 3><<<blocks, threads, 0, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ total, channels, apply_silu);
+ break;
+ case 4:
+        CausalConvDecodeKernelFixedK<T, 4><<<blocks, threads, 0, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ total, channels, apply_silu);
+ break;
+ case 5:
+        CausalConvDecodeKernelFixedK<T, 5><<<blocks, threads, 0, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ total, channels, apply_silu);
+ break;
+ default:
+        CausalConvDecodeKernel<T><<<blocks, threads, 0, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ total, channels, kernel_size, apply_silu);
+ break;
+ }
+ } else {
+ // Prefill path: choose between batched (short seq) or single-channel (long seq) kernel
+ int pad = kernel_size - 1;
+
+ // For short sequences, batch multiple channels per block to improve occupancy.
+ // CPB=4: each block handles 4 channels, reducing block count by 4x.
+ // Threshold: use batched when seq_len <= 128 (small per-channel work).
+ constexpr int CPB = 4;
+ if (seq_len <= 128 && channels >= CPB) {
+ int channel_blocks = (channels + CPB - 1) / CPB;
+ const dim3 grid(batch_size, channel_blocks, 1);
+ // Each channel gets threads/CPB threads
+ int threads_per_ch = std::min(seq_len, max_threads_per_block / CPB);
+ threads_per_ch = ((threads_per_ch + 31) / 32) * 32;
+ if (threads_per_ch < 32) threads_per_ch = 32;
+ int total_threads = threads_per_ch * CPB;
+ if (total_threads > max_threads_per_block) {
+ total_threads = (max_threads_per_block / CPB) * CPB; // round down to multiple of CPB
+ }
+ const dim3 block(total_threads, 1, 1);
+      size_t smem_size = static_cast<size_t>(CPB) * (static_cast<size_t>(pad + seq_len) + kernel_size) * sizeof(float);
+
+ // Request extended shared memory if needed (default limit is 48 KB)
+ if (smem_size > 48 * 1024) {
+ cudaError_t attr_err = cudaFuncSetAttribute(
+            CausalConvPrefillKernelBatched<T, CPB>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            static_cast<int>(smem_size));
+ if (attr_err != cudaSuccess) {
+ return CUDA_CALL(attr_err);
+ }
+ }
+
+      CausalConvPrefillKernelBatched<T, CPB><<<grid, block, smem_size, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ seq_len, channels, kernel_size, apply_silu);
+ } else {
+ // Original single-channel-per-block path for long sequences
+ const dim3 grid(batch_size, channels, 1);
+ int threads = std::min(seq_len, max_threads_per_block);
+ threads = ((threads + 31) / 32) * 32; // round to warp
+ if (threads > max_threads_per_block) threads = max_threads_per_block;
+ const dim3 block(threads, 1, 1);
+
+      size_t smem_size = (static_cast<size_t>(pad + seq_len) + kernel_size) * sizeof(float);
+
+ // Request extended shared memory if needed (default limit is 48 KB)
+ if (smem_size > 48 * 1024) {
+ cudaError_t attr_err = cudaFuncSetAttribute(
+            CausalConvPrefillKernel<T>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize,
+            static_cast<int>(smem_size));
+ if (attr_err != cudaSuccess) {
+ return CUDA_CALL(attr_err);
+ }
+ }
+
+      CausalConvPrefillKernel<T><<<grid, block, smem_size, stream>>>(
+ input, weight, bias, past_state, output, present_state,
+ seq_len, channels, kernel_size, apply_silu);
+ }
+ }
+
+ return CUDA_CALL(cudaGetLastError());
+}
+
+// Explicit instantiations
+template Status LaunchCausalConvWithStateKernel<float>(
+ cudaStream_t, const float*, const float*, const float*, const float*,
+ float*, float*, int, int, int, int, bool, int);
+
+template Status LaunchCausalConvWithStateKernel<half>(
+ cudaStream_t, const half*, const half*, const half*, const half*,
+ half*, half*, int, int, int, int, bool, int);
+
+#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+template Status LaunchCausalConvWithStateKernel<__nv_bfloat16>(
+ cudaStream_t, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*,
+ __nv_bfloat16*, __nv_bfloat16*, int, int, int, int, bool, int);
+#endif
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h
new file mode 100644
index 0000000000000..4427a1df1fd6d
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <cuda_runtime.h>  // cudaStream_t
+#include "core/providers/cuda/cuda_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// Fused causal depthwise conv1d + activation + state management.
+// One thread block per (batch, channel). For decode (L=1), this is a simple
+// dot product from shared memory. For prefill (L>1), each thread handles
+// one output position.
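+//
+// As a rough sketch of the math the kernels compute (hedged; exact edge handling
+// lives in the kernels), with x_pad = concat(past_state, input) along time and
+// w = weight[c, 0, :]:
+//   output[b, c, t]        = act(bias[c] + sum_{j=0..K-1} w[j] * x_pad[b, c, t + j])
+//   present_state[b, c, :] = x_pad[b, c, L : L + K - 1]   // last K-1 time steps
+// where act is SiLU when apply_silu is true (identity otherwise) and past_state
+// is treated as zeros when it is nullptr.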
+template <typename T>
+Status LaunchCausalConvWithStateKernel(
+ cudaStream_t stream,
+ const T* input, // [B, C, L]
+ const T* weight, // [C, 1, K]
+ const T* bias, // [C] or nullptr
+ const T* past_state, // [B, C, K-1] or nullptr
+ T* output, // [B, C, L]
+ T* present_state, // [B, C, K-1]
+ int batch_size,
+ int channels,
+ int seq_len,
+ int kernel_size,
+ bool apply_silu,
+ int max_threads_per_block);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc b/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc
new file mode 100644
index 0000000000000..c8f460b0ca002
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc
@@ -0,0 +1,177 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cuda/bert/linear_attention.h"
+#include "contrib_ops/cuda/bert/linear_attention_impl.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/cuda_type_conversion.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda; // CudaKernel, Stream, GetDeviceProp, ToCudaType
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ LinearAttention, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+ LinearAttention<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+template <typename T>
+LinearAttention<T>::LinearAttention(const OpKernelInfo& info) : CudaKernel(info) {
+ int64_t q_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("q_num_heads", &q_num_heads).IsOK() && q_num_heads > 0);
+ q_num_heads_ = static_cast<int>(q_num_heads);
+
+ int64_t kv_num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
+ kv_num_heads_ = static_cast<int>(kv_num_heads);
+
+ update_rule_ = info.GetAttrOrDefault<std::string>("update_rule", "gated_delta");
+ ORT_ENFORCE(update_rule_ == "linear" || update_rule_ == "gated" ||
+ update_rule_ == "delta" || update_rule_ == "gated_delta",
+ "update_rule must be one of: linear, gated, delta, gated_delta");
+ scale_ = info.GetAttrOrDefault("scale", 0.0f);
+
+ int64_t chunk_size = info.GetAttrOrDefault<int64_t>("chunk_size", 64);
+ // chunk_size_ reserved for future chunk-parallel prefill algorithm; not yet used.
+ chunk_size_ = static_cast<int>(chunk_size);
+}
+
+template <typename T>
+Status LinearAttention<T>::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* query_tensor = context->Input<Tensor>(0);
+ const Tensor* key_tensor = context->Input<Tensor>(1); // optional
+ const Tensor* value_tensor = context->Input<Tensor>(2); // optional
+ const Tensor* past_state_tensor = context->Input<Tensor>(3); // optional
+ const Tensor* decay_tensor = context->Input<Tensor>(4); // optional
+ const Tensor* beta_tensor = context->Input<Tensor>(5); // optional
+
+ ORT_RETURN_IF_NOT(query_tensor != nullptr, "query input is required");
+
+ const auto& query_shape = query_tensor->Shape();
+ ORT_RETURN_IF_NOT(query_shape.NumDimensions() == 3, "query must be 3D");
+
+ const int batch_size = static_cast<int>(query_shape[0]);
+ const int seq_len = static_cast<int>(query_shape[1]);
+ const int query_hidden = static_cast<int>(query_shape[2]);
+
+ ORT_RETURN_IF_NOT(key_tensor != nullptr && value_tensor != nullptr, "key and value inputs are required");
+
+ const auto& key_shape = key_tensor->Shape();
+ const auto& value_shape = value_tensor->Shape();
+
+ int d_k = query_hidden / q_num_heads_;
+ int d_v = static_cast<int>(value_shape[2]) / kv_num_heads_;
+ ORT_ENFORCE(static_cast<int>(key_shape[2]) % d_k == 0,
+ "key last dim (", key_shape[2], ") must be divisible by d_k (", d_k, ")");
+ int n_k_heads = static_cast<int>(key_shape[2]) / d_k;
+
+ // GQA head mapping validations
+ if (q_num_heads_ >= kv_num_heads_) {
+ ORT_ENFORCE(q_num_heads_ % kv_num_heads_ == 0,
+ "q_num_heads must be divisible by kv_num_heads");
+ } else {
+ ORT_ENFORCE(kv_num_heads_ % q_num_heads_ == 0,
+ "kv_num_heads must be divisible by q_num_heads (inverse GQA)");
+ }
+ ORT_ENFORCE(kv_num_heads_ % n_k_heads == 0,
+ "kv_num_heads must be divisible by n_k_heads");
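+ // Example (illustrative): q_num_heads=32, kv_num_heads=8 gives heads_per_group=4,
+ // so query heads 4*h .. 4*h+3 all read the recurrent state of kv head h. In the
+ // inverse-GQA case (kv_num_heads > q_num_heads) several kv heads map onto one
+ // query head, hence the max() when sizing output_hidden below.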
+
+ float s = scale_;
+ if (s == 0.0f) {
+ s = 1.0f / std::sqrt(static_cast<float>(d_k));
+ }
+
+ bool needs_decay = (update_rule_ == "gated" || update_rule_ == "gated_delta");
+ bool needs_beta = (update_rule_ == "delta" || update_rule_ == "gated_delta");
+ bool needs_retrieval = (update_rule_ == "delta" || update_rule_ == "gated_delta");
+
+ ORT_ENFORCE(!needs_decay || decay_tensor != nullptr,
+ "decay input is required for update_rule=", update_rule_);
+ ORT_ENFORCE(!needs_beta || beta_tensor != nullptr,
+ "beta input is required for update_rule=", update_rule_);
+
+ bool decay_per_key_dim = false;
+ if (decay_tensor != nullptr) {
+ int64_t decay_last = decay_tensor->Shape()[2];
+ decay_per_key_dim = (decay_last == kv_num_heads_ * d_k);
+ }
+
+ bool beta_per_head = false;
+ if (beta_tensor != nullptr) {
+ int64_t beta_last = beta_tensor->Shape()[2];
+ beta_per_head = (beta_last == kv_num_heads_);
+ }
+
+ // Allocate outputs
+ int output_hidden = std::max(q_num_heads_, kv_num_heads_) * d_v;
+ TensorShape output_shape({batch_size, seq_len, output_hidden});
+ Tensor* output_tensor = context->Output(0, output_shape);
+
+ TensorShape state_shape({batch_size, kv_num_heads_, d_k, d_v});
+ Tensor* present_state_tensor = context->Output(1, state_shape);
+
+ // If past_state is nullptr, zero-init present_state on device
+ if (past_state_tensor == nullptr) {
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
+ present_state_tensor->MutableData<T>(), 0,
+ static_cast<size_t>(batch_size) * kv_num_heads_ * d_k * d_v * sizeof(T),
+ Stream(context)));
+ } else {
+ // Validate past_state shape matches expected (B, H_kv, d_k, d_v)
+ const auto& past_shape = past_state_tensor->Shape();
+ ORT_ENFORCE(past_shape.NumDimensions() == 4,
+ "past_state must be rank 4 (B, H_kv, d_k, d_v), got rank ", past_shape.NumDimensions());
+ ORT_ENFORCE(past_shape[0] == batch_size && past_shape[1] == kv_num_heads_ &&
+ past_shape[2] == d_k && past_shape[3] == d_v,
+ "past_state shape mismatch: expected (", batch_size, ", ", kv_num_heads_, ", ", d_k, ", ", d_v,
+ "), got (", past_shape[0], ", ", past_shape[1], ", ", past_shape[2], ", ", past_shape[3], ")");
+ // Copy past_state -> present_state (will be updated in-place by kernel)
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(
+ present_state_tensor->MutableData<T>(),
+ past_state_tensor->Data<T>(),
+ static_cast<size_t>(batch_size) * kv_num_heads_ * d_k * d_v * sizeof(T),
+ cudaMemcpyDeviceToDevice,
+ Stream(context)));
+ }
+
+ typedef typename OrtToCudaType<T>::type CudaT;
+
+ return LaunchLinearAttentionKernel(
+ Stream(context),
+ reinterpret_cast<const CudaT*>(query_tensor->Data<T>()),
+ reinterpret_cast<const CudaT*>(key_tensor->Data<T>()),
+ reinterpret_cast<const CudaT*>(value_tensor->Data<T>()),
+ decay_tensor ? reinterpret_cast<const CudaT*>(decay_tensor->Data<T>()) : nullptr,
+ beta_tensor ? reinterpret_cast<const CudaT*>(beta_tensor->Data<T>()) : nullptr,
+ reinterpret_cast<CudaT*>(output_tensor->MutableData<T>()),
+ reinterpret_cast<CudaT*>(present_state_tensor->MutableData<T>()),
+ batch_size,
+ seq_len,
+ q_num_heads_,
+ kv_num_heads_,
+ n_k_heads,
+ d_k,
+ d_v,
+ s,
+ needs_decay,
+ decay_per_key_dim,
+ needs_beta,
+ beta_per_head,
+ needs_retrieval,
+ GetDeviceProp().maxThreadsPerBlock);
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention.h b/onnxruntime/contrib_ops/cuda/bert/linear_attention.h
new file mode 100644
index 0000000000000..ed398218771d0
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+#include <string>
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+class LinearAttention final : public onnxruntime::cuda::CudaKernel {
+ public:
+ LinearAttention(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ int q_num_heads_;
+ int kv_num_heads_;
+ std::string update_rule_;
+ float scale_;
+ int chunk_size_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu
new file mode 100644
index 0000000000000..9137fcf9b25b9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu
@@ -0,0 +1,781 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Fused recurrent linear attention CUDA kernel for gated_delta / delta / gated / linear update rules.
+//
+// Design: One thread block per (batch, kv_head). The state matrix [d_k, d_v] is loaded into
+// shared memory at the start and kept there for the entire token loop. Each token's
+// decay → retrieval → delta → update → readout sequence runs without global memory
+// round-trips for the state. This matches the FLA (flash-linear-attention) kernel design.
+//
+// State tiles: For d_k=128, d_v=128, fp32 state = 64 KB shared memory. On SM80+ GPUs with
+// 164 KB shared memory per SM, this fits with room for scratch. Requires
+// cudaFuncSetAttribute to opt into extended shared memory (>48 KB).
+//
+// Thread mapping: num_threads = max(d_k, d_v) rounded to warp boundary. Each thread
+// participates in both row operations (decay/update: tid < d_k handles row tid) and
+// column operations (retrieval/readout: tid < d_v computes column tid's dot product).
+//
+// Reductions: Matrix-vector products (S^T @ k, S^T @ q) use column-per-thread dot products
+// instead of atomicAdd, eliminating contention. Each thread tid computes
+// sum_i(S[i, tid] * scalar[i]) by reading shared memory column-wise (bank-conflict-free
+// when d_v is a multiple of 32).
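+//
+// Per-token recurrence (shown for update_rule="gated_delta"; the other rules drop
+// the decay and/or the retrieval/beta terms — an illustrative summary of the steps
+// implemented below, not a formal spec):
+//   S~_t = exp(g_t) * S_{t-1}                         // decay
+//   r_t  = S~_t^T @ k_t                               // retrieval
+//   S_t  = S~_t + k_t ⊗ (beta_t * (v_t - r_t))        // delta-rule update
+//   o_t  = scale * (S_t^T @ q_t)                      // readout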
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cmath>
+#include <algorithm>
+#include "contrib_ops/cuda/bert/linear_attention_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+namespace {
+
+// Convert half/bfloat16 to float
+template <typename T>
+__device__ __forceinline__ float to_float(T val);
+
+template <>
+__device__ __forceinline__ float to_float(float val) { return val; }
+
+template <>
+__device__ __forceinline__ float to_float(half val) { return __half2float(val); }
+
+#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+template <>
+__device__ __forceinline__ float to_float(__nv_bfloat16 val) { return __bfloat162float(val); }
+#endif
+
+// Convert float to half/bfloat16/float
+template <typename T>
+__device__ __forceinline__ T from_float(float val);
+
+template <>
+__device__ __forceinline__ float from_float(float val) { return val; }
+
+template <>
+__device__ __forceinline__ half from_float(float val) { return __float2half(val); }
+
+#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
+template <>
+__device__ __forceinline__ __nv_bfloat16 from_float(float val) { return __float2bfloat16(val); }
+#endif
+
+// =============================================================================
+// Fused recurrent linear attention kernel
+//
+// Grid: (batch_size, kv_num_heads, 1)
+// Block: (max(d_k, d_v) rounded to warp, 1, 1)
+//
+// Shared memory layout (dynamic):
+// float S_smem[d_k * d_v] — recurrent state matrix (fp32)
+// float s_scratch[max(d_k, d_v)] — broadcast/retrieval/delta buffer
+//
+// State is stored as type T in global memory but computed in fp32 in shared
+// memory for numerical stability.
+// =============================================================================
+template <typename T>
+__global__ void LinearAttentionRecurrentKernel(
+ const T* __restrict__ query, // [B, T, H_q * d_k]
+ const T* __restrict__ key, // [B, T, n_k * d_k]
+ const T* __restrict__ value, // [B, T, H_kv * d_v]
+ T* __restrict__ state, // [B, H_kv, d_k, d_v] — in-place updated
+ const T* __restrict__ decay, // [B, T, H_kv] or [B, T, H_kv*d_k] or nullptr
+ const T* __restrict__ beta_in, // [B, T, H_kv] or [B, T, 1] or nullptr
+ T* __restrict__ output, // [B, T, max(H_q, H_kv) * d_v]
+ int seq_len,
+ int q_num_heads,
+ int kv_num_heads,
+ int n_k_heads,
+ int d_k,
+ int d_v,
+ int output_hidden,
+ float scale,
+ bool needs_decay,
+ bool decay_per_key_dim,
+ bool needs_beta,
+ bool beta_per_head,
+ bool needs_retrieval) {
+ const int b = blockIdx.x;
+ const int h_kv = blockIdx.y;
+ const int tid = threadIdx.x;
+ const int num_threads = blockDim.x;
+ const int kv_per_k = kv_num_heads / n_k_heads;
+ const int h_k = h_kv / kv_per_k;
+
+ // Global state pointer for this (batch, head): [d_k, d_v]
+ T* S_global = state + ((int64_t)b * kv_num_heads + h_kv) * d_k * d_v;
+
+ // Shared memory layout
+ extern __shared__ float smem[];
+ float* S_smem = smem; // [d_k * d_v]
+ float* k_buf = smem + d_k * d_v; // [d_k]
+ float* s_scratch = smem + d_k * d_v + d_k; // [max(d_k, d_v)]
+
+ // Load state from global memory (type T) into shared memory (fp32)
+ for (int idx = tid; idx < d_k * d_v; idx += num_threads) {
+ S_smem[idx] = to_float(S_global[idx]);
+ }
+ __syncthreads();
+
+ // ---- Token loop ----
+ for (int t = 0; t < seq_len; ++t) {
+ // Load k_t[tid] into register (each thread loads one element)
+ float kt_val = 0.0f;
+ if (tid < d_k) {
+ int k_offset = ((int64_t)b * seq_len + t) * (n_k_heads * d_k) + h_k * d_k + tid;
+ kt_val = to_float(key[k_offset]);
+ }
+
+ // Steps 1+2: Decay + Retrieval (fused for scalar per-head decay)
+ bool fused_decay_update = false;
+ float fused_exp_g = 1.0f;
+
+ if (needs_decay && needs_retrieval && !decay_per_key_dim) {
+ if (tid < d_k) {
+ k_buf[tid] = kt_val;
+ }
+ if (tid == 0) {
+ int g_offset = ((int64_t)b * seq_len + t) * kv_num_heads + h_kv;
+ s_scratch[0] = expf(to_float(decay[g_offset]));
+ }
+ __syncthreads();
+
+ fused_exp_g = s_scratch[0];
+
+ if (tid < d_v) {
+ float acc = 0.0f;
+ for (int i = 0; i < d_k; ++i) {
+ acc += S_smem[i * d_v + tid] * k_buf[i];
+ }
+ s_scratch[tid] = fused_exp_g * acc;
+ }
+ __syncthreads();
+
+ fused_decay_update = true;
+
+ } else {
+ // Non-fused path: separate decay and retrieval steps
+ if (needs_decay) {
+ if (!decay_per_key_dim) {
+ if (tid == 0) {
+ int g_offset = ((int64_t)b * seq_len + t) * kv_num_heads + h_kv;
+ s_scratch[0] = expf(to_float(decay[g_offset]));
+ }
+ __syncthreads();
+ }
+ if (tid < d_k) {
+ float exp_g;
+ if (decay_per_key_dim) {
+ int g_offset = ((int64_t)b * seq_len + t) * (kv_num_heads * d_k) + h_kv * d_k + tid;
+ exp_g = expf(to_float(decay[g_offset]));
+ } else {
+ exp_g = s_scratch[0];
+ }
+ for (int j = 0; j < d_v; ++j) {
+ S_smem[tid * d_v + j] *= exp_g;
+ }
+ }
+ __syncthreads();
+ }
+
+ if (needs_retrieval) {
+ // Store k in k_buf (not s_scratch) to avoid inter-warp race when
+ // d_k > 32: retrieval overwrites s_scratch[tid] while other warps
+ // may still be reading s_scratch[i] in the dot product loop.
+ if (tid < d_k) {
+ k_buf[tid] = kt_val;
+ }
+ __syncthreads();
+
+ if (tid < d_v) {
+ float acc = 0.0f;
+ for (int i = 0; i < d_k; ++i) {
+ acc += S_smem[i * d_v + tid] * k_buf[i];
+ }
+ s_scratch[tid] = acc;
+ }
+ __syncthreads();
+ }
+ }
+
+ // Step 3: State update — S += k_t ⊗ delta (or k_t ⊗ v_t for linear)
+ // When fused_decay_update, applies: S = exp_g * S + k * delta
+ if (needs_beta) {
+ float bt;
+ if (beta_per_head) {
+ bt = to_float(beta_in[((int64_t)b * seq_len + t) * kv_num_heads + h_kv]);
+ } else {
+ bt = to_float(beta_in[((int64_t)b * seq_len + t)]);
+ }
+
+ if (tid < d_v) {
+ int v_base = ((int64_t)b * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v;
+ float vj = to_float(value[v_base + tid]);
+ s_scratch[tid] = bt * (vj - s_scratch[tid]);
+ }
+ __syncthreads();
+
+ if (tid < d_k) {
+ if (fused_decay_update) {
+ for (int j = 0; j < d_v; ++j) {
+ S_smem[tid * d_v + j] = fused_exp_g * S_smem[tid * d_v + j] + kt_val * s_scratch[j];
+ }
+ } else {
+ for (int j = 0; j < d_v; ++j) {
+ S_smem[tid * d_v + j] += kt_val * s_scratch[j];
+ }
+ }
+ }
+ } else {
+ if (tid < d_v) {
+ int v_base = ((int64_t)b * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v;
+ s_scratch[tid] = to_float(value[v_base + tid]);
+ }
+ __syncthreads();
+
+ if (tid < d_k) {
+ if (fused_decay_update) {
+ for (int j = 0; j < d_v; ++j) {
+ S_smem[tid * d_v + j] = fused_exp_g * S_smem[tid * d_v + j] + kt_val * s_scratch[j];
+ }
+ } else {
+ for (int j = 0; j < d_v; ++j) {
+ S_smem[tid * d_v + j] += kt_val * s_scratch[j];
+ }
+ }
+ }
+ }
+ __syncthreads();
+
+ // Step 4: Query readout — output = S^T @ q_t (standard GQA or inverse GQA)
+ if (q_num_heads >= kv_num_heads) {
+ int heads_per_group = q_num_heads / kv_num_heads;
+ for (int g = 0; g < heads_per_group; ++g) {
+ if (g > 0) {
+ __syncthreads();
+ }
+
+ int h_q = h_kv * heads_per_group + g;
+ if (tid < d_k) {
+ int q_offset = ((int64_t)b * seq_len + t) * (q_num_heads * d_k) + h_q * d_k + tid;
+ s_scratch[tid] = to_float(query[q_offset]);
+ }
+ __syncthreads();
+
+ if (tid < d_v) {
+ float acc = 0.0f;
+ for (int i = 0; i < d_k; ++i) {
+ acc += S_smem[i * d_v + tid] * s_scratch[i];
+ }
+ int out_offset = ((int64_t)b * seq_len + t) * output_hidden + h_q * d_v + tid;
+ output[out_offset] = from_float<T>(scale * acc);
+ }
+ }
+ } else {
+ int h_q = h_kv * q_num_heads / kv_num_heads;
+ if (tid < d_k) {
+ int q_offset = ((int64_t)b * seq_len + t) * (q_num_heads * d_k) + h_q * d_k + tid;
+ s_scratch[tid] = to_float(query[q_offset]);
+ }
+ __syncthreads();
+
+ if (tid < d_v) {
+ float acc = 0.0f;
+ for (int i = 0; i < d_k; ++i) {
+ acc += S_smem[i * d_v + tid] * s_scratch[i];
+ }
+ int out_offset = ((int64_t)b * seq_len + t) * output_hidden + h_kv * d_v + tid;
+ output[out_offset] = from_float<T>(scale * acc);
+ }
+ }
+
+ __syncthreads();
+ }
+
+ // Write back state from shared memory (fp32) to global memory (type T)
+ for (int idx = tid; idx < d_k * d_v; idx += num_threads) {
+ S_global[idx] = from_float<T>(S_smem[idx]);
+ }
+}
+
+// Compile-time specialized variant for common (d_k, d_v) pairs.
+// Optimizations over the generic kernel:
+// 1. #pragma unroll on all inner loops for better ILP
+// 2. float4 vectorized row operations (decay, state update) — 4x fewer shared memory transactions
+// 3. Fused decay+retrieval for scalar per-head decay — eliminates one state pass and one __syncthreads()
+// 4. Dedicated k_buf in shared memory avoids scratch aliasing during fused path
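+// Shared-memory footprint is DK*DV + DK + max(DK, DV) floats (see the host-side
+// launch_fixed lambda). Illustrative numbers: DK = DV = 128 gives a 64 KB fp32
+// state tile plus roughly 1 KB of scratch, which is why launches above 48 KB must
+// opt into extended dynamic shared memory via cudaFuncSetAttribute.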
+template <typename T, int DK, int DV>
+__global__ void LinearAttentionRecurrentKernelFixedShape(
+ const T* __restrict__ query,
+ const T* __restrict__ key,
+ const T* __restrict__ value,
+ T* __restrict__ state,
+ const T* __restrict__ decay,
+ const T* __restrict__ beta_in,
+ T* __restrict__ output,
+ int seq_len,
+ int q_num_heads,
+ int kv_num_heads,
+ int n_k_heads,
+ int output_hidden,
+ float scale,
+ bool needs_decay,
+ bool decay_per_key_dim,
+ bool needs_beta,
+ bool beta_per_head,
+ bool needs_retrieval) {
+ static_assert(DV % 4 == 0 && DK % 4 == 0, "DK and DV must be multiples of 4 for float4 optimization");
+ constexpr int DV4 = DV / 4;
+
+ const int b = blockIdx.x;
+ const int h_kv = blockIdx.y;
+ const int tid = threadIdx.x;
+ const int kv_per_k = kv_num_heads / n_k_heads;
+ const int h_k = h_kv / kv_per_k;
+
+ T* S_global = state + ((int64_t)b * kv_num_heads + h_kv) * DK * DV;
+
+ // Shared memory layout:
+ // S_smem[DK * DV] — recurrent state matrix (fp32)
+ // k_buf[DK] — persistent key broadcast buffer
+ // s_scratch[max(DK, DV)] — general scratch (retrieval, delta, query broadcast)
+ extern __shared__ float smem[];
+ float* S_smem = smem; // [DK * DV]
+ float* k_buf = smem + DK * DV; // [DK]
+ float* s_scratch = smem + DK * DV + DK; // [max(DK, DV)]
+
+ // Load state from global memory (type T) into shared memory (fp32) — vectorized
+ if constexpr (sizeof(T) == 2 && DV % 2 == 0) {
+ // half/bf16: load 2 elements at a time via uint32
+ const uint32_t* S_global_u32 = reinterpret_cast<const uint32_t*>(S_global);
+ int half_pairs = (DK * DV) / 2;
+ for (int idx = tid; idx < half_pairs; idx += blockDim.x) {
+ uint32_t packed = S_global_u32[idx];
+ T lo, hi;
+ memcpy(&lo, &packed, sizeof(T));
+ memcpy(&hi, reinterpret_cast<const char*>(&packed) + sizeof(T), sizeof(T));
+ S_smem[idx * 2] = to_float(lo);
+ S_smem[idx * 2 + 1] = to_float(hi);
+ }
+ } else {
+ for (int idx = tid; idx < DK * DV; idx += blockDim.x) {
+ S_smem[idx] = to_float(S_global[idx]);
+ }
+ }
+ __syncthreads();
+
+ // Precompute per-batch strides to avoid repeated int64 multiplications in the token loop
+ const int64_t b_seq = (int64_t)b * seq_len;
+ const int k_hidden = n_k_heads * DK;
+ const int kv_v_hidden = kv_num_heads * DV;
+ const int q_hidden = q_num_heads * DK;
+ const int kv_dk_hidden = kv_num_heads * DK;
+
+ for (int t = 0; t < seq_len; ++t) {
+ const int64_t bt = b_seq + t;
+
+ float kt_val = 0.0f;
+ if (tid < DK) {
+ kt_val = to_float(key[bt * k_hidden + h_k * DK + tid]);
+ }
+
+ // ==================================================================
+ // Steps 1+2: Decay + Retrieval
+ // ==================================================================
+ // For the fused scalar-decay + gated_delta path, we also fuse the
+ // state update (step 3) to avoid a separate decay pass entirely:
+ // retrieval = exp_g * (S^T @ k) [on old S]
+ // delta = beta * (v - retrieval)
+ // S = exp_g * S + k ⊗ delta [single fused pass]
+ // This reduces 3 state passes (decay, retrieval, update) to 2 (retrieval, fused update).
+ bool fused_decay_update = false;
+ float fused_exp_g = 1.0f;
+
+ if (needs_decay && needs_retrieval && !decay_per_key_dim) {
+ // --- FUSED path: scalar per-head decay + retrieval ---
+ if (tid < DK) {
+ k_buf[tid] = kt_val;
+ }
+ if (tid == 0) {
+ s_scratch[0] = expf(to_float(decay[bt * kv_num_heads + h_kv]));
+ }
+ __syncthreads();
+
+ fused_exp_g = s_scratch[0];
+
+ // Retrieval on old state, pre-scaled by exp_g
+ if (tid < DV) {
+ float acc = 0.0f;
+#pragma unroll
+ for (int i = 0; i < DK; ++i) {
+ acc += S_smem[i * DV + tid] * k_buf[i];
+ }
+ s_scratch[tid] = fused_exp_g * acc;
+ }
+ __syncthreads();
+
+ // Decay is deferred to the update step (fused_decay_update = true)
+ fused_decay_update = true;
+
+ } else if (needs_decay && needs_retrieval) {
+ // --- Per-key-dim decay then retrieval (cannot fuse — exp_g differs per row) ---
+ if (tid < DK) {
+ k_buf[tid] = kt_val;
+ float exp_g = expf(to_float(decay[bt * kv_dk_hidden + h_kv * DK + tid]));
+ float4* row = reinterpret_cast<float4*>(S_smem + tid * DV);
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 v = row[j];
+ v.x *= exp_g;
+ v.y *= exp_g;
+ v.z *= exp_g;
+ v.w *= exp_g;
+ row[j] = v;
+ }
+ }
+ __syncthreads(); // decay done, k_buf visible
+
+ if (tid < DV) {
+ float acc = 0.0f;
+#pragma unroll
+ for (int i = 0; i < DK; ++i) {
+ acc += S_smem[i * DV + tid] * k_buf[i];
+ }
+ s_scratch[tid] = acc;
+ }
+ __syncthreads(); // retrieval done
+
+ } else {
+ // --- Decay only, retrieval only, or neither ---
+ if (needs_decay) {
+ if (!decay_per_key_dim) {
+ if (tid == 0) {
+ s_scratch[0] = expf(to_float(decay[bt * kv_num_heads + h_kv]));
+ }
+ __syncthreads();
+ }
+ if (tid < DK) {
+ float exp_g;
+ if (decay_per_key_dim) {
+ exp_g = expf(to_float(decay[bt * kv_dk_hidden + h_kv * DK + tid]));
+ } else {
+ exp_g = s_scratch[0];
+ }
+ float4* row = reinterpret_cast<float4*>(S_smem + tid * DV);
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 v = row[j];
+ v.x *= exp_g;
+ v.y *= exp_g;
+ v.z *= exp_g;
+ v.w *= exp_g;
+ row[j] = v;
+ }
+ }
+ __syncthreads();
+ }
+
+ if (needs_retrieval) {
+ if (tid < DK) {
+ k_buf[tid] = kt_val;
+ }
+ __syncthreads(); // k_buf visible
+
+ if (tid < DV) {
+ float acc = 0.0f;
+#pragma unroll
+ for (int i = 0; i < DK; ++i) {
+ acc += S_smem[i * DV + tid] * k_buf[i];
+ }
+ s_scratch[tid] = acc;
+ }
+ __syncthreads(); // retrieval done
+ }
+ }
+
+ // ==================================================================
+ // Step 3: State update with float4 vectorization
+ // When fused_decay_update is true, decay is applied here:
+ // S[i,j] = exp_g * S[i,j] + k[i] * delta[j]
+ // ==================================================================
+ if (needs_beta) {
+ float beta_t;
+ if (beta_per_head) {
+ beta_t = to_float(beta_in[bt * kv_num_heads + h_kv]);
+ } else {
+ beta_t = to_float(beta_in[bt]);
+ }
+
+ if (tid < DV) {
+ float vj = to_float(value[bt * kv_v_hidden + h_kv * DV + tid]);
+ s_scratch[tid] = beta_t * (vj - s_scratch[tid]);
+ }
+ __syncthreads();
+
+ if (tid < DK) {
+ float4* row = reinterpret_cast<float4*>(S_smem + tid * DV);
+ const float4* delta4 = reinterpret_cast<const float4*>(s_scratch);
+ if (fused_decay_update) {
+ // Fused: S = exp_g * S + k * delta (single pass, no separate decay)
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 s = row[j];
+ float4 d = delta4[j];
+ s.x = fused_exp_g * s.x + kt_val * d.x;
+ s.y = fused_exp_g * s.y + kt_val * d.y;
+ s.z = fused_exp_g * s.z + kt_val * d.z;
+ s.w = fused_exp_g * s.w + kt_val * d.w;
+ row[j] = s;
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 s = row[j];
+ float4 d = delta4[j];
+ s.x += kt_val * d.x;
+ s.y += kt_val * d.y;
+ s.z += kt_val * d.z;
+ s.w += kt_val * d.w;
+ row[j] = s;
+ }
+ }
+ }
+ } else {
+ if (tid < DV) {
+ s_scratch[tid] = to_float(value[bt * kv_v_hidden + h_kv * DV + tid]);
+ }
+ __syncthreads();
+
+ if (tid < DK) {
+ float4* row = reinterpret_cast<float4*>(S_smem + tid * DV);
+ const float4* v4 = reinterpret_cast<const float4*>(s_scratch);
+ if (fused_decay_update) {
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 s = row[j];
+ float4 v = v4[j];
+ s.x = fused_exp_g * s.x + kt_val * v.x;
+ s.y = fused_exp_g * s.y + kt_val * v.y;
+ s.z = fused_exp_g * s.z + kt_val * v.z;
+ s.w = fused_exp_g * s.w + kt_val * v.w;
+ row[j] = s;
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < DV4; ++j) {
+ float4 s = row[j];
+ float4 v = v4[j];
+ s.x += kt_val * v.x;
+ s.y += kt_val * v.y;
+ s.z += kt_val * v.z;
+ s.w += kt_val * v.w;
+ row[j] = s;
+ }
+ }
+ }
+ }
+ __syncthreads();
+
+ // ==================================================================
+ // Step 4: Query readout (column dot products — not float4-vectorizable)
+ // ==================================================================
+ if (q_num_heads >= kv_num_heads) {
+ int heads_per_group = q_num_heads / kv_num_heads;
+ for (int g = 0; g < heads_per_group; ++g) {
+ if (g > 0) {
+ __syncthreads();
+ }
+
+ int h_q = h_kv * heads_per_group + g;
+ if (tid < DK) {
+ s_scratch[tid] = to_float(query[bt * q_hidden + h_q * DK + tid]);
+ }
+ __syncthreads();
+
+ if (tid < DV) {
+ float acc = 0.0f;
+#pragma unroll
+ for (int i = 0; i < DK; ++i) {
+ acc += S_smem[i * DV + tid] * s_scratch[i];
+ }
+ output[bt * output_hidden + h_q * DV + tid] = from_float<T>(scale * acc);
+ }
+ }
+ } else {
+ int h_q = h_kv * q_num_heads / kv_num_heads;
+ if (tid < DK) {
+ s_scratch[tid] = to_float(query[bt * q_hidden + h_q * DK + tid]);
+ }
+ __syncthreads();
+
+ if (tid < DV) {
+ float acc = 0.0f;
+#pragma unroll
+ for (int i = 0; i < DK; ++i) {
+ acc += S_smem[i * DV + tid] * s_scratch[i];
+ }
+ output[bt * output_hidden + h_kv * DV + tid] = from_float<T>(scale * acc);
+ }
+ }
+
+ __syncthreads();
+ }
+
+ // Write back state from shared memory (fp32) to global memory (type T) — vectorized
+ if constexpr (sizeof(T) == 2 && DV % 2 == 0) {
+ uint32_t* S_global_u32 = reinterpret_cast<uint32_t*>(S_global);
+ int half_pairs = (DK * DV) / 2;
+ for (int idx = tid; idx < half_pairs; idx += blockDim.x) {
+ T lo = from_float<T>(S_smem[idx * 2]);
+ T hi = from_float<T>(S_smem[idx * 2 + 1]);
+ uint32_t packed;
+ memcpy(&packed, &lo, sizeof(T));
+ memcpy(reinterpret_cast<char*>(&packed) + sizeof(T), &hi, sizeof(T));
+ S_global_u32[idx] = packed;
+ }
+ } else if constexpr (sizeof(T) == 4 && DV % 4 == 0) {
+ float4* S_global_f4 = reinterpret_cast<float4*>(S_global);
+ int quads = (DK * DV) / 4;
+ for (int idx = tid; idx < quads; idx += blockDim.x) {
+ float4 v;
+ v.x = S_smem[idx * 4];
+ v.y = S_smem[idx * 4 + 1];
+ v.z = S_smem[idx * 4 + 2];
+ v.w = S_smem[idx * 4 + 3];
+ S_global_f4[idx] = v;
+ }
+ } else {
+ for (int idx = tid; idx < DK * DV; idx += blockDim.x) {
+ S_global[idx] = from_float<T>(S_smem[idx]);
+ }
+ }
+}
+
+} // anonymous namespace
+
+template <typename T>
+Status LaunchLinearAttentionKernel(
+ cudaStream_t stream,
+ const T* query,
+ const T* key,
+ const T* value,
+ const T* decay,
+ const T* beta,
+ T* output,
+ T* present_state,
+ int batch_size,
+ int seq_len,
+ int q_num_heads,
+ int kv_num_heads,
+ int n_k_heads,
+ int d_k,
+ int d_v,
+ float scale,
+ bool needs_decay,
+ bool decay_per_key_dim,
+ bool needs_beta,
+ bool beta_per_head,
+ bool needs_retrieval,
+ int max_threads_per_block) {
+ // Grid: one block per (batch, kv_head)
+ const dim3 grid(batch_size, kv_num_heads, 1);
+
+ int output_hidden = std::max(q_num_heads, kv_num_heads) * d_v;
+
+ auto launch_fixed = [&](auto dk_tag, auto dv_tag) -> Status {
+ constexpr int DK = decltype(dk_tag)::value;
+ constexpr int DV = decltype(dv_tag)::value;
+ constexpr int max_dim = (DK > DV) ? DK : DV;
+ // Layout: S_smem[DK*DV] + k_buf[DK] + s_scratch[max(DK,DV)]
+ const size_t fixed_smem_size = (static_cast<size_t>(DK) * DV + DK + max_dim) * sizeof(float);
+ const dim3 fixed_block(max_dim, 1, 1);
+
+ if (fixed_smem_size > 48 * 1024) {
+ cudaError_t attr_err = cudaFuncSetAttribute(
+ LinearAttentionRecurrentKernelFixedShape