Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,27 @@ struct PagedAttentionData {
// Fused op buffers
T* workspace_buffer = nullptr;

// Memory-efficient attention (CUTLASS fMHA) buffers for the unfused fallback path
// taken when FlashAttention is unavailable (e.g. SM<80, ORT_DISABLE_FLASH_ATTENTION, or a build without USE_FLASH_ATTENTION).
T* gathered_key = nullptr; // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
T* gathered_value = nullptr; // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
T* fmha_buffer = nullptr; // CUTLASS fMHA output-accumulator workspace
// Populated by the caller after a D->H sync on cumulative_seqlens_kv[batch_size].
int total_kv_tokens = 0;

// Actual max of per-batch new-query lengths (cumulative_seqlens_q[i+1] - cumulative_seqlens_q[i]).
// Populated by the caller via the same D->H sync so the MEA path's rotary grid and MEA's
// grid_x (ceil_div(sequence_length, kQueriesPerBlock)) cover every query token. The previous
// heuristic `token_count - batch_size + 1` assumes every batch has at least one new token; it
// underestimates when any batch has 0 new tokens, so MEA and rotary silently skipped some query tokens.
int max_query_len = 0;

// Output Tensors
T* output = nullptr;

// Kernel Flags
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
};

} // namespace cuda
Expand Down
141 changes: 130 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "contrib_ops/cuda/bert/paged_attention.h"
#include "contrib_ops/cuda/bert/paged_attention_helper.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
Expand Down Expand Up @@ -50,6 +51,7 @@ PagedAttention<T>::PagedAttention(const OpKernelInfo& info)

kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
disable_memory_efficient_attention_ = sizeof(T) != 2 || !kernel_options_->UseEfficientAttention();
}

template <typename T>
Expand Down Expand Up @@ -141,31 +143,57 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
"value_cache and value_cache_out must be the same buffer");
}

// Check flash kernel availability and allocate buffers
// Empty query input: output is already shaped [0, hidden_size], and the cache outputs
// alias the input caches (verified above), so no backend kernel or cache update is needed.
if (parameters.token_count == 0) {
return Status::OK();
}

// Kernel backend selection — FlashAttention preferred, fall back to MemoryEfficientAttention.
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported<T>(device_prop,
parameters.head_size,
parameters.num_heads,
parameters.kv_num_heads);
size_t softmax_lse_bytes = 0;
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
parameters.num_heads);
}
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_buffer = GetScratchBuffer<void>(0, GetComputeStream(context)); // nullptr
#endif

if (!use_flash_attention) {
#if USE_MEMORY_EFFICIENT_ATTENTION
const int sm = device_prop.major * 10 + device_prop.minor;
const bool is_half = std::is_same<T, MLFloat16>::value;
const bool is_bf16 = std::is_same<T, BFloat16>::value;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
has_memory_efficient_attention(sm, is_half, is_bf16,
parameters.head_size, parameters.head_size);
#else
constexpr bool use_memory_efficient_attention = false;
#endif

if (!use_flash_attention && !use_memory_efficient_attention) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Currently PagedAttention is only supported through the FlashAttention kernel.");
"PagedAttention requires FlashAttention (sm>=80, fp16/bf16) or "
"MemoryEfficientAttention (fp16 sm>=53, bf16 sm>=80, head_size<=1024 and %8==0) "
"to be available. Check ORT_DISABLE_FLASH_ATTENTION / "
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION env vars and dtype/head_size.");
}

// Scratch buffers common to both backends.
size_t softmax_lse_bytes = 0;
#if USE_FLASH_ATTENTION
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
parameters.num_heads);
}
#endif
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));

size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1);
auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, GetComputeStream(context));
int* cumulative_seqlens_kv_ptr = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());

size_t workspace_buffer_bytes = 0;
if (do_rotary_) {
Expand All @@ -175,10 +203,91 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, GetComputeStream(context));

// Populate cumulative_seqlens_kv for both backends. The MEA path additionally needs
// the last element on the host to size the tight gather buffers, so we D->H sync below.
//
// LaunchGetCumulativeSeqlensKV uses a per-block cub::BlockScan with a block size of 256
// and launches (batch_size + 255) / 256 blocks, so blocks scan independently. Enforce
// batch_size <= 256 so the cumulative sum is correct; a larger batch would silently
// produce wrong KV offsets. (A future grid-wide scan could lift this limit.)
constexpr int kMaxBatchSizeForCumulativeSeqlensKV = 256;
if (parameters.batch_size > kMaxBatchSizeForCumulativeSeqlensKV) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"PagedAttention currently supports batch_size <= ",
kMaxBatchSizeForCumulativeSeqlensKV,
" (LaunchGetCumulativeSeqlensKV limitation); got batch_size=",
parameters.batch_size, ".");
}

cudaStream_t cuda_stream = static_cast<cudaStream_t>(ort_stream.get()->GetHandle());
ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(
cumulative_seqlens_kv_ptr,
reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>()),
reinterpret_cast<const int*>(past_seqlens->Data<int>()),
parameters.batch_size, cuda_stream));
Comment thread
tianleiwu marked this conversation as resolved.

int total_kv_tokens = 0;
int max_query_len = 0;
IAllocatorUniquePtr<void> gathered_key_buffer;
IAllocatorUniquePtr<void> gathered_value_buffer;
IAllocatorUniquePtr<void> fmha_buffer;

#if USE_MEMORY_EFFICIENT_ATTENTION
if (use_memory_efficient_attention) {
// MEA needs two host-side quantities:
// - total_kv_tokens (= cumulative_seqlens_kv[batch_size]) to size tight gather buffers.
// - max_query_len (= max per-batch new-query length) to size the rotary and MEA grids
// correctly. The heuristic `token_count - batch_size + 1` assumes every batch has at
// least one new token; when any batch has 0 new tokens (valid input) it underestimates,
// silently dropping query tokens from batches longer than the estimate.
// Both come from cumulative_seqlens_q / cumulative_seqlens_kv, which are tiny (batch+1
// ints each), so copying both full arrays D->H behind a single stream sync is cheaper
// than issuing an extra reduction kernel and avoids a second sync.
const int kCumulativeCount = parameters.batch_size + 1;
auto cum_q_pinned = this->AllocateBufferOnCPUPinned<int>(kCumulativeCount);
auto cum_kv_pinned = this->AllocateBufferOnCPUPinned<int>(kCumulativeCount);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_q_pinned.get(),
reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>()),
sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cum_kv_pinned.get(), cumulative_seqlens_kv_ptr,
sizeof(int) * kCumulativeCount, cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
total_kv_tokens = cum_kv_pinned.get()[parameters.batch_size];
for (int i = 0; i < parameters.batch_size; ++i) {
const int q_len_i = cum_q_pinned.get()[i + 1] - cum_q_pinned.get()[i];
if (q_len_i > max_query_len) {
max_query_len = q_len_i;
}
}
if (total_kv_tokens == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"PagedAttention MEA fallback: total_kv_tokens is zero for non-empty input.");
}
if (total_kv_tokens < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"PagedAttention MEA fallback: total_kv_tokens is negative (", total_kv_tokens, ").");
}

const size_t gather_elems = static_cast<size_t>(total_kv_tokens) *
parameters.num_heads * parameters.head_size;
gathered_key_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));
gathered_value_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));

if (MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
// MEA output accumulator is float32 regardless of input dtype (see GQA pattern at
// group_query_attention.cc:482); use sizeof(float), not sizeof(T).
const size_t fmha_elems = static_cast<size_t>(parameters.token_count) *
parameters.num_heads * parameters.head_size;
fmha_buffer = GetScratchBuffer<void>(sizeof(float) * fmha_elems, GetComputeStream(context));
}
}
#endif

// Print debug info
if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;

debug_info.Print("PagedAttention",
this->Node().Name(),
Expand All @@ -194,10 +303,11 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>()));
data.cumulative_seqlens_q = reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>());
data.past_seqlens = reinterpret_cast<const int*>(past_seqlens->Data<int>());
data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
data.cumulative_seqlens_kv = cumulative_seqlens_kv_ptr;
data.block_table = reinterpret_cast<const int*>(block_table->Data<int>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
Expand All @@ -208,6 +318,15 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}
if (use_memory_efficient_attention) {
data.gathered_key = reinterpret_cast<CudaT*>(gathered_key_buffer.get());
data.gathered_value = reinterpret_cast<CudaT*>(gathered_value_buffer.get());
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
data.total_kv_tokens = total_kv_tokens;
data.max_query_len = max_query_len;
}

cublasHandle_t cublas = GetCublasHandle(context);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class PagedAttention final : public CudaKernel {
float scale_;
float softcap_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
const AttentionKernelOptions* kernel_options_;
};

Expand Down
Loading
Loading