diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 88f6a5a2b1be6..b8b7fa06d7f12 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -564,6 +564,7 @@ Do not modify directly.*
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)| |BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)| |CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
+|CausalConvWithState|*in* input:**T**<br>*in* weight:**T**<br>*in* bias:**T**<br>*in* past_state:**T**<br>*out* output:**T**<br>*out* present_state:**T**|1+|**T** = tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |CropAndResize|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*in* crop_size:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(int32)| |DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**T** = tensor(float)|
@@ -584,6 +585,7 @@ Do not modify directly.*
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LinearAttention|*in* query:**T**<br>*in* key:**T**<br>*in* value:**T**<br>*in* past_state:**S**<br>*in* decay:**T**<br>*in* beta:**T**<br>*out* output:**T**<br>*out* present_state:**S**|1+|**T** = tensor(float)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
@@ -1032,6 +1034,7 @@ Do not modify directly.*
|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BitmaskBiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)| |BitmaskDropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T3**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)
**T3** = tensor(uint32)|
+|CausalConvWithState|*in* input:**T**<br>*in* weight:**T**<br>*in* bias:**T**<br>*in* past_state:**T**<br>*out* output:**T**<br>*out* present_state:**T**|1+|**T** = tensor(float), tensor(float16)|
|ComplexMul|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
@@ -1056,6 +1059,7 @@ Do not modify directly.*
|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|LinearAttention|*in* query:**T**<br>*in* key:**T**<br>*in* value:**T**<br>*in* past_state:**S**<br>*in* decay:**T**<br>*in* beta:**T**<br>*out* output:**T**<br>*out* present_state:**S**|1+|**T** = tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc new file mode 100644 index 0000000000000..72a97c60df84f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/causal_conv_with_state.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/common/safeint.h" +#include "core/platform/threadpool.h" + +#include +#include +#include + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +// Note: Only float is registered for CPU. The op schema allows float16/bfloat16 +// for CUDA compatibility, but the CPU kernel computes in float32 internally. +// MLFloat16 CPU support would require input/output conversion buffers +// (MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer). +// +// MLAS usage: No MLAS kernels are used currently. The depthwise causal conv +// is implemented with scalar loops. Potential future optimization: use +// MlasConv1D or vectorized MLAS routines for the 1D convolution. +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + CausalConvWithState, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + CausalConvWithState); + +REGISTER_KERNEL_TYPED(float) + +template +CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) : OpKernel(info) { + int64_t ndim = info.GetAttrOrDefault("ndim", 1); + ORT_ENFORCE(ndim == 1, "CPU CausalConvWithState only supports ndim=1"); + ndim_ = static_cast(ndim); + + activation_ = info.GetAttrOrDefault("activation", "none"); + ORT_ENFORCE(activation_ == "none" || activation_ == "silu" || activation_ == "swish", + "activation must be one of: none, silu, swish"); +} + +namespace { + +inline float ApplySilu(float x) { + return x / (1.0f + std::exp(-x)); +} + +template +inline void ProcessChannelDecodeFixedK( + const float* past_row, + const float* input_val, + const float* w, + float bias_val, + bool apply_silu, + float* out_val, + float* present_row) { + constexpr int pad = K - 1; + float sum = bias_val; + if (past_row != nullptr) { + for (int k = 0; k < pad; ++k) { + sum += w[k] * past_row[k]; + } + } + sum += w[pad] * input_val[0]; + + if (apply_silu) { + sum = ApplySilu(sum); + } + out_val[0] = sum; + + if constexpr (pad > 0) { + if (past_row != nullptr) { + if constexpr (pad > 1) { + std::memcpy(present_row, past_row + 1, static_cast(pad - 1) * sizeof(float)); + } + } else { + if constexpr (pad > 1) { + std::memset(present_row, 0, static_cast(pad - 1) * sizeof(float)); + } + } + present_row[pad - 1] = input_val[0]; + } +} + +// Decode fast-path: L=1, no padded buffer needed. +// The "visible window" for position 0 is [past_state(K-1 values), input(1 value)] = K values. +// Compute dot(weight, window), shift state left by 1, append new input. 
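+// Hypothetical example, K = 4: past_state = [x1, x2, x3], input = [x4], weight = [w0, w1, w2, w3]:
+//   output[0]     = w0*x1 + w1*x2 + w2*x3 + w3*x4 + bias   (then SiLU if activation is "silu"/"swish")
+//   present_state = [x2, x3, x4]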
+void ProcessChannelDecode( + const float* past_row, // past_state for this (b,c): [K-1] or nullptr + const float* input_val, // &input[b,c,0] — single value + const float* w, // weight for this channel: [K] + float bias_val, + bool apply_silu, + float* out_val, // &output[b,c,0] — single value + float* present_row, // present_state for this (b,c): [K-1] + int64_t K) { + int64_t pad = K - 1; + + // Dot product over the window: [past_state..., input] + float sum = bias_val; + // First K-1 elements come from past_state + if (past_row != nullptr) { + for (int64_t k = 0; k < pad; ++k) { + sum += w[k] * past_row[k]; + } + } + // Last element is the current input + sum += w[pad] * input_val[0]; + + if (apply_silu) { + sum = ApplySilu(sum); + } + out_val[0] = sum; + + // Update present_state: shift past_state left by 1, append input + if (pad > 0) { + if (past_row != nullptr && pad > 1) { + std::memcpy(present_row, past_row + 1, static_cast(pad - 1) * sizeof(float)); + } else if (pad > 1) { + std::memset(present_row, 0, static_cast(pad - 1) * sizeof(float)); + } + present_row[pad - 1] = input_val[0]; + } +} + +// Prefill path: L>1, uses padded buffer for the convolution window. +void ProcessChannelPrefill( + const float* past_row, // past_state for this (b,c): [K-1] or nullptr + const float* in_row, // input for this (b,c): [L] + const float* w, // weight for this channel: [K] + float bias_val, + bool apply_silu, + float* out_row, // output for this (b,c): [L] + float* present_row, // present_state for this (b,c): [K-1] + float* padded_row, // scratch buffer: [K-1 + L] + int64_t L, + int64_t K) { + int64_t pad = K - 1; + int64_t padded_len = pad + L; + + // Build padded window: [past_state | input] + if (past_row != nullptr) { + std::memcpy(padded_row, past_row, static_cast(pad) * sizeof(float)); + } else { + std::memset(padded_row, 0, static_cast(pad) * sizeof(float)); + } + std::memcpy(padded_row + pad, in_row, static_cast(L) * sizeof(float)); + + // Depthwise 1D convolution + for (int64_t l = 0; l < L; ++l) { + float sum = bias_val; + for (int64_t k = 0; k < K; ++k) { + sum += w[k] * padded_row[l + k]; + } + if (apply_silu) { + sum = ApplySilu(sum); + } + out_row[l] = sum; + } + + // Save present_state: last K-1 elements of (past_state | input) + std::memcpy(present_row, padded_row + padded_len - pad, static_cast(pad) * sizeof(float)); +} + +} // anonymous namespace + +template +Status CausalConvWithState::Compute(OpKernelContext* context) const { + const Tensor* input_tensor = context->Input(0); + const Tensor* weight_tensor = context->Input(1); + const Tensor* bias_tensor = context->Input(2); // optional + const Tensor* past_state_tensor = context->Input(3); // optional + + ORT_RETURN_IF_NOT(input_tensor != nullptr, "input is required"); + ORT_RETURN_IF_NOT(weight_tensor != nullptr, "weight is required"); + + const auto& input_shape = input_tensor->Shape(); + const auto& weight_shape = weight_tensor->Shape(); + + ORT_RETURN_IF_NOT(static_cast(input_shape.NumDimensions()) == 2 + ndim_, + "input must have ", 2 + ndim_, " dimensions for ndim=", ndim_); + ORT_RETURN_IF_NOT(static_cast(weight_shape.NumDimensions()) == 2 + ndim_, + "weight must have ", 2 + ndim_, " dimensions for ndim=", ndim_); + + const int64_t batch_size = input_shape[0]; + const int64_t channels = input_shape[1]; + + ORT_RETURN_IF_NOT(weight_shape[0] == channels, "weight channels must match input channels"); + ORT_RETURN_IF_NOT(weight_shape[1] == 1, "weight must be depthwise (group=1)"); + + if (bias_tensor != nullptr) { + 
ORT_RETURN_IF_NOT(bias_tensor->Shape().NumDimensions() == 1 && + bias_tensor->Shape()[0] == channels, + "bias must be 1D with size C"); + } + + // ==== ndim=1 implementation: (B, C, L) with kernel (C, 1, K) ==== + if (ndim_ == 1) { + const int64_t L = input_shape[2]; + const int64_t K = weight_shape[2]; + const int64_t pad = K - 1; + + if (past_state_tensor != nullptr) { + const auto& ps_shape = past_state_tensor->Shape(); + ORT_RETURN_IF_NOT(ps_shape.NumDimensions() == 3 && + ps_shape[0] == batch_size && + ps_shape[1] == channels && + ps_shape[2] == pad, + "past_state must be (B, C, K-1)"); + } + + // ==== Allocate outputs ==== + Tensor* output_tensor = context->Output(0, input_shape); + float* output_data = output_tensor->MutableData(); + + TensorShape state_shape({batch_size, channels, pad}); + Tensor* present_state_tensor = context->Output(1, state_shape); + float* present_data = present_state_tensor->MutableData(); + + const float* input_data = input_tensor->Data(); + const float* weight_data = weight_tensor->Data(); + const float* bias_data = bias_tensor ? bias_tensor->Data() : nullptr; + const float* past_data = past_state_tensor ? past_state_tensor->Data() : nullptr; + bool apply_silu = (activation_ == "silu" || activation_ == "swish"); + + // ==== Thread-parallel over (batch, channel) pairs ==== + // Depthwise conv: each channel is fully independent. + int64_t total_tasks = batch_size * channels; + double cost_per_task = static_cast(L * K); // FLOPs per channel + + auto* tp = context->GetOperatorThreadPool(); + + if (L == 1) { + // ==== Decode fast-path: no padded buffer needed ==== + ThreadPool::TryParallelFor( + tp, + static_cast(total_tasks), + cost_per_task, + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + for (std::ptrdiff_t task = first; task < last; ++task) { + int64_t b = task / channels; + int64_t c = task % channels; + + const float* past_row = past_data + ? past_data + (b * channels + c) * pad + : nullptr; + const float* input_val = input_data + (b * channels + c) * L; + const float* w = weight_data + c * K; + float bias_val = bias_data ? bias_data[c] : 0.0f; + float* out_val = output_data + (b * channels + c) * L; + float* present_row = present_data + (b * channels + c) * pad; + switch (K) { + case 2: + ProcessChannelDecodeFixedK<2>(past_row, input_val, w, bias_val, apply_silu, + out_val, present_row); + break; + case 3: + ProcessChannelDecodeFixedK<3>(past_row, input_val, w, bias_val, apply_silu, + out_val, present_row); + break; + case 4: + ProcessChannelDecodeFixedK<4>(past_row, input_val, w, bias_val, apply_silu, + out_val, present_row); + break; + case 5: + ProcessChannelDecodeFixedK<5>(past_row, input_val, w, bias_val, apply_silu, + out_val, present_row); + break; + default: + ProcessChannelDecode(past_row, input_val, w, bias_val, apply_silu, + out_val, present_row, K); + break; + } + } + }); + } else { + // ==== Prefill path: uses per-thread scratch buffer ==== + ThreadPool::TryParallelFor( + tp, + static_cast(total_tasks), + cost_per_task, + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + // Per-thread scratch buffer for padded input + std::vector padded_buf(static_cast(pad + L)); + + for (std::ptrdiff_t task = first; task < last; ++task) { + int64_t b = task / channels; + int64_t c = task % channels; + + const float* past_row = past_data + ? past_data + (b * channels + c) * pad + : nullptr; + const float* in_row = input_data + (b * channels + c) * L; + const float* w = weight_data + c * K; + float bias_val = bias_data ? 
bias_data[c] : 0.0f; + float* out_row = output_data + (b * channels + c) * L; + float* present_row = present_data + (b * channels + c) * pad; + + ProcessChannelPrefill(past_row, in_row, w, bias_val, apply_silu, + out_row, present_row, padded_buf.data(), L, K); + } + }); + } + + return Status::OK(); + } + + // ==== ndim=2 or ndim=3: not yet implemented ==== + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "CausalConvWithState with ndim=", ndim_, + " is not yet implemented. " + "Currently only ndim=1 is supported."); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h new file mode 100644 index 0000000000000..e859f69677e80 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +#include + +namespace onnxruntime { +namespace contrib { + +template +class CausalConvWithState final : public OpKernel { + public: + CausalConvWithState(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + int ndim_; + std::string activation_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc new file mode 100644 index 0000000000000..052e7df8bda14 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention.cc @@ -0,0 +1,509 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/linear_attention.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/common/safeint.h" +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" + +#include +#include + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +// Note: Only float is registered for CPU. The op schema allows float16/bfloat16 +// for CUDA compatibility, but the CPU kernel computes in float32 internally. +// MLFloat16 CPU support would require input/output conversion buffers +// (MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer). +// +// MLAS usage: MlasGemm is used for retrieval (S^T @ k), state update (k ⊗ delta), +// and query readout (S^T @ q) when d_k * d_v >= 4096. Smaller dimensions use +// scalar loops to avoid MLAS overhead. 
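+//
+// Per-token recurrence computed by ProcessHead for update_rule="gated_delta"
+// (the other rules drop the decay and/or the delta correction accordingly):
+//   S     = exp(g_t) * S          // decay, per key dimension or per head
+//   r_t   = S^T @ k_t             // retrieval, length d_v
+//   delta = beta_t * (v_t - r_t)
+//   S    += k_t ⊗ delta           // rank-1 state update
+//   o_t   = scale * S^T @ q_t     // readout for each query head mapped to this KV head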
+#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + LinearAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + LinearAttention); + +REGISTER_KERNEL_TYPED(float) + +template +LinearAttention::LinearAttention(const OpKernelInfo& info) : OpKernel(info) { + int64_t q_num_heads = 0; + ORT_ENFORCE(info.GetAttr("q_num_heads", &q_num_heads).IsOK() && q_num_heads > 0, + "q_num_heads must be a positive integer"); + q_num_heads_ = static_cast(q_num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0, + "kv_num_heads must be a positive integer"); + kv_num_heads_ = static_cast(kv_num_heads); + + update_rule_ = info.GetAttrOrDefault("update_rule", "gated_delta"); + ORT_ENFORCE(update_rule_ == "linear" || update_rule_ == "gated" || + update_rule_ == "delta" || update_rule_ == "gated_delta", + "update_rule must be one of: linear, gated, delta, gated_delta"); + + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + int64_t chunk_size = info.GetAttrOrDefault("chunk_size", 64); + // chunk_size_ reserved for future chunk-parallel prefill algorithm; not yet used. + chunk_size_ = static_cast(chunk_size); +} + +namespace { + +// Process a single (batch, kv_head) pair across all timesteps. +// This is the hot inner loop — called once per (b, h_kv) combination, +// potentially from different threads. +void ProcessHead( + float* S, // State matrix [d_k, d_v], in-place updated + const float* q_data, // Packed Q: (B, T, H_q*d_k) + const float* k_data, // Packed K: (B, T, n_k*d_k) + const float* v_data, // Packed V: (B, T, H_kv*d_v) + const float* decay_data, // Decay gates (may be nullptr) + const float* beta_data, // Beta rates (may be nullptr) + float* output_data, // Output: (B, T, H_q*d_v) + float* retrieved_buf, // Pre-allocated scratch buffer [d_v] + int64_t batch_idx, + int h_kv, + int h_k, // Key head index (may differ from h_kv when n_k != kv_num_heads) + int64_t seq_len, + int64_t d_k, + int64_t d_v, + int q_num_heads, + int kv_num_heads, + int n_k_heads, + int heads_per_group, + int64_t output_hidden, + float scale, + bool needs_decay, + bool decay_per_key_dim, + bool needs_beta, + bool beta_per_head, + bool needs_retrieval) { + const size_t dk = static_cast(d_k); + const size_t dv = static_cast(d_v); + const bool use_mlas = (d_k * d_v >= 4096); + + for (int64_t t = 0; t < seq_len; ++t) { + // Pointers into packed 3D tensors at position [batch_idx, t, head*d] + const float* kt = k_data + (batch_idx * seq_len + t) * (n_k_heads * d_k) + h_k * d_k; + const float* vt = v_data + (batch_idx * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v; + + // ---- Step 1: Apply decay S *= exp(g_t) ---- + if (needs_decay) { + if (decay_per_key_dim) { + const float* gt = decay_data + (batch_idx * seq_len + t) * (kv_num_heads * d_k) + h_kv * d_k; + for (int64_t i = 0; i < d_k; ++i) { + float exp_g = std::exp(gt[i]); + // Scale row i of S by exp_g + for (int64_t j = 0; j < d_v; ++j) { + S[i * d_v + j] *= exp_g; + } + } + } else { + const float* gt = decay_data + (batch_idx * seq_len + t) * kv_num_heads + h_kv; + float exp_g = std::exp(gt[0]); + for (int64_t i = 0; i < d_k * d_v; ++i) { + S[i] *= exp_g; + } + } + } + + // ---- Step 2: Retrieval = S^T @ k_t ---- + if (needs_retrieval) { + if (use_mlas) { + MlasGemm( + CblasNoTrans, + CblasNoTrans, + 1, + dv, + dk, + 1.0f, + kt, + dk, + S, + dv, + 0.0f, + retrieved_buf, + dv, + nullptr, + 
nullptr); + } else { + for (int64_t j = 0; j < d_v; ++j) { + float acc = 0.0f; + for (int64_t i = 0; i < d_k; ++i) { + acc += S[i * d_v + j] * kt[i]; + } + retrieved_buf[static_cast(j)] = acc; + } + } + } + + // ---- Step 3: State update ---- + if (needs_beta) { + float bt; + if (beta_per_head) { + bt = beta_data[(batch_idx * seq_len + t) * kv_num_heads + h_kv]; + } else { + bt = beta_data[(batch_idx * seq_len + t) * 1]; + } + // Compute delta = beta * (v_t - retrieved) in-place into retrieved_buf + for (size_t j = 0; j < dv; ++j) { + retrieved_buf[j] = bt * (vt[j] - retrieved_buf[j]); + } + // S += k_t outer delta + if (use_mlas) { + MlasGemm( + CblasNoTrans, + CblasNoTrans, + dk, + dv, + 1, + 1.0f, + kt, + 1, + retrieved_buf, + dv, + 1.0f, + S, + dv, + nullptr, + nullptr); + } else { + for (int64_t i = 0; i < d_k; ++i) { + float* s_row = S + i * d_v; + const float ki = kt[i]; + for (int64_t j = 0; j < d_v; ++j) { + s_row[j] += ki * retrieved_buf[static_cast(j)]; + } + } + } + } else { + // linear/gated: S += k_t outer v_t + if (use_mlas) { + MlasGemm( + CblasNoTrans, + CblasNoTrans, + dk, + dv, + 1, + 1.0f, + kt, + 1, + vt, + dv, + 1.0f, + S, + dv, + nullptr, + nullptr); + } else { + for (int64_t i = 0; i < d_k; ++i) { + float* s_row = S + i * d_v; + const float ki = kt[i]; + for (int64_t j = 0; j < d_v; ++j) { + s_row[j] += ki * vt[j]; + } + } + } + } + + // ---- Step 4: Query readout for each q head in this kv group ---- + // o_t = scale * q_t^T @ S -> [1, d_v] + // + // Standard GQA (heads_per_group > 0): multiple Q heads share this KV state. + // Inverse GQA (heads_per_group == 0): multiple KV states share one Q head. + if (heads_per_group > 0) { + for (int g = 0; g < heads_per_group; ++g) { + int h_q = h_kv * heads_per_group + g; + const float* qt = q_data + (batch_idx * seq_len + t) * (q_num_heads * d_k) + h_q * d_k; + float* ot = output_data + (batch_idx * seq_len + t) * output_hidden + h_q * d_v; + + if (use_mlas) { + // Use alpha=1.0 to hit the MLAS M=1 gemv fast path, then scale output. + MlasGemm( + CblasNoTrans, + CblasNoTrans, + 1, + dv, + dk, + 1.0f, + qt, + dk, + S, + dv, + 0.0f, + ot, + dv, + nullptr, + nullptr); + if (scale != 1.0f) { + for (size_t j = 0; j < dv; ++j) { + ot[j] *= scale; + } + } + } else { + for (int64_t j = 0; j < d_v; ++j) { + float acc = 0.0f; + for (int64_t i = 0; i < d_k; ++i) { + acc += qt[i] * S[i * d_v + j]; + } + ot[j] = scale * acc; + } + } + } + } else { + // Inverse GQA: this KV head's Q is determined by h_kv * q_num / kv_num + int h_q = h_kv * q_num_heads / kv_num_heads; + const float* qt = q_data + (batch_idx * seq_len + t) * (q_num_heads * d_k) + h_q * d_k; + float* ot = output_data + (batch_idx * seq_len + t) * output_hidden + h_kv * d_v; + + if (use_mlas) { + // Use alpha=1.0 to hit the MLAS M=1 gemv fast path, then scale output. 
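+          // GEMM shapes here: A = q_t (1 x d_k), B = S (d_k x d_v), C = o_t (1 x d_v),
+          // i.e. M=1, N=d_v, K=d_k with lda=d_k, ldb=ldc=d_v.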
+ MlasGemm( + CblasNoTrans, + CblasNoTrans, + 1, + dv, + dk, + 1.0f, + qt, + dk, + S, + dv, + 0.0f, + ot, + dv, + nullptr, + nullptr); + if (scale != 1.0f) { + for (size_t j = 0; j < dv; ++j) { + ot[j] *= scale; + } + } + } else { + for (int64_t j = 0; j < d_v; ++j) { + float acc = 0.0f; + for (int64_t i = 0; i < d_k; ++i) { + acc += qt[i] * S[i * d_v + j]; + } + ot[j] = scale * acc; + } + } + } + } +} + +} // anonymous namespace + +template +Status LinearAttention::Compute(OpKernelContext* context) const { + // ==== Input Retrieval ==== + const Tensor* query_tensor = context->Input(0); + const Tensor* key_tensor = context->Input(1); // optional + const Tensor* value_tensor = context->Input(2); // optional + const Tensor* past_state_tensor = context->Input(3); // optional + const Tensor* decay_tensor = context->Input(4); // optional + const Tensor* beta_tensor = context->Input(5); // optional + + ORT_RETURN_IF_NOT(query_tensor != nullptr, "query input is required"); + + const auto& query_shape = query_tensor->Shape(); + ORT_RETURN_IF_NOT(query_shape.NumDimensions() == 3, + "query must be 3D [B, T, H*D], got ", query_shape.NumDimensions(), "D"); + + const int64_t batch_size = query_shape[0]; + const int64_t seq_len = query_shape[1]; + const int64_t query_hidden = query_shape[2]; + + // ==== Determine d_k and d_v ==== + ORT_RETURN_IF_NOT(key_tensor != nullptr && value_tensor != nullptr, + "key and value inputs are required"); + + int64_t d_k, d_v; + int n_k_heads; + const float* q_data; + const float* k_data; + const float* v_data; + + { + const auto& key_shape = key_tensor->Shape(); + const auto& value_shape = value_tensor->Shape(); + ORT_RETURN_IF_NOT(key_shape.NumDimensions() == 3 && value_shape.NumDimensions() == 3, + "key and value must be 3D"); + ORT_RETURN_IF_NOT(key_shape[0] == batch_size && value_shape[0] == batch_size, + "batch size mismatch"); + ORT_RETURN_IF_NOT(key_shape[1] == seq_len && value_shape[1] == seq_len, + "sequence length mismatch"); + + d_k = query_hidden / q_num_heads_; + ORT_RETURN_IF_NOT(query_hidden == q_num_heads_ * d_k, + "query hidden size must be divisible by q_num_heads"); + ORT_RETURN_IF_NOT(key_shape[2] % d_k == 0, + "key hidden size must be divisible by d_k"); + n_k_heads = static_cast(key_shape[2] / d_k); + d_v = value_shape[2] / kv_num_heads_; + ORT_RETURN_IF_NOT(value_shape[2] == kv_num_heads_ * d_v, + "value hidden size must be divisible by kv_num_heads"); + + q_data = query_tensor->Data(); + k_data = key_tensor->Data(); + v_data = value_tensor->Data(); + } + + // ==== Determine scale ==== + float s = scale_; + if (s == 0.0f) { + s = 1.0f / std::sqrt(static_cast(d_k)); + } + + // ==== Validate optional inputs based on update_rule ==== + bool needs_decay = (update_rule_ == "gated" || update_rule_ == "gated_delta"); + bool needs_beta = (update_rule_ == "delta" || update_rule_ == "gated_delta"); + bool needs_retrieval = (update_rule_ == "delta" || update_rule_ == "gated_delta"); + + ORT_RETURN_IF_NOT(!needs_decay || decay_tensor != nullptr, + "decay input is required for update_rule=", update_rule_); + ORT_RETURN_IF_NOT(!needs_beta || beta_tensor != nullptr, + "beta input is required for update_rule=", update_rule_); + + const float* decay_data = decay_tensor ? decay_tensor->Data() : nullptr; + const float* beta_data = beta_tensor ? 
beta_tensor->Data() : nullptr; + + bool decay_per_key_dim = false; + if (decay_tensor != nullptr) { + const auto& decay_shape = decay_tensor->Shape(); + ORT_RETURN_IF_NOT(decay_shape.NumDimensions() == 3, + "decay must be rank 3 (B, T, ...), got rank ", decay_shape.NumDimensions()); + ORT_RETURN_IF_NOT(decay_shape[0] == batch_size && decay_shape[1] == seq_len, + "decay dims 0/1 must match (B=", batch_size, ", T=", seq_len, + "), got (", decay_shape[0], ", ", decay_shape[1], ")"); + int64_t decay_last = decay_shape[2]; + if (decay_last == kv_num_heads_ * d_k) { + decay_per_key_dim = true; + } else { + ORT_RETURN_IF_NOT(decay_last == kv_num_heads_, + "decay last dim must be H_kv or H_kv*d_k"); + } + } + + bool beta_per_head = false; + if (beta_tensor != nullptr) { + const auto& beta_shape = beta_tensor->Shape(); + ORT_RETURN_IF_NOT(beta_shape.NumDimensions() == 3, + "beta must be rank 3 (B, T, ...), got rank ", beta_shape.NumDimensions()); + ORT_RETURN_IF_NOT(beta_shape[0] == batch_size && beta_shape[1] == seq_len, + "beta dims 0/1 must match (B=", batch_size, ", T=", seq_len, + "), got (", beta_shape[0], ", ", beta_shape[1], ")"); + int64_t beta_last = beta_shape[2]; + if (beta_last == kv_num_heads_) { + beta_per_head = true; + } else { + ORT_RETURN_IF_NOT(beta_last == 1, "beta last dim must be H_kv or 1"); + } + } + + // ==== Initialize state: write directly into output present_state ==== + // present_state: (B, H_kv, d_k, d_v) + TensorShape state_shape({batch_size, static_cast(kv_num_heads_), d_k, d_v}); + Tensor* present_state_tensor = context->Output(1, state_shape); + float* state_data = present_state_tensor->MutableData(); + int64_t state_per_head = d_k * d_v; + int64_t total_state = batch_size * kv_num_heads_ * state_per_head; + + if (past_state_tensor != nullptr) { + const auto& ps_shape = past_state_tensor->Shape(); + ORT_RETURN_IF_NOT(ps_shape.NumDimensions() == 4 && + ps_shape[0] == batch_size && + ps_shape[1] == kv_num_heads_ && + ps_shape[2] == d_k && + ps_shape[3] == d_v, + "past_state must be (B, H_kv, d_k, d_v)"); + const float* ps_data = past_state_tensor->Data(); + std::memcpy(state_data, ps_data, static_cast(total_state) * sizeof(float)); + } else { + std::memset(state_data, 0, static_cast(total_state) * sizeof(float)); + } + + // ==== Allocate output ==== + // Output hidden dim: max(q_num_heads, kv_num_heads) * d_v + // Standard GQA: q_num_heads * d_v; Inverse GQA: kv_num_heads * d_v + int64_t output_hidden = std::max(q_num_heads_, kv_num_heads_) * d_v; + TensorShape output_shape({batch_size, seq_len, output_hidden}); + Tensor* output_tensor = context->Output(0, output_shape); + float* output_data = output_tensor->MutableData(); + + // ==== GQA head mapping ==== + // Standard GQA: q_num_heads >= kv_num_heads, multiple Q heads per KV group. + // Inverse GQA: q_num_heads < kv_num_heads (e.g., Qwen3.5 9B: n_k=16, n_kv=32). + // Also n_k_heads may differ from both (K has its own head count). 
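+  // Example: q_num_heads=32, kv_num_heads=8 gives heads_per_group=4 (four Q heads
+  // share each KV state); q_num_heads=16, kv_num_heads=32 gives heads_per_group=0
+  // and KV head h_kv reads Q head h_kv * 16 / 32 (inverse GQA).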
+ int heads_per_group; // Q heads per KV group (0 if inverse GQA) + if (q_num_heads_ >= kv_num_heads_) { + ORT_RETURN_IF_NOT(q_num_heads_ % kv_num_heads_ == 0, + "q_num_heads must be divisible by kv_num_heads"); + heads_per_group = q_num_heads_ / kv_num_heads_; + } else { + ORT_RETURN_IF_NOT(kv_num_heads_ % q_num_heads_ == 0, + "kv_num_heads must be divisible by q_num_heads (inverse GQA)"); + heads_per_group = 0; // signals inverse GQA to ProcessHead + } + + // K-to-KV head mapping: when n_k < kv_num_heads, multiple KV heads share one K head + ORT_RETURN_IF_NOT(kv_num_heads_ % n_k_heads == 0, + "kv_num_heads must be divisible by n_k_heads"); + int kv_per_k_head = kv_num_heads_ / n_k_heads; + + // ==== Thread-parallel over (batch, kv_head) pairs ==== + // Each (b, h_kv) pair is fully independent — the state matrix for each + // head is disjoint, and the sequential token dependency is within a + // single head only. This gives us batch_size * kv_num_heads parallel tasks. + int64_t total_tasks = batch_size * kv_num_heads_; + + // Cost estimate: per task processes seq_len tokens, each doing ~3*d_k*d_v FLOPs + double cost_per_task = static_cast(seq_len) * static_cast(d_k * d_v) * 3.0; + + auto* tp = context->GetOperatorThreadPool(); + + ThreadPool::TryParallelFor( + tp, + static_cast(total_tasks), + cost_per_task, + [&](std::ptrdiff_t first, std::ptrdiff_t last) { + // Pre-allocate scratch buffer per thread-batch to avoid malloc in hot loop + std::vector retrieved_buf(static_cast(d_v)); + + for (std::ptrdiff_t task = first; task < last; ++task) { + int64_t b = task / kv_num_heads_; + int h_kv = static_cast(task % kv_num_heads_); + int h_k = h_kv / kv_per_k_head; // map KV head to K head + + float* S = state_data + (b * kv_num_heads_ + h_kv) * state_per_head; + + ProcessHead( + S, q_data, k_data, v_data, decay_data, beta_data, output_data, + retrieved_buf.data(), + b, h_kv, h_k, seq_len, d_k, d_v, + q_num_heads_, kv_num_heads_, n_k_heads, heads_per_group, output_hidden, + s, needs_decay, decay_per_key_dim, needs_beta, beta_per_head, + needs_retrieval); + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/linear_attention.h b/onnxruntime/contrib_ops/cpu/bert/linear_attention.h new file mode 100644 index 0000000000000..9aaa9f80a3369 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/linear_attention.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +#include + +namespace onnxruntime { +namespace contrib { + +template +class LinearAttention final : public OpKernel { + public: + LinearAttention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + int q_num_heads_; + int kv_num_heads_; + std::string update_rule_; + float scale_; + int chunk_size_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0949ad6c36f58..8588449a3895c 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -29,6 +29,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SparseAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SparseAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, LinearAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CausalConvWithState); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); #if !defined(DISABLE_GENERATION_OPS) @@ -319,6 +321,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, #if !defined(DISABLE_GENERATION_OPS) diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc new file mode 100644 index 0000000000000..e60a87afe5f25 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/causal_conv_with_state.h" +#include "contrib_ops/cuda/bert/causal_conv_with_state_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; // CudaKernel, Stream, GetDeviceProp, ToCudaType + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + CausalConvWithState, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + CausalConvWithState); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) : CudaKernel(info) { + int64_t ndim = info.GetAttrOrDefault("ndim", 1); + ORT_ENFORCE(ndim == 1, "CUDA CausalConvWithState only supports ndim=1"); + ndim_ = static_cast(ndim); + + activation_ = info.GetAttrOrDefault("activation", "none"); + ORT_ENFORCE(activation_ == "none" || activation_ == "silu" || activation_ == "swish", + "activation must be one of: none, silu, swish"); +} + +template +Status CausalConvWithState::ComputeInternal(OpKernelContext* context) const { + const Tensor* input_tensor = context->Input(0); + const Tensor* weight_tensor = context->Input(1); + const Tensor* bias_tensor = context->Input(2); // optional + const Tensor* past_state_tensor = context->Input(3); // optional + + ORT_RETURN_IF_NOT(input_tensor != nullptr, "input is required"); + ORT_RETURN_IF_NOT(weight_tensor != nullptr, "weight is required"); + + const auto& input_shape = input_tensor->Shape(); + const auto& weight_shape = weight_tensor->Shape(); + + // Validate input rank and weight rank + ORT_RETURN_IF_NOT(input_shape.NumDimensions() == 3, + "input must be rank 3 (batch, channels, length), got rank ", input_shape.NumDimensions()); + ORT_RETURN_IF_NOT(weight_shape.NumDimensions() == 3, + "weight must be rank 3 (channels, 1, kernel_size), got rank ", weight_shape.NumDimensions()); + + const int batch_size = static_cast(input_shape[0]); + const int channels = static_cast(input_shape[1]); + const int L = static_cast(input_shape[2]); + const int K = static_cast(weight_shape[2]); + const int pad = K - 1; + + // Validate weight shape compatibility + ORT_RETURN_IF_NOT(weight_shape[0] == channels, + "weight[0] (", weight_shape[0], ") must match input channels (", channels, ")"); + ORT_RETURN_IF_NOT(weight_shape[1] == 1, + "weight[1] must be 1 for depthwise convolution, got ", weight_shape[1]); + + // Validate optional bias shape + if (bias_tensor != nullptr) { + const auto& bias_shape = bias_tensor->Shape(); + ORT_RETURN_IF_NOT(bias_shape.NumDimensions() == 1 && bias_shape[0] == channels, + "bias must have shape (", channels, "), got ", bias_shape.ToString()); + } + + // Validate optional past_state shape + if (past_state_tensor != nullptr) { + const auto& past_shape = past_state_tensor->Shape(); + ORT_RETURN_IF_NOT(past_shape.NumDimensions() == 3, + "past_state must be rank 3 (batch, channels, kernel_size-1), got rank ", past_shape.NumDimensions()); + ORT_RETURN_IF_NOT(past_shape[0] == batch_size && past_shape[1] == channels && past_shape[2] == pad, + "past_state shape mismatch: expected (", batch_size, ", ", channels, ", ", pad, + "), got (", past_shape[0], ", ", past_shape[1], ", ", past_shape[2], ")"); + } + + // Allocate outputs + Tensor* output_tensor = context->Output(0, input_shape); + TensorShape state_shape({batch_size, 
channels, pad}); + Tensor* present_state_tensor = context->Output(1, state_shape); + + // Note: no need to zero-initialize present_state — the kernel writes all + // positions unconditionally. When past_state is null, the kernel uses + // zeros for the padding region internally. + // Note: past_state pointer is passed to kernel; kernel reads it directly + + bool apply_silu = (activation_ == "silu" || activation_ == "swish"); + + typedef typename OrtToCudaType::type CudaT; + + return LaunchCausalConvWithStateKernel( + Stream(context), + reinterpret_cast(input_tensor->Data()), + reinterpret_cast(weight_tensor->Data()), + bias_tensor ? reinterpret_cast(bias_tensor->Data()) : nullptr, + past_state_tensor ? reinterpret_cast(past_state_tensor->Data()) : nullptr, + reinterpret_cast(output_tensor->MutableData()), + reinterpret_cast(present_state_tensor->MutableData()), + batch_size, + channels, + L, + K, + apply_silu, + GetDeviceProp().maxThreadsPerBlock); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h new file mode 100644 index 0000000000000..37a9c29b5e749 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +class CausalConvWithState final : public onnxruntime::cuda::CudaKernel { + public: + CausalConvWithState(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int ndim_; + std::string activation_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu new file mode 100644 index 0000000000000..87774f7205bc9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu @@ -0,0 +1,459 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Fused causal depthwise conv1d CUDA kernel with stateful carry and optional SiLU activation. +// +// Design: One thread block per (batch, channel). Two execution paths: +// +// 1. Decode (L=1): The convolution window is [past_state(K-1), input(1)]. +// Load K values into registers, compute a single dot product, shift state. +// One thread block does the entire operation — zero shared memory needed. +// +// 2. Prefill (L>1): Load past_state + input into shared memory as a padded buffer, +// then each thread computes one output position's convolution. +// +// State is stored in type T to match the op schema convention. 
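+//
+// Prefill math per (batch, channel), with padded = [past_state | input] of length K-1+L:
+//   output[l]     = bias + sum_{k=0..K-1} weight[k] * padded[l + k]   (optionally SiLU)
+//   present_state = padded[L .. L+K-2]   (the last K-1 entries)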
+ +#include +#include +#include +#include "contrib_ops/cuda/bert/causal_conv_with_state_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +namespace { + +template +__device__ __forceinline__ float to_float(T val); + +template <> +__device__ __forceinline__ float to_float(float val) { return val; } + +template <> +__device__ __forceinline__ float to_float(half val) { return __half2float(val); } + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ __forceinline__ float to_float(__nv_bfloat16 val) { return __bfloat162float(val); } +#endif + +template +__device__ __forceinline__ T from_float(float val); + +template <> +__device__ __forceinline__ float from_float(float val) { return val; } + +template <> +__device__ __forceinline__ half from_float(float val) { return __float2half(val); } + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ __forceinline__ __nv_bfloat16 from_float(float val) { return __float2bfloat16(val); } +#endif + +__device__ __forceinline__ float silu_fn(float x) { + return x / (1.0f + expf(-x)); +} + +// ============================================================================= +// Decode kernel: L=1, one dot product per (batch, channel) +// Grid: (batch_size * channels, 1, 1) +// Block: (1, 1, 1) — one thread per (batch, channel) +// No shared memory needed. +// ============================================================================= +template +__global__ void CausalConvDecodeKernel( + const T* __restrict__ input, // [B, C, 1] + const T* __restrict__ weight, // [C, 1, K] + const T* __restrict__ bias, // [C] or nullptr + const T* __restrict__ past_state, // [B, C, K-1] or nullptr + T* __restrict__ output, // [B, C, 1] + T* __restrict__ present_state, // [B, C, K-1] + int batch_channels, // = batch_size * channels (actual element count) + int channels, + int kernel_size, + bool apply_silu) { + const int bc = blockIdx.x * blockDim.x + threadIdx.x; + if (bc >= batch_channels) return; + const int b = bc / channels; + const int c = bc % channels; + + const int pad = kernel_size - 1; + + // Cache input value in register — avoids redundant global reads + const float input_val = to_float(input[(int64_t)b * channels + c]); + + // Cache past_state base pointer for this (b, c) + const T* ps_in = (past_state != nullptr) + ? past_state + (int64_t)b * channels * pad + (int64_t)c * pad + : nullptr; + + // Load weight for this channel: [K] values + // weight layout: [C, 1, K], so channel c starts at c * K + float sum = (bias != nullptr) ? to_float(bias[c]) : 0.0f; + + // Convolution window: [past_state[0..K-2], input[0]] + for (int k = 0; k < pad; ++k) { + float wk = to_float(weight[c * kernel_size + k]); + float xk = (ps_in != nullptr) ? to_float(ps_in[k]) : 0.0f; + sum += wk * xk; + } + // Last element of window is current input + sum += to_float(weight[c * kernel_size + pad]) * input_val; + + if (apply_silu) { + sum = silu_fn(sum); + } + output[(int64_t)b * channels + c] = from_float(sum); + + // Update present_state: shift left by 1, append input + T* ps_out = present_state + (int64_t)b * channels * pad + (int64_t)c * pad; + for (int k = 0; k < pad - 1; ++k) { + ps_out[k] = (ps_in != nullptr) ? 
ps_in[k + 1] : from_float(0.0f); + } + if (pad > 0) { + ps_out[pad - 1] = from_float(input_val); + } +} + +template +__global__ void CausalConvDecodeKernelFixedK( + const T* __restrict__ input, + const T* __restrict__ weight, + const T* __restrict__ bias, + const T* __restrict__ past_state, + T* __restrict__ output, + T* __restrict__ present_state, + int batch_channels, + int channels, + bool apply_silu) { + const int bc = blockIdx.x * blockDim.x + threadIdx.x; + if (bc >= batch_channels) return; + + const int b = bc / channels; + const int c = bc % channels; + constexpr int pad = K - 1; + + float sum = (bias != nullptr) ? to_float(bias[c]) : 0.0f; + const T* w = weight + static_cast(c) * K; + const T* ps_in = (past_state != nullptr) + ? past_state + static_cast(b) * channels * pad + static_cast(c) * pad + : nullptr; + + if (ps_in != nullptr) { +#pragma unroll + for (int k = 0; k < pad; ++k) { + sum += to_float(w[k]) * to_float(ps_in[k]); + } + } + sum += to_float(w[pad]) * to_float(input[static_cast(b) * channels + c]); + + if (apply_silu) { + sum = silu_fn(sum); + } + output[static_cast(b) * channels + c] = from_float(sum); + + T* ps_out = present_state + static_cast(b) * channels * pad + static_cast(c) * pad; + if constexpr (pad > 0) { +#pragma unroll + for (int k = 0; k < pad - 1; ++k) { + ps_out[k] = (ps_in != nullptr) ? ps_in[k + 1] : from_float(0.0f); + } + ps_out[pad - 1] = input[static_cast(b) * channels + c]; + } +} + +// ============================================================================= +// Prefill kernel: L>1, one thread per output position within a (batch, channel) +// Grid: (batch_size, channels, 1) +// Block: (min(L, max_threads), 1, 1) +// Shared memory: padded input buffer [K-1 + L] floats + weight [K] floats +// ============================================================================= +template +__global__ void CausalConvPrefillKernel( + const T* __restrict__ input, // [B, C, L] + const T* __restrict__ weight, // [C, 1, K] + const T* __restrict__ bias, // [C] or nullptr + const T* __restrict__ past_state, // [B, C, K-1] or nullptr + T* __restrict__ output, // [B, C, L] + T* __restrict__ present_state, // [B, C, K-1] + int seq_len, + int channels, + int kernel_size, + bool apply_silu) { + const int b = blockIdx.x; + const int c = blockIdx.y; + const int tid = threadIdx.x; + + const int pad = kernel_size - 1; + const int padded_len = pad + seq_len; + + // Shared memory: padded input [pad + L] floats + weight [K] floats + extern __shared__ float smem[]; + float* s_padded = smem; + float* s_weight = smem + padded_len; + + // Cooperatively load padded input into shared memory + // Past state portion: [0..pad-1] + for (int i = tid; i < pad; i += blockDim.x) { + if (past_state != nullptr) { + s_padded[i] = to_float(past_state[(int64_t)b * channels * pad + (int64_t)c * pad + i]); + } else { + s_padded[i] = 0.0f; + } + } + // Current input portion: [pad..pad+L-1] + for (int i = tid; i < seq_len; i += blockDim.x) { + s_padded[pad + i] = to_float(input[((int64_t)b * channels + c) * seq_len + i]); + } + // Load weight into shared memory + for (int i = tid; i < kernel_size; i += blockDim.x) { + s_weight[i] = to_float(weight[(int64_t)c * kernel_size + i]); + } + __syncthreads(); + + // Each thread computes one output position + float bias_val = (bias != nullptr) ? 
to_float(bias[c]) : 0.0f; + for (int l = tid; l < seq_len; l += blockDim.x) { + float sum = bias_val; + for (int k = 0; k < kernel_size; ++k) { + sum += s_weight[k] * s_padded[l + k]; + } + if (apply_silu) { + sum = silu_fn(sum); + } + output[((int64_t)b * channels + c) * seq_len + l] = from_float(sum); + } + + // Save present_state: last K-1 elements of padded input + __syncthreads(); + T* ps = present_state + (int64_t)b * channels * pad + (int64_t)c * pad; + for (int i = tid; i < pad; i += blockDim.x) { + ps[i] = from_float(s_padded[padded_len - pad + i]); + } +} + +// ============================================================================= +// Batched prefill kernel: processes CHANNELS_PER_BLOCK channels per block +// to improve occupancy when per-channel work is small (short sequences). +// +// Grid: (batch_size, ceil(channels / CPB), 1) +// Block: (threads, 1, 1) — threads are split across CPB channels +// +// Each channel gets (blockDim.x / CPB) threads. Weight is loaded into +// registers (small K), input+state goes through shared memory. +// ============================================================================= +template +__global__ void CausalConvPrefillKernelBatched( + const T* __restrict__ input, // [B, C, L] + const T* __restrict__ weight, // [C, 1, K] + const T* __restrict__ bias, // [C] or nullptr + const T* __restrict__ past_state, // [B, C, K-1] or nullptr + T* __restrict__ output, // [B, C, L] + T* __restrict__ present_state, // [B, C, K-1] + int seq_len, + int channels, + int kernel_size, + bool apply_silu) { + const int b = blockIdx.x; + const int c_base = blockIdx.y * CPB; + const int tid = threadIdx.x; + + const int pad = kernel_size - 1; + const int padded_len = pad + seq_len; + + // Which channel within this block's CPB group does this thread serve? + const int threads_per_channel = blockDim.x / CPB; + const int local_ch = tid / threads_per_channel; // 0..CPB-1 + const int local_tid = tid % threads_per_channel; // thread index within channel + const int c = c_base + local_ch; + + // Shared memory: CPB * (padded_len + kernel_size) floats + extern __shared__ float smem[]; + const int smem_per_ch = padded_len + kernel_size; + float* s_padded = smem + local_ch * smem_per_ch; + float* s_weight = s_padded + padded_len; + + if (c < channels) { + // Load past state + for (int i = local_tid; i < pad; i += threads_per_channel) { + if (past_state != nullptr) { + s_padded[i] = to_float(past_state[(int64_t)b * channels * pad + (int64_t)c * pad + i]); + } else { + s_padded[i] = 0.0f; + } + } + // Load input + for (int i = local_tid; i < seq_len; i += threads_per_channel) { + s_padded[pad + i] = to_float(input[((int64_t)b * channels + c) * seq_len + i]); + } + // Load weight + for (int i = local_tid; i < kernel_size; i += threads_per_channel) { + s_weight[i] = to_float(weight[(int64_t)c * kernel_size + i]); + } + } + __syncthreads(); + + if (c < channels) { + float bias_val = (bias != nullptr) ? to_float(bias[c]) : 0.0f; + for (int l = local_tid; l < seq_len; l += threads_per_channel) { + float sum = bias_val; + for (int k = 0; k < kernel_size; ++k) { + sum += s_weight[k] * s_padded[l + k]; + } + if (apply_silu) { + sum = silu_fn(sum); + } + output[((int64_t)b * channels + c) * seq_len + l] = from_float(sum); + } + } + + // Unconditional barrier — s_padded is read-only after the cooperative load, + // so this is safe even when c >= channels. Hoisted out of the conditional + // to avoid divergent __syncthreads() (undefined behavior in CUDA). 
+ __syncthreads(); + + if (c < channels) { + // Save present state + T* ps = present_state + (int64_t)b * channels * pad + (int64_t)c * pad; + for (int i = local_tid; i < pad; i += threads_per_channel) { + ps[i] = from_float(s_padded[padded_len - pad + i]); + } + } +} + +} // anonymous namespace + +template +Status LaunchCausalConvWithStateKernel( + cudaStream_t stream, + const T* input, + const T* weight, + const T* bias, + const T* past_state, + T* output, + T* present_state, + int batch_size, + int channels, + int seq_len, + int kernel_size, + bool apply_silu, + int max_threads_per_block) { + if (seq_len == 1) { + // Decode fast-path: one thread per (batch, channel) + int total = batch_size * channels; + int threads = 256; + int blocks = (total + threads - 1) / threads; + switch (kernel_size) { + case 2: + CausalConvDecodeKernelFixedK<<>>( + input, weight, bias, past_state, output, present_state, + total, channels, apply_silu); + break; + case 3: + CausalConvDecodeKernelFixedK<<>>( + input, weight, bias, past_state, output, present_state, + total, channels, apply_silu); + break; + case 4: + CausalConvDecodeKernelFixedK<<>>( + input, weight, bias, past_state, output, present_state, + total, channels, apply_silu); + break; + case 5: + CausalConvDecodeKernelFixedK<<>>( + input, weight, bias, past_state, output, present_state, + total, channels, apply_silu); + break; + default: + CausalConvDecodeKernel<<>>( + input, weight, bias, past_state, output, present_state, + total, channels, kernel_size, apply_silu); + break; + } + } else { + // Prefill path: choose between batched (short seq) or single-channel (long seq) kernel + int pad = kernel_size - 1; + + // For short sequences, batch multiple channels per block to improve occupancy. + // CPB=4: each block handles 4 channels, reducing block count by 4x. + // Threshold: use batched when seq_len <= 128 (small per-channel work). 
+ constexpr int CPB = 4; + if (seq_len <= 128 && channels >= CPB) { + int channel_blocks = (channels + CPB - 1) / CPB; + const dim3 grid(batch_size, channel_blocks, 1); + // Each channel gets threads/CPB threads + int threads_per_ch = std::min(seq_len, max_threads_per_block / CPB); + threads_per_ch = ((threads_per_ch + 31) / 32) * 32; + if (threads_per_ch < 32) threads_per_ch = 32; + int total_threads = threads_per_ch * CPB; + if (total_threads > max_threads_per_block) { + total_threads = (max_threads_per_block / CPB) * CPB; // round down to multiple of CPB + } + const dim3 block(total_threads, 1, 1); + size_t smem_size = static_cast(CPB) * (static_cast(pad + seq_len) + kernel_size) * sizeof(float); + + // Request extended shared memory if needed (default limit is 48 KB) + if (smem_size > 48 * 1024) { + cudaError_t attr_err = cudaFuncSetAttribute( + CausalConvPrefillKernelBatched, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(smem_size)); + if (attr_err != cudaSuccess) { + return CUDA_CALL(attr_err); + } + } + + CausalConvPrefillKernelBatched<<>>( + input, weight, bias, past_state, output, present_state, + seq_len, channels, kernel_size, apply_silu); + } else { + // Original single-channel-per-block path for long sequences + const dim3 grid(batch_size, channels, 1); + int threads = std::min(seq_len, max_threads_per_block); + threads = ((threads + 31) / 32) * 32; // round to warp + if (threads > max_threads_per_block) threads = max_threads_per_block; + const dim3 block(threads, 1, 1); + + size_t smem_size = (static_cast(pad + seq_len) + kernel_size) * sizeof(float); + + // Request extended shared memory if needed (default limit is 48 KB) + if (smem_size > 48 * 1024) { + cudaError_t attr_err = cudaFuncSetAttribute( + CausalConvPrefillKernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(smem_size)); + if (attr_err != cudaSuccess) { + return CUDA_CALL(attr_err); + } + } + + CausalConvPrefillKernel<<>>( + input, weight, bias, past_state, output, present_state, + seq_len, channels, kernel_size, apply_silu); + } + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Explicit instantiations +template Status LaunchCausalConvWithStateKernel( + cudaStream_t, const float*, const float*, const float*, const float*, + float*, float*, int, int, int, int, bool, int); + +template Status LaunchCausalConvWithStateKernel( + cudaStream_t, const half*, const half*, const half*, const half*, + half*, half*, int, int, int, int, bool, int); + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template Status LaunchCausalConvWithStateKernel<__nv_bfloat16>( + cudaStream_t, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, + __nv_bfloat16*, __nv_bfloat16*, int, int, int, int, bool, int); +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h new file mode 100644 index 0000000000000..4427a1df1fd6d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Fused causal depthwise conv1d + activation + state management. +// One thread block per (batch, channel). 
For decode (L=1), this is a simple +// dot product from shared memory. For prefill (L>1), each thread handles +// one output position. +template +Status LaunchCausalConvWithStateKernel( + cudaStream_t stream, + const T* input, // [B, C, L] + const T* weight, // [C, 1, K] + const T* bias, // [C] or nullptr + const T* past_state, // [B, C, K-1] or nullptr + T* output, // [B, C, L] + T* present_state, // [B, C, K-1] + int batch_size, + int channels, + int seq_len, + int kernel_size, + bool apply_silu, + int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc b/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc new file mode 100644 index 0000000000000..c8f460b0ca002 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/bert/linear_attention.h" +#include "contrib_ops/cuda/bert/linear_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; // CudaKernel, Stream, GetDeviceProp, ToCudaType + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + LinearAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + LinearAttention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +LinearAttention::LinearAttention(const OpKernelInfo& info) : CudaKernel(info) { + int64_t q_num_heads = 0; + ORT_ENFORCE(info.GetAttr("q_num_heads", &q_num_heads).IsOK() && q_num_heads > 0); + q_num_heads_ = static_cast(q_num_heads); + + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0); + kv_num_heads_ = static_cast(kv_num_heads); + + update_rule_ = info.GetAttrOrDefault("update_rule", "gated_delta"); + ORT_ENFORCE(update_rule_ == "linear" || update_rule_ == "gated" || + update_rule_ == "delta" || update_rule_ == "gated_delta", + "update_rule must be one of: linear, gated, delta, gated_delta"); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + + int64_t chunk_size = info.GetAttrOrDefault("chunk_size", 64); + // chunk_size_ reserved for future chunk-parallel prefill algorithm; not yet used. 
+ chunk_size_ = static_cast(chunk_size); +} + +template +Status LinearAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query_tensor = context->Input(0); + const Tensor* key_tensor = context->Input(1); // optional + const Tensor* value_tensor = context->Input(2); // optional + const Tensor* past_state_tensor = context->Input(3); // optional + const Tensor* decay_tensor = context->Input(4); // optional + const Tensor* beta_tensor = context->Input(5); // optional + + ORT_RETURN_IF_NOT(query_tensor != nullptr, "query input is required"); + + const auto& query_shape = query_tensor->Shape(); + ORT_RETURN_IF_NOT(query_shape.NumDimensions() == 3, "query must be 3D"); + + const int batch_size = static_cast(query_shape[0]); + const int seq_len = static_cast(query_shape[1]); + const int query_hidden = static_cast(query_shape[2]); + + ORT_RETURN_IF_NOT(key_tensor != nullptr && value_tensor != nullptr, "key and value inputs are required"); + + const auto& key_shape = key_tensor->Shape(); + const auto& value_shape = value_tensor->Shape(); + + int d_k = query_hidden / q_num_heads_; + int d_v = static_cast(value_shape[2]) / kv_num_heads_; + ORT_ENFORCE(static_cast(key_shape[2]) % d_k == 0, + "key last dim (", key_shape[2], ") must be divisible by d_k (", d_k, ")"); + int n_k_heads = static_cast(key_shape[2]) / d_k; + + // GQA head mapping validations + if (q_num_heads_ >= kv_num_heads_) { + ORT_ENFORCE(q_num_heads_ % kv_num_heads_ == 0, + "q_num_heads must be divisible by kv_num_heads"); + } else { + ORT_ENFORCE(kv_num_heads_ % q_num_heads_ == 0, + "kv_num_heads must be divisible by q_num_heads (inverse GQA)"); + } + ORT_ENFORCE(kv_num_heads_ % n_k_heads == 0, + "kv_num_heads must be divisible by n_k_heads"); + + float s = scale_; + if (s == 0.0f) { + s = 1.0f / std::sqrt(static_cast(d_k)); + } + + bool needs_decay = (update_rule_ == "gated" || update_rule_ == "gated_delta"); + bool needs_beta = (update_rule_ == "delta" || update_rule_ == "gated_delta"); + bool needs_retrieval = (update_rule_ == "delta" || update_rule_ == "gated_delta"); + + ORT_ENFORCE(!needs_decay || decay_tensor != nullptr, + "decay input is required for update_rule=", update_rule_); + ORT_ENFORCE(!needs_beta || beta_tensor != nullptr, + "beta input is required for update_rule=", update_rule_); + + bool decay_per_key_dim = false; + if (decay_tensor != nullptr) { + int64_t decay_last = decay_tensor->Shape()[2]; + decay_per_key_dim = (decay_last == kv_num_heads_ * d_k); + } + + bool beta_per_head = false; + if (beta_tensor != nullptr) { + int64_t beta_last = beta_tensor->Shape()[2]; + beta_per_head = (beta_last == kv_num_heads_); + } + + // Allocate outputs + int output_hidden = std::max(q_num_heads_, kv_num_heads_) * d_v; + TensorShape output_shape({batch_size, seq_len, output_hidden}); + Tensor* output_tensor = context->Output(0, output_shape); + + TensorShape state_shape({batch_size, kv_num_heads_, d_k, d_v}); + Tensor* present_state_tensor = context->Output(1, state_shape); + + // If past_state is nullptr, zero-init present_state on device + if (past_state_tensor == nullptr) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + present_state_tensor->MutableData(), 0, + static_cast(batch_size) * kv_num_heads_ * d_k * d_v * sizeof(T), + Stream(context))); + } else { + // Validate past_state shape matches expected (B, H_kv, d_k, d_v) + const auto& past_shape = past_state_tensor->Shape(); + ORT_ENFORCE(past_shape.NumDimensions() == 4, + "past_state must be rank 4 (B, H_kv, d_k, d_v), got rank ", 
past_shape.NumDimensions()); + ORT_ENFORCE(past_shape[0] == batch_size && past_shape[1] == kv_num_heads_ && + past_shape[2] == d_k && past_shape[3] == d_v, + "past_state shape mismatch: expected (", batch_size, ", ", kv_num_heads_, ", ", d_k, ", ", d_v, + "), got (", past_shape[0], ", ", past_shape[1], ", ", past_shape[2], ", ", past_shape[3], ")"); + // Copy past_state -> present_state (will be updated in-place by kernel) + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_state_tensor->MutableData(), + past_state_tensor->Data(), + static_cast(batch_size) * kv_num_heads_ * d_k * d_v * sizeof(T), + cudaMemcpyDeviceToDevice, + Stream(context))); + } + + typedef typename OrtToCudaType::type CudaT; + + return LaunchLinearAttentionKernel( + Stream(context), + reinterpret_cast(query_tensor->Data()), + reinterpret_cast(key_tensor->Data()), + reinterpret_cast(value_tensor->Data()), + decay_tensor ? reinterpret_cast(decay_tensor->Data()) : nullptr, + beta_tensor ? reinterpret_cast(beta_tensor->Data()) : nullptr, + reinterpret_cast(output_tensor->MutableData()), + reinterpret_cast(present_state_tensor->MutableData()), + batch_size, + seq_len, + q_num_heads_, + kv_num_heads_, + n_k_heads, + d_k, + d_v, + s, + needs_decay, + decay_per_key_dim, + needs_beta, + beta_per_head, + needs_retrieval, + GetDeviceProp().maxThreadsPerBlock); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention.h b/onnxruntime/contrib_ops/cuda/bert/linear_attention.h new file mode 100644 index 0000000000000..ed398218771d0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +class LinearAttention final : public onnxruntime::cuda::CudaKernel { + public: + LinearAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int q_num_heads_; + int kv_num_heads_; + std::string update_rule_; + float scale_; + int chunk_size_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu new file mode 100644 index 0000000000000..9137fcf9b25b9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu @@ -0,0 +1,781 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Fused recurrent linear attention CUDA kernel for gated_delta / delta / gated / linear update rules. +// +// Design: One thread block per (batch, kv_head). The state matrix [d_k, d_v] is loaded into +// shared memory at the start and kept there for the entire token loop. Each token's +// decay → retrieval → delta → update → readout sequence runs without global memory +// round-trips for the state. This matches the FLA (flash-linear-attention) kernel design. +// +// State tiles: For d_k=128, d_v=128, fp32 state = 64 KB shared memory. On SM80+ GPUs with +// 164 KB shared memory per SM, this fits with room for scratch. Requires +// cudaFuncSetAttribute to opt into extended shared memory (>48 KB). +// +// Thread mapping: num_threads = max(d_k, d_v) rounded to warp boundary. 
Each thread +// participates in both row operations (decay/update: tid < d_k handles row tid) and +// column operations (retrieval/readout: tid < d_v computes column tid's dot product). +// +// Reductions: Matrix-vector products (S^T @ k, S^T @ q) use column-per-thread dot products +// instead of atomicAdd, eliminating contention. Each thread tid computes +// sum_i(S[i, tid] * scalar[i]) by reading shared memory column-wise (bank-conflict-free +// when d_v is a multiple of 32). + +#include +#include +#include +#include +#include "contrib_ops/cuda/bert/linear_attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +namespace { + +// Convert half/bfloat16 to float +template +__device__ __forceinline__ float to_float(T val); + +template <> +__device__ __forceinline__ float to_float(float val) { return val; } + +template <> +__device__ __forceinline__ float to_float(half val) { return __half2float(val); } + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ __forceinline__ float to_float(__nv_bfloat16 val) { return __bfloat162float(val); } +#endif + +// Convert float to half/bfloat16/float +template +__device__ __forceinline__ T from_float(float val); + +template <> +__device__ __forceinline__ float from_float(float val) { return val; } + +template <> +__device__ __forceinline__ half from_float(float val) { return __float2half(val); } + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ __forceinline__ __nv_bfloat16 from_float(float val) { return __float2bfloat16(val); } +#endif + +// ============================================================================= +// Fused recurrent linear attention kernel +// +// Grid: (batch_size, kv_num_heads, 1) +// Block: (max(d_k, d_v) rounded to warp, 1, 1) +// +// Shared memory layout (dynamic): +// float S_smem[d_k * d_v] — recurrent state matrix (fp32) +// float s_scratch[max(d_k, d_v)] — broadcast/retrieval/delta buffer +// +// State is stored as type T in global memory but computed in fp32 in shared +// memory for numerical stability. 
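To make the token loop concrete, here is a hedged NumPy sketch of one gated_delta step for a single (batch, kv_head) with scalar per-head decay. It mirrors the decay → retrieval → delta → update → readout order described above; names and sizes are illustrative, and it is a reference sketch rather than the kernel itself.

```python
import numpy as np

def gated_delta_step(S, q, k, v, g, beta, scale):
    """One token of the recurrence run per (batch, kv_head).

    S: [d_k, d_v] state, q/k: [d_k], v: [d_v], g/beta: scalars.
    """
    exp_g = np.exp(g)
    retrieved = exp_g * (S.T @ k)        # S^T @ k on the decayed state
    delta = beta * (v - retrieved)       # delta-rule correction
    S = exp_g * S + np.outer(k, delta)   # decay + rank-1 update in one pass
    out = scale * (S.T @ q)              # readout for this token
    return S, out

# Example with illustrative sizes (d_k = d_v = 4):
rng = np.random.default_rng(0)
S = rng.standard_normal((4, 4))
q, k = rng.standard_normal((2, 4))
v = rng.standard_normal(4)
S, out = gated_delta_step(S, q, k, v, g=-0.1, beta=0.5, scale=0.5)
print(out.shape)  # (4,)
```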
+// ============================================================================= +template +__global__ void LinearAttentionRecurrentKernel( + const T* __restrict__ query, // [B, T, H_q * d_k] + const T* __restrict__ key, // [B, T, n_k * d_k] + const T* __restrict__ value, // [B, T, H_kv * d_v] + T* __restrict__ state, // [B, H_kv, d_k, d_v] — in-place updated + const T* __restrict__ decay, // [B, T, H_kv] or [B, T, H_kv*d_k] or nullptr + const T* __restrict__ beta_in, // [B, T, H_kv] or [B, T, 1] or nullptr + T* __restrict__ output, // [B, T, max(H_q, H_kv) * d_v] + int seq_len, + int q_num_heads, + int kv_num_heads, + int n_k_heads, + int d_k, + int d_v, + int output_hidden, + float scale, + bool needs_decay, + bool decay_per_key_dim, + bool needs_beta, + bool beta_per_head, + bool needs_retrieval) { + const int b = blockIdx.x; + const int h_kv = blockIdx.y; + const int tid = threadIdx.x; + const int num_threads = blockDim.x; + const int kv_per_k = kv_num_heads / n_k_heads; + const int h_k = h_kv / kv_per_k; + + // Global state pointer for this (batch, head): [d_k, d_v] + T* S_global = state + ((int64_t)b * kv_num_heads + h_kv) * d_k * d_v; + + // Shared memory layout + extern __shared__ float smem[]; + float* S_smem = smem; // [d_k * d_v] + float* k_buf = smem + d_k * d_v; // [d_k] + float* s_scratch = smem + d_k * d_v + d_k; // [max(d_k, d_v)] + + // Load state from global memory (type T) into shared memory (fp32) + for (int idx = tid; idx < d_k * d_v; idx += num_threads) { + S_smem[idx] = to_float(S_global[idx]); + } + __syncthreads(); + + // ---- Token loop ---- + for (int t = 0; t < seq_len; ++t) { + // Load k_t[tid] into register (each thread loads one element) + float kt_val = 0.0f; + if (tid < d_k) { + int k_offset = ((int64_t)b * seq_len + t) * (n_k_heads * d_k) + h_k * d_k + tid; + kt_val = to_float(key[k_offset]); + } + + // Steps 1+2: Decay + Retrieval (fused for scalar per-head decay) + bool fused_decay_update = false; + float fused_exp_g = 1.0f; + + if (needs_decay && needs_retrieval && !decay_per_key_dim) { + if (tid < d_k) { + k_buf[tid] = kt_val; + } + if (tid == 0) { + int g_offset = ((int64_t)b * seq_len + t) * kv_num_heads + h_kv; + s_scratch[0] = expf(to_float(decay[g_offset])); + } + __syncthreads(); + + fused_exp_g = s_scratch[0]; + + if (tid < d_v) { + float acc = 0.0f; + for (int i = 0; i < d_k; ++i) { + acc += S_smem[i * d_v + tid] * k_buf[i]; + } + s_scratch[tid] = fused_exp_g * acc; + } + __syncthreads(); + + fused_decay_update = true; + + } else { + // Non-fused path: separate decay and retrieval steps + if (needs_decay) { + if (!decay_per_key_dim) { + if (tid == 0) { + int g_offset = ((int64_t)b * seq_len + t) * kv_num_heads + h_kv; + s_scratch[0] = expf(to_float(decay[g_offset])); + } + __syncthreads(); + } + if (tid < d_k) { + float exp_g; + if (decay_per_key_dim) { + int g_offset = ((int64_t)b * seq_len + t) * (kv_num_heads * d_k) + h_kv * d_k + tid; + exp_g = expf(to_float(decay[g_offset])); + } else { + exp_g = s_scratch[0]; + } + for (int j = 0; j < d_v; ++j) { + S_smem[tid * d_v + j] *= exp_g; + } + } + __syncthreads(); + } + + if (needs_retrieval) { + // Store k in k_buf (not s_scratch) to avoid inter-warp race when + // d_k > 32: retrieval overwrites s_scratch[tid] while other warps + // may still be reading s_scratch[i] in the dot product loop. 
+ if (tid < d_k) { + k_buf[tid] = kt_val; + } + __syncthreads(); + + if (tid < d_v) { + float acc = 0.0f; + for (int i = 0; i < d_k; ++i) { + acc += S_smem[i * d_v + tid] * k_buf[i]; + } + s_scratch[tid] = acc; + } + __syncthreads(); + } + } + + // Step 3: State update — S += k_t ⊗ delta (or k_t ⊗ v_t for linear) + // When fused_decay_update, applies: S = exp_g * S + k * delta + if (needs_beta) { + float bt; + if (beta_per_head) { + bt = to_float(beta_in[((int64_t)b * seq_len + t) * kv_num_heads + h_kv]); + } else { + bt = to_float(beta_in[((int64_t)b * seq_len + t)]); + } + + if (tid < d_v) { + int v_base = ((int64_t)b * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v; + float vj = to_float(value[v_base + tid]); + s_scratch[tid] = bt * (vj - s_scratch[tid]); + } + __syncthreads(); + + if (tid < d_k) { + if (fused_decay_update) { + for (int j = 0; j < d_v; ++j) { + S_smem[tid * d_v + j] = fused_exp_g * S_smem[tid * d_v + j] + kt_val * s_scratch[j]; + } + } else { + for (int j = 0; j < d_v; ++j) { + S_smem[tid * d_v + j] += kt_val * s_scratch[j]; + } + } + } + } else { + if (tid < d_v) { + int v_base = ((int64_t)b * seq_len + t) * (kv_num_heads * d_v) + h_kv * d_v; + s_scratch[tid] = to_float(value[v_base + tid]); + } + __syncthreads(); + + if (tid < d_k) { + if (fused_decay_update) { + for (int j = 0; j < d_v; ++j) { + S_smem[tid * d_v + j] = fused_exp_g * S_smem[tid * d_v + j] + kt_val * s_scratch[j]; + } + } else { + for (int j = 0; j < d_v; ++j) { + S_smem[tid * d_v + j] += kt_val * s_scratch[j]; + } + } + } + } + __syncthreads(); + + // Step 4: Query readout — output = S^T @ q_t (standard GQA or inverse GQA) + if (q_num_heads >= kv_num_heads) { + int heads_per_group = q_num_heads / kv_num_heads; + for (int g = 0; g < heads_per_group; ++g) { + if (g > 0) { + __syncthreads(); + } + + int h_q = h_kv * heads_per_group + g; + if (tid < d_k) { + int q_offset = ((int64_t)b * seq_len + t) * (q_num_heads * d_k) + h_q * d_k + tid; + s_scratch[tid] = to_float(query[q_offset]); + } + __syncthreads(); + + if (tid < d_v) { + float acc = 0.0f; + for (int i = 0; i < d_k; ++i) { + acc += S_smem[i * d_v + tid] * s_scratch[i]; + } + int out_offset = ((int64_t)b * seq_len + t) * output_hidden + h_q * d_v + tid; + output[out_offset] = from_float(scale * acc); + } + } + } else { + int h_q = h_kv * q_num_heads / kv_num_heads; + if (tid < d_k) { + int q_offset = ((int64_t)b * seq_len + t) * (q_num_heads * d_k) + h_q * d_k + tid; + s_scratch[tid] = to_float(query[q_offset]); + } + __syncthreads(); + + if (tid < d_v) { + float acc = 0.0f; + for (int i = 0; i < d_k; ++i) { + acc += S_smem[i * d_v + tid] * s_scratch[i]; + } + int out_offset = ((int64_t)b * seq_len + t) * output_hidden + h_kv * d_v + tid; + output[out_offset] = from_float(scale * acc); + } + } + + __syncthreads(); + } + + // Write back state from shared memory (fp32) to global memory (type T) + for (int idx = tid; idx < d_k * d_v; idx += num_threads) { + S_global[idx] = from_float(S_smem[idx]); + } +} + +// Compile-time specialized variant for common (d_k, d_v) pairs. +// Optimizations over the generic kernel: +// 1. #pragma unroll on all inner loops for better ILP +// 2. float4 vectorized row operations (decay, state update) — 4x fewer shared memory transactions +// 3. Fused decay+retrieval for scalar per-head decay — eliminates one state pass and one __syncthreads() +// 4. 
Dedicated k_buf in shared memory avoids scratch aliasing during fused path +template +__global__ void LinearAttentionRecurrentKernelFixedShape( + const T* __restrict__ query, + const T* __restrict__ key, + const T* __restrict__ value, + T* __restrict__ state, + const T* __restrict__ decay, + const T* __restrict__ beta_in, + T* __restrict__ output, + int seq_len, + int q_num_heads, + int kv_num_heads, + int n_k_heads, + int output_hidden, + float scale, + bool needs_decay, + bool decay_per_key_dim, + bool needs_beta, + bool beta_per_head, + bool needs_retrieval) { + static_assert(DV % 4 == 0 && DK % 4 == 0, "DK and DV must be multiples of 4 for float4 optimization"); + constexpr int DV4 = DV / 4; + + const int b = blockIdx.x; + const int h_kv = blockIdx.y; + const int tid = threadIdx.x; + const int kv_per_k = kv_num_heads / n_k_heads; + const int h_k = h_kv / kv_per_k; + + T* S_global = state + ((int64_t)b * kv_num_heads + h_kv) * DK * DV; + + // Shared memory layout: + // S_smem[DK * DV] — recurrent state matrix (fp32) + // k_buf[DK] — persistent key broadcast buffer + // s_scratch[max(DK, DV)] — general scratch (retrieval, delta, query broadcast) + extern __shared__ float smem[]; + float* S_smem = smem; // [DK * DV] + float* k_buf = smem + DK * DV; // [DK] + float* s_scratch = smem + DK * DV + DK; // [max(DK, DV)] + + // Load state from global memory (type T) into shared memory (fp32) — vectorized + if constexpr (sizeof(T) == 2 && DV % 2 == 0) { + // half/bf16: load 2 elements at a time via uint32 + const uint32_t* S_global_u32 = reinterpret_cast(S_global); + int half_pairs = (DK * DV) / 2; + for (int idx = tid; idx < half_pairs; idx += blockDim.x) { + uint32_t packed = S_global_u32[idx]; + T lo, hi; + memcpy(&lo, &packed, sizeof(T)); + memcpy(&hi, reinterpret_cast(&packed) + sizeof(T), sizeof(T)); + S_smem[idx * 2] = to_float(lo); + S_smem[idx * 2 + 1] = to_float(hi); + } + } else { + for (int idx = tid; idx < DK * DV; idx += blockDim.x) { + S_smem[idx] = to_float(S_global[idx]); + } + } + __syncthreads(); + + // Precompute per-batch strides to avoid repeated int64 multiplications in the token loop + const int64_t b_seq = (int64_t)b * seq_len; + const int k_hidden = n_k_heads * DK; + const int kv_v_hidden = kv_num_heads * DV; + const int q_hidden = q_num_heads * DK; + const int kv_dk_hidden = kv_num_heads * DK; + + for (int t = 0; t < seq_len; ++t) { + const int64_t bt = b_seq + t; + + float kt_val = 0.0f; + if (tid < DK) { + kt_val = to_float(key[bt * k_hidden + h_k * DK + tid]); + } + + // ================================================================== + // Steps 1+2: Decay + Retrieval + // ================================================================== + // For the fused scalar-decay + gated_delta path, we also fuse the + // state update (step 3) to avoid a separate decay pass entirely: + // retrieval = exp_g * (S^T @ k) [on old S] + // delta = beta * (v - retrieval) + // S = exp_g * S + k ⊗ delta [single fused pass] + // This reduces 3 state passes (decay, retrieval, update) to 2 (retrieval, fused update). 
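The pass-count claim above can be checked in a few lines of NumPy (illustrative values, not part of the kernel): the fused single pass produces the same state as decaying first and then applying the delta update.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v = 8, 8
S = rng.standard_normal((d_k, d_v))
k = rng.standard_normal(d_k)
v = rng.standard_normal(d_v)
exp_g, beta = np.exp(-0.2), 0.7

# Unfused: decay pass, then retrieval, then rank-1 update (three passes over S).
S_ref = exp_g * S
delta_ref = beta * (v - S_ref.T @ k)
S_ref = S_ref + np.outer(k, delta_ref)

# Fused: retrieval on the old state pre-scaled by exp_g, then one combined pass.
delta = beta * (v - exp_g * (S.T @ k))
S_fused = exp_g * S + np.outer(k, delta)

assert np.allclose(S_ref, S_fused)
```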
+ bool fused_decay_update = false; + float fused_exp_g = 1.0f; + + if (needs_decay && needs_retrieval && !decay_per_key_dim) { + // --- FUSED path: scalar per-head decay + retrieval --- + if (tid < DK) { + k_buf[tid] = kt_val; + } + if (tid == 0) { + s_scratch[0] = expf(to_float(decay[bt * kv_num_heads + h_kv])); + } + __syncthreads(); + + fused_exp_g = s_scratch[0]; + + // Retrieval on old state, pre-scaled by exp_g + if (tid < DV) { + float acc = 0.0f; +#pragma unroll + for (int i = 0; i < DK; ++i) { + acc += S_smem[i * DV + tid] * k_buf[i]; + } + s_scratch[tid] = fused_exp_g * acc; + } + __syncthreads(); + + // Decay is deferred to the update step (fused_decay_update = true) + fused_decay_update = true; + + } else if (needs_decay && needs_retrieval) { + // --- Per-key-dim decay then retrieval (cannot fuse — exp_g differs per row) --- + if (tid < DK) { + k_buf[tid] = kt_val; + float exp_g = expf(to_float(decay[bt * kv_dk_hidden + h_kv * DK + tid])); + float4* row = reinterpret_cast(S_smem + tid * DV); +#pragma unroll + for (int j = 0; j < DV4; ++j) { + float4 v = row[j]; + v.x *= exp_g; + v.y *= exp_g; + v.z *= exp_g; + v.w *= exp_g; + row[j] = v; + } + } + __syncthreads(); // decay done, k_buf visible + + if (tid < DV) { + float acc = 0.0f; +#pragma unroll + for (int i = 0; i < DK; ++i) { + acc += S_smem[i * DV + tid] * k_buf[i]; + } + s_scratch[tid] = acc; + } + __syncthreads(); // retrieval done + + } else { + // --- Decay only, retrieval only, or neither --- + if (needs_decay) { + if (!decay_per_key_dim) { + if (tid == 0) { + s_scratch[0] = expf(to_float(decay[bt * kv_num_heads + h_kv])); + } + __syncthreads(); + } + if (tid < DK) { + float exp_g; + if (decay_per_key_dim) { + exp_g = expf(to_float(decay[bt * kv_dk_hidden + h_kv * DK + tid])); + } else { + exp_g = s_scratch[0]; + } + float4* row = reinterpret_cast(S_smem + tid * DV); +#pragma unroll + for (int j = 0; j < DV4; ++j) { + float4 v = row[j]; + v.x *= exp_g; + v.y *= exp_g; + v.z *= exp_g; + v.w *= exp_g; + row[j] = v; + } + } + __syncthreads(); + } + + if (needs_retrieval) { + if (tid < DK) { + k_buf[tid] = kt_val; + } + __syncthreads(); // k_buf visible + + if (tid < DV) { + float acc = 0.0f; +#pragma unroll + for (int i = 0; i < DK; ++i) { + acc += S_smem[i * DV + tid] * k_buf[i]; + } + s_scratch[tid] = acc; + } + __syncthreads(); // retrieval done + } + } + + // ================================================================== + // Step 3: State update with float4 vectorization + // When fused_decay_update is true, decay is applied here: + // S[i,j] = exp_g * S[i,j] + k[i] * delta[j] + // ================================================================== + if (needs_beta) { + float beta_t; + if (beta_per_head) { + beta_t = to_float(beta_in[bt * kv_num_heads + h_kv]); + } else { + beta_t = to_float(beta_in[bt]); + } + + if (tid < DV) { + float vj = to_float(value[bt * kv_v_hidden + h_kv * DV + tid]); + s_scratch[tid] = beta_t * (vj - s_scratch[tid]); + } + __syncthreads(); + + if (tid < DK) { + float4* row = reinterpret_cast(S_smem + tid * DV); + const float4* delta4 = reinterpret_cast(s_scratch); + if (fused_decay_update) { + // Fused: S = exp_g * S + k * delta (single pass, no separate decay) +#pragma unroll + for (int j = 0; j < DV4; ++j) { + float4 s = row[j]; + float4 d = delta4[j]; + s.x = fused_exp_g * s.x + kt_val * d.x; + s.y = fused_exp_g * s.y + kt_val * d.y; + s.z = fused_exp_g * s.z + kt_val * d.z; + s.w = fused_exp_g * s.w + kt_val * d.w; + row[j] = s; + } + } else { +#pragma unroll + for (int j = 0; j < 
DV4; ++j) { + float4 s = row[j]; + float4 d = delta4[j]; + s.x += kt_val * d.x; + s.y += kt_val * d.y; + s.z += kt_val * d.z; + s.w += kt_val * d.w; + row[j] = s; + } + } + } + } else { + if (tid < DV) { + s_scratch[tid] = to_float(value[bt * kv_v_hidden + h_kv * DV + tid]); + } + __syncthreads(); + + if (tid < DK) { + float4* row = reinterpret_cast(S_smem + tid * DV); + const float4* v4 = reinterpret_cast(s_scratch); + if (fused_decay_update) { +#pragma unroll + for (int j = 0; j < DV4; ++j) { + float4 s = row[j]; + float4 v = v4[j]; + s.x = fused_exp_g * s.x + kt_val * v.x; + s.y = fused_exp_g * s.y + kt_val * v.y; + s.z = fused_exp_g * s.z + kt_val * v.z; + s.w = fused_exp_g * s.w + kt_val * v.w; + row[j] = s; + } + } else { +#pragma unroll + for (int j = 0; j < DV4; ++j) { + float4 s = row[j]; + float4 v = v4[j]; + s.x += kt_val * v.x; + s.y += kt_val * v.y; + s.z += kt_val * v.z; + s.w += kt_val * v.w; + row[j] = s; + } + } + } + } + __syncthreads(); + + // ================================================================== + // Step 4: Query readout (column dot products — not float4-vectorizable) + // ================================================================== + if (q_num_heads >= kv_num_heads) { + int heads_per_group = q_num_heads / kv_num_heads; + for (int g = 0; g < heads_per_group; ++g) { + if (g > 0) { + __syncthreads(); + } + + int h_q = h_kv * heads_per_group + g; + if (tid < DK) { + s_scratch[tid] = to_float(query[bt * q_hidden + h_q * DK + tid]); + } + __syncthreads(); + + if (tid < DV) { + float acc = 0.0f; +#pragma unroll + for (int i = 0; i < DK; ++i) { + acc += S_smem[i * DV + tid] * s_scratch[i]; + } + output[bt * output_hidden + h_q * DV + tid] = from_float(scale * acc); + } + } + } else { + int h_q = h_kv * q_num_heads / kv_num_heads; + if (tid < DK) { + s_scratch[tid] = to_float(query[bt * q_hidden + h_q * DK + tid]); + } + __syncthreads(); + + if (tid < DV) { + float acc = 0.0f; +#pragma unroll + for (int i = 0; i < DK; ++i) { + acc += S_smem[i * DV + tid] * s_scratch[i]; + } + output[bt * output_hidden + h_kv * DV + tid] = from_float(scale * acc); + } + } + + __syncthreads(); + } + + // Write back state from shared memory (fp32) to global memory (type T) — vectorized + if constexpr (sizeof(T) == 2 && DV % 2 == 0) { + uint32_t* S_global_u32 = reinterpret_cast(S_global); + int half_pairs = (DK * DV) / 2; + for (int idx = tid; idx < half_pairs; idx += blockDim.x) { + T lo = from_float(S_smem[idx * 2]); + T hi = from_float(S_smem[idx * 2 + 1]); + uint32_t packed; + memcpy(&packed, &lo, sizeof(T)); + memcpy(reinterpret_cast(&packed) + sizeof(T), &hi, sizeof(T)); + S_global_u32[idx] = packed; + } + } else if constexpr (sizeof(T) == 4 && DV % 4 == 0) { + float4* S_global_f4 = reinterpret_cast(S_global); + int quads = (DK * DV) / 4; + for (int idx = tid; idx < quads; idx += blockDim.x) { + float4 v; + v.x = S_smem[idx * 4]; + v.y = S_smem[idx * 4 + 1]; + v.z = S_smem[idx * 4 + 2]; + v.w = S_smem[idx * 4 + 3]; + S_global_f4[idx] = v; + } + } else { + for (int idx = tid; idx < DK * DV; idx += blockDim.x) { + S_global[idx] = from_float(S_smem[idx]); + } + } +} + +} // anonymous namespace + +template +Status LaunchLinearAttentionKernel( + cudaStream_t stream, + const T* query, + const T* key, + const T* value, + const T* decay, + const T* beta, + T* output, + T* present_state, + int batch_size, + int seq_len, + int q_num_heads, + int kv_num_heads, + int n_k_heads, + int d_k, + int d_v, + float scale, + bool needs_decay, + bool decay_per_key_dim, + bool needs_beta, + bool 
beta_per_head, + bool needs_retrieval, + int max_threads_per_block) { + // Grid: one block per (batch, kv_head) + const dim3 grid(batch_size, kv_num_heads, 1); + + int output_hidden = std::max(q_num_heads, kv_num_heads) * d_v; + + auto launch_fixed = [&](auto dk_tag, auto dv_tag) -> Status { + constexpr int DK = decltype(dk_tag)::value; + constexpr int DV = decltype(dv_tag)::value; + constexpr int max_dim = (DK > DV) ? DK : DV; + // Layout: S_smem[DK*DV] + k_buf[DK] + s_scratch[max(DK,DV)] + const size_t fixed_smem_size = (static_cast(DK) * DV + DK + max_dim) * sizeof(float); + const dim3 fixed_block(max_dim, 1, 1); + + if (fixed_smem_size > 48 * 1024) { + cudaError_t attr_err = cudaFuncSetAttribute( + LinearAttentionRecurrentKernelFixedShape, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(fixed_smem_size)); + if (attr_err != cudaSuccess) { + return CUDA_CALL(attr_err); + } + } + + LinearAttentionRecurrentKernelFixedShape<<>>( + query, key, value, present_state, decay, beta, output, + seq_len, q_num_heads, kv_num_heads, n_k_heads, output_hidden, scale, + needs_decay, decay_per_key_dim, needs_beta, beta_per_head, needs_retrieval); + + return CUDA_CALL(cudaGetLastError()); + }; + + // Fast paths for common (d_k, d_v) pairs + if (d_k == 64 && d_v == 64 && max_threads_per_block >= 64) { + return launch_fixed(std::integral_constant{}, std::integral_constant{}); + } + if (d_k == 128 && d_v == 128 && max_threads_per_block >= 128) { + return launch_fixed(std::integral_constant{}, std::integral_constant{}); + } + if (d_k == 128 && d_v == 64 && max_threads_per_block >= 128) { + return launch_fixed(std::integral_constant{}, std::integral_constant{}); + } + if (d_k == 64 && d_v == 128 && max_threads_per_block >= 128) { + return launch_fixed(std::integral_constant{}, std::integral_constant{}); + } + + // Generic fallback + // Block: max(d_k, d_v) threads, rounded up to warp boundary + int threads = ((std::max(d_k, d_v) + 31) / 32) * 32; + if (threads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "LinearAttention: max(d_k=", d_k, ", d_v=", d_v, + ") exceeds max threads per block (", max_threads_per_block, + "). 
Use a model with smaller head dimensions."); + } + const dim3 block(threads, 1, 1); + + // Shared memory: state[d_k*d_v] + k_buf[d_k] + scratch[max(d_k,d_v)] + size_t smem_size = (static_cast(d_k) * d_v + d_k + std::max(d_k, d_v)) * sizeof(float); + + // Request extended shared memory if needed (default limit is 48 KB) + if (smem_size > 48 * 1024) { + cudaError_t attr_err = cudaFuncSetAttribute( + LinearAttentionRecurrentKernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(smem_size)); + if (attr_err != cudaSuccess) { + return CUDA_CALL(attr_err); + } + } + + LinearAttentionRecurrentKernel<<>>( + query, key, value, present_state, decay, beta, output, + seq_len, q_num_heads, kv_num_heads, n_k_heads, d_k, d_v, output_hidden, scale, + needs_decay, decay_per_key_dim, needs_beta, beta_per_head, needs_retrieval); + + return CUDA_CALL(cudaGetLastError()); +} + +// Explicit instantiations +template Status LaunchLinearAttentionKernel( + cudaStream_t, const float*, const float*, const float*, + const float*, const float*, float*, float*, + int, int, int, int, int, int, int, float, bool, bool, bool, bool, bool, int); + +template Status LaunchLinearAttentionKernel( + cudaStream_t, const half*, const half*, const half*, + const half*, const half*, half*, half*, + int, int, int, int, int, int, int, float, bool, bool, bool, bool, bool, int); + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template Status LaunchLinearAttentionKernel<__nv_bfloat16>( + cudaStream_t, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, + int, int, int, int, int, int, int, float, bool, bool, bool, bool, bool, int); +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h new file mode 100644 index 0000000000000..c30e081e6cd2f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Fused recurrent kernel for gated_delta update rule (decode path, T tokens sequentially). +// Processes all heads in parallel on GPU; each (batch, kv_head) gets one thread block. +// State is kept in shared memory for the entire token loop to avoid global memory round-trips. 
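For orientation, a hedged sketch of the per-block dynamic shared-memory budget this launcher requests (state tile plus key buffer plus scratch, all fp32); the 48 KB figure is the default per-block limit that the .cu file opts out of when exceeded. The helper name and the example head sizes are illustrative.

```python
# state[d_k*d_v] + k_buf[d_k] + s_scratch[max(d_k, d_v)], all fp32
def linear_attention_smem_bytes(d_k: int, d_v: int) -> int:
    return (d_k * d_v + d_k + max(d_k, d_v)) * 4

for d_k, d_v in [(64, 64), (128, 64), (128, 128)]:
    size = linear_attention_smem_bytes(d_k, d_v)
    # Above 48 KB the launcher opts in via cudaFuncAttributeMaxDynamicSharedMemorySize.
    print(d_k, d_v, size, "needs opt-in" if size > 48 * 1024 else "fits default limit")
```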
+template +Status LaunchLinearAttentionKernel( + cudaStream_t stream, + const T* query, // [B, T, H_q * d_k] + const T* key, // [B, T, n_k * d_k] + const T* value, // [B, T, H_kv * d_v] + const T* decay, // [B, T, H_kv] or [B, T, H_kv * d_k] or nullptr + const T* beta, // [B, T, H_kv] or [B, T, 1] or nullptr + T* output, // [B, T, max(H_q, H_kv) * d_v] + T* present_state, // [B, H_kv, d_k, d_v] -- in-place (caller pre-fills from past) + int batch_size, + int seq_len, + int q_num_heads, + int kv_num_heads, + int n_k_heads, + int d_k, + int d_v, + float scale, + bool needs_decay, + bool decay_per_key_dim, + bool needs_beta, + bool beta_per_head, + bool needs_retrieval, + int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index a1dcb1a203cc8..da7ef35d25052 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -147,6 +147,10 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RotaryEmbedding); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RotaryEmbedding); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, RotaryEmbedding); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GemmaRotaryEmbedding); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, LinearAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, LinearAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, CausalConvWithState); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, CausalConvWithState); #if !defined(DISABLE_GENERATION_OPS) class CUDA_MS_OP_CLASS_NAME(1, Sampling); #endif @@ -404,6 +408,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_GENERATION_OPS) BuildKernelCreateInfo, #endif diff --git a/onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py b/onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py new file mode 100644 index 0000000000000..d3b053f1a094a --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py @@ -0,0 +1,564 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import unittest + +import numpy as np +import torch +from onnx import TensorProto, checker, helper + +import onnxruntime + +onnxruntime.preload_dlls() + + +def _has_cuda_ep() -> bool: + return "CUDAExecutionProvider" in onnxruntime.get_available_providers() + + +def _run_onnx(model_bytes: bytes, inputs: dict[str, np.ndarray], provider: str) -> list[np.ndarray]: + session = onnxruntime.InferenceSession(model_bytes, providers=[provider]) + return session.run(None, inputs) + + +def _torch_linear_attention_reference( + query: np.ndarray, + key: np.ndarray, + value: np.ndarray, + past_state: np.ndarray, + decay: np.ndarray | None, + beta: np.ndarray | None, + q_num_heads: int, + kv_num_heads: int, + d_k: int, + d_v: int, + scale: float, + update_rule: str = "gated_delta", +) -> tuple[np.ndarray, np.ndarray]: + q = torch.from_numpy(query) + k = torch.from_numpy(key) + v = torch.from_numpy(value) + s = torch.from_numpy(past_state).clone() + g = torch.from_numpy(decay) if decay is not None else None + b = torch.from_numpy(beta) if beta is not None else None + + batch, seq_len, _ = query.shape + output_heads = max(q_num_heads, kv_num_heads) + output = torch.empty((batch, seq_len, output_heads * d_v), dtype=torch.float32) + + # Detect per-key-dim decay from shape + decay_per_key_dim = g is not None and g.shape[-1] == kv_num_heads * d_k + + for bi in range(batch): + for hk in range(kv_num_heads): + state = s[bi, hk] + for t in range(seq_len): + kt = k[bi, t, hk * d_k : (hk + 1) * d_k] + vt = v[bi, t, hk * d_v : (hk + 1) * d_v] + + if update_rule == "linear": + # S_t = S_{t-1} + k_t ⊗ v_t + state = state + torch.outer(kt, vt) + + elif update_rule == "gated": + # S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t + if decay_per_key_dim: + exp_g = torch.exp(g[bi, t, hk * d_k : (hk + 1) * d_k]) # [d_k] + state = state * exp_g.unsqueeze(1) + torch.outer(kt, vt) + else: + exp_g = torch.exp(g[bi, t, hk]) + state = state * exp_g + torch.outer(kt, vt) + + elif update_rule == "delta": + # S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t) + retrieved = torch.matmul(kt, state) + bt = b[bi, t, 0] + delta = bt * (vt - retrieved) + state = state + torch.outer(kt, delta) + + elif update_rule == "gated_delta": + # S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t) + if decay_per_key_dim: + exp_g = torch.exp(g[bi, t, hk * d_k : (hk + 1) * d_k]) # [d_k] + state = state * exp_g.unsqueeze(1) + else: + exp_g = torch.exp(g[bi, t, hk]) + state = state * exp_g + retrieved = torch.matmul(kt, state) + bt = b[bi, t, 0] + delta = bt * (vt - retrieved) + state = state + torch.outer(kt, delta) + + else: + raise ValueError(f"Unknown update_rule: {update_rule}") + + # Query readout: standard GQA or inverse GQA + if q_num_heads >= kv_num_heads: + heads_per_group = q_num_heads // kv_num_heads + for hg in range(heads_per_group): + hq = hk * heads_per_group + hg + qt = q[bi, t, hq * d_k : (hq + 1) * d_k] + ot = scale * torch.matmul(qt, state) + output[bi, t, hq * d_v : (hq + 1) * d_v] = ot + else: + # Inverse GQA: multiple kv heads map to fewer q heads + hq = hk * q_num_heads // kv_num_heads + qt = q[bi, t, hq * d_k : (hq + 1) * d_k] + ot = scale * torch.matmul(qt, state) + output[bi, t, hk * d_v : (hk + 1) * d_v] = ot + + s[bi, hk] = state + + return output.numpy(), s.numpy() + + +def _torch_causal_conv_reference( + x: np.ndarray, + weight: np.ndarray, + bias: np.ndarray, + past_state: np.ndarray, + activation: str, +) -> tuple[np.ndarray, np.ndarray]: + xt = torch.from_numpy(x) + wt = torch.from_numpy(weight) + bt = 
torch.from_numpy(bias) + pst = torch.from_numpy(past_state) + + pad = wt.shape[2] - 1 + padded = torch.cat([pst, xt], dim=2) + out = torch.nn.functional.conv1d(padded, wt, bias=bt, stride=1, padding=0, groups=xt.shape[1]) + + if activation in ("silu", "swish"): + out = torch.nn.functional.silu(out) + + present = padded[:, :, -pad:] if pad > 0 else torch.empty((xt.shape[0], xt.shape[1], 0), dtype=torch.float32) + return out.numpy(), present.numpy() + + +def _build_linear_attention_model( + q_num_heads: int, + kv_num_heads: int, + update_rule: str, + scale: float, + decay_per_key_dim: bool = False, +) -> bytes: + node = helper.make_node( + "LinearAttention", + ["query", "key", "value", "past_state", "decay", "beta"], + ["output", "present_state"], + domain="com.microsoft", + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + update_rule=update_rule, + scale=scale, + ) + + # Decay shape: [B, T, H_kv * d_k] for per-key-dim, [B, T, H_kv] for per-head + decay_shape = ["B", "T", "DH"] if decay_per_key_dim else ["B", "T", "H"] + + graph = helper.make_graph( + [node], + "LinearAttentionParity", + [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, ["B", "T", "QH"]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, ["B", "T", "KH"]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, ["B", "T", "VH"]), + helper.make_tensor_value_info("past_state", TensorProto.FLOAT, ["B", "H", "DK", "DV"]), + helper.make_tensor_value_info("decay", TensorProto.FLOAT, decay_shape), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["B", "T", 1]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["B", "T", "OH"]), + helper.make_tensor_value_info("present_state", TensorProto.FLOAT, ["B", "H", "DK", "DV"]), + ], + ) + + model = helper.make_model( + graph, + opset_imports=[ + helper.make_opsetid("", 17), + helper.make_opsetid("com.microsoft", 1), + ], + ir_version=8, + ) + checker.check_model(model) + return model.SerializeToString() + + +def _build_causal_conv_model(activation: str) -> bytes: + node = helper.make_node( + "CausalConvWithState", + ["input", "weight", "bias", "past_state"], + ["output", "present_state"], + domain="com.microsoft", + ndim=1, + activation=activation, + ) + + graph = helper.make_graph( + [node], + "CausalConvWithStateParity", + [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, ["B", "C", "L"]), + helper.make_tensor_value_info("weight", TensorProto.FLOAT, ["C", 1, "K"]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("past_state", TensorProto.FLOAT, ["B", "C", "P"]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["B", "C", "L"]), + helper.make_tensor_value_info("present_state", TensorProto.FLOAT, ["B", "C", "P"]), + ], + ) + + model = helper.make_model( + graph, + opset_imports=[ + helper.make_opsetid("", 17), + helper.make_opsetid("com.microsoft", 1), + ], + ir_version=8, + ) + checker.check_model(model) + return model.SerializeToString() + + +@unittest.skipUnless(_has_cuda_ep(), "CUDAExecutionProvider is required for parity tests") +class TestLinearAttentionCausalConvPyTorchParity(unittest.TestCase): + def _run_linear_attention_test( + self, + update_rule, + q_num_heads, + kv_num_heads, + d_k, + d_v, + batch, + seq_lens, + provider="CUDAExecutionProvider", + decay_per_key_dim=False, + ): + rng = np.random.default_rng(0) + scale = 1.0 / np.sqrt(float(d_k)) + + needs_decay = update_rule in ("gated", "gated_delta") + needs_beta = update_rule in 
("delta", "gated_delta") + + model = _build_linear_attention_model( + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + update_rule=update_rule, + scale=scale, + decay_per_key_dim=decay_per_key_dim, + ) + + decay_dim = kv_num_heads * d_k if decay_per_key_dim else kv_num_heads + + for seq_len in seq_lens: + inputs = { + "query": rng.standard_normal((batch, seq_len, q_num_heads * d_k), dtype=np.float32), + "key": rng.standard_normal((batch, seq_len, kv_num_heads * d_k), dtype=np.float32), + "value": rng.standard_normal((batch, seq_len, kv_num_heads * d_v), dtype=np.float32), + "past_state": rng.standard_normal((batch, kv_num_heads, d_k, d_v), dtype=np.float32), + "decay": rng.standard_normal((batch, seq_len, decay_dim), dtype=np.float32) + if needs_decay + else np.zeros((batch, seq_len, decay_dim), dtype=np.float32), + "beta": rng.uniform(0.0, 1.0, size=(batch, seq_len, 1)).astype(np.float32) + if needs_beta + else np.zeros((batch, seq_len, 1), dtype=np.float32), + } + + ort_output, ort_state = _run_onnx(model, inputs, provider) + ref_output, ref_state = _torch_linear_attention_reference( + query=inputs["query"], + key=inputs["key"], + value=inputs["value"], + past_state=inputs["past_state"], + decay=inputs["decay"] if needs_decay else None, + beta=inputs["beta"] if needs_beta else None, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + d_k=d_k, + d_v=d_v, + scale=scale, + update_rule=update_rule, + ) + + output_max_diff = np.max(np.abs(ort_output - ref_output)) + state_max_diff = np.max(np.abs(ort_state - ref_state)) + print( + f"LinearAttention parity ({update_rule}, seq_len={seq_len}, " + f"q={q_num_heads}, kv={kv_num_heads}): " + f"output_max_diff={output_max_diff:.6e}, state_max_diff={state_max_diff:.6e}" + ) + + # Tolerances: gated/gated_delta with multi-step sequences accumulate + # floating-point differences due to exp(g) amplification in the recurrence. + # Use relaxed tolerances for seq_len > 1 with decay-based rules. 
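As an illustration of why the decay-based rules get looser tolerances (a synthetic example, not part of the test suite): with random log-decays, exp(g) frequently exceeds one, so per-step rounding differences in the state compound across tokens instead of staying bounded.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal(32)                       # per-token log-decay
k = rng.standard_normal((32, 8)).astype(np.float32)
v = rng.standard_normal((32, 8)).astype(np.float32)

def run(dtype):
    S = np.zeros((8, 8), dtype=dtype)
    for t in range(32):                            # gated rule: S = exp(g) * S + k ⊗ v
        S = dtype(np.exp(g[t])) * S + np.outer(k[t], v[t]).astype(dtype)
    return S

drift = np.max(np.abs(run(np.float32).astype(np.float64) - run(np.float64)))
print(f"fp32 vs fp64 state drift after 32 tokens: {drift:.3e}")
```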
+ if update_rule in ("gated", "gated_delta") and seq_len > 1: + rtol, atol = 2e-2, 5e-2 + else: + rtol, atol = 2e-4, 3e-4 + + np.testing.assert_allclose(ort_output, ref_output, rtol=rtol, atol=atol) + np.testing.assert_allclose(ort_state, ref_state, rtol=rtol, atol=atol) + + def test_linear_attention_gated_delta(self): + """Original test — gated_delta update rule.""" + self._run_linear_attention_test( + "gated_delta", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_linear(self): + """Linear update rule: S_t = S_{t-1} + k_t ⊗ v_t.""" + self._run_linear_attention_test( + "linear", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_gated(self): + """Gated update rule: S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t.""" + self._run_linear_attention_test( + "gated", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_delta(self): + """Delta update rule: S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S^T k_t).""" + self._run_linear_attention_test( + "delta", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_inverse_gqa(self): + """Inverse GQA: kv_num_heads > q_num_heads — exercises the q_num_heads < kv_num_heads readout path.""" + self._run_linear_attention_test( + "gated_delta", q_num_heads=8, kv_num_heads=16, d_k=64, d_v=64, batch=1, seq_lens=(1, 3) + ) + + def test_linear_attention_d128(self): + """d_k=d_v=128 — exercises the 128-thread fixed template path.""" + self._run_linear_attention_test( + "gated_delta", q_num_heads=4, kv_num_heads=2, d_k=128, d_v=128, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_per_key_dim_decay(self): + """Per-key-dim decay: decay shape [B, T, H_kv * d_k] — exercises per-row expf path.""" + self._run_linear_attention_test( + "gated_delta", + q_num_heads=4, + kv_num_heads=2, + d_k=64, + d_v=64, + batch=2, + seq_lens=(1, 5), + decay_per_key_dim=True, + ) + + def test_causal_conv_with_state_pytorch_parity(self): + rng = np.random.default_rng(1) + + batch = 2 + channels = 16 + kernel = 4 + pad = kernel - 1 + + model = _build_causal_conv_model(activation="silu") + + for seq_len in (1, 7): + with self.subTest(seq_len=seq_len): + inputs = { + "input": rng.standard_normal((batch, channels, seq_len), dtype=np.float32), + "weight": rng.standard_normal((channels, 1, kernel), dtype=np.float32), + "bias": rng.standard_normal((channels,), dtype=np.float32), + "past_state": rng.standard_normal((batch, channels, pad), dtype=np.float32), + } + + cuda_output, cuda_state = _run_onnx(model, inputs, "CUDAExecutionProvider") + ref_output, ref_state = _torch_causal_conv_reference( + x=inputs["input"], + weight=inputs["weight"], + bias=inputs["bias"], + past_state=inputs["past_state"], + activation="silu", + ) + + output_max_diff = np.max(np.abs(cuda_output - ref_output)) + state_max_diff = np.max(np.abs(cuda_state - ref_state)) + print( + "CausalConvWithState parity " + f"(seq_len={seq_len}): output_max_diff={output_max_diff:.6e}, " + f"state_max_diff={state_max_diff:.6e}" + ) + + np.testing.assert_allclose(cuda_output, ref_output, rtol=1e-4, atol=2e-4) + np.testing.assert_allclose(cuda_state, ref_state, rtol=1e-5, atol=1e-5) + + def test_causal_conv_with_state_kernel_1(self): + """K=1 edge case: pad=0, zero-size state tensors.""" + rng = np.random.default_rng(2) + batch = 2 + channels = 16 + kernel = 1 + pad = kernel - 1 # = 0 + + model = 
_build_causal_conv_model(activation="silu") + + for seq_len in (1, 5): + with self.subTest(seq_len=seq_len): + inputs = { + "input": rng.standard_normal((batch, channels, seq_len), dtype=np.float32), + "weight": rng.standard_normal((channels, 1, kernel), dtype=np.float32), + "bias": rng.standard_normal((channels,), dtype=np.float32), + "past_state": np.zeros((batch, channels, pad), dtype=np.float32), + } + + cuda_output, cuda_state = _run_onnx(model, inputs, "CUDAExecutionProvider") + ref_output, ref_state = _torch_causal_conv_reference( + x=inputs["input"], + weight=inputs["weight"], + bias=inputs["bias"], + past_state=inputs["past_state"], + activation="silu", + ) + + np.testing.assert_allclose(cuda_output, ref_output, rtol=1e-4, atol=2e-4) + self.assertEqual(cuda_state.shape[2], 0) # zero-size state + + +class TestLinearAttentionCausalConvCPUParity(unittest.TestCase): + def _run_linear_attention_test(self, update_rule, q_num_heads, kv_num_heads, d_k, d_v, batch, seq_lens): + rng = np.random.default_rng(0) + scale = 1.0 / np.sqrt(float(d_k)) + + needs_decay = update_rule in ("gated", "gated_delta") + needs_beta = update_rule in ("delta", "gated_delta") + + model = _build_linear_attention_model( + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + update_rule=update_rule, + scale=scale, + ) + + for seq_len in seq_lens: + inputs = { + "query": rng.standard_normal((batch, seq_len, q_num_heads * d_k), dtype=np.float32), + "key": rng.standard_normal((batch, seq_len, kv_num_heads * d_k), dtype=np.float32), + "value": rng.standard_normal((batch, seq_len, kv_num_heads * d_v), dtype=np.float32), + "past_state": rng.standard_normal((batch, kv_num_heads, d_k, d_v), dtype=np.float32), + "decay": rng.standard_normal((batch, seq_len, kv_num_heads), dtype=np.float32) + if needs_decay + else np.zeros((batch, seq_len, kv_num_heads), dtype=np.float32), + "beta": rng.uniform(0.0, 1.0, size=(batch, seq_len, 1)).astype(np.float32) + if needs_beta + else np.zeros((batch, seq_len, 1), dtype=np.float32), + } + + cpu_output, cpu_state = _run_onnx(model, inputs, "CPUExecutionProvider") + ref_output, ref_state = _torch_linear_attention_reference( + query=inputs["query"], + key=inputs["key"], + value=inputs["value"], + past_state=inputs["past_state"], + decay=inputs["decay"] if needs_decay else None, + beta=inputs["beta"] if needs_beta else None, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + d_k=d_k, + d_v=d_v, + scale=scale, + update_rule=update_rule, + ) + + np.testing.assert_allclose(cpu_output, ref_output, rtol=2e-4, atol=3e-4) + np.testing.assert_allclose(cpu_state, ref_state, rtol=2e-4, atol=3e-4) + + def test_linear_attention_cpu_gated_delta(self): + self._run_linear_attention_test( + "gated_delta", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_cpu_linear(self): + self._run_linear_attention_test( + "linear", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_cpu_gated(self): + self._run_linear_attention_test( + "gated", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_linear_attention_cpu_delta(self): + self._run_linear_attention_test( + "delta", q_num_heads=4, kv_num_heads=2, d_k=64, d_v=64, batch=2, seq_lens=(1, 5) + ) + + def test_causal_conv_with_state_cpu_pytorch_parity(self): + rng = np.random.default_rng(1) + + batch = 2 + channels = 16 + kernel = 4 + pad = kernel - 1 + + model = _build_causal_conv_model(activation="silu") + + for 
seq_len in (1, 7): + with self.subTest(seq_len=seq_len): + inputs = { + "input": rng.standard_normal((batch, channels, seq_len), dtype=np.float32), + "weight": rng.standard_normal((channels, 1, kernel), dtype=np.float32), + "bias": rng.standard_normal((channels,), dtype=np.float32), + "past_state": rng.standard_normal((batch, channels, pad), dtype=np.float32), + } + + cpu_output, cpu_state = _run_onnx(model, inputs, "CPUExecutionProvider") + ref_output, ref_state = _torch_causal_conv_reference( + x=inputs["input"], + weight=inputs["weight"], + bias=inputs["bias"], + past_state=inputs["past_state"], + activation="silu", + ) + + output_max_diff = np.max(np.abs(cpu_output - ref_output)) + state_max_diff = np.max(np.abs(cpu_state - ref_state)) + print( + "CausalConvWithState CPU parity " + f"(seq_len={seq_len}): output_max_diff={output_max_diff:.6e}, " + f"state_max_diff={state_max_diff:.6e}" + ) + + np.testing.assert_allclose(cpu_output, ref_output, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(cpu_state, ref_state, rtol=1e-5, atol=1e-5) + + def test_causal_conv_with_state_cpu_kernel_1(self): + """K=1 edge case on CPU: pad=0, zero-size state tensors.""" + rng = np.random.default_rng(2) + batch = 2 + channels = 16 + kernel = 1 + pad = kernel - 1 # = 0 + + model = _build_causal_conv_model(activation="silu") + + for seq_len in (1, 5): + with self.subTest(seq_len=seq_len): + inputs = { + "input": rng.standard_normal((batch, channels, seq_len), dtype=np.float32), + "weight": rng.standard_normal((channels, 1, kernel), dtype=np.float32), + "bias": rng.standard_normal((channels,), dtype=np.float32), + "past_state": np.zeros((batch, channels, pad), dtype=np.float32), + } + + cpu_output, cpu_state = _run_onnx(model, inputs, "CPUExecutionProvider") + ref_output, ref_state = _torch_causal_conv_reference( + x=inputs["input"], + weight=inputs["weight"], + bias=inputs["bias"], + past_state=inputs["past_state"], + activation="silu", + ) + + np.testing.assert_allclose(cpu_output, ref_output, rtol=1e-5, atol=1e-5) + self.assertEqual(cpu_state.shape[2], 0) # zero-size state + + +if __name__ == "__main__": + unittest.main()
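Finally, a hedged sketch of the decode-time usage pattern these stateful ops target (one token per step, with present_state fed back as past_state), reusing the `_build_linear_attention_model` helper defined earlier in this file; the provider choice and sizes are illustrative.

```python
# Illustrative decode loop: run one token at a time and carry the state forward.
import numpy as np
import onnxruntime as ort

q_heads, kv_heads, d_k, d_v = 4, 2, 64, 64
model = _build_linear_attention_model(q_heads, kv_heads, "gated_delta", float(1.0 / np.sqrt(d_k)))
sess = ort.InferenceSession(model, providers=["CPUExecutionProvider"])

rng = np.random.default_rng(0)
state = np.zeros((1, kv_heads, d_k, d_v), dtype=np.float32)
for _ in range(4):  # four single-token steps
    feeds = {
        "query": rng.standard_normal((1, 1, q_heads * d_k), dtype=np.float32),
        "key": rng.standard_normal((1, 1, kv_heads * d_k), dtype=np.float32),
        "value": rng.standard_normal((1, 1, kv_heads * d_v), dtype=np.float32),
        "past_state": state,
        "decay": rng.standard_normal((1, 1, kv_heads), dtype=np.float32),
        "beta": rng.uniform(0.0, 1.0, size=(1, 1, 1)).astype(np.float32),
    }
    output, state = sess.run(None, feeds)  # present_state becomes next step's past_state
print(output.shape, state.shape)  # (1, 1, 256) (1, 2, 64, 64)
```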