18 changes: 9 additions & 9 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -25,15 +25,16 @@ struct AttentionParameters {
int num_splits; // number of splits for splitkv
int rotary_dim = 0; // rotary embedding dimension
int beam_width;
bool is_unidirectional;
bool past_present_share_buffer;
bool is_unidirectional = false;
bool past_present_share_buffer = false;
bool is_packed_qkv = false; // whether qkv is packed
bool do_rotary;
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;
bool do_rotary = false;
bool broadcast_attn_bias_dim_0 = false;
bool broadcast_attn_bias_dim_1 = false;
float mask_filter_value;
float scale;
bool use_tf32;
bool use_tf32 = false;
bool is_output_bnsh = false; // whether the output format is BNSH
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
@@ -87,9 +88,8 @@ struct GroupQueryAttentionParameters : AttentionParameters {
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int local_window_size; // Mask out tokens prior to total_sequence_length - local_window_size
bool kv_share_buffer;
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
bool is_first_prompt; // indicates whether this is first decoding step
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
bool is_first_prompt; // indicates whether this is first decoding step
bool rotary_interleaved;
bool use_smooth_softmax;
float softcap;
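
Reviewer note (not part of the diff): the in-class "= false" initializers added above give these flags a defined value even when a caller constructs the struct without assigning every field. A minimal, hypothetical illustration of the gap they close:

// Hypothetical illustration, not taken from this PR.
struct Flags {
  bool do_rotary;         // no initializer: indeterminate after default-initialization
  bool use_tf32 = false;  // in-class initializer: defined no matter how the object is created
};

Flags a;    // a.do_rotary is indeterminate (reading it is undefined behavior); a.use_tf32 is false
Flags b{};  // value-initialization: both members end up false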
18 changes: 12 additions & 6 deletions onnxruntime/contrib_ops/cpu/utils/debug_macros.h
@@ -1,12 +1,9 @@
#pragma once
#include <cstdio>
#include "core/common/make_string.h"

// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)

#ifdef DEBUG_GENERATION
#define DUMP_TENSOR_LEVEL 2
#else
#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation.
#if !defined(DUMP_TENSOR_LEVEL)
#define DUMP_TENSOR_LEVEL 0
#endif

#define DUMP_CPU_TENSOR_LEVEL DUMP_TENSOR_LEVEL
@@ -48,3 +45,12 @@
#else
#define DUMP_TENSOR_D(...)
#endif

#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG)
#define DEBUG_PRINTF(fmt, ...) \
std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
#define DEBUG_PRINTF(fmt, ...) \
do { \
} while (0)
#endif
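
Reviewer note (not part of the diff): with the #if !defined(DUMP_TENSOR_LEVEL) guard, the dump level can now be injected from the build instead of editing this header, and DEBUG_PRINTF compiles away outside GCC/Clang debug builds. A minimal usage sketch; the function name and flag value are hypothetical:

// Hypothetical usage sketch, not from this PR.
// Build with, e.g., -DDUMP_TENSOR_LEVEL=2 to enable tensor dumping without touching the header.
#include "contrib_ops/cpu/utils/debug_macros.h"

void log_batch(int batch_size) {
  // Expands to std::printf("[DEBUG] batch_size=%d\n", batch_size) on GCC/Clang debug builds,
  // and to an empty do { } while (0) statement everywhere else.
  DEBUG_PRINTF("batch_size=%d", batch_size);
}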
23 changes: 4 additions & 19 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -179,35 +179,20 @@ struct GroupQueryAttentionData {

// Memory Efficient buffers
T* fmha_buffer = nullptr;
T* unpacked_qkv_buffer = nullptr;
T* rotary_buffer = nullptr;
int64_t* position_ids_buffer = nullptr; // Separate buffer for generated position IDs
T* qkv_buffer = nullptr;

T* k = nullptr;
T* v = nullptr;

#ifndef NDEBUG
// Buffer size tracking for debug validation
// Allocated sizes are set during buffer allocation in group_query_attention.cc
// Max used sizes are updated during kernel calls in group_query_attention_impl.cu
// Validated before operator returns to ensure usage exactly matches allocation
size_t unpacked_qkv_buffer_size = 0; // Allocated size
size_t rotary_buffer_size = 0; // Allocated size
size_t position_ids_buffer_size = 0; // Allocated size
mutable size_t unpacked_qkv_max_used = 0; // Max offset accessed (updated by kernels)
mutable size_t rotary_max_used = 0; // Max offset accessed (updated by kernels)
mutable size_t position_ids_max_used = 0; // Max offset accessed (updated by kernels)
#endif

// Output Tensors
T* output = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
void* present_key = nullptr;
void* present_value = nullptr;

// Kernel Flags
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
bool use_flash_attention_fast_decode = false;
bool disable_fused_kv = false;
};

template <typename T>
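Reviewer note (not part of the diff): present_key and present_value are now void* rather than T*, so the same struct can point at either a T-typed cache or the quantized (Int8/Int4) KV cache mentioned in the GroupQueryAttention operator comment later in this diff. A hypothetical sketch of the kind of dispatch this enables; the is_int8_kv_cache flag is illustrative, not an actual field:

// Hypothetical sketch, not from this PR: reinterpreting a void* present buffer
// based on the cache element type.
template <typename CudaT>
void write_present_key(GroupQueryAttentionData<CudaT>& data, bool is_int8_kv_cache) {
  if (is_int8_kv_cache) {
    auto* present_k = static_cast<int8_t*>(data.present_key);  // quantized path
    (void)present_k;
  } else {
    auto* present_k = static_cast<CudaT*>(data.present_key);   // full-precision path
    (void)present_k;
  }
}
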
@@ -533,7 +533,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;
if (seqlens_k_ != nullptr) {
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
params.seqused_k = static_cast<int*>(seqlens_k_);
}

if (rotary_cos != nullptr) {
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -132,9 +132,12 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,

size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads);
size_t get_softmax_lse_size(size_t token_count, size_t num_heads);
size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q);
size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads,
size_t seqlen_q, size_t head_size_rounded);

std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
size_t head_size, size_t num_SMs);
std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k,
size_t num_heads, size_t head_size, size_t num_SMs);

bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k);

126 changes: 66 additions & 60 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <vector>
#include <algorithm>
#include "core/providers/cuda/cuda_common.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
@@ -39,8 +40,17 @@ REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV";

// Group Query Attention (GQA) Operator
//
// This operator implements Group Query Attention, a variation of Multi-Head Attention (MHA)
// where the number of key/value heads is smaller than the number of query heads.
// It supports:
// - Rotary Positional Embeddings (RoPE)
// - KV Cache (past/present key/value)
// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData
// - Flash Attention and Memory Efficient Attention backends
//
template <typename T>
GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
: CudaKernel(info) {
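
Reviewer note (not part of the diff): the operator comment above describes GQA as MHA with fewer key/value heads than query heads. The core mapping, independent of this kernel, is that consecutive query heads share one key/value head:

// Hypothetical illustration, not from this PR. Assumes num_heads % kv_num_heads == 0.
int kv_head_for_query_head(int query_head, int num_heads, int kv_num_heads) {
  int group_size = num_heads / kv_num_heads;  // e.g. 32 query heads, 8 kv heads -> groups of 4
  return query_head / group_size;             // query heads 0..3 read kv head 0, 4..7 read kv head 1, ...
}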
@@ -63,7 +73,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)

disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();

// Memory efficient attention supports float and float16. BFloat16 support is added for SM80+ via cutlass kernels.
// Memory efficient attention supports float and float16. BFloat16 support added for SM80+.
disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();

if (!disable_flash_attention_) {
@@ -72,9 +82,23 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
}

disable_flash_decode_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFlashDecode, false);
disable_fused_kv_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFusedKv, false);
}

// ComputeInternal executes the GQA kernel.
//
// Inputs:
// 0. query (Tensor) [batch, sequence_length, hidden_size]
// 1. key (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
// 2. value (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
// 3. past_key (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
// 4. past_value (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
// 5. seqlens_k (Tensor) [batch] - Total sequence length minus 1 (for historical compatibility)
// 6. total_seqlen (Tensor) - Max total sequence length
// 7. cos_cache (Tensor) - Precomputed cosine table for RoPE
// 8. sin_cache (Tensor) - Precomputed sine table for RoPE
// 9. position_ids (Tensor) - Position indices for RoPE
// 10. attention_bias (Tensor) - Not supported in this kernel
// 11. head_sink (Tensor) - Attention sink for GPT-OSS
template <typename T>
Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);
@@ -162,7 +186,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
IAllocatorUniquePtr<void> k_buffer;
IAllocatorUniquePtr<void> v_buffer;
IAllocatorUniquePtr<void> rotary_buffer;
IAllocatorUniquePtr<void> position_ids_buffer;
IAllocatorUniquePtr<void> fmha_buffer;
IAllocatorUniquePtr<void> unpacked_qkv_buffer;
IAllocatorUniquePtr<int> seq_lens_buffer;
@@ -185,24 +208,39 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.present_value = reinterpret_cast<CudaT*>(context->Output<Tensor>(2)->MutableData<T>());

// Compute past_present_share_buffer early since it's needed for flash attention path selection.
// This compares the final pointer values after quantization handling.
parameters.past_present_share_buffer = (data.past_key == data.present_key);

#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported<T>(device_prop,
parameters.head_size,
parameters.num_heads,
parameters.kv_num_heads);
data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer;
if (use_flash_attention) {
data.use_flash_attention = true;
data.use_memory_efficient_attention = false;

data.use_flash_attention = use_flash_attention;
data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer;

if (use_flash_attention) {
// Allocate Flash specific buffers (Softmax LSE, Accum)
size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);

int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads;
auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split,
parameters.head_size, device_prop.multiProcessorCount);

parameters.num_splits = static_cast<int>(num_splits);

if (data.use_flash_attention_fast_decode && num_splits > 1) {
// The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
// However, the LSE and Accum buffers must store results for ALL num_heads.
softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32));
}

softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
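
Reviewer note (not part of the diff): the recomputation above is needed because the split heuristic was fed kv_num_heads (to size the grid for the GQA-aware decode kernel), while each split still produces one partial row per query head, so the LSE/output accumulators must be sized with num_heads. A small worked example with assumed shapes; only the rounding rule comes directly from the code above:

// Illustration only, with made-up shapes: head_size = 100, so
//   round_multiple(100, 32) == (100 + 31) / 32 * 32 == 4 * 32 == 128
// and out_accum is sized with head_size_rounded = 128 and num_heads (all query heads),
// even though num_splits itself was chosen using kv_num_heads.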
@@ -214,11 +252,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#endif

if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) {
// FlashAttentionDecoding Fast Path:
// FlashDecoding Fast Path:
// - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed.
// - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to:
// 1. Determine where to append new K/V in the cache
// 2. Apply correct causal masking (attention only to positions [0, past_seq_len])
// - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1.
// - This optimization avoids launching GetSequenceLengths kernel for single-token decoding.
data.past_seq_lens = const_cast<int*>(total_seq_lens_minus_one->Data<int>());
@@ -239,16 +274,20 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.is_first_prompt,
cuda_stream,
device_prop.maxThreadsPerBlock));
DUMP_TENSOR_INIT();
DUMP_TENSOR("total_seq_lens", data.total_seq_lens, parameters.batch_size, 1);
DUMP_TENSOR("past_seq_lens", data.past_seq_lens, parameters.batch_size, 1);
DUMP_TENSOR("padded_seq_lens", data.padded_seq_lens, parameters.batch_size, 1);
}

if (!use_flash_attention) {
// Fall back to memory efficient attention.
#if USE_MEMORY_EFFICIENT_ATTENTION
if (!data.use_flash_attention) {
// Fall back to memory efficient attention.
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value, parameters.head_size, parameters.head_size);
data.use_memory_efficient_attention = use_memory_efficient_attention;

// KV buffer for head expansion (when num_heads != kv_num_heads)
size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads))
@@ -262,49 +301,30 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
#else
constexpr bool use_memory_efficient_attention = false;
#endif

data.use_memory_efficient_attention = use_memory_efficient_attention;
data.use_flash_attention = false;

data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
data.disable_fused_kv = disable_fused_kv_;
}
#endif

// -------------
// Centralized scratch buffer allocation using GQABufferRequirements
// This ensures allocation logic stays in sync with kernel usage
auto buffer_req = GQABufferRequirements::Compute<T>(
parameters,
use_flash_attention,
data.use_flash_attention,
data.use_flash_attention_fast_decode,
data.use_memory_efficient_attention);

if (buffer_req.unpacked_qkv_bytes > 0) {
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.unpacked_qkv_bytes, context->GetComputeStream());
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}
if (buffer_req.rotary_buffer_bytes > 0) {
rotary_buffer = GetScratchBuffer<void>(buffer_req.rotary_buffer_bytes, context->GetComputeStream());
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
if (buffer_req.qkv_buffer_bytes > 0) {
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}
if (buffer_req.position_ids_bytes > 0) {
position_ids_buffer = GetScratchBuffer<void>(buffer_req.position_ids_bytes, context->GetComputeStream());
data.position_ids_buffer = reinterpret_cast<int64_t*>(position_ids_buffer.get());
}
#ifndef NDEBUG
// Track allocated sizes for validation
data.unpacked_qkv_buffer_size = buffer_req.unpacked_qkv_bytes;
data.rotary_buffer_size = buffer_req.rotary_buffer_bytes;
data.position_ids_buffer_size = buffer_req.position_ids_bytes;
#endif

if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_flash_attention = data.use_flash_attention;
debug_info.use_efficient_attention = data.use_memory_efficient_attention;

debug_info.Print("GroupQueryAttention",
@@ -313,12 +333,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
std::is_same<T, BFloat16>::value);
}

if (data.past_key == data.present_key) {
parameters.kv_share_buffer = true;
ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when kv_share_buffer is true");
// Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
if (parameters.past_present_share_buffer) {
ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
} else {
parameters.kv_share_buffer = false;
ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false");
ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
}

data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -337,19 +356,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(QkvToContext<CudaT>(
device_prop, cublas, context->GetComputeStream(), parameters, data));

#ifndef NDEBUG
// Validate buffer usage matches allocation exactly
ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size,
"unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used,
" bytes but allocated ", data.unpacked_qkv_buffer_size);
ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size,
"rotary_buffer: used ", data.rotary_max_used,
" bytes but allocated ", data.rotary_buffer_size);
ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size,
"position_ids_buffer: used ", data.position_ids_max_used,
" bytes but allocated ", data.position_ids_buffer_size);
#endif

return Status::OK();
}

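Reviewer note (not part of the diff): GQABufferRequirements::Compute is referenced in ComputeInternal above but its definition is not part of this diff. The sketch below is a hypothetical outline of the shape such a centralized computation usually takes, to show why allocation and kernel usage stay in sync; all names and fields here are assumptions:

// Hypothetical sketch, not the actual onnxruntime implementation.
struct GQABufferRequirementsSketch {
  size_t qkv_buffer_bytes = 0;

  template <typename T>
  static GQABufferRequirementsSketch Compute(int batch, int seq_len, int num_heads,
                                             int kv_num_heads, int head_size,
                                             bool is_packed_qkv, bool needs_rotary) {
    GQABufferRequirementsSketch req;
    // A single workspace size is derived from the same predicates the kernels check,
    // so the allocation in ComputeInternal and the usage in the .cu file cannot drift apart.
    if (is_packed_qkv || needs_rotary) {
      size_t q_bytes = static_cast<size_t>(batch) * seq_len * num_heads * head_size * sizeof(T);
      size_t kv_bytes = static_cast<size_t>(batch) * seq_len * kv_num_heads * head_size * sizeof(T) * 2;
      req.qkv_buffer_bytes = q_bytes + kv_bytes;
    }
    return req;
  }
};
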
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -35,7 +35,6 @@ class GroupQueryAttention final : public CudaKernel {
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
bool disable_flash_decode_;
bool disable_fused_kv_;

static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
IAllocatorUniquePtr<int> zeros_;