From 634e15cfb7881851418c76518b3870031ec88a22 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 22 Jan 2026 22:27:36 -0800 Subject: [PATCH 1/2] Fix GQA Parity (#27108) Fix [#27079](https://github.com/microsoft/onnxruntime/issues/27079) - Qwen3 model quality regression on CUDA backend. The parity issue was caused by **buffer pointer misconfiguration** in the GQA (Group Query Attention) QKV preprocessing pipeline. The original implementation used multiple separate kernels for: 1. Unpacking packed QKV tensor 2. Applying RoPE (Rotary Position Embedding) to Q and K 3. Appending K/V to cache This multi-kernel approach created opportunities for misconfiguration: - Buffers were allocated but not properly used - Pointers could reference memory that was not yet allocated or initialized - Buffer sharing logic was fragmented across different code paths Consolidate QKV preprocessing into a **single fused kernel** (`UnpackRoPEAppend`) that performs all operations in one pass: 1. **Unified kernel design**: A single kernel handles unpacking, RoPE application, and cache append operations 2. **Simplified buffer management**: The new `PrepareQKV` function clearly manages buffer allocation and ensures proper initialization 3. **Explicit past-to-present cache copy**: When `past_present_share_buffer` is false, explicitly copy past KV cache to present buffer before appending new tokens 4. 
**Zero-initialization for non-shared buffers**: Clear present KV buffers when not sharing with past to ensure deterministic output | File | Changes | |------|---------| | [group_query_attention_qkv.cuh](cci:7://file:///home/tlwu/onnxruntime/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh:0:0-0:0) | New fused `UnpackRoPEAppend` kernel with shared memory optimization for non-interleaved RoPE | | `group_query_attention_impl.cu` | New `PrepareQKV` helper function that orchestrates buffer setup and kernel launch | | `group_query_attention.cc` | Simplified operator logic by delegating QKV prep to unified helper | | `test_gqa.py` | Enhanced test coverage for various QKV configurations | - **Reduced kernel launches**: From 4-5 separate kernel calls to a single fused kernel - **Better memory safety**: All buffer pointers are validated in a single location - **Improved RoPE handling**: Uses shared memory for efficient non-interleaved RoPE computation - **Deterministic output**: Explicit buffer initialization ensures consistent results across runs - **Compatible with quantized KV cache**: The new preprocessing kernel design supports future quantization work - All existing GQA unit tests pass - Verified Qwen3 model no longer produces gibberish output - Tested both fp16/bf16 and various head configurations --- .../cpu/bert/attention_parameters.h | 18 +- .../contrib_ops/cpu/utils/debug_macros.h | 18 +- .../contrib_ops/cuda/bert/attention_data.h | 23 +- .../cuda/bert/flash_attention/flash_api.cc | 1 - .../cuda/bert/flash_attention/flash_api.h | 7 +- .../cuda/bert/group_query_attention.cc | 126 +-- .../cuda/bert/group_query_attention.h | 1 - .../cuda/bert/group_query_attention_impl.cu | 818 ++++-------------- .../cuda/bert/group_query_attention_impl.h | 108 +-- .../cuda/bert/group_query_attention_qkv.cuh | 248 ++++++ .../cuda/bert/rotary_embedding_impl.cu | 51 +- .../test/python/transformers/test_gqa.py | 432 +++++---- 12 files changed, 836 insertions(+), 1015 
deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f237b24b899a0..4ad11dce7e093 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -25,15 +25,16 @@ struct AttentionParameters { int num_splits; // number of splits for splitkv int rotary_dim = 0; // rotary embedding dimension int beam_width; - bool is_unidirectional; - bool past_present_share_buffer; + bool is_unidirectional = false; + bool past_present_share_buffer = false; bool is_packed_qkv = false; // whether qkv is packed - bool do_rotary; - bool broadcast_attn_bias_dim_0; - bool broadcast_attn_bias_dim_1; + bool do_rotary = false; + bool broadcast_attn_bias_dim_0 = false; + bool broadcast_attn_bias_dim_1 = false; float mask_filter_value; float scale; - bool use_tf32; + bool use_tf32 = false; + bool is_output_bnsh = false; // whether the output format is BNSH AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -87,9 +88,8 @@ struct GroupQueryAttentionParameters : AttentionParameters { int seqlen_past_kv_cache; // sequence length of past kv tensor int seqlen_present_kv_cache; // sequence length of present kv tensor int local_window_size; // Mask out tokens prior to total_sequence_length - local_window_size - bool kv_share_buffer; - bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 - bool is_first_prompt; // indicates whether this is first decoding step + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; bool use_smooth_softmax; float softcap; diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h index 
47d0fc5e4008c..415612582ee4b 100644 --- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h +++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h @@ -1,12 +1,9 @@ #pragma once +#include #include "core/common/make_string.h" -// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) - -#ifdef DEBUG_GENERATION -#define DUMP_TENSOR_LEVEL 2 -#else -#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation. +#if !defined(DUMP_TENSOR_LEVEL) +#define DUMP_TENSOR_LEVEL 0 #endif #define DUMP_CPU_TENSOR_LEVEL DUMP_TENSOR_LEVEL @@ -48,3 +45,12 @@ #else #define DUMP_TENSOR_D(...) #endif + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG) +#define DEBUG_PRINTF(fmt, ...) \ + std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__) +#else +#define DEBUG_PRINTF(fmt, ...) \ + do { \ + } while (0) +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 2344b425ed263..1622bb6622412 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -179,35 +179,20 @@ struct GroupQueryAttentionData { // Memory Efficient buffers T* fmha_buffer = nullptr; - T* unpacked_qkv_buffer = nullptr; - T* rotary_buffer = nullptr; - int64_t* position_ids_buffer = nullptr; // Separate buffer for generated position IDs + T* qkv_buffer = nullptr; + T* k = nullptr; T* v = nullptr; -#ifndef NDEBUG - // Buffer size tracking for debug validation - // Allocated sizes are set during buffer allocation in group_query_attention.cc - // Max used sizes are updated during kernel calls in group_query_attention_impl.cu - // Validated before operator returns to ensure usage exactly matches allocation - size_t unpacked_qkv_buffer_size = 0; // Allocated size - size_t rotary_buffer_size = 0; // Allocated size - size_t position_ids_buffer_size = 0; // Allocated size - mutable size_t 
unpacked_qkv_max_used = 0; // Max offset accessed (updated by kernels) - mutable size_t rotary_max_used = 0; // Max offset accessed (updated by kernels) - mutable size_t position_ids_max_used = 0; // Max offset accessed (updated by kernels) -#endif - // Output Tensors T* output = nullptr; - T* present_key = nullptr; - T* present_value = nullptr; + void* present_key = nullptr; + void* present_value = nullptr; // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; bool use_flash_attention_fast_decode = false; - bool disable_fused_kv = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 91cac731054e6..fcc470b19a7b4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -533,7 +533,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; if (seqlens_k_ != nullptr) { params.cu_seqlens_k = static_cast(seqlens_k_); - params.seqused_k = static_cast(seqlens_k_); } if (rotary_cos != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 22b075d8533f9..83f94a31d1786 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -132,9 +132,12 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); size_t get_softmax_lse_size(size_t token_count, size_t num_heads); +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q); +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded); -std::tuple 
get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, - size_t head_size, size_t num_SMs); +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, + size_t num_heads, size_t head_size, size_t num_SMs); bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index c99db85f93421..29ef660e562e0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" @@ -39,8 +40,17 @@ REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(BFloat16) constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; -constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV"; +// Group Query Attention (GQA) Operator +// +// This operator implements Group Query Attention, a variation of Multi-Head Attention (MHA) +// where the number of key/value heads is smaller than the number of query heads. +// It supports: +// - Rotary Positional Embeddings (RoPE) +// - KV Cache (past/present key/value) +// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData +// - Flash Attention and Memory Efficient Attention backends +// template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : CudaKernel(info) { @@ -63,7 +73,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); - // Memory efficient attention supports float and float16. BFloat16 support is added for SM80+ via cutlass kernels. 
+ // Memory efficient attention supports float and float16. BFloat16 support added for SM80+. disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); if (!disable_flash_attention_) { @@ -72,9 +82,23 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) } disable_flash_decode_ = ParseEnvironmentVariableWithDefault(kDisableFlashDecode, false); - disable_fused_kv_ = ParseEnvironmentVariableWithDefault(kDisableFusedKv, false); } +// ComputeInternal executes the GQA kernel. +// +// Inputs: +// 0. query (Tensor) [batch, sequence_length, hidden_size] +// 1. key (Tensor) [batch, sequence_length, kv_hidden_size] (Optional) +// 2. value (Tensor) [batch, sequence_length, kv_hidden_size] (Optional) +// 3. past_key (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional) +// 4. past_value (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional) +// 5. seqlens_k (Tensor) [batch] - Total sequence length minus 1 (for historical compatibility) +// 6. total_seqlen (Tensor) - Max total sequence length +// 7. cos_cache (Tensor) - Precomputed cosine table for RoPE +// 8. sin_cache (Tensor) - Precomputed sine table for RoPE +// 9. position_ids (Tensor) - Position indices for RoPE +// 10. attention_bias (Tensor) - Not supported in this kernel +// 11. head_sink (Tensor) - Attention sink for GPT-OSS template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* query = context->Input(0); @@ -162,7 +186,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr k_buffer; IAllocatorUniquePtr v_buffer; IAllocatorUniquePtr rotary_buffer; - IAllocatorUniquePtr position_ids_buffer; IAllocatorUniquePtr fmha_buffer; IAllocatorUniquePtr unpacked_qkv_buffer; IAllocatorUniquePtr seq_lens_buffer; @@ -185,24 +208,39 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.past_value = (past_value == nullptr) ? 
nullptr : reinterpret_cast(past_value->Data()); data.present_value = reinterpret_cast(context->Output(2)->MutableData()); + // Compute past_present_share_buffer early since it's needed for flash attention path selection. + // This compares the final pointer values after quantization handling. + parameters.past_present_share_buffer = (data.past_key == data.present_key); + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, parameters.head_size, parameters.num_heads, parameters.kv_num_heads); - data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer; - if (use_flash_attention) { - data.use_flash_attention = true; - data.use_memory_efficient_attention = false; + data.use_flash_attention = use_flash_attention; + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer; + + if (use_flash_attention) { // Allocate Flash specific buffers (Softmax LSE, Accum) size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + + int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads; auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split, parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = static_cast(num_splits); + if (data.use_flash_attention_fast_decode && num_splits > 1) { + // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel. 
+ // However, the LSE and Accum buffers must store results for ALL num_heads. + softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32)); + } + softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); @@ -214,11 +252,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { #endif if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) { - // FlashAttentionDecoding Fast Path: + // FlashDecoding Fast Path: // - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed. - // - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to: - // 1. Determine where to append new K/V in the cache - // 2. Apply correct causal masking (attention only to positions [0, past_seq_len]) // - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1. // - This optimization avoids launching GetSequenceLengths kernel for single-token decoding. 
data.past_seq_lens = const_cast(total_seq_lens_minus_one->Data()); @@ -239,16 +274,20 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.is_first_prompt, cuda_stream, device_prop.maxThreadsPerBlock)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("total_seq_lens", data.total_seq_lens, parameters.batch_size, 1); + DUMP_TENSOR("past_seq_lens", data.past_seq_lens, parameters.batch_size, 1); + DUMP_TENSOR("padded_seq_lens", data.padded_seq_lens, parameters.batch_size, 1); } - if (!use_flash_attention) { - // Fall back to memory efficient attention. #if USE_MEMORY_EFFICIENT_ATTENTION + if (!data.use_flash_attention) { + // Fall back to memory efficient attention. int sm = (device_prop.major * 10) + device_prop.minor; bool use_memory_efficient_attention = - !use_flash_attention && !disable_memory_efficient_attention_ && has_memory_efficient_attention(sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.head_size); + data.use_memory_efficient_attention = use_memory_efficient_attention; // KV buffer for head expansion (when num_heads != kv_num_heads) size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads)) @@ -262,49 +301,30 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); -#else - constexpr bool use_memory_efficient_attention = false; -#endif - - data.use_memory_efficient_attention = use_memory_efficient_attention; - data.use_flash_attention = false; data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); - data.disable_fused_kv = disable_fused_kv_; } +#endif + // ------------- // Centralized scratch buffer allocation 
using GQABufferRequirements // This ensures allocation logic stays in sync with kernel usage auto buffer_req = GQABufferRequirements::Compute( parameters, - use_flash_attention, + data.use_flash_attention, data.use_flash_attention_fast_decode, data.use_memory_efficient_attention); - if (buffer_req.unpacked_qkv_bytes > 0) { - unpacked_qkv_buffer = GetScratchBuffer(buffer_req.unpacked_qkv_bytes, context->GetComputeStream()); - data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); - } - if (buffer_req.rotary_buffer_bytes > 0) { - rotary_buffer = GetScratchBuffer(buffer_req.rotary_buffer_bytes, context->GetComputeStream()); - data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); + if (buffer_req.qkv_buffer_bytes > 0) { + unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); + data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } - if (buffer_req.position_ids_bytes > 0) { - position_ids_buffer = GetScratchBuffer(buffer_req.position_ids_bytes, context->GetComputeStream()); - data.position_ids_buffer = reinterpret_cast(position_ids_buffer.get()); - } -#ifndef NDEBUG - // Track allocated sizes for validation - data.unpacked_qkv_buffer_size = buffer_req.unpacked_qkv_bytes; - data.rotary_buffer_size = buffer_req.rotary_buffer_bytes; - data.position_ids_buffer_size = buffer_req.position_ids_bytes; -#endif if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; + debug_info.use_flash_attention = data.use_flash_attention; debug_info.use_efficient_attention = data.use_memory_efficient_attention; debug_info.Print("GroupQueryAttention", @@ -313,12 +333,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { std::is_same::value); } - if (data.past_key == data.present_key) { - parameters.kv_share_buffer = true; - ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be 
the same tensor when kv_share_buffer is true"); + // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup) + if (parameters.past_present_share_buffer) { + ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true"); } else { - parameters.kv_share_buffer = false; - ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false"); + ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false"); } data.output = reinterpret_cast(output->MutableData()); @@ -337,19 +356,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data)); -#ifndef NDEBUG - // Validate buffer usage matches allocation exactly - ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size, - "unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used, - " bytes but allocated ", data.unpacked_qkv_buffer_size); - ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size, - "rotary_buffer: used ", data.rotary_max_used, - " bytes but allocated ", data.rotary_buffer_size); - ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size, - "position_ids_buffer: used ", data.position_ids_max_used, - " bytes but allocated ", data.position_ids_buffer_size); -#endif - return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 5bf26e8c6edac..2536da9fe1379 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -35,7 +35,6 @@ class GroupQueryAttention final : public CudaKernel { bool 
disable_flash_attention_; bool disable_memory_efficient_attention_; bool disable_flash_decode_; - bool disable_fused_kv_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index c8a1629f21bce..0b6da63b31af6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -27,11 +27,11 @@ limitations under the License. #include #include -#include // For getenv #include #include +#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/utils/debug_macros.h" #include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -40,14 +40,16 @@ limitations under the License. #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include "contrib_ops/cuda/bert/rotary_common.cuh" #include "contrib_ops/cuda/bert/transformer_common.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "core/providers/cuda/cuda_common.h" + #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; @@ -59,6 +61,100 @@ namespace onnxruntime { namespace contrib { namespace cuda { +// ============================================================================ +// QKV Preprocessing Helpers +// 
============================================================================ + +// Internal helper to get Q, K, V pointers, handling packed input +// +// This function orchestrates the preparation of Q, K, and V tensors for attention kernels. +// It performs: +// 1. Handling packed vs. unpacked QKV inputs. +// 2. Managing KV cache updates (appending new tokens). +// 3. Ensuring synchronization between past and present KV caches when necessary. +// 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches. +// 5. Returning strict Q, K, V pointers ready for the core attention operation. +template +Status PrepareQKV( + cudaStream_t stream, + const int max_threads_per_block, + const GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + const T*& q, + const T*& k, + const T*& v) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + using CudaT = typename ToCudaType::MappedType; + CudaT* q_out = data.qkv_buffer; + + if (!parameters.is_packed_qkv && !parameters.do_rotary) { + q_out = nullptr; + } + + CudaT* k_final_ptr = reinterpret_cast(data.present_key); + CudaT* v_final_ptr = reinterpret_cast(data.present_value); + int final_max_seqlen = parameters.seqlen_present_kv_cache; + bool final_is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + + if (!parameters.past_present_share_buffer) { + size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * final_max_seqlen * head_size * sizeof(CudaT); + CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream)); + CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); + } + + if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { + bool is_bnsh = 
(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + if (is_bnsh) { + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaT); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaT); + size_t width = src_pitch; + size_t height = (size_t)batch_size * kv_num_heads; + + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + } else { + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaT); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaT); + size_t width = src_pitch; + size_t height = (size_t)batch_size; + + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + } + } + + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppendKV( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? 
nullptr : reinterpret_cast(data.value), + q_out, k_final_ptr, v_final_ptr, + num_heads, kv_num_heads, head_size, sequence_length, batch_size, + final_max_seqlen, data.past_seq_lens, + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, + final_is_bnsh, + stream, max_threads_per_block)); + + if (q_out != nullptr) { + q = reinterpret_cast(q_out); + } else { + q = reinterpret_cast(data.query); + } + k = reinterpret_cast(k_final_ptr); + v = reinterpret_cast(v_final_ptr); + return Status::OK(); +} + ////////// Auxiliary Kernels for KV prep // Concat new to past in present. Supports past BSNH or past BNSH @@ -393,267 +489,6 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache -// This eliminates 4 kernel launches: Unpack -> Rotate Q -> Rotate K -> Append K -> Append V -// Becomes: Single kernel that does all operations in one pass -// -// Bounds Safety: -// - cache_s = past_seq_len + s is guaranteed < max_seqlen by the caller (group_query_attention.cc) -// because present_sequence_length = max(past + new_seq_len) across batches, and the present -// buffer is allocated with seqlen_present_kv_cache >= total_seq_lens[b] for all b. -// - The kernel processes exactly batch_size * sequence_length * (Q+K+V hidden) elements, -// which matches the packed_qkv input size allocated by the model. 
-// -// RoPE Contiguity Requirement: -// - packed_qkv MUST be strictly contiguous with layout [B, S, (H_q + 2*H_kv) * D] -// - The half-split RoPE logic (RotaryDispatcher::apply) fetches pair elements at offset -// (h + rotary_dim/2) relative to the start of each head -// - If strided/non-contiguous inputs are ever supported, this pointer arithmetic must change -// -// Performance Optimization: -// Uses 3D grid layout to eliminate expensive integer divisions: -// - blockIdx.z = batch index (b) -// - blockIdx.y = sequence index (s) -// - blockIdx.x * blockDim.x + threadIdx.x = offset within QKV hidden dimension -// This removes 4 divisions (/, %) per thread that would otherwise be needed. -template -__global__ void UnpackQKVWithRoPEAndAppendKV( - const T* packed_qkv, // Input: packed QKV [B, S, (Q+K+V) hidden] - T* unpacked_q, // Output: rotated Q [B, S, Q_heads, H] (BSNH) - T* k_cache, // Output: K cache [B, N, MaxS, H] or [B, MaxS, N, H] - T* v_cache, // Output: V cache [B, N, MaxS, H] or [B, MaxS, N, H] - const int num_heads, - const int kv_num_heads, - const int head_size, - const int d, // QKV hidden stride = (num_heads + 2*kv_num_heads) * head_size - const int max_seqlen, // KV cache max sequence length - const int* past_seq_lens, - // RoPE params - const T* cos_cache, - const T* sin_cache, - const int rotary_dim, - const int64_t* position_ids, - const bool interleaved, - const bool is_cache_bnsh) { - // Vectorized load/store using float4 (16 bytes) - using LoadT = float4; - constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); - - // 3D grid layout eliminates integer division: - // - blockIdx.z = batch index (b) - obtained from grid dimension, no division needed - // - blockIdx.y = sequence index (s) - obtained from grid dimension, no division needed - // - linear thread index within (b, s) gives offset directly - const int b = blockIdx.z; - const int s = blockIdx.y; - const int offset_vec_idx = blockIdx.x * blockDim.x + threadIdx.x; // Vector index 
within d - const int offset = offset_vec_idx * elements_per_thread; // Element offset within d - - // Bounds check: offset must be within the QKV hidden dimension - if (offset >= d) return; - - const int q_hidden = num_heads * head_size; - const int k_hidden = kv_num_heads * head_size; - const int sequence_length = gridDim.y; // Get from grid dimension - - // Calculate linear index for packed_qkv load - const int64_t packed_idx = static_cast(b) * sequence_length * d + - static_cast(s) * d + offset; - - // Load vector from packed buffer - LoadT val_vec = reinterpret_cast(packed_qkv)[packed_idx / elements_per_thread]; - - // Common RoPE Calculations - const int past_seq_len = past_seq_lens[b]; - int pos_id = 0; - if (position_ids != nullptr) { - pos_id = static_cast(position_ids[b * sequence_length + s]); - } else { - pos_id = past_seq_len + s; - } - - // Determine Q, K, or V based on offset - if (offset < q_hidden) { - // Q: Apply RoPE and write to unpacked_q buffer (BSNH format) - const int q_head_idx = offset / head_size; - const int h = offset % head_size; - const int h_idx = h / elements_per_thread; - - if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { - // For half-split RoPE, pair values should be read relative to the START of the current Q head. - // Calculate offset to head start: (b, s, q_head_n, 0) in packed QKV. 
- const int64_t q_head_start_in_packed = static_cast(b) * sequence_length * d + - static_cast(s) * d + - static_cast(q_head_idx) * head_size; - RotaryDispatcher::apply(val_vec, - reinterpret_cast(cos_cache), - reinterpret_cast(sin_cache), - rotary_dim, h_idx, pos_id, interleaved, - reinterpret_cast(packed_qkv), - q_head_start_in_packed / elements_per_thread); - } - - const int64_t q_idx = static_cast(b) * sequence_length * num_heads * head_size + - static_cast(s) * num_heads * head_size + offset; - // Vector store to unpacked_q - reinterpret_cast(unpacked_q)[q_idx / elements_per_thread] = val_vec; - - } else if (offset < q_hidden + k_hidden) { - // K: Apply RoPE and write DIRECTLY to K cache - const int k_offset = offset - q_hidden; - const int n = k_offset / head_size; - const int h = k_offset % head_size; - const int h_idx = h / elements_per_thread; - - if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { - // For half-split RoPE, pair values should be read relative to the START of the current K head. - // Calculate offset to head start: (b, s, k_head_n, 0) in packed QKV. 
- const int64_t k_head_start_in_packed = static_cast(b) * sequence_length * d + - static_cast(s) * d + - q_hidden + - static_cast(n) * head_size; - RotaryDispatcher::apply(val_vec, - reinterpret_cast(cos_cache), - reinterpret_cast(sin_cache), - rotary_dim, h_idx, pos_id, interleaved, - reinterpret_cast(packed_qkv), - k_head_start_in_packed / elements_per_thread); - } - - const int cache_s = past_seq_len + s; - int64_t cache_idx; - if (is_cache_bnsh) { - cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + - static_cast(n) * max_seqlen * head_size + - static_cast(cache_s) * head_size + h; - } else { // BSNH - cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + - static_cast(cache_s) * kv_num_heads * head_size + - static_cast(n) * head_size + h; - } - // Vector store to k_cache - reinterpret_cast(k_cache)[cache_idx / elements_per_thread] = val_vec; - - } else { - // V: Write DIRECTLY to V cache (no rotation) - const int v_offset = offset - q_hidden - k_hidden; - const int n = v_offset / head_size; - const int h = v_offset % head_size; - - const int cache_s = past_seq_len + s; - int64_t cache_idx; - if (is_cache_bnsh) { - cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + - static_cast(n) * max_seqlen * head_size + - static_cast(cache_s) * head_size + h; - } else { // BSNH - cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + - static_cast(cache_s) * kv_num_heads * head_size + - static_cast(n) * head_size + h; - } - // Vector store to v_cache - reinterpret_cast(v_cache)[cache_idx / elements_per_thread] = val_vec; - } -} - -// Launcher for fused UnpackQKV + RoPE + KV Append -template -Status LaunchUnpackQKVWithRoPEAndAppendKV( - const T* packed_qkv, - T* unpacked_q, - T* k_cache, - T* v_cache, - const int num_heads, - const int kv_num_heads, - const int head_size, - const int sequence_length, - const int batch_size, - const int max_seqlen, - const int* past_seq_lens, - const T* cos_cache, - const T* 
sin_cache, - const int rotary_dim, - const int64_t* position_ids, - const bool interleaved, - const bool is_cache_bnsh, - cudaStream_t stream, - const int max_threads_per_block) { - // Determine vectorization factor (float4 is 16 bytes) - constexpr int vector_bytes = sizeof(float4); - constexpr int element_bytes = sizeof(T); - constexpr int elements_per_vector = vector_bytes / element_bytes; - - // Validate head_size alignment - if (head_size % elements_per_vector != 0) { - // If strict alignment is not met (unlikely given GQA constraints), we should fall back or fail. - // Typically GQA enforces head_size % 8 == 0. - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be divisible by ", elements_per_vector, " for vectorized GQA kernel."); - } - - // Validate grid dimensions - CUDA limits gridDim.y to 65535 - if (sequence_length > 65535) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Sequence length ", sequence_length, - " exceeds CUDA grid dimension limit (65535) for fused UnpackQKV kernel."); - } - -#ifndef NDEBUG - // Debug-mode alignment assertions for vectorized memory access - assert(reinterpret_cast(packed_qkv) % 16 == 0 && "packed_qkv must be 16-byte aligned"); - assert(reinterpret_cast(unpacked_q) % 16 == 0 && "unpacked_q must be 16-byte aligned"); - assert(reinterpret_cast(k_cache) % 16 == 0 && "k_cache must be 16-byte aligned"); - assert(reinterpret_cast(v_cache) % 16 == 0 && "v_cache must be 16-byte aligned"); - if (cos_cache != nullptr) { - assert(reinterpret_cast(cos_cache) % 16 == 0 && "cos_cache must be 16-byte aligned"); - assert(reinterpret_cast(sin_cache) % 16 == 0 && "sin_cache must be 16-byte aligned"); - } -#endif - - // QKV hidden dimension stride - const int d = (num_heads + 2 * kv_num_heads) * head_size; - const int d_vectors = d / elements_per_vector; // Number of vectors per (b, s) - - // 3D grid layout for eliminating integer divisions in kernel: - // grid.x = number of blocks needed to cover d_vectors with threads_per_block 
threads - // grid.y = sequence_length - // grid.z = batch_size - const int threads_per_block = std::min(max_threads_per_block, d_vectors); - const int blocks_x = (d_vectors + threads_per_block - 1) / threads_per_block; - const dim3 grid(blocks_x, sequence_length, batch_size); - const dim3 block(threads_per_block); - - UnpackQKVWithRoPEAndAppendKV<<>>( - packed_qkv, - unpacked_q, - k_cache, - v_cache, - num_heads, - kv_num_heads, - head_size, - d, - max_seqlen, - past_seq_lens, - cos_cache, - sin_cache, - rotary_dim, - position_ids, - interleaved, - is_cache_bnsh); - - return CUDA_CALL(cudaGetLastError()); -} - -// Explicit template instantiations -template Status LaunchUnpackQKVWithRoPEAndAppendKV( - const half*, half*, half*, half*, - int, int, int, int, int, int, const int*, - const half*, const half*, int, const int64_t*, bool, bool, - cudaStream_t, int); - -template Status LaunchUnpackQKVWithRoPEAndAppendKV( - const BFloat16*, BFloat16*, BFloat16*, BFloat16*, - int, int, int, int, int, int, const int*, - const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, - cudaStream_t, int); - // ============================================================================ // GetSequenceLengths Kernel // ============================================================================ @@ -697,6 +532,7 @@ __global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, padded_seq_lens[i] = sequence_length; } else { past_seq_lens[i] = total_len - sequence_length; + padded_seq_lens[i] = 0; } } } @@ -716,20 +552,32 @@ Status LaunchGetSequenceLengths( return CUDA_CALL(cudaGetLastError()); } -////////// Kernels (supports right padding but not left padding) +// Trace function for debugging +#define ORT_GQA_TRACE(func_name) \ + DEBUG_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \ + func_name, \ + static_cast(parameters.is_packed_qkv), \ + static_cast(parameters.is_first_prompt), \ + 
static_cast(parameters.is_subsequent_prompt), \ + static_cast(parameters.past_present_share_buffer)); +////////// Kernels (supports right padding but not left padding) +// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. +// Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. #if USE_FLASH_ATTENTION // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. template -Status FlashAttentionDecoding( +Status FlashDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { - assert(!parameters.is_first_prompt && parameters.kv_share_buffer); + assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); + + ORT_GQA_TRACE("FlashDecoding"); const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -757,8 +605,8 @@ Status FlashAttentionDecoding( void* seqlens_k = reinterpret_cast(data.past_seq_lens); - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* present_key = data.present_key; + void* present_value = data.present_value; void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); void* head_sink = reinterpret_cast(const_cast(data.head_sink)); @@ -773,7 +621,8 @@ Status FlashAttentionDecoding( parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - 
parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv)); + parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv, + 0, 1)); return Status::OK(); } @@ -799,242 +648,21 @@ Status FlashAttention( bool is_causal = parameters.is_unidirectional; bool is_bf16 = std::is_same::value; - void* query = reinterpret_cast(const_cast(data.query)); - void* key; - void* value; - - if (!parameters.is_packed_qkv) { - key = reinterpret_cast(const_cast(data.key)); - value = reinterpret_cast(const_cast(data.value)); - } else { - const size_t key_offset = static_cast(num_heads * head_size); - const size_t value_offset = static_cast(kv_num_heads * head_size); - key = reinterpret_cast(query) + key_offset; - value = reinterpret_cast(key) + value_offset; - } - -#if DUMP_TENSOR_LEVEL > 0 - printf("[GQA FlashAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", - static_cast(parameters.is_packed_qkv), - static_cast(parameters.is_first_prompt), - static_cast(parameters.is_subsequent_prompt), - static_cast(parameters.kv_share_buffer)); -#endif DUMP_TENSOR_INIT(); - // Track whether we keep packed QKV for FA kernels - bool use_packed_for_fa = parameters.is_packed_qkv; - - // Track if we used the fully fused path (packed + share_buffer + rotary) - bool used_fused_packed_path = false; - - // ========================================================================= - // Handle Packed QKV Input - // ========================================================================= - if (parameters.is_packed_qkv) { - T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); - if (unpacked_buffer != nullptr) { - T* unpacked_q = unpacked_buffer; - - // Check if we can use the fully fused path - if (parameters.kv_share_buffer && parameters.do_rotary && !data.disable_fused_kv) { - // FULLY FUSED PATH: Unpack + RoPE Q + RoPE K + Append KV in single kernel - // This eliminates 4 kernel 
launches! - ORT_RETURN_IF_ERROR(LaunchUnpackQKVWithRoPEAndAppendKV( - reinterpret_cast(data.query), // packed QKV - unpacked_q, // Q output buffer (rotated) - data.present_key, // K cache (direct write) - data.present_value, // V cache (direct write) - num_heads, - kv_num_heads, - head_size, - sequence_length, - batch_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - parameters.rotary_dim, - data.position_ids, - parameters.rotary_interleaved, - !past_bsnh, // is_cache_bnsh - stream, - max_threads_per_block)); - - // Update query to point to rotated Q - query = unpacked_q; - use_packed_for_fa = false; - used_fused_packed_path = true; - - // Track buffer usage: Only Q is stored in unpacked_qkv_buffer (fused path writes K/V to cache) - size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, q_bytes); - - // K and V are already in cache - no need to set key/value pointers + const T* q_prep = nullptr; + const T* k_prep = nullptr; + const T* v_prep = nullptr; + ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); - } else { - // Standard path: Unpack first, then process K/V separately - size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* unpacked_k = unpacked_buffer + q_size; + void* query = const_cast(q_prep); + (void)k_prep; // Key/value are now processed by PrepareQKV + (void)v_prep; - size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - T* unpacked_v = unpacked_k + k_size; + bool use_packed_for_fa = false; - // If we need Q rotation, we MUST unpack Q as well. - T* q_dst = parameters.do_rotary ? 
unpacked_q : nullptr; - - // Always unpack to BSNH as LaunchConcatNewToPastKV expects contiguous BSNH input - ORT_RETURN_IF_ERROR((LaunchUnpackQKV(reinterpret_cast(data.query), q_dst, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size, stream, max_threads_per_block))); - - // Update key/value to point to unpacked buffers - key = unpacked_k; - value = unpacked_v; - - if (parameters.do_rotary) { - query = unpacked_q; - use_packed_for_fa = false; - } - - // Track buffer usage: Q+K+V unpacked - size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } - } - } - // ========================================================================= - // Handle Unpacked Q, K, V Input (with optional RoPE) - // ========================================================================= - else { - if (parameters.do_rotary) { - // For unpacked input, we need to rotate Q and K. - // The rotated Q and K will be stored in unpacked_qkv_buffer with layout [Q (B*S*H*D), K (B*S*H_kv*D)]. - T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); - if (unpacked_buffer != nullptr) { - query = unpacked_buffer; - // Do not update key here for Unpacked path. - // key must remain pointing to data.key (Input) for Explicit K Rotation (k_src). - // k_dst will be calculated from unpacked_buffer explicitly. - } - } - } - - const int64_t* position_ids = data.position_ids; - - // Explicit Q Rotation (skip if fused path already applied RoPE) - if (parameters.do_rotary && !used_fused_packed_path) { - // Rotate Q - // Q ptr is already set to the destination buffer (unpacked_buffer) above. - // Input for Rotation: - // If packed: we unpacked into `query` buffer. So Input==Output (In-place). - // If unpacked: we set `query = unpacked_buffer`. But Input is `data.query`. - const T* q_input_for_rope = parameters.is_packed_qkv ? 
reinterpret_cast(query) : reinterpret_cast(data.query); - T* q_output_for_rope = reinterpret_cast(query); // Destination - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - q_output_for_rope, - q_input_for_rope, - nullptr, // position_ids unused for format 2/3 - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - num_heads, - head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - 2, // position_ids_format = 2 (Implicit: past_seq_lens[b] + s) - parameters.rotary_interleaved, - max_threads_per_block, - false // is_input_bnsh_format (Q is BSNH) - )); - DUMP_TENSOR("Rotated Q", q_output_for_rope, batch_size, sequence_length, num_heads, head_size); - - // Rotate K will be done later in fused kernel. - } - - // Skip KV append if we used the fully fused path (KV already in cache) - if (!used_fused_packed_path) { - if (parameters.kv_share_buffer && !parameters.is_first_prompt) { - constexpr bool is_new_kv_bnsh_format = false; - if (parameters.do_rotary) { - // Explicit K Rotation (replacing internal RoPE in fused kernel) - size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* k_dst = reinterpret_cast(data.unpacked_qkv_buffer) + q_elements; - const T* k_src = reinterpret_cast(key); - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - k_dst, - k_src, - position_ids, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - kv_num_heads, - head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - position_ids != nullptr ? 
1 : 2, - parameters.rotary_interleaved, - max_threads_per_block, - false)); - - if (!data.disable_fused_kv) { - // Use fused kernel for K (rotated) + V append - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( - batch_size, - kv_num_heads, - head_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.total_seq_lens, - sequence_length, - k_dst, - reinterpret_cast(data.value), - data.present_key, - data.present_value, - !past_bsnh, - is_new_kv_bnsh_format, - stream, - max_threads_per_block)); - } else { - // Unfused Fallback: LaunchConcatKVInPlace - // We must pass the ROTATED K (k_dst) to it. - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - - // Track buffer usage: Q + K rotated in unpacked_qkv_buffer - size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - size_t total_bytes = (q_elements + k_elements) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } else { - // No RoPE - use original kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - } else { - // ORT MUST perform the append (using unpacked data for packed case) - bool skip_new_append = false; - // FUSED ROPE: Pass RoPE params to ConcatKV (applies RoPE to K as it is appended) - // IMPORTANT: For Fused RoPE with unpacked input, we must pass data.key (the original input), - // not the scratch buffer 'key' which is empty since explicit rotation was skipped. - const void* key_for_concat = parameters.is_packed_qkv ? 
key : data.key; - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key_for_concat, value, stream, max_threads_per_block, skip_new_append, - data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); - } - } - - DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); - DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); - DUMP_TENSOR("Present Key", data.present_key, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - DUMP_TENSOR("Present Value", data.present_value, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* present_key = data.present_key; + void* present_value = data.present_value; // Disable internal RoPE in Flash Attention (pass nullptr) void* cos_cache = nullptr; @@ -1047,7 +675,6 @@ Status FlashAttention( void* kernel_new_v = nullptr; // Use padded seq lens for first prompt since mha_fwd_kvcache assumes uniform seqlen_q. - // The causal mask offset (seqlen_k - seqlen_q) becomes negative when seqlen_k < seqlen_q, causing incorrect masking. int* seq_lens = parameters.is_first_prompt ? 
data.padded_seq_lens : data.total_seq_lens; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( @@ -1057,12 +684,16 @@ Status FlashAttention( /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, head_sink, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, - parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, + 0, // rotary_dim = 0 as it is already rotated + scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size - 1, parameters.rotary_interleaved, use_packed_for_fa, 0, 1)); + DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); + DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); + return Status::OK(); } #endif @@ -1084,164 +715,17 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if DUMP_TENSOR_LEVEL > 0 - printf("[GQA EfficientAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", - static_cast(parameters.is_packed_qkv), - static_cast(parameters.is_first_prompt), - static_cast(parameters.is_subsequent_prompt), - static_cast(parameters.kv_share_buffer)); -#endif - - const void* query; - const void* key; - const void* value; - - if (!parameters.is_packed_qkv) { - query = reinterpret_cast(data.query); - key = reinterpret_cast(data.key); - value = reinterpret_cast(data.value); - } else { - size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; - size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - auto q = reinterpret_cast(data.unpacked_qkv_buffer); - auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); - auto v = reinterpret_cast(data.unpacked_qkv_buffer + 
q_size + k_size); - - Status status = LaunchUnpackQKV( - reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads, - head_size, sequence_length, batch_size, stream, max_threads_per_block); - if (status != Status::OK()) { - return status; - } - - query = reinterpret_cast(q); - key = reinterpret_cast(k); - value = reinterpret_cast(v); - - // Track buffer usage: Q+K+V unpacked - size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } - - const int64_t* position_ids = data.position_ids; - if (parameters.do_rotary) { - auto q_buffer = reinterpret_cast(data.rotary_buffer); - - // Launch rotary embedding kernel for Q - if (position_ids != nullptr) { - // User provided explicit position_ids - Use Format 1 - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, q_buffer, reinterpret_cast(query), - position_ids, nullptr /*past_seq_lens not used in format 1*/, - data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, - 1, // Format 1: Explicit position_ids - parameters.rotary_interleaved, - max_threads_per_block, - false)); - } else { - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, q_buffer, reinterpret_cast(query), - nullptr, data.past_seq_lens, - data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, - 2, // Format 2: Implicit (past_seq_lens[b] + s) - parameters.rotary_interleaved, - max_threads_per_block, - false)); - } - query = reinterpret_cast(q_buffer); - - // For kv_share_buffer path, we use Fused RoPE in LaunchConcatKVInPlaceWithRoPE. - // For non-share-buffer path, we use Fused RoPE in LaunchConcatNewToPastKVHelper. - // No explicit K rotation needed here - handled by fused kernels. 
- - // key remains pointing to original source for use in fused kernel below - - // Track rotary buffer usage: Q rotated (K rotation is fused in KV append) - size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); - size_t k_bytes = static_cast(batch_size) * sequence_length * kv_num_heads * head_size * sizeof(T); - // Note: rotary_buffer layout is [Q_rotated, K_rotated] - no position_ids here - UpdateRotaryMaxUsed(data, q_bytes + k_bytes); + ORT_GQA_TRACE("EfficientAttention"); - // Track position_ids_buffer usage - size_t pos_ids_bytes = static_cast(batch_size) * sequence_length * sizeof(int64_t); - UpdatePositionIdsMaxUsed(data, pos_ids_bytes); - } + const T* q_prep = nullptr; + const T* k_prep = nullptr; + const T* v_prep = nullptr; + ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); - if (parameters.kv_share_buffer) { - // Concatenate new kv in place - constexpr bool is_new_kv_bnsh_format = false; - - if (parameters.do_rotary) { - // Explicit K Rotation - size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* k_dst = reinterpret_cast(data.rotary_buffer) + q_elements; - const T* k_src = reinterpret_cast(key); - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - k_dst, - k_src, - position_ids, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - parameters.kv_num_heads, - parameters.head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - position_ids != nullptr ? 
1 : 2, - parameters.rotary_interleaved, - max_threads_per_block, - false)); - - if (!data.disable_fused_kv) { - // Use truly fused kernel for K (already rotated) + V append in single kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( - batch_size, - parameters.kv_num_heads, - parameters.head_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.total_seq_lens, - parameters.sequence_length, - k_dst, - reinterpret_cast(value), - data.present_key, - data.present_value, - past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, // is_past_kv_bnsh_format - is_new_kv_bnsh_format, - stream, - max_threads_per_block)); - } else { - // Unfused Fallback - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } + const void* query = reinterpret_cast(q_prep); + const void* key = reinterpret_cast(k_prep); + const void* value = reinterpret_cast(v_prep); - // Track rotary buffer usage: Q + K rotated (no position_ids in rotary_buffer) - size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - UpdateRotaryMaxUsed(data, (q_elements + k_elements) * sizeof(T)); - } else { - // No RoPE - use original kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - } else { - // Copy past and concat new KV to present buffer - // FUSED ROPE: Pass RoPE params to ConcatKV - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key, value, stream, max_threads_per_block, false, - data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); - } - - // Ungroup if grouped, otherwise use present kv directly const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; if (num_heads == kv_num_heads) { // Use present kv directly if not grouped @@ -1309,7 +793,7 @@ Status QkvToContext( #if USE_FLASH_ATTENTION if 
(data.use_flash_attention_fast_decode) { - return FlashAttentionDecoding(device_prop, stream, parameters, data, scale); + return FlashDecoding(device_prop, stream, parameters, data, scale); } if (data.use_flash_attention) { @@ -1327,6 +811,7 @@ Status QkvToContext( } template struct GroupQueryAttentionData; +template struct GroupQueryAttentionData; template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -1335,24 +820,15 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); - -template struct GroupQueryAttentionData; - template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, - GroupQueryAttentionParameters& parameters, + contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); +template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +template Status LaunchUnpackQKV(const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int 
max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index c42fe53e4b625..4ad71c5003e0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -39,12 +39,9 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp // auto req = GQABufferRequirements::Compute(params, use_flash, fast_decode, use_mea, disable_fused); // unpacked_qkv_buffer = GetScratchBuffer(req.unpacked_qkv_bytes, ...); // rotary_buffer = GetScratchBuffer(req.rotary_buffer_bytes, ...); -// position_ids_buffer = GetScratchBuffer(req.position_ids_bytes, ...); // ============================================================================ struct GQABufferRequirements { - size_t unpacked_qkv_bytes = 0; - size_t rotary_buffer_bytes = 0; - size_t position_ids_bytes = 0; + size_t qkv_buffer_bytes = 0; template static GQABufferRequirements Compute( @@ -53,6 +50,9 @@ struct GQABufferRequirements { bool use_flash_attention_fast_decode, bool use_memory_efficient_attention) { GQABufferRequirements req; + if (use_flash_attention_fast_decode) { + return req; // All zeros - no scratch buffers needed + } const size_t elem_size = sizeof(T); const size_t batch_size = static_cast(params.batch_size); @@ -61,49 +61,36 @@ struct GQABufferRequirements { const size_t kv_num_heads = static_cast(params.kv_num_heads); const size_t head_size = static_cast(params.head_size); - // Fast decode path: Flash Attention handles everything internally - if (use_flash_attention_fast_decode) { - return req; // All zeros - no scratch buffers needed - } - - // Q, K, V element counts + // Base requirements for all paths const size_t q_elements = batch_size * seq_len * num_heads * head_size; const size_t k_elements = batch_size * seq_len * kv_num_heads * head_size; const 
size_t v_elements = k_elements; if (use_flash_attention) { // Flash Attention path: - // - unpacked_qkv_buffer is used for: - // 1. Unpacking packed QKV input - // 2. Storing rotated Q (and K for non-fused path) - // - rotary_buffer is NOT used (rotations go to unpacked_qkv_buffer) - // - position_ids_buffer is NOT used (flash attention uses implicit position IDs) - - if (params.is_packed_qkv) { - // Need full Q+K+V for unpacking - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); - } else if (params.do_rotary) { - // Unpacked input with RoPE: need Q+K for rotation output - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements); + // qkv_buffer is used for: + // 1. Unpacking packed Q (and K/V if needed) + // 2. Storing rotated Q + // + // Logic: + // - we generally only need Q buffer (for rotary Q) if we can write K/V directly to cache/output. + + if (params.do_rotary || params.is_packed_qkv) { + // Just Q buffer needed for rotation/unpacking. + // K and V are written directly to present_key/value (unpacked/rotated/quantized/appended). + req.qkv_buffer_bytes = elem_size * q_elements; } - // Note: unpacked + no-RoPE case does NOT need unpacked_qkv_buffer - } else if (use_memory_efficient_attention) { // Memory Efficient Attention path: - // - unpacked_qkv_buffer: for unpacking packed QKV - // - rotary_buffer: for Q and K rotation output (separate from unpack buffer) - // - position_ids_buffer: for explicit position IDs if needed + // - qkv_buffer: for unpacking packed QKV or Q rotation + // MEA path usually needs Q, and also K, V if they need unpacking. + // Current MEA implementation can handle separate K/V, but if packed, we unpack all. 
if (params.is_packed_qkv) { - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); - } - - if (params.do_rotary) { + req.qkv_buffer_bytes = elem_size * (q_elements + k_elements + v_elements); + } else if (params.do_rotary) { // Q rotation + K rotation - // Note: K uses kv_num_heads which may be less than num_heads - req.rotary_buffer_bytes = elem_size * (q_elements + k_elements); - // Position IDs space (always allocated for MEA + RoPE path) - req.position_ids_bytes = sizeof(int64_t) * batch_size * seq_len; + req.qkv_buffer_bytes = elem_size * (q_elements + k_elements); } } @@ -111,47 +98,6 @@ struct GQABufferRequirements { } }; -// ============================================================================ -// Debug helper for tracking buffer usage -// ============================================================================ -// Call these after buffer access to record the maximum offset used. -// In release builds, these are no-ops. -// -// Example: -// T* unpacked_q = data.unpacked_qkv_buffer; -// // ... kernel writes to unpacked_q[0..Q_size-1] ... 
-// UpdateUnpackedQkvMaxUsed(data, Q_size * sizeof(T)); -// ============================================================================ -#ifndef NDEBUG -template -inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.unpacked_qkv_max_used) { - data.unpacked_qkv_max_used = bytes_used; - } -} - -template -inline void UpdateRotaryMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.rotary_max_used) { - data.rotary_max_used = bytes_used; - } -} - -template -inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.position_ids_max_used) { - data.position_ids_max_used = bytes_used; - } -} -#else -template -inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData&, size_t) {} -template -inline void UpdateRotaryMaxUsed(GroupQueryAttentionData&, size_t) {} -template -inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData&, size_t) {} -#endif - Status LaunchGetSequenceLengths( const int* total_seq_lens_minus_one, int* past_seq_lens, @@ -163,6 +109,16 @@ Status LaunchGetSequenceLengths( cudaStream_t stream, const int max_threads_per_block); +template +Status LaunchUnpackRoPEAppendKV( + const T* packed_qkv, const T* query, const T* key, const T* value, + T* unpacked_q, T* k_cache, T* v_cache, + const int num_heads, const int kv_num_heads, const int head_size, + const int sequence_length, const int batch_size, const int max_seqlen, + const int* past_seq_lens, const T* cos_cache, const T* sin_cache, + const int rotary_dim, const int64_t* position_ids, const bool interleaved, + const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh new file mode 100644 index 0000000000000..ddf24aff27442 
--- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/rotary_common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache +// +// OPTIMIZATION: This version uses Shared Memory to store the current head being processed. +// Shared memory allows RoPE dispatcher to access paired elements in non-interleaved mode +// (element i pairs with i ± rotary_dim/2) without global memory gathers. +// +// Alignment Note: This kernel assumes that base pointers (packed_qkv, query, etc.) +// are 16-byte aligned and that head_size is a multiple of elements_per_thread. 
+// +// Grid Layout: +// blockIdx.x: sequence index (s) -> Max 2^31-1 (Supports very long context) +// blockIdx.y: head index (head_idx) -> Max 65535 +// blockIdx.z: batch index (b) -> Max 65535 +template +__global__ void UnpackRoPEAppend( + const T* packed_qkv, + const T* query, + const T* key, + const T* value, + T* unpacked_q, + T* k_cache, + T* v_cache, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int d, // packed QKV hidden stride = (num_heads + 2*kv_num_heads) * head_size + const int max_seqlen, // KV cache max sequence length + const int* past_seq_lens, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved, + const bool is_cache_bnsh) { + using LoadT = float4; + constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); + + const int s = blockIdx.x; + const int head_idx = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const int h = tid * elements_per_thread; + + // Guard work with 'valid' instead of early return to ensure all threads reach __syncthreads() + const bool valid = (h < head_size); + + const int q_hidden = num_heads * head_size; + const int k_hidden = kv_num_heads * head_size; + const int sequence_length = gridDim.x; + + __shared__ T shared_head[MAX_HEAD_SIZE]; + + // Determine Head Type and Offset within hidden dimension + enum HeadType { QUERY, + KEY, + VALUE }; + HeadType head_type; + int n; // Index within its specific type + int offset_in_hidden; + + if (head_idx < num_heads) { + head_type = QUERY; + n = head_idx; + offset_in_hidden = n * head_size; + } else if (head_idx < num_heads + kv_num_heads) { + head_type = KEY; + n = head_idx - num_heads; + offset_in_hidden = q_hidden + n * head_size; + } else { + head_type = VALUE; + n = head_idx - (num_heads + kv_num_heads); + offset_in_hidden = q_hidden + k_hidden + n * head_size; + } + + // 1. 
Load data into Registers + T vals[elements_per_thread]; + if (valid) { + if (packed_qkv != nullptr) { + const int64_t packed_idx = static_cast(b) * sequence_length * d + + static_cast(s) * d + + static_cast(offset_in_hidden) + h; + *reinterpret_cast(vals) = reinterpret_cast(packed_qkv)[packed_idx / elements_per_thread]; + } else { + if (head_type == QUERY) { + const int64_t q_idx = static_cast(b) * sequence_length * q_hidden + + static_cast(s) * q_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(query)[q_idx / elements_per_thread]; + } else if (head_type == KEY) { + const int64_t k_idx = static_cast(b) * sequence_length * k_hidden + + static_cast(s) * k_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(key)[k_idx / elements_per_thread]; + } else { + const int64_t v_idx = static_cast(b) * sequence_length * k_hidden + + static_cast(s) * k_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(value)[v_idx / elements_per_thread]; + } + } + } + + // 2. Process RoPE + // Optimization: Only use shared memory for non-interleaved mode + const bool is_qk = (head_type == QUERY || head_type == KEY); + if (valid && rotary_dim > 0 && is_qk && !interleaved) { + T* shared_ptr = &shared_head[h]; + *reinterpret_cast(shared_ptr) = *reinterpret_cast(vals); + } + + // CRITICAL: Barrier must be outside the 'if(valid)' and 'if(is_qk)' blocks + // to ensure every thread in the block participates. + __syncthreads(); + + if (valid && rotary_dim > 0 && is_qk) { + const int past_seq_len = past_seq_lens[b]; + const int64_t pos_base = static_cast(b) * sequence_length; + int pos_id = (position_ids != nullptr) ? 
static_cast(position_ids[pos_base + s]) : (past_seq_len + s); + const int h_idx = h / elements_per_thread; + + onnxruntime::contrib::cuda::RotaryDispatcher::apply( + *reinterpret_cast(vals), + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, h_idx, pos_id, interleaved, + reinterpret_cast(shared_head), + 0); + } + + // 3. Store results back to Global Memory + if (valid) { + if (head_type == QUERY) { + if (unpacked_q != nullptr) { + const int64_t q_out_idx = static_cast(b) * sequence_length * q_hidden + + static_cast(s) * q_hidden + + static_cast(n) * head_size + h; + reinterpret_cast(unpacked_q)[q_out_idx / elements_per_thread] = *reinterpret_cast(vals); + } + } else { + const int cache_s = past_seq_lens[b] + s; + if (cache_s < max_seqlen) { + T* cache_ptr = (head_type == KEY) ? k_cache : v_cache; + if (cache_ptr != nullptr) { + int64_t cache_idx = is_cache_bnsh ? (static_cast(b) * kv_num_heads * max_seqlen * head_size + static_cast(n) * max_seqlen * head_size + static_cast(cache_s) * head_size + h) : (static_cast(b) * max_seqlen * kv_num_heads * head_size + static_cast(cache_s) * kv_num_heads * head_size + static_cast(n) * head_size + h); + reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); + } + } + } + } +} + +template +Status LaunchUnpackRoPEAppendKV( + const T* packed_qkv, const T* query, const T* key, const T* value, + T* unpacked_q, T* k_cache, T* v_cache, + const int num_heads, const int kv_num_heads, const int head_size, + const int sequence_length, const int batch_size, const int max_seqlen, + const int* past_seq_lens, const T* cos_cache, const T* sin_cache, + const int rotary_dim, const int64_t* position_ids, const bool interleaved, + const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block) { + constexpr int elements_per_vector = sizeof(float4) / sizeof(T); + + if (head_size % elements_per_vector != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be 
divisible by vector size (16 bytes)."); + } + + // rotary_dim <= head_size check to prevent out-of-bounds in shared memory + if (rotary_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "rotary_dim (", rotary_dim, ") cannot exceed head_size (", head_size, ")."); + } + + if (!interleaved && rotary_dim % 2 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Non-interleaved RoPE requires even rotary_dim."); + } + + const int total_heads = num_heads + 2 * kv_num_heads; + const int d = total_heads * head_size; + + const int threads_per_block = (head_size + elements_per_vector - 1) / elements_per_vector; + if (threads_per_block > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size too large for current block configuration."); + } + + if (total_heads > 65535) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Total heads (", total_heads, ") exceeds CUDA grid limit (65535)."); + } + if (batch_size > 65535) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "batch_size (", batch_size, ") exceeds CUDA grid limit (65535)."); + } + + const dim3 grid(sequence_length, total_heads, batch_size); + const dim3 block(threads_per_block); + + // Dynamic dispatch for MAX_HEAD_SIZE templates to improve occupancy for common LLM head sizes + if (head_size <= 64) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else if (head_size <= 128) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else if (head_size <= 256) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, 
sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (256)."); + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Explicit template instantiations +template Status LaunchUnpackRoPEAppendKV( + const half*, const half*, const half*, const half*, half*, half*, half*, + int, int, int, int, int, int, const int*, const half*, const half*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +template Status LaunchUnpackRoPEAppendKV( + const BFloat16*, const BFloat16*, const BFloat16*, const BFloat16*, BFloat16*, BFloat16*, BFloat16*, + int, int, int, int, int, int, const int*, const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index ce6b4724af705..0c1e346503194 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -17,7 +17,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const T* input, // BxSxNxH const T* cos_cache, // Mx(H/2) @@ -38,15 +38,30 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx const int i = threadIdx.x; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; + + [[maybe_unused]] extern __shared__ char smem_[]; + [[maybe_unused]] T* smem = reinterpret_cast(smem_); + + if constexpr (use_smem) { + // Load to shared memory for safe in-place update + if (i < head_size) { + smem[i] = input_data[i]; + } + __syncthreads(); + } + if (i >= head_size) { return; } - const T* 
input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; - T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; - if (i >= rotary_embedding_dim) { - output_data[i] = input_data[i]; + if constexpr (use_smem) { + output_data[i] = smem[i]; + } else { + output_data[i] = input_data[i]; + } return; } @@ -79,7 +94,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx sign = (i < half_rotary_embedding_dim) ? -1 : 1; j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + + // Use values from shared memory + if constexpr (use_smem) { + output_data[i] = smem[i] * cos_data[cache_idx] + sign * smem[j] * sin_data[cache_idx]; + } else { + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } } template @@ -137,9 +158,21 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu const dim3 grid(sequence_length, batch_size, num_heads); assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, - num_heads, head_size, rotary_embedding_dim, position_ids_format, - interleaved, in_strides, out_strides); + + if (output == input) { + // In-place operation: use shared memory to avoid read-after-write hazards + size_t smem_size = head_size * sizeof(T); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, in_strides, out_strides); + } else { + // Separate buffers: no shared memory needed + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, in_strides, 
out_strides); + } return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index b3a5c15718ffb..9cbe2a01698ae 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -22,21 +22,36 @@ from onnx import TensorProto, helper from parameterized import parameterized -from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_build_info +from onnxruntime import ( + InferenceSession, + SessionOptions, + get_available_providers, + get_build_info, +) # Set seed for reproducibility torch.manual_seed(0) random.seed(69) +try: + from rotary_flash import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + # Reduces number of tests to run for faster pipeline checks pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" # Number of values per parameter (compared to pipeline mode) param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 -# When quick build is used, flash attention only supports fp16 and head_size=128 -quick_build = ", quick-build=1, " in get_build_info() +# When quick build is used, flash attention only supports head_size=128 +quick_build = ", quick-build=" in get_build_info() + +enable_debug_print = quick_build + +enable_deterministic_check = True +enable_quantized_kv_tests = True # ################################################################################################# # Configuration and Helper Classes # ################################################################################################# @@ -52,6 +67,14 @@ "int4": TensorProto.UINT8, } +TORCH_DTYPE_TO_ONNX_MAP = { + torch.float32: TensorProto.FLOAT, + torch.float16: TensorProto.FLOAT16, + torch.bfloat16: TensorProto.BFLOAT16, + torch.int32: TensorProto.INT32, + torch.int8: TensorProto.INT8, +} + TORCH_DTYPE_MAP = { "float32": torch.float32, "float16": torch.float16, @@ -156,8 
+179,8 @@ def forward(self, x, cos, sin, pos, interleaved): # Triton-based implementation for CUDA def rotary_embedding_cuda(*args, **kwargs): - from rotary_flash import apply_rotary_emb # noqa: PLC0415 - + if apply_rotary_emb is None: + raise ImportError("rotary_flash not found") return apply_rotary_emb(*args, **kwargs) @@ -262,7 +285,10 @@ def create_gqa_node_and_io( helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), ] - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if not config.packed: graph_input.extend( @@ -431,12 +457,19 @@ def gqa_prompt_func( bind_tensor(io_binding, "key", new_k, device, ort_type) bind_tensor(io_binding, "value", new_v, device, ort_type) - # 3. Bind 'past_key', 'past_value' (if share_buffer and passed as k/v) + # 3. Bind 'past_key', 'past_value' if share_buffer: # cache_ort_type corresponds to config.kv_cache_type - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - bind_tensor(io_binding, "past_key", k, device, cache_ort_type) - bind_tensor(io_binding, "past_value", v, device, cache_ort_type) + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + k_to_bind = k if share_buffer else k[:, :, :0, :] + v_to_bind = v if share_buffer else v[:, :, :0, :] + bind_tensor(io_binding, "past_key", k_to_bind, device, cache_ort_type) + bind_tensor(io_binding, "past_value", v_to_bind, device, cache_ort_type) + + # Scales are bound below in section 6 # 4. 
Bind scalars/1D tensors # seqlens_k is INT32 @@ -487,8 +520,16 @@ def gqa_prompt_func( # Determine dtype for cache tensors cache_dtype = out_dtype cache_ort_type = ort_type - if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP + else: + is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP + + if is_valid_type: + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # We bind output to the input buffer 'k' / 'v' (in-place update) @@ -559,7 +600,10 @@ def gqa_past_func( # 3. Bind 'past_key', 'past_value' # These are required inputs for past_func # cache_ort_type corresponds to config.kv_cache_type - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # If sharing buffer, we bind 'past_key' to the large buffer 'k' @@ -615,14 +659,22 @@ def gqa_past_func( if share_buffer: present_seqlen = config.buffer_sequence_length else: - present_seqlen = total_seq_len + present_seqlen = total_seq_len # For past_func, total seq len is accumulated present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] cache_dtype = out_dtype cache_ort_type = ort_type - if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP + else: + is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP + + if is_valid_type: + if isinstance(config.kv_cache_type, 
torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # In-place update to k/v buffers @@ -754,9 +806,9 @@ def parity_check_gqa_prompt( causal, rtol, atol, + std=0.2, ): torch.manual_seed(0) - std = 0.02 q = ( torch.randn( config.batch_size, @@ -873,24 +925,56 @@ def parity_check_gqa_prompt( # seqlens_k for GQA op is past_seq_len + seq_len - 1 ort_seqlens = cache_seqlens - 1 - out, present_k, present_v = gqa_prompt_func( - q=q_ort, - k=k_ort, - v=v_ort, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens, - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = gqa_prompt_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens, + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None else None + else: + if present_k is not None: + try: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + except AssertionError as e: + print(e) + raise e + if present_v is not None: + try: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) + except AssertionError as e: + print(e) + raise e + try: + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + 
except AssertionError as e: + max_diff = (out - first_out).abs().max().item() + print(f"Output mismatch max diff: {max_diff}") + with open("/tmp/gqa_diff_info.txt", "w") as f: + f.write(f"Max Diff: {max_diff}\n") + print(e) + raise e out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -917,9 +1001,12 @@ def parity_check_gqa_prompt( k_cache_ref_np = k_cache_ref_np[:, :, : config.kv_sequence_length, :] v_cache_ref_np = v_cache_ref_np[:, :, : config.kv_sequence_length, :] + print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) @@ -932,6 +1019,7 @@ def parity_check_gqa_past( causal, rtol, atol, + std=0.2, ): if ort_type == TensorProto.FLOAT16: torch_type = torch.float16 @@ -940,7 +1028,6 @@ def parity_check_gqa_past( else: torch_type = torch.float32 torch.manual_seed(0) - std = 0.02 # --- Test Data Generation --- q = ( torch.randn( @@ -966,10 +1053,10 @@ def parity_check_gqa_past( ) v = torch.randn_like(k) * std - # Random past sequence lengths. This tests paddings in decoding. 
+ # past cache sequence length is in [1, past_kv_sequence_length] cache_seqlens = torch.randint( - 0, - config.past_kv_sequence_length - config.q_sequence_length + 1, + 1, + config.past_kv_sequence_length + 1, (config.batch_size,), device=device, dtype=torch.long, @@ -1056,27 +1143,50 @@ def parity_check_gqa_past( new_k_ort, new_v_ort = None, None ort_seqlens = cache_seqlens + config.q_sequence_length - 1 - out, present_k, present_v = gqa_past_func( - q=q_ort, - k=k, - v=v, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens.int(), - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = gqa_past_func( + q=q_ort, + k=k, + v=v, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens.int(), + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None else None + else: + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + if present_k is not None: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + if present_v is not None: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() + if enable_debug_print: + print(f"[DEBUG] out_np 
non-zeros: {numpy.count_nonzero(out_np)} / {out_np.size}") + print(f"[DEBUG] out_ref_np non-zeros: {numpy.count_nonzero(out_ref_np)} / {out_ref_np.size}") + + if numpy.count_nonzero(out_ref_np) > 0 and numpy.count_nonzero(out_np) == 0: + raise RuntimeError("Output is all zeros") + # --- Comparison --- # Compare KV cache # Transpose reference back to BNSH to match ORT output @@ -1090,9 +1200,12 @@ def parity_check_gqa_past( k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) @@ -1240,6 +1353,51 @@ def parity_test_gqa_padding_prompt(): torch.testing.assert_close(out_ort, out_ref, rtol=1e-2, atol=1e-2) +# ################################################################################################# +# Test Utilities +# ################################################################################################# + + +def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""): + """ + Print percentile statistics (75%, 95%, 99%) for a difference tensor. + This helps assess parity quality beyond just max difference. + + Args: + diff_tensor: Tensor containing absolute differences between expected and actual outputs. + prefix: Optional prefix string for the output message. 
+ """ + if not enable_debug_print: + return + + diff_flat = diff_tensor.flatten().float() + if diff_flat.numel() == 0: + print(f"{prefix}Diff statistics: empty tensor") + return + + # Compute percentiles + sorted_diff, _ = torch.sort(diff_flat) + n = sorted_diff.numel() + + p75_idx = min(int(n * 0.75), n - 1) + p90_idx = min(int(n * 0.90), n - 1) + p95_idx = min(int(n * 0.95), n - 1) + p99_idx = min(int(n * 0.99), n - 1) + p999_idx = min(int(n * 0.999), n - 1) + + p75 = sorted_diff[p75_idx].item() + p90 = sorted_diff[p90_idx].item() + p95 = sorted_diff[p95_idx].item() + p99 = sorted_diff[p99_idx].item() + p999 = sorted_diff[p999_idx].item() + max_val = sorted_diff[-1].item() + mean_val = diff_flat.mean().item() + + print( + f"{prefix} Diff stats - mean: {mean_val:.6f}, p75: {p75:.6f}, p90: {p90:.6f}, p95: {p95:.6f}, p99: {p99:.6f}, p999: {p999:.6f}, max: {max_val:.6f}" + ) + + # ################################################################################################# # Test Case Generators # ################################################################################################# @@ -1260,11 +1418,11 @@ def get_softmax_options(allow_head_sink: bool = True): return options -def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): +def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True, allow_local: bool = True): batches = [3, 1, 5] seqs = [(35, 35), (1, 1), (64, 64), (128, 128), (240, 240), (2000, 2000)] heads = [(6, 3), (3, 1), (32, 8)] - h_sizes = [128] if quick_build else [128, 32, 64, 256] + h_sizes = [128] if quick_build else [128, 32, 64, 80, 160, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) @@ -1291,7 +1449,7 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): b = batches[combo_index % len(batches)] sq, skv = seqs[combo_index % len(seqs)] n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, skv)] + lws_opts = [-1, 
random.randint(1, skv)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1327,19 +1485,21 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): yield name, config -def gqa_cuda_past_test_cases(allow_head_sink: bool = True): +def gqa_cuda_past_test_cases( + allow_head_sink: bool = True, allow_local: bool = True, enforce_share_buffer: bool = False +): batches = [2, 1, 3] - # s: new sequence length, s2: past sequence length + # s: new sequence length, s2: past sequence length`` seqs = [(1, 1), (1, 128), (1, 2048), (1, 5000)] subsequent_prompt_seqs = [(3, 256)] heads = [(32, 8), (6, 3), (9, 9)] - h_sizes = [128] if quick_build else [128, 40, 64, 256] + h_sizes = [128] if quick_build else [128, 40, 64, 80, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) packed_opts = [False, True] # For past test: pipeline tests share_buffer=True only, comprehensive tests both - share_buffer_opts = [True] if pipeline_mode else [True, False] + share_buffer_opts = [True] if pipeline_mode or enforce_share_buffer else [True, False] softcap_opts = [0.0, 50.0] # Use new strategy for both modes: iterate over key code path parameters @@ -1367,7 +1527,7 @@ def gqa_cuda_past_test_cases(allow_head_sink: bool = True): b = 1 # Force batch=1 for subsequent prompt n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, s2)] + lws_opts = [-1, random.randint(1, s2)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1419,9 +1579,7 @@ def has_cuda_device(min_capability: int = 80): return major * 10 + minor >= min_capability -def has_flash_attention(bf16: bool = False): - if bf16 and quick_build: - return False +def 
has_flash_attention(): return has_cuda_device(80) @@ -1460,7 +1618,7 @@ def test_gqa_past_flash_attention(self, name, config): ) -@unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.") +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestFlashGQABF16(unittest.TestCase): @parameterized.expand(gqa_cuda_prompt_test_cases()) def test_gqa_prompt_flash_attention_bf16(self, name, config): @@ -1561,102 +1719,10 @@ def test_gqa_padding_prompt_memory_efficient_attention(self): parity_test_gqa_padding_prompt() -# ################################################################################################# -# Fused Kernel Parity Tests (ORT_DISABLE_FUSED_KV and ORT_DISABLE_FLASH_DECODE) -# ################################################################################################# - - -def fused_kernel_test_cases(): - """Test cases specifically for fused vs unfused kernel parity.""" - configs = [ - # Decoding with RoPE and shared buffer - GQAConfig( - batch_size=2, - q_sequence_length=1, - kv_sequence_length=1, - num_heads=16, - kv_num_heads=4, - head_size=128, - past_kv_sequence_length=128, - buffer_sequence_length=256, - rotary=True, - packed=False, - share_buffer=True, - ), - # Packed QKV decoding with RoPE - GQAConfig( - batch_size=2, - q_sequence_length=1, - kv_sequence_length=1, - num_heads=8, - kv_num_heads=2, - head_size=128, - past_kv_sequence_length=64, - buffer_sequence_length=128, - rotary=True, - packed=True, - share_buffer=True, - ), - # Subsequent prompt with RoPE - GQAConfig( - batch_size=1, - q_sequence_length=4, - kv_sequence_length=4, - num_heads=8, - kv_num_heads=4, - head_size=128, - past_kv_sequence_length=32, - buffer_sequence_length=64, - rotary=True, - packed=False, - share_buffer=True, - ), - ] - for i, config in enumerate(configs): - yield f"fused_config_{i}", config - - @unittest.skipIf(not has_flash_attention(), "Flash Attention is not 
available, skipping tests.") class TestFusedKernelParity(unittest.TestCase): """Tests that verify fused kernels produce the same results as unfused kernels.""" - @parameterized.expand(fused_kernel_test_cases()) - def test_fused_kv_parity(self, name, config): - """Test ORT_DISABLE_FUSED_KV: fused vs unfused KV append kernels.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - # Run with fused kernels (default) - if "ORT_DISABLE_FUSED_KV" in os.environ: - del os.environ["ORT_DISABLE_FUSED_KV"] - - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) - - # Run with unfused kernels - os.environ["ORT_DISABLE_FUSED_KV"] = "1" - - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) - - # Clean up - del os.environ["ORT_DISABLE_FUSED_KV"] - def test_flash_decode_parity(self): """Test ORT_DISABLE_FLASH_DECODE: fast decode vs standard path.""" os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -1709,5 +1775,49 @@ def test_flash_decode_parity(self): del os.environ["ORT_DISABLE_FLASH_DECODE"] +class TestGQARegressions(unittest.TestCase): + """Specific regression tests for historical bugs.""" + + def test_gqa_rope_separate_qkv_bug(self): + """ + Regression test for separate QKV + RoPE + FlashAttention bug. + The bug caused q_out to be nullptr when unpacking separate QKV with only Q rotation (standard GQA), + leading to unrotated Q being used in Attention. 
+ """ + if "CUDAExecutionProvider" not in get_available_providers(): + self.skipTest("CUDA required") + + # Config that triggers the path: Prompt phase, Separate QKV inputs, RoPE enabled + config = GQAConfig( + batch_size=1, + num_heads=4, + kv_num_heads=4, + head_size=128, + q_sequence_length=16, + kv_sequence_length=16, + past_kv_sequence_length=0, + buffer_sequence_length=16, + rotary=True, + rotary_interleaved=False, + share_buffer=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=1e-3, + atol=1e-3, + std=1.0, + ) + + if __name__ == "__main__": unittest.main() From 5dc32a4a31a50902d565f3bd4c82763250f3cbac Mon Sep 17 00:00:00 2001 From: Stephan Seitz Date: Thu, 22 Jan 2026 23:09:57 +0100 Subject: [PATCH 2/2] Linux device discovery for TRT-RTX Ep (#26210) ### Description This change adds PCIe bus_id to the properties detected during Linux device discovery. This property is used to enable device discovery on Linux for the TRT-RTX execution provider. ### Motivation and Context I want to use device discovery for TRT-EP also on Linux. This changes have already been tested with the newly added inference samples https://github.com/microsoft/onnxruntime-inference-examples/pull/529 . 
@gedoensmax for visibility --- .../core/platform/linux/device_discovery.cc | 29 +++++++++++++++++++ .../nv_tensorrt_rtx/nv_provider_factory.cc | 20 +++++++++++++ .../nv_tensorrt_rtx/version_script.lds | 2 ++ .../nv_tensorrt_rtx/nv_basic_test.cc | 8 ++--- .../nv_tensorrt_rtx/nv_ep_context_test.cc | 4 --- .../test_nv_trt_rtx_ep_util.cc | 2 -- 6 files changed, 55 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index e9c45a6966ef8..db6ac73996863 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "core/common/common.h" @@ -114,6 +115,28 @@ std::optional IsGpuDiscrete(uint16_t vendor_id, uint16_t device_id) { return std::nullopt; } +Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& pci_bus_id) { + constexpr const char* regex_pattern{R"([0-9a-f]+:[0-9a-f]+:[0-9a-f]+[.][0-9a-f]+)"}; + static const std::regex pci_bus_id_regex(regex_pattern); + + std::error_code error_code; + auto pci_bus_id_path = std::filesystem::canonical(sysfs_path / "device", error_code); // resolves symlink to PCI bus id, e.g. 
0000:65:00.0 + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + auto pci_bus_id_filename = pci_bus_id_path.filename(); + if (std::regex_match(pci_bus_id_filename.string(), pci_bus_id_regex)) { + pci_bus_id = pci_bus_id_filename.string(); + } else { + pci_bus_id = {}; + LOGS_DEFAULT(WARNING) << MakeString("Skipping pci_bus_id for PCI path at \"", + pci_bus_id_path.string(), + "\" because filename \"", pci_bus_id_filename, "\" did not match expected pattern of ", + regex_pattern); + }; + + return Status::OK(); +} + Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevice& gpu_device_out) { OrtHardwareDevice gpu_device{}; const auto& sysfs_path = path_info.path; @@ -140,6 +163,12 @@ Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevic gpu_device.metadata.Add("Discrete", (*is_gpu_discrete ? "1" : "0")); } + std::optional pci_bus_id; + ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id)); + if (pci_bus_id) { + gpu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id)); + } + gpu_device.type = OrtHardwareDeviceType_GPU; gpu_device_out = std::move(gpu_device); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e5015e705958d..9955e73bf69ad 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -584,6 +584,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { * @return True if the device is a supported NVIDIA GPU, false otherwise. 
*/ bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) { +#if _WIN32 const auto& metadata_entries = device.metadata.Entries(); const auto it = metadata_entries.find("LUID"); if (it == metadata_entries.end()) { @@ -625,6 +626,25 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { } return false; +#else + const auto& metadata_entries = device.metadata.Entries(); + const auto it = metadata_entries.find("pci_bus_id"); + if (it == metadata_entries.end()) { + return false; + } + auto& target_id = it->second; + int cuda_device_idx = 0; + if (cudaDeviceGetByPCIBusId(&cuda_device_idx, target_id.c_str()) != cudaSuccess) { + return false; + } + + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, cuda_device_idx) != cudaSuccess) { + return false; + } + // Ampere architecture or newer is required. + return prop.major >= 8; +#endif } // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds b/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds index 094abb3329781..251e39e089275 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds @@ -2,6 +2,8 @@ VERS_1.0 { global: GetProvider; + CreateEpFactories; + ReleaseEpFactory; # Hide everything else. local: diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 1a987ab4f411a..f017c86824df6 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -275,7 +275,6 @@ INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ), [](const testing::TestParamInfo& info) { return getTypeAsName(info.param); }); -#ifdef _WIN32 static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. 
const OrtSession* ort_session = session; @@ -292,7 +291,6 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { } // Tests autoEP feature to automatically select an EP that supports the GPU. -// Currently only works on Windows. TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("nv_execution_provider_auto_ep.onnx"); std::string graph_name = "test"; @@ -302,7 +300,11 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { CreateBaseModel(model_name, graph_name, dims); { +#if _WIN32 ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); +#else + ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("libonnxruntime_providers_nv_tensorrt_rtx.so")); +#endif Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); @@ -599,7 +601,5 @@ TEST(NvExecutionProviderTest, FP4CustomOpModel) { LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP4 dynamic quantize model run completed successfully"; } -#endif - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc index ac24dcb70c1dd..bcdfd18407ca8 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -14,7 +14,6 @@ namespace test { RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, std::unordered_map& option_map) { RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; -#ifdef _WIN32 /// Since this test runs after other tests that use registration interface this test has to use it as well /// windows as otherwise the kernel registry inside the EP will not be populated. The legacy APis ony call the initialize once. 
Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); @@ -26,9 +25,6 @@ RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, } } session_options.AppendExecutionProvider_V2(*ort_env, {selected_device}, option_map); -#else - session_options.AppendExecutionProvider(onnxruntime::kNvTensorRTRTXExecutionProvider, option_map); -#endif return nv_tensorrt_rtx_ep; } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index 47127399b4646..de028bf613a27 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -24,7 +24,6 @@ namespace onnxruntime { namespace test { -#ifdef _WIN32 Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; @@ -61,7 +60,6 @@ void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniqu c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); }); } -#endif // _WIN32 void CreateBaseModel(const PathString& model_name, std::string graph_name,