diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.cc b/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.cc
new file mode 100644
index 0000000000000..799642e22da97
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.cc
@@ -0,0 +1,167 @@
+#include "contrib_ops/cpu/bert/causal_conv1d_with_state.h"
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "core/util/math.h"
+#include "core/providers/common.h"
+
+using namespace ::onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+namespace {
+
+inline float ToFloat(float v) { return v; }
+inline float ToFloat(MLFloat16 v) { return v.ToFloat(); }
+inline float ToFloat(BFloat16 v) { return v.ToFloat(); }
+
+inline void StoreFloat(float val, float& out) { out = val; }
+inline void StoreFloat(float val, MLFloat16& out) { out = MLFloat16(val); }
+inline void StoreFloat(float val, BFloat16& out) { out = BFloat16(val); }
+
+inline float ApplySiLU(float x) {
+  return x / (1.0f + expf(-x));
+}
+
+}
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      CausalConv1DWithState,                                      \
+      kMSDomain,                                                  \
+      1,                                                          \
+      T,                                                          \
+      kCpuExecutionProvider,                                      \
+      (*KernelDefBuilder::Create())                               \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      CausalConv1DWithState<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
+
+template <typename T>
+CausalConv1DWithState<T>::CausalConv1DWithState(const OpKernelInfo& info)
+    : OpKernel(info) {
+  activation_str_ = info.GetAttrOrDefault<std::string>("activation", "silu");
+  if (activation_str_ == "silu" || activation_str_ == "swish") {
+    activation_ = CausalConv1DActivation::kSiLU;
+  } else if (activation_str_ == "none") {
+    activation_ = CausalConv1DActivation::kNone;
+  } else {
+    ORT_THROW("CausalConv1DWithState: unknown activation '", activation_str_, "'");
+  }
+}
+
+template <typename T>
+Status CausalConv1DWithState<T>::Compute(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);  // (B, D, T)
+  const Tensor* weight = context->Input<Tensor>(1);      // (D, 1, K)
+  const Tensor* bias = context->Input<Tensor>(2);        // (D,) optional
+  const Tensor* conv_state = context->Input<Tensor>(3);  // (B, D, K-1) optional
+
+  ORT_RETURN_IF_NOT(input != nullptr, "input is required");
+  ORT_RETURN_IF_NOT(weight != nullptr, "weight is required");
+
+  const auto& in_shape = input->Shape();
+  const auto& wt_shape = weight->Shape();
+
+  ORT_RETURN_IF_NOT(in_shape.NumDimensions() == 3, "input must be 3D (B,D,T)");
+  ORT_RETURN_IF_NOT(wt_shape.NumDimensions() == 3, "weight must be 3D (D,1,K)");
+
+  const int batch_size = static_cast<int>(in_shape[0]);
+  const int channels = static_cast<int>(in_shape[1]);
+  const int seq_len = static_cast<int>(in_shape[2]);
+  const int kernel_size = static_cast<int>(wt_shape[2]);
+  const int state_len = kernel_size - 1;
+
+  ORT_RETURN_IF_NOT(wt_shape[0] == channels, "weight dim 0 must equal channels");
+  ORT_RETURN_IF_NOT(wt_shape[1] == 1, "weight dim 1 must be 1 (depthwise)");
+  ORT_RETURN_IF_NOT(kernel_size <= 32, "kernel_size must be <= 32");
+
+  if (bias != nullptr) {
+    ORT_RETURN_IF_NOT(bias->Shape().NumDimensions() == 1, "bias must be 1D");
+    ORT_RETURN_IF_NOT(bias->Shape()[0] == channels, "bias length must equal channels");
+  }
+
+  if (conv_state != nullptr) {
+    const auto& cs = conv_state->Shape();
+    ORT_RETURN_IF_NOT(cs.NumDimensions() == 3, "conv_state must be 3D (B,D,K-1)");
+    ORT_RETURN_IF_NOT(cs[0] == batch_size, "conv_state batch size must match input");
+    ORT_RETURN_IF_NOT(cs[1] == channels, "conv_state channels must match input");
+    ORT_RETURN_IF_NOT(cs[2] == state_len, "conv_state dim 2 must be K-1");
+  }
+
+  Tensor* output = context->Output(0, TensorShape({batch_size, channels, seq_len}));
+  Tensor* present_state = context->Output(1, TensorShape({batch_size, channels, state_len}));
+
+  const T* in_data = input->Data<T>();
+  const T* wt_data = weight->Data<T>();
+  T* out_data = output->MutableData<T>();
+  T* ps_data = present_state->MutableData<T>();
+
+  for (int b = 0; b < batch_size; b++)
+  {
+    for (int d = 0; d < channels; d++) {
+      const int bd = b * channels + d;
+
+      float w[32];
+      for (int k = 0; k < kernel_size; k++) {
+        w[k] = ToFloat(wt_data[d * kernel_size + k]);
+      }
+
+      float bias_val = (bias != nullptr) ? ToFloat(bias->Data<T>()[d]) : 0.0f;
+
+      for (int t = 0; t < seq_len; t++) {
+        float sum = bias_val;
+
+        for (int k = 0; k < kernel_size; k++) {
+          const int src_t = t - state_len + k;
+
+          float input_val;
+          if (src_t >= 0) {
+            input_val = ToFloat(in_data[bd * seq_len + src_t]);
+          } else {
+            const int state_idx = state_len + src_t;
+            if (conv_state != nullptr && state_idx >= 0) {
+              input_val = ToFloat(conv_state->Data<T>()[bd * state_len + state_idx]);
+            } else {
+              input_val = 0.0f;
+            }
+          }
+
+          sum += w[k] * input_val;
+        }
+
+        if (activation_ == CausalConv1DActivation::kSiLU) {
+          sum = ApplySiLU(sum);
+        }
+
+        StoreFloat(sum, out_data[bd * seq_len + t]);
+      }
+
+      for (int k = 0; k < state_len; k++) {
+        const int src_t = seq_len - state_len + k;
+        float val;
+        if (src_t >= 0) {
+          val = ToFloat(in_data[bd * seq_len + src_t]);
+        } else {
+          const int state_idx = state_len + src_t;
+          if (conv_state != nullptr && state_idx >= 0) {
+            val = ToFloat(conv_state->Data<T>()[bd * state_len + state_idx]);
+          } else {
+            val = 0.0f;
+          }
+        }
+        StoreFloat(val, ps_data[bd * state_len + k]);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}
+}
diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.h b/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.h
new file mode 100644
index 0000000000000..beac7af959977
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv1d_with_state.h
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+enum class CausalConv1DActivation {
+  kNone,
+  kSiLU,
+};
+
+template <typename T>
+class CausalConv1DWithState final : public OpKernel {
+ public:
+  explicit CausalConv1DWithState(const OpKernelInfo& info);
+  Status Compute(OpKernelContext* context) const override;
+
+ private:
+  std::string activation_str_;
+  CausalConv1DActivation activation_;
+};
+
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.cc b/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.cc
new file mode 100644
index 0000000000000..56a73256ef8bf
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.cc
@@ -0,0 +1,243 @@
+#include "contrib_ops/cpu/bert/linear_attention_chunk_parallel.h"
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "core/util/math.h"
+#include "core/providers/common.h"
+
+using namespace ::onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+namespace {
+
+inline float ToFloat(float v) { return v; }
+inline float ToFloat(MLFloat16 v) { return v.ToFloat(); }
+inline float ToFloat(BFloat16 v) { return v.ToFloat(); }
+
+inline void StoreFloat(float val, float& out) { out = val; }
+inline void StoreFloat(float val, MLFloat16& out) { out = MLFloat16(val); }
+inline void StoreFloat(float val, BFloat16& out) { out = BFloat16(val); }
+
+LinearAttentionUpdateRule ParseUpdateRule(const std::string& s) {
+  if (s == "linear") return LinearAttentionUpdateRule::kLinear;
+  if (s == "gated") return LinearAttentionUpdateRule::kGated;
+  if (s == "delta") return LinearAttentionUpdateRule::kDelta;
+  if (s == "gated_delta") return LinearAttentionUpdateRule::kGatedDelta;
+  ORT_THROW("Unknown linear attention update_rule: ", s);
+}
+
+}
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      LinearAttentionChunkParallel,                               \
+      kMSDomain,                                                  \
+      1,
\ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + LinearAttentionChunkParallel); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +LinearAttentionChunkParallel::LinearAttentionChunkParallel(const OpKernelInfo& info) + : OpKernel(info) { + std::string rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); + update_rule_ = ParseUpdateRule(rule_str); + chunk_size_ = static_cast(info.GetAttrOrDefault("chunk_size", 64)); + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template +void LinearAttentionChunkParallel::StepSingleHead( + const float* q, const float* k, const float* v, + float* state, + const float* decay, float beta_val, + float* output, + int d_k, int d_v, float scale) const { + + std::vector retrieved(d_v, 0.0f); + if (update_rule_ == LinearAttentionUpdateRule::kDelta || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta) { + for (int i = 0; i < d_k; i++) { + float gi = (update_rule_ == LinearAttentionUpdateRule::kGatedDelta) ? 
decay[i] : 1.0f; + for (int j = 0; j < d_v; j++) { + retrieved[j] += gi * state[i * d_v + j] * k[i]; + } + } + } + + std::fill(output, output + d_v, 0.0f); + + for (int i = 0; i < d_k; i++) { + for (int j = 0; j < d_v; j++) { + float s = state[i * d_v + j]; + float new_s = 0.0f; + + switch (update_rule_) { + case LinearAttentionUpdateRule::kLinear: + new_s = s + k[i] * v[j]; + break; + case LinearAttentionUpdateRule::kGated: + new_s = decay[i] * s + k[i] * v[j]; + break; + case LinearAttentionUpdateRule::kDelta: { + float delta = v[j] - retrieved[j]; + new_s = s + beta_val * k[i] * delta; + break; + } + case LinearAttentionUpdateRule::kGatedDelta: { + float delta = v[j] - retrieved[j]; + new_s = decay[i] * s + beta_val * k[i] * delta; + break; + } + } + + state[i * d_v + j] = new_s; + output[j] += q[i] * new_s; + } + } + + for (int j = 0; j < d_v; j++) { + output[j] *= scale; + } +} + +template +Status LinearAttentionChunkParallel::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* initial_state = context->Input(3); // optional + const Tensor* decay = context->Input(4); // optional + const Tensor* beta = context->Input(5); // optional + + ORT_RETURN_IF_NOT(query != nullptr, "query input is required"); + ORT_RETURN_IF_NOT(key != nullptr, "key input is required"); + ORT_RETURN_IF_NOT(value != nullptr, "value input is required"); + + const auto& q_shape = query->Shape(); + const auto& v_shape = value->Shape(); + + ORT_RETURN_IF_NOT(q_shape.NumDimensions() == 4, "query must be 4D (B,H,T,d_k)"); + ORT_RETURN_IF_NOT(v_shape.NumDimensions() == 4, "value must be 4D (B,H,T,d_v)"); + + const int batch_size = static_cast(q_shape[0]); + const int num_heads = static_cast(q_shape[1]); + const int seq_len = static_cast(q_shape[2]); + const int d_k = static_cast(q_shape[3]); + const int d_v = static_cast(v_shape[3]); + + 
ORT_RETURN_IF_NOT(key->Shape()[2] == seq_len, "key sequence length must match query"); + ORT_RETURN_IF_NOT(v_shape[2] == seq_len, "value sequence length must match query"); + + if (initial_state != nullptr) { + const auto& s = initial_state->Shape(); + ORT_RETURN_IF_NOT(s.NumDimensions() == 4 && + s[0] == batch_size && s[1] == num_heads && + s[2] == d_k && s[3] == d_v, + "initial_state shape must be (B,H,d_k,d_v)"); + } + + const bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::kGated || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta); + const bool needs_beta = (update_rule_ == LinearAttentionUpdateRule::kDelta || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta); + + ORT_RETURN_IF_NOT(!needs_decay || decay != nullptr, + "decay is required for gated/gated_delta update rules"); + ORT_RETURN_IF_NOT(!needs_beta || beta != nullptr, + "beta is required for delta/gated_delta update rules"); + + bool decay_broadcasted = false; + if (decay != nullptr) { + ORT_RETURN_IF_NOT(decay->Shape().NumDimensions() == 4, "decay must be 4D"); + decay_broadcasted = (decay->Shape()[3] == d_k); + } + + const float scale = (scale_ == 0.0f) ? 
(1.0f / sqrtf(static_cast(d_k))) : scale_; + + Tensor* output = context->Output(0, TensorShape({batch_size, num_heads, seq_len, d_v})); + Tensor* final_state = context->Output(1, TensorShape({batch_size, num_heads, d_k, d_v})); + + const T* q_data = query->Data(); + const T* k_data = key->Data(); + const T* v_data = value->Data(); + T* out_data = output->MutableData(); + T* fs_data = final_state->MutableData(); + + const int state_elems = batch_size * num_heads * d_k * d_v; + std::vector state_f(state_elems, 0.0f); + + if (initial_state != nullptr) { + const T* is_data = initial_state->Data(); + for (int i = 0; i < state_elems; i++) { + state_f[i] = ToFloat(is_data[i]); + } + } + + std::vector q_f(d_k), k_f(d_k), v_f(d_v); + std::vector decay_f(d_k, 1.0f); + std::vector out_f(d_v); + + for (int t = 0; t < seq_len; t++) { + for (int b = 0; b < batch_size; b++) { + for (int h = 0; h < num_heads; h++) { + const int bh = b * num_heads + h; + + for (int i = 0; i < d_k; i++) { + q_f[i] = ToFloat(q_data[bh * seq_len * d_k + t * d_k + i]); + k_f[i] = ToFloat(k_data[bh * seq_len * d_k + t * d_k + i]); + } + for (int j = 0; j < d_v; j++) { + v_f[j] = ToFloat(v_data[bh * seq_len * d_v + t * d_v + j]); + } + + if (decay != nullptr) { + const T* decay_data = decay->Data(); + if (decay_broadcasted) { + for (int i = 0; i < d_k; i++) { + decay_f[i] = expf(ToFloat(decay_data[bh * seq_len * d_k + t * d_k + i])); + } + } else { + const float scalar = expf(ToFloat(decay_data[bh * seq_len + t])); + std::fill(decay_f.begin(), decay_f.end(), scalar); + } + } + + float beta_val = 0.0f; + if (beta != nullptr) { + beta_val = ToFloat(beta->Data()[bh * seq_len + t]); + } + + StepSingleHead( + q_f.data(), k_f.data(), v_f.data(), + state_f.data() + bh * d_k * d_v, + decay_f.data(), beta_val, + out_f.data(), + d_k, d_v, scale); + + for (int j = 0; j < d_v; j++) { + StoreFloat(out_f[j], out_data[bh * seq_len * d_v + t * d_v + j]); + } + } + } + } + + for (int i = 0; i < state_elems; i++) { + 
StoreFloat(state_f[i], fs_data[i]); + } + + return Status::OK(); +} + +} +} diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.h b/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.h new file mode 100644 index 0000000000000..ce21b1651e23c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention_chunk_parallel.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/bert/linear_attention_recurrent.h" // for LinearAttentionUpdateRule + +namespace onnxruntime { +namespace contrib { + +template +class LinearAttentionChunkParallel final : public OpKernel { + public: + explicit LinearAttentionChunkParallel(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + LinearAttentionUpdateRule update_rule_; + int chunk_size_; + float scale_; + + // Apply one recurrent step (token t) to the float state for head (b,h). + // Identical math to LinearAttentionRecurrent::ComputeSingleHead. 
+ void StepSingleHead( + const float* q, // [d_k] + const float* k, // [d_k] + const float* v, // [d_v] + float* state, // [d_k * d_v], updated in-place + const float* decay, // [d_k] — already exp(·), or nullptr + float beta_val, + float* output, // [d_v] + int d_k, int d_v, float scale) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.cc b/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.cc new file mode 100644 index 0000000000000..93a6cb7a13cfa --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.cc @@ -0,0 +1,244 @@ +#include "contrib_ops/cpu/bert/linear_attention_recurrent.h" + +#include +#include + +#include "core/util/math.h" +#include "core/mlas/inc/mlas.h" +#include "core/providers/common.h" + +using namespace ::onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +namespace { + +inline float ToFloat(float v) { return v; } +inline float ToFloat(MLFloat16 v) { return v.ToFloat(); } +inline float ToFloat(BFloat16 v) { return v.ToFloat(); } + +inline void StoreFloat(float val, float& out) { out = val; } +inline void StoreFloat(float val, MLFloat16& out) { out = MLFloat16(val); } +inline void StoreFloat(float val, BFloat16& out) { out = BFloat16(val); } + +LinearAttentionUpdateRule ParseUpdateRule(const std::string& s) { + if (s == "linear") return LinearAttentionUpdateRule::kLinear; + if (s == "gated") return LinearAttentionUpdateRule::kGated; + if (s == "delta") return LinearAttentionUpdateRule::kDelta; + if (s == "gated_delta") return LinearAttentionUpdateRule::kGatedDelta; + ORT_THROW("Unknown linear attention update_rule: ", s); +} + +} + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + LinearAttentionRecurrent, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + LinearAttentionRecurrent); + 
+REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +LinearAttentionRecurrent::LinearAttentionRecurrent(const OpKernelInfo& info) + : OpKernel(info) { + std::string rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); + update_rule_ = ParseUpdateRule(rule_str); + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template +void LinearAttentionRecurrent::ComputeSingleHead( + const float* q, const float* k, const float* v, + float* state, + const float* decay, float beta_val, + float* output, + int d_k, int d_v, float scale) const { + + // Step 1: Compute retrieved = (decay * S)^T k (needed for delta modes) + std::vector retrieved(d_v, 0.0f); + if (update_rule_ == LinearAttentionUpdateRule::kDelta || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta) { + for (int i = 0; i < d_k; i++) { + // For gated_delta the state is decayed before retrieval (same as CUDA impl) + float gi = (update_rule_ == LinearAttentionUpdateRule::kGatedDelta) ? 
decay[i] : 1.0f; + for (int j = 0; j < d_v; j++) { + retrieved[j] += gi * state[i * d_v + j] * k[i]; + } + } + } + + // Step 2 + 3: Update state and accumulate output in a single pass + std::fill(output, output + d_v, 0.0f); + + for (int i = 0; i < d_k; i++) { + for (int j = 0; j < d_v; j++) { + float s = state[i * d_v + j]; + float new_s = 0.0f; + + switch (update_rule_) { + case LinearAttentionUpdateRule::kLinear: + new_s = s + k[i] * v[j]; + break; + + case LinearAttentionUpdateRule::kGated: + new_s = decay[i] * s + k[i] * v[j]; + break; + + case LinearAttentionUpdateRule::kDelta: { + float delta = v[j] - retrieved[j]; + new_s = s + beta_val * k[i] * delta; + break; + } + + case LinearAttentionUpdateRule::kGatedDelta: { + float delta = v[j] - retrieved[j]; + new_s = decay[i] * s + beta_val * k[i] * delta; + break; + } + } + + state[i * d_v + j] = new_s; + output[j] += q[i] * new_s; + } + } + + // Scale output + for (int j = 0; j < d_v; j++) { + output[j] *= scale; + } +} + +template +Status LinearAttentionRecurrent::Compute(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* past_state = context->Input(3); + const Tensor* decay = context->Input(4); // optional + const Tensor* beta = context->Input(5); // optional + + ORT_RETURN_IF_NOT(query != nullptr, "query input is required"); + ORT_RETURN_IF_NOT(key != nullptr, "key input is required"); + ORT_RETURN_IF_NOT(value != nullptr, "value input is required"); + ORT_RETURN_IF_NOT(past_state != nullptr, "past_state input is required"); + + const auto& q_shape = query->Shape(); + const auto& k_shape = key->Shape(); + const auto& v_shape = value->Shape(); + const auto& s_shape = past_state->Shape(); + + ORT_RETURN_IF_NOT(q_shape.NumDimensions() == 4, "query must be 4D (B,H,1,d_k)"); + ORT_RETURN_IF_NOT(k_shape.NumDimensions() == 4, "key must be 4D (B,H,1,d_k)"); + 
ORT_RETURN_IF_NOT(v_shape.NumDimensions() == 4, "value must be 4D (B,H,1,d_v)"); + ORT_RETURN_IF_NOT(s_shape.NumDimensions() == 4, "past_state must be 4D (B,H,d_k,d_v)"); + + ORT_RETURN_IF_NOT(q_shape[2] == 1, "query sequence length must be 1 (recurrent mode)"); + ORT_RETURN_IF_NOT(k_shape[2] == 1, "key sequence length must be 1 (recurrent mode)"); + ORT_RETURN_IF_NOT(v_shape[2] == 1, "value sequence length must be 1 (recurrent mode)"); + + const int batch_size = static_cast(q_shape[0]); + const int num_heads = static_cast(q_shape[1]); + const int d_k = static_cast(q_shape[3]); + const int d_v = static_cast(v_shape[3]); + + ORT_RETURN_IF_NOT(s_shape[0] == batch_size && s_shape[1] == num_heads && + s_shape[2] == d_k && s_shape[3] == d_v, + "past_state shape must be (B,H,d_k,d_v)"); + + const bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::kGated || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta); + const bool needs_beta = (update_rule_ == LinearAttentionUpdateRule::kDelta || + update_rule_ == LinearAttentionUpdateRule::kGatedDelta); + + ORT_RETURN_IF_NOT(!needs_decay || decay != nullptr, + "decay is required for gated/gated_delta update rules"); + ORT_RETURN_IF_NOT(!needs_beta || beta != nullptr, + "beta is required for delta/gated_delta update rules"); + + bool decay_broadcasted = false; + if (decay != nullptr) { + ORT_RETURN_IF_NOT(decay->Shape().NumDimensions() == 4, "decay must be 4D"); + decay_broadcasted = (decay->Shape()[3] == d_k); + } + + float scale = (scale_ == 0.0f) ? 
(1.0f / sqrtf(static_cast(d_k))) : scale_; + + Tensor* output = context->Output(0, TensorShape({batch_size, num_heads, 1, d_v})); + Tensor* present_state = context->Output(1, s_shape); + + const T* q_data = query->Data(); + const T* k_data = key->Data(); + const T* v_data = value->Data(); + const T* s_data = past_state->Data(); + T* out_data = output->MutableData(); + T* pstate_data = present_state->MutableData(); + + const int state_elems = batch_size * num_heads * d_k * d_v; + std::vector state_f(state_elems); + for (int i = 0; i < state_elems; i++) { + state_f[i] = ToFloat(s_data[i]); + } + + for (int b = 0; b < batch_size; b++) { + for (int h = 0; h < num_heads; h++) { + const int bh = b * num_heads + h; + + std::vector q_f(d_k), k_f(d_k), v_f(d_v); + for (int i = 0; i < d_k; i++) { + q_f[i] = ToFloat(q_data[bh * d_k + i]); + k_f[i] = ToFloat(k_data[bh * d_k + i]); + } + for (int j = 0; j < d_v; j++) { + v_f[j] = ToFloat(v_data[bh * d_v + j]); + } + + std::vector decay_f(d_k, 1.0f); + if (decay != nullptr) { + const T* decay_data = decay->Data(); + if (decay_broadcasted) { + for (int i = 0; i < d_k; i++) { + decay_f[i] = expf(ToFloat(decay_data[bh * d_k + i])); + } + } else { + const float scalar = expf(ToFloat(decay_data[bh])); + std::fill(decay_f.begin(), decay_f.end(), scalar); + } + } + + float beta_val = 0.0f; + if (beta != nullptr) { + beta_val = ToFloat(beta->Data()[bh]); + } + + std::vector out_f(d_v, 0.0f); + ComputeSingleHead( + q_f.data(), k_f.data(), v_f.data(), + state_f.data() + bh * d_k * d_v, + decay_f.data(), beta_val, + out_f.data(), + d_k, d_v, scale); + + for (int j = 0; j < d_v; j++) { + StoreFloat(out_f[j], out_data[bh * d_v + j]); + } + } + } + + for (int i = 0; i < state_elems; i++) { + StoreFloat(state_f[i], pstate_data[i]); + } + + return Status::OK(); +} + +} +} diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.h b/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.h new file mode 100644 index 
0000000000000..26aa2405a0f60 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention_recurrent.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +// Mirrors LinearAttentionUpdateRule from the CUDA header — +// kept in the CPU namespace so both can coexist in the same build. +enum class LinearAttentionUpdateRule { + kLinear, // S_t = S_{t-1} + k_t ⊗ v_t + kGated, // S_t = exp(g_t) · S_{t-1} + k_t ⊗ v_t + kDelta, // S_t = S_{t-1} + β_t · k_t ⊗ (v_t − S_{t-1}^T k_t) + kGatedDelta, // S_t = exp(g_t)·S_{t-1} + β_t·k_t ⊗ (v_t − exp(g_t)·S_{t-1}^T k_t) +}; + +template +class LinearAttentionRecurrent final : public OpKernel { + public: + explicit LinearAttentionRecurrent(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + LinearAttentionUpdateRule update_rule_; + float scale_; + + // Compute one (batch, head) recurrent step entirely in float32. + // state is updated in-place; output receives the query readout. 
+ void ComputeSingleHead( + const float* q, // [d_k] + const float* k, // [d_k] + const float* v, // [d_v] + float* state, // [d_k * d_v], updated in-place + const float* decay, // [d_k] — already exp(·), or nullptr for linear/delta + float beta_val, // scalar beta, 0 for linear/gated + float* output, // [d_v] + int d_k, int d_v, float scale) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0949ad6c36f58..7922a4c9b2b33 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -163,6 +163,15 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inver class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionRecurrent); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, LinearAttentionRecurrent); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, LinearAttentionRecurrent); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionChunkParallel); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, LinearAttentionChunkParallel); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, LinearAttentionChunkParallel); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CausalConv1DWithState); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, 
CausalConv1DWithState);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, CausalConv1DWithState);
 
 #ifdef ENABLE_ATEN
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -390,6 +399,15 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionRecurrent)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, LinearAttentionRecurrent)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, LinearAttentionRecurrent)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionChunkParallel)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, LinearAttentionChunkParallel)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, LinearAttentionChunkParallel)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CausalConv1DWithState)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, CausalConv1DWithState)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BFloat16, CausalConv1DWithState)>,
 #ifdef ENABLE_ATEN
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,