diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
index 486bf05bd86d5..70e5683c2f974 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -225,11 +225,27 @@ struct PagedAttentionData {
   // Fused op buffers
   T* workspace_buffer = nullptr;
 
+  // Memory-efficient attention (CUTLASS fMHA) buffers for the unfused fallback path
+  // taken when FlashAttention is unavailable (SM<80 or ORT_DISABLE_FLASH_ATTENTION).
+  T* gathered_key = nullptr;    // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
+  T* gathered_value = nullptr;  // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
+  T* fmha_buffer = nullptr;     // CUTLASS fMHA output-accumulator workspace
+  // Populated by the caller after a D->H sync on cumulative_seqlens_kv[batch_size].
+  int total_kv_tokens = 0;
+
+  // Actual max of per-batch new-query lengths (cumulative_seqlens_q[i+1] - cumulative_seqlens_q[i]).
+  // Populated by the caller via the same D->H sync so the MEA path's rotary grid and MEA's
+  // grid_x (ceil_div(sequence_length, kQueriesPerBlock)) cover every query token. The previous
+  // heuristic `token_count - batch_size + 1` underestimates when any batch has 0 new tokens,
+  // silently dropping query tokens in both the MEA and rotary kernels.
+  int max_query_len = 0;
+
   // Output Tensors
   T* output = nullptr;
 
   // Kernel Flags
   bool use_flash_attention = false;
+  bool use_memory_efficient_attention = false;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
index 5df2c8b438771..7fba61270e280 100644
--- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
@@ -9,6 +9,7 @@
 #include "contrib_ops/cuda/bert/paged_attention.h"
 #include "contrib_ops/cuda/bert/paged_attention_helper.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 
 using namespace onnxruntime::cuda;
 using namespace ::onnxruntime::common;
@@ -50,6 +51,7 @@ PagedAttention<T>::PagedAttention(const OpKernelInfo& info)
   kernel_options_ = this->GetAttentionKernelOptions();
 
   disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
+  disable_memory_efficient_attention_ = sizeof(T) != 2 || !kernel_options_->UseEfficientAttention();
 }
 
 template <typename T>
@@ -141,31 +143,57 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
                            "value_cache and value_cache_out must be the same buffer");
   }
 
-  // Check flash kernel availability and allocate buffers
+  // Empty query input: output is already shaped [0, hidden_size], and the cache outputs
+  // alias the input caches (verified above), so no backend kernel or cache update is needed.
+  if (parameters.token_count == 0) {
+    return Status::OK();
+  }
+
+  // Kernel backend selection — FlashAttention preferred, fall back to MemoryEfficientAttention.
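+  // Roughly: FlashAttention covers fp16/bf16 on sm>=80; the MemoryEfficientAttention fallback
+  // below covers fp16 from sm>=53 (bf16 still needs sm>=80), subject to the head_size limits
+  // enforced by has_memory_efficient_attention().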
 #if USE_FLASH_ATTENTION
   bool use_flash_attention = !disable_flash_attention_ &&
                              onnxruntime::flash::is_supported(device_prop,
                                                               parameters.head_size,
                                                               parameters.num_heads,
                                                               parameters.kv_num_heads);
-  size_t softmax_lse_bytes = 0;
-  if (use_flash_attention) {
-    softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
-                                                                 parameters.num_heads);
-  }
-  auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));
 #else
   constexpr bool use_flash_attention = false;
-  auto softmax_lse_buffer = GetScratchBuffer<void>(0, GetComputeStream(context));  // nullptr
 #endif
 
-  if (!use_flash_attention) {
+#if USE_MEMORY_EFFICIENT_ATTENTION
+  const int sm = device_prop.major * 10 + device_prop.minor;
+  const bool is_half = std::is_same<T, MLFloat16>::value;
+  const bool is_bf16 = std::is_same<T, BFloat16>::value;
+  bool use_memory_efficient_attention =
+      !use_flash_attention &&
+      !disable_memory_efficient_attention_ &&
+      has_memory_efficient_attention(sm, is_half, is_bf16,
                                      parameters.head_size, parameters.head_size);
+#else
+  constexpr bool use_memory_efficient_attention = false;
+#endif
+
+  if (!use_flash_attention && !use_memory_efficient_attention) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "Currently PagedAttention is only supported through the FlashAttention kernel.");
+                           "PagedAttention requires FlashAttention (sm>=80, fp16/bf16) or "
+                           "MemoryEfficientAttention (fp16 sm>=53, bf16 sm>=80, head_size <= 1024 and a multiple of 8) "
+                           "to be available. Check ORT_DISABLE_FLASH_ATTENTION / "
+                           "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION env vars and dtype/head_size.");
   }
 
+  // Scratch buffers common to both backends.
+  size_t softmax_lse_bytes = 0;
+#if USE_FLASH_ATTENTION
+  if (use_flash_attention) {
+    softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
+                                                                 parameters.num_heads);
+  }
+#endif
+  auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));
+
   size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1);
   auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, GetComputeStream(context));
+  int* cumulative_seqlens_kv_ptr = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
 
   size_t workspace_buffer_bytes = 0;
   if (do_rotary_) {
@@ -175,10 +203,91 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
   }
   auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, GetComputeStream(context));
 
+  // Populate cumulative_seqlens_kv for both backends. The MEA path additionally needs
+  // the last element on the host to size the tight gather buffers, so we D->H sync below.
+  //
+  // LaunchGetCumulativeSeqlensKV uses a per-block cub::BlockScan with a block size of 256
+  // and launches (batch_size + 255) / 256 blocks, so blocks scan independently. Enforce
+  // batch_size <= 256 so the cumulative sum is correct; a larger batch would silently
+  // produce wrong KV offsets. (A future grid-wide scan could lift this limit.)
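+  // For reference, the scan being produced is (assuming the existing kernel's semantics):
+  //   cumulative_seqlens_kv[0]     = 0
+  //   cumulative_seqlens_kv[i + 1] = cumulative_seqlens_kv[i] + past_seqlens[i]
+  //                                  + (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i])
+  // i.e. a prefix sum over per-batch past + new KV lengths.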
+  constexpr int kMaxBatchSizeForCumulativeSeqlensKV = 256;
+  if (parameters.batch_size > kMaxBatchSizeForCumulativeSeqlensKV) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "PagedAttention currently supports batch_size <= ",
+                           kMaxBatchSizeForCumulativeSeqlensKV,
+                           " (LaunchGetCumulativeSeqlensKV limitation); got batch_size=",
+                           parameters.batch_size, ".");
+  }
+
+  cudaStream_t cuda_stream = static_cast<cudaStream_t>(ort_stream.get()->GetHandle());
+  ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(
+      cumulative_seqlens_kv_ptr,
+      reinterpret_cast<const int32_t*>(cumulative_seqlens_q->Data<int32_t>()),
+      reinterpret_cast<const int32_t*>(past_seqlens->Data<int32_t>()),
+      parameters.batch_size, cuda_stream));
+
+  int total_kv_tokens = 0;
+  int max_query_len = 0;
+  IAllocatorUniquePtr<void> gathered_key_buffer;
+  IAllocatorUniquePtr<void> gathered_value_buffer;
+  IAllocatorUniquePtr<void> fmha_buffer;
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+  if (use_memory_efficient_attention) {
+    // MEA needs two host-side quantities:
+    //  - total_kv_tokens (= cumulative_seqlens_kv[batch_size]) to size tight gather buffers.
+    //  - max_query_len (= max per-batch new-query length) to size the rotary and MEA grids
+    //    correctly. The heuristic `token_count - batch_size + 1` underestimates when any
+    //    batch has 0 new tokens (valid input), silently dropping query tokens from those
+    //    larger-than-average batches.
+    // Both come from cumulative_seqlens_q / cumulative_seqlens_kv, which are tiny (batch+1
+    // ints each), so one D->H copy of the full arrays is cheaper than issuing an extra
+    // reduction kernel and avoids a second sync.
+    const int kCumulativeCount = parameters.batch_size + 1;
+    auto cum_q_pinned = this->AllocateBufferOnCPUPinned<int32_t>(kCumulativeCount);
+    auto cum_kv_pinned = this->AllocateBufferOnCPUPinned<int32_t>(kCumulativeCount);
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_q_pinned.get(),
                                         reinterpret_cast<const int32_t*>(cumulative_seqlens_q->Data<int32_t>()),
+                                         sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_kv_pinned.get(), cumulative_seqlens_kv_ptr,
+                                         sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
+    CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
+    total_kv_tokens = cum_kv_pinned.get()[parameters.batch_size];
+    for (int i = 0; i < parameters.batch_size; ++i) {
+      const int q_len_i = cum_q_pinned.get()[i + 1] - cum_q_pinned.get()[i];
+      if (q_len_i > max_query_len) {
+        max_query_len = q_len_i;
+      }
+    }
+    if (total_kv_tokens == 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                             "PagedAttention MEA fallback: total_kv_tokens is zero for non-empty input.");
+    }
+    if (total_kv_tokens < 0) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                             "PagedAttention MEA fallback: total_kv_tokens is negative (", total_kv_tokens, ").");
+    }
+
+    const size_t gather_elems = static_cast<size_t>(total_kv_tokens) *
+                                parameters.num_heads * parameters.head_size;
+    gathered_key_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));
+    gathered_value_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));
+
+    if (MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
+      // MEA output accumulator is float32 regardless of input dtype (see GQA pattern at
+      // group_query_attention.cc:482); use sizeof(float), not sizeof(T).
+      const size_t fmha_elems = static_cast<size_t>(parameters.token_count) *
+                                parameters.num_heads * parameters.head_size;
+      fmha_buffer = GetScratchBuffer<void>(sizeof(float) * fmha_elems, GetComputeStream(context));
+    }
+  }
+#endif
+
   // Print debug info
   if (kernel_options_->AllowDebugInfo()) {
     AttentionKernelDebugInfo debug_info;
     debug_info.use_flash_attention = use_flash_attention;
+    debug_info.use_efficient_attention = use_memory_efficient_attention;
     debug_info.Print("PagedAttention",
                      this->Node().Name(),
@@ -194,10 +303,11 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>()));
   data.cumulative_seqlens_q = reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int32_t>());
   data.past_seqlens = reinterpret_cast<const int*>(past_seqlens->Data<int32_t>());
-  data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
+  data.cumulative_seqlens_kv = cumulative_seqlens_kv_ptr;
   data.block_table = reinterpret_cast<const int*>(block_table->Data<int32_t>());
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
   data.use_flash_attention = use_flash_attention;
+  data.use_memory_efficient_attention = use_memory_efficient_attention;
   if (softmax_lse_buffer != nullptr) {
     data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
   }
@@ -208,6 +318,15 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
     data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
     data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
   }
+  if (use_memory_efficient_attention) {
+    data.gathered_key = reinterpret_cast<CudaT*>(gathered_key_buffer.get());
+    data.gathered_value = reinterpret_cast<CudaT*>(gathered_value_buffer.get());
+    if (fmha_buffer != nullptr) {
+      data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
+    }
+    data.total_kv_tokens = total_kv_tokens;
+    data.max_query_len = max_query_len;
+  }
 
   cublasHandle_t cublas = GetCublasHandle(context);
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h
index a3df144745f61..027141f02b9ae 100644
--- a/onnxruntime/contrib_ops/cuda/bert/paged_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h
@@ -29,6 +29,7 @@ class PagedAttention final : public CudaKernel {
   float scale_;
   float softcap_;
   bool disable_flash_attention_;
+  bool disable_memory_efficient_attention_;
   const AttentionKernelOptions* kernel_options_;
 };
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
index 06608ebed44cc..2241fa232a2c0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu
@@ -9,6 +9,7 @@
 #include "contrib_ops/cuda/bert/attention_softmax.h"
 #include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/paged_attention_impl.h"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 #include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
@@ -237,6 +238,101 @@ Status LaunchReshapeAndCache(const T* key, const T* value, T* key_cache, T* valu
   return CUDA_CALL(cudaGetLastError());
 }
 
+// Gather paged KV into packed-varlen [total_kv_tokens, num_heads, head_size], expanding GQA heads.
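+// The flat output index decomposes as tid = (token_id * num_heads + head_id) * head_size + h,
+// matching that [total_kv_tokens, num_heads, head_size] layout (see the loop body below).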
+// total_elems = total_kv_tokens * num_heads * head_size can exceed INT32_MAX for realistic
+// large-context GQA configs (e.g., 2M tokens * 64 * 128 = 16.4B), so the linear index is int64_t
+// and the kernel uses a grid-stride loop instead of a single (tid >= total_elems) early-exit.
+template <typename T>
+__global__ void GatherAndExpandPagedKVCache(const T* __restrict__ key_cache,
+                                            const T* __restrict__ value_cache,
+                                            T* __restrict__ gathered_key,
+                                            T* __restrict__ gathered_value,
+                                            const int* __restrict__ block_table,
+                                            const int* __restrict__ cumulative_seqlens_kv,
+                                            const int batch_size,
+                                            const int num_heads,
+                                            const int kv_num_heads,
+                                            const int head_size,
+                                            const int block_size,
+                                            const int max_num_blocks_per_seq,
+                                            const int64_t total_elems) {
+  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
+  const int64_t num_heads_times_head = static_cast<int64_t>(num_heads) * head_size;
+  const int q_kv_head_ratio = num_heads / kv_num_heads;
+  const int64_t page_stride = static_cast<int64_t>(block_size) * kv_num_heads * head_size;
+
+  for (int64_t tid = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
+       tid < total_elems;
+       tid += stride) {
+    const int h = static_cast<int>(tid % head_size);
+    const int head_id = static_cast<int>((tid / head_size) % num_heads);
+    const int token_id = static_cast<int>(tid / num_heads_times_head);
+
+    // cumulative_seqlens_kv is a prefix sum of non-negative per-batch KV lengths
+    // (past_seqlens[i] + new_tokens[i]), so it is monotonically non-decreasing for
+    // any valid op input — the same assumption the previous linear scan made.
+    // Binary-search for the batch this token belongs to: O(log2(batch_size)) per element,
+    // versus the previous linear scan, which paid O(batch_size) once per (token, head, h)
+    // element, i.e. its cost was multiplied by num_heads * head_size.
+    int left = 0;
+    int right = batch_size;
+    while (left < right) {
+      const int mid = left + (right - left) / 2;
+      if (token_id < cumulative_seqlens_kv[mid + 1]) {
+        right = mid;
+      } else {
+        left = mid + 1;
+      }
+    }
+    const int batch_id = left;
+
+    const int pos = token_id - cumulative_seqlens_kv[batch_id];
+    const int block_idx_in_seq = pos / block_size;
+    const int block_offset = pos % block_size;
+    const int block_id = block_table[batch_id * max_num_blocks_per_seq + block_idx_in_seq];
+
+    // GQA expansion: each output head maps to kv_head_id = head_id / (num_heads / kv_num_heads).
+    // For MHA (num_heads == kv_num_heads) this is the identity.
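+    // Example (hypothetical shapes): num_heads = 32, kv_num_heads = 8 gives q_kv_head_ratio = 4,
+    // so output heads 0..3 read KV head 0, heads 4..7 read KV head 1, and so on.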
+    const int kv_head_id = head_id / q_kv_head_ratio;
+
+    const int64_t paged_idx = static_cast<int64_t>(block_id) * page_stride +
+                              static_cast<int64_t>(block_offset) * kv_num_heads * head_size +
+                              kv_head_id * head_size +
+                              h;
+
+    gathered_key[tid] = key_cache[paged_idx];
+    gathered_value[tid] = value_cache[paged_idx];
+  }
+}
+
+template <typename T>
+Status LaunchGatherAndExpandPagedKVCache(const T* key_cache, const T* value_cache,
+                                         T* gathered_key, T* gathered_value,
+                                         const int* block_table, const int* cumulative_seqlens_kv,
+                                         const int batch_size, const int num_heads,
+                                         const int kv_num_heads, const int head_size,
+                                         const int block_size, const int max_num_blocks_per_seq,
+                                         const int total_kv_tokens, cudaStream_t stream,
+                                         const int max_threads_per_block) {
+  const int64_t total_elems = static_cast<int64_t>(total_kv_tokens) * num_heads * head_size;
+  if (total_elems == 0) {
+    return Status::OK();
+  }
+  // With the op's batch_size <= 256 precondition (paged_attention.cc) and MEA's
+  // head_size <= 1024 cap, blocks_needed = ceil(total_elems / threads) stays comfortably
+  // within int range for any realistic input, so no explicit clamp is needed. The kernel
+  // uses a grid-stride loop so launching fewer blocks than total_elems / threads would
+  // also be correct — we don't need an artificial "keep SMs busy" cap.
+  const int threads = static_cast<int>(std::min<int64_t>(max_threads_per_block, total_elems));
+  const int blocks = static_cast<int>((total_elems + threads - 1) / threads);
+  GatherAndExpandPagedKVCache<<<blocks, threads, 0, stream>>>(
+      key_cache, value_cache, gathered_key, gathered_value,
+      block_table, cumulative_seqlens_kv,
+      batch_size, num_heads, kv_num_heads, head_size,
+      block_size, max_num_blocks_per_seq, total_elems);
+  return CUDA_CALL(cudaGetLastError());
+}
+
 ////////// Launch Kernels
 
 #if USE_FLASH_ATTENTION
@@ -276,12 +372,11 @@ Status FlashAttention(
     value = reinterpret_cast<T*>(key) + static_cast<size_t>(kv_num_heads * head_size);
   }
 
-  // Calculate cumulative present sequence length in cumulative_seqlens_kv
+  // cumulative_seqlens_kv is populated by the caller (paged_attention.cc) before QkvToContext;
+  // shared across FA and MEA dispatch paths so the host can also read total_kv_tokens.
   int* cumulative_seqlens_q = const_cast<int*>(data.cumulative_seqlens_q);
   int* past_seqlens = const_cast<int*>(data.past_seqlens);
   int* cumulative_seqlens_kv = data.cumulative_seqlens_kv;
-  ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens,
-                                                   batch_size, stream));
 
   if (parameters.do_rotary) {
     // Will unpack Q and K in case of packed_qkv
@@ -335,6 +430,127 @@ Status FlashAttention(
 }
 #endif
 
+#if USE_MEMORY_EFFICIENT_ATTENTION
+// Fallback when FlashAttention is unavailable (SM<80 or ORT_DISABLE_FLASH_ATTENTION=1).
+// Mirrors the FlashAttention preprocessing (rotary, unpack, ReshapeAndCache), then gathers
+// the paged KV cache into a packed-varlen [total_kv_tokens, num_heads, head_size] buffer and
+// dispatches to CUTLASS memory-efficient attention via its seqstart_q / seqstart_k varlen ABI.
+// Caller must populate data.gathered_key / data.gathered_value / data.total_kv_tokens.
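+// In that varlen ABI, seqstart_q_ptr / seqstart_k_ptr are the (batch_size + 1)-entry prefix sums
+// cumulative_seqlens_q / cumulative_seqlens_kv, so batch b attends query rows
+// [seqstart_q[b], seqstart_q[b+1]) against gathered KV rows [seqstart_k[b], seqstart_k[b+1]).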
+template <typename T>
+Status EfficientAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::PagedAttentionParameters& parameters,
+    PagedAttentionData<T>& data,
+    float scale) {
+  const int max_threads_per_block = device_prop.maxThreadsPerBlock;
+  const int batch_size = parameters.batch_size;
+  const int token_count = parameters.token_count;
+  const int q_hidden_size = parameters.hidden_size;
+  const int kv_hidden_size = parameters.kv_hidden_size;
+  const int num_heads = parameters.num_heads;
+  const int kv_num_heads = parameters.kv_num_heads;
+  const int head_size = parameters.head_size;
+  const int block_size = parameters.block_size;
+  const int max_num_blocks_per_seq = parameters.max_num_blocks_per_seq;
+  const int local_window_size = parameters.local_window_size;
+  const int total_kv_tokens = data.total_kv_tokens;
+  // Use the caller-computed actual max of per-batch new-query lengths, not the
+  // `token_count - batch_size + 1` heuristic: the heuristic assumes >=1 new token per batch
+  // and underestimates otherwise, which would silently drop query tokens from the
+  // rotary grid and from MEA's `grid_x = ceil_div(sequence_length, kQueriesPerBlock)`.
+  const int max_query_len = data.max_query_len;
+
+  T* query = const_cast<T*>(data.query);
+  T* key;
+  T* value;
+  if (!parameters.is_packed_qkv) {
+    key = const_cast<T*>(data.key);
+    value = const_cast<T*>(data.value);
+  } else {
+    key = reinterpret_cast<T*>(query) + static_cast<size_t>(num_heads * head_size);
+    value = reinterpret_cast<T*>(key) + static_cast<size_t>(kv_num_heads * head_size);
+  }
+
+  // cumulative_seqlens_kv is populated by the caller (paged_attention.cc) before QkvToContext;
+  // shared across FA and MEA dispatch paths.
+  int* cumulative_seqlens_q = const_cast<int*>(data.cumulative_seqlens_q);
+  int* past_seqlens = const_cast<int*>(data.past_seqlens);
+  int* cumulative_seqlens_kv = data.cumulative_seqlens_kv;
+
+  if (parameters.do_rotary) {
+    auto q_buffer = data.workspace_buffer;
+    auto k_buffer = data.workspace_buffer + token_count * num_heads * head_size;
+    const int packed_seq_stride = parameters.is_packed_qkv ? (num_heads + 2 * kv_num_heads) * head_size : -1;
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+        stream, q_buffer, query, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+        max_query_len, num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+        max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(
+        stream, k_buffer, key, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size,
+        max_query_len, kv_num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride,
+        max_threads_per_block));
+    query = q_buffer;
+    key = k_buffer;
+  } else if (parameters.is_packed_qkv) {
+    auto q_buffer = data.workspace_buffer;
+    const int packed_seq_stride = q_hidden_size + 2 * kv_hidden_size;
+    ORT_RETURN_IF_ERROR(LaunchUnpackCumulative(
+        query, q_buffer, token_count, q_hidden_size, packed_seq_stride, stream, max_threads_per_block));
+    query = q_buffer;
+  }
+
+  int* block_table = const_cast<int*>(data.block_table);
+  const int key_stride = parameters.is_packed_qkv && !parameters.do_rotary ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+  const int value_stride = parameters.is_packed_qkv ?
+      q_hidden_size + 2 * kv_hidden_size : kv_hidden_size;
+  ORT_RETURN_IF_ERROR(LaunchReshapeAndCache(key, value, data.key_cache, data.value_cache, block_table, past_seqlens,
+                                            cumulative_seqlens_q, batch_size, max_num_blocks_per_seq, token_count,
+                                            kv_hidden_size, block_size, key_stride, value_stride, stream,
+                                            max_threads_per_block));
+
+  ORT_RETURN_IF_ERROR(LaunchGatherAndExpandPagedKVCache(
+      data.key_cache, data.value_cache, data.gathered_key, data.gathered_value,
+      block_table, cumulative_seqlens_kv, batch_size, num_heads, kv_num_heads,
+      head_size, block_size, max_num_blocks_per_seq, total_kv_tokens, stream, max_threads_per_block));
+
+  MemoryEfficientAttentionParams p;
+  p.sm = device_prop.major * 10 + device_prop.minor;
+  p.is_bf16 = std::is_same<T, BFloat16>::value;
+  p.is_half = !p.is_bf16 && (sizeof(T) == 2);
+  p.batch_size = batch_size;
+  p.num_heads = num_heads;
+  p.sequence_length = max_query_len;
+  p.kv_sequence_length = total_kv_tokens;
+  p.max_sequence_length = total_kv_tokens;
+  p.qk_head_size = head_size;
+  p.v_head_size = head_size;
+  p.causal = true;
+  p.scale = scale;
+  p.softcap = parameters.softcap;
+  p.local_window_size = local_window_size;
+  p.seqstart_q_ptr = cumulative_seqlens_q;
+  p.seqstart_k_ptr = cumulative_seqlens_kv;
+  p.seqlen_k_ptr = nullptr;
+  p.query = query;
+  p.key = data.gathered_key;
+  p.value = data.gathered_value;
+  p.attn_bias = nullptr;
+  p.is_kv_bsnh = true;
+  p.has_custom_right_padding = false;
+  p.output = data.output;
+  p.workspace = MemoryEfficientAttentionParams::need_workspace(head_size, sizeof(T) == sizeof(float))
+                    ? data.fmha_buffer
+                    : nullptr;
+  p.stream = stream;
+  run_memory_efficient_attention(p);
+
+  DUMP_TENSOR_INIT();
+  DUMP_TENSOR("mea paged attention output", data.output, token_count, num_heads, head_size);
+
+  return Status::OK();
+}
+#endif
+
 ////////// API Functions
 
 template <typename T>
@@ -353,7 +569,13 @@ Status QkvToContext(
   }
 #endif
 
-  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Paged Attention not implemented.");
+#if USE_MEMORY_EFFICIENT_ATTENTION
+  if (data.use_memory_efficient_attention) {
+    return EfficientAttention(device_prop, stream, parameters, data, scale);
+  }
+#endif
+
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No PagedAttention kernel available for the current configuration.");
 }
 
 template struct PagedAttentionData<half>;
diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
index 7e27556a5c63f..22f9793be0af6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h
@@ -27,6 +27,11 @@ Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_q, T* unpacked
                                  const int kv_num_heads, const int head_size, const int token_count,
                                  cudaStream_t stream, const int max_threads_per_block);
 
+// Exposed so paged_attention.cc can populate cumulative_seqlens_kv on both the FA and MEA
+// dispatch paths (producer hoisted out of FlashAttention/UnfusedAttention in impl.cu).
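+// Note: the scan is block-local, so callers are expected to enforce batch_size <= 256 before
+// calling (see the kMaxBatchSizeForCumulativeSeqlensKV check in paged_attention.cc).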
+Status LaunchGetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q,
+                                    const int32_t* past_seqlens, const int batch_size, cudaStream_t stream);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
index 66eb4a885620b..fda861c8125ff 100644
--- a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
+++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py
@@ -262,6 +262,7 @@ def paged_attention_func(
     cos=None,
     sin=None,
     window_size=-1,
+    sdpa_kernel=0,
 ):
     num_tokens = cumulative_sequence_length[-1].item()
     num_blocks = key_cache.shape[0]
@@ -282,7 +283,11 @@ def paged_attention_func(
         "block_table": block_table.detach().cpu().numpy(),
     }
     sess_options = SessionOptions()
-    ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep])
+    if sdpa_kernel != 0 and config.ep == "CUDAExecutionProvider":
+        providers = [(config.ep, {"sdpa_kernel": str(sdpa_kernel)})]
+    else:
+        providers = [config.ep]
+    ort_session = InferenceSession(onnx_model_str, sess_options, providers=providers)
     io_binding = ort_session.io_binding()
     if key is not None and value is not None:
         ort_inputs["key"] = key.detach().cpu().numpy()
@@ -490,6 +495,7 @@ def parity_check_paged_attention(
     config: Config,
     rtol=1e-3,
     atol=1e-3,
+    sdpa_kernel=0,
 ):
     # Generate padded inputs
     q = torch.randn(
@@ -620,6 +626,7 @@ def parity_check_paged_attention(
         cos,
         sin,
         left_window_size,
+        sdpa_kernel=sdpa_kernel,
     )
     num_tokens = q_unpad.shape[0]
     out = torch.reshape(out, (num_tokens, config.num_heads, config.head_size))
@@ -672,6 +679,25 @@ def has_flash_attention():
     )
 
 
+def has_memory_efficient_attention():
+    # CUTLASS fMHA (MemoryEfficientAttention) gate — these tests are fp16-only,
+    # so sm>=53 is sufficient. bf16 MEA would require sm>=80 but is not covered here.
+    if not torch.cuda.is_available():
+        return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
+    major, minor = torch.cuda.get_device_capability()
+    return (major * 10 + minor) >= 53
+
+
+# Bit value matching AttentionBackend::EFFICIENT_ATTENTION in
+# onnxruntime/contrib_ops/cpu/bert/attention_common.h. Passing this as the
+# CUDA provider option `sdpa_kernel` forces the PagedAttention kernel to
+# select the MemoryEfficientAttention (CUTLASS fMHA) fallback even on SM>=80
+# where FlashAttention would otherwise be preferred.
+SDPA_KERNEL_EFFICIENT_ATTENTION = 2
+
+
 def paged_attention_test_cases():
     batches = [4] if pipeline_mode else [1, 3, 5]
     seqs = (
@@ -732,5 +758,25 @@ def test_paged_attention(self, _, config):
         parity_check_paged_attention(config, rtol=5e-3, atol=5e-3)
 
 
+@unittest.skipIf(
+    not has_memory_efficient_attention(),
+    reason="MemoryEfficientAttention (fp16) requires sm>=53; skipping.",
+)
+class TestPagedAttentionMEA(unittest.TestCase):
+    """Runs the same parity matrix as TestPagedAttention but forces the CUTLASS
+    memory-efficient attention fallback via the `sdpa_kernel` CUDA provider option.
+    This is the only coverage for the SM<80 fallback path introduced for PagedAttention;
+    on SM>=80 the class still runs to exercise the MEA dispatch end-to-end."""
+
+    @parameterized.expand(paged_attention_test_cases())
+    def test_paged_attention_mea(self, _, config):
+        parity_check_paged_attention(
+            config,
+            rtol=5e-3,
+            atol=5e-3,
+            sdpa_kernel=SDPA_KERNEL_EFFICIENT_ATTENTION,
+        )
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)