Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,20 @@ struct PagedAttentionData {
// Fused op buffers
T* workspace_buffer = nullptr;

// Memory-efficient attention (CUTLASS fMHA) buffers for the unfused fallback path
// taken when FlashAttention is unavailable (SM<80 or ORT_DISABLE_FLASH_ATTENTION).
T* gathered_key = nullptr; // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
T* gathered_value = nullptr; // [total_kv_tokens, num_heads, head_size], packed varlen (GQA-expanded)
T* fmha_buffer = nullptr; // CUTLASS fMHA output-accumulator workspace
// Populated by the caller after a D->H sync on cumulative_seqlens_kv[batch_size].
int total_kv_tokens = 0;

// Output Tensors
T* output = nullptr;

// Kernel Flags
bool use_flash_attention = false;
bool use_memory_efficient_attention = false;
};

} // namespace cuda
Expand Down
96 changes: 85 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "contrib_ops/cuda/bert/paged_attention.h"
#include "contrib_ops/cuda/bert/paged_attention_helper.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
Expand Down Expand Up @@ -50,6 +51,7 @@ PagedAttention<T>::PagedAttention(const OpKernelInfo& info)

kernel_options_ = this->GetAttentionKernelOptions();
disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
disable_memory_efficient_attention_ = sizeof(T) != 2 || !kernel_options_->UseEfficientAttention();
}

template <typename T>
Expand Down Expand Up @@ -141,31 +143,51 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
"value_cache and value_cache_out must be the same buffer");
}

// Check flash kernel availability and allocate buffers
// Kernel backend selection — FlashAttention preferred, fall back to MemoryEfficientAttention.
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
onnxruntime::flash::is_supported<T>(device_prop,
parameters.head_size,
parameters.num_heads,
parameters.kv_num_heads);
size_t softmax_lse_bytes = 0;
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
parameters.num_heads);
}
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_buffer = GetScratchBuffer<void>(0, GetComputeStream(context)); // nullptr
#endif

if (!use_flash_attention) {
#if USE_MEMORY_EFFICIENT_ATTENTION
const int sm = device_prop.major * 10 + device_prop.minor;
const bool is_half = std::is_same<T, MLFloat16>::value;
const bool is_bf16 = std::is_same<T, BFloat16>::value;
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
has_memory_efficient_attention(sm, is_half, is_bf16,
parameters.head_size, parameters.head_size);
#else
constexpr bool use_memory_efficient_attention = false;
#endif

if (!use_flash_attention && !use_memory_efficient_attention) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Currently PagedAttention is only supported through the FlashAttention kernel.");
"PagedAttention requires FlashAttention (sm>=80, fp16/bf16) or "
"MemoryEfficientAttention (fp16 sm>=53, bf16 sm>=80, head_size<=1024 and %8==0) "
"to be available. Check ORT_DISABLE_FLASH_ATTENTION / "
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION env vars and dtype/head_size.");
}

// Scratch buffers common to both backends.
size_t softmax_lse_bytes = 0;
#if USE_FLASH_ATTENTION
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count,
parameters.num_heads);
}
#endif
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));

size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1);
auto cumulative_seqlens_kv_buffer = GetScratchBuffer<void>(cumulative_seqlens_kv_bytes, GetComputeStream(context));
int* cumulative_seqlens_kv_ptr = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());

size_t workspace_buffer_bytes = 0;
if (do_rotary_) {
Expand All @@ -175,10 +197,53 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
auto workspace_buffer = GetScratchBuffer<void>(workspace_buffer_bytes, GetComputeStream(context));

// Populate cumulative_seqlens_kv for both backends. The MEA path additionally needs
// the last element on the host to size the tight gather buffers, so we D->H sync below.
cudaStream_t cuda_stream = static_cast<cudaStream_t>(ort_stream.get()->GetHandle());
ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(
cumulative_seqlens_kv_ptr,
reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>()),
reinterpret_cast<const int*>(past_seqlens->Data<int>()),
parameters.batch_size, cuda_stream));
Comment thread
tianleiwu marked this conversation as resolved.

int total_kv_tokens = 0;
IAllocatorUniquePtr<void> gathered_key_buffer;
IAllocatorUniquePtr<void> gathered_value_buffer;
IAllocatorUniquePtr<void> fmha_buffer;

#if USE_MEMORY_EFFICIENT_ATTENTION
if (use_memory_efficient_attention) {
auto total_kv_pinned = this->AllocateBufferOnCPUPinned<int>(1);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(total_kv_pinned.get(),
cumulative_seqlens_kv_ptr + parameters.batch_size,
sizeof(int), cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
total_kv_tokens = total_kv_pinned.get()[0];
if (total_kv_tokens <= 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"PagedAttention MEA fallback: total_kv_tokens is non-positive (", total_kv_tokens, ").");
}
Comment thread
tianleiwu marked this conversation as resolved.
Outdated

const size_t gather_elems = static_cast<size_t>(total_kv_tokens) *
parameters.num_heads * parameters.head_size;
gathered_key_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));
gathered_value_buffer = GetScratchBuffer<void>(sizeof(T) * gather_elems, GetComputeStream(context));

if (MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
// MEA output accumulator is float32 regardless of input dtype (see GQA pattern at
// group_query_attention.cc:482); use sizeof(float), not sizeof(T).
const size_t fmha_elems = static_cast<size_t>(parameters.token_count) *
parameters.num_heads * parameters.head_size;
fmha_buffer = GetScratchBuffer<void>(sizeof(float) * fmha_elems, GetComputeStream(context));
}
}
#endif

// Print debug info
if (kernel_options_->AllowDebugInfo()) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_efficient_attention = use_memory_efficient_attention;

debug_info.Print("PagedAttention",
this->Node().Name(),
Expand All @@ -194,10 +259,11 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.value_cache = reinterpret_cast<CudaT*>(const_cast<T*>(value_cache->Data<T>()));
data.cumulative_seqlens_q = reinterpret_cast<const int*>(cumulative_seqlens_q->Data<int>());
data.past_seqlens = reinterpret_cast<const int*>(past_seqlens->Data<int>());
data.cumulative_seqlens_kv = reinterpret_cast<int*>(cumulative_seqlens_kv_buffer.get());
data.cumulative_seqlens_kv = cumulative_seqlens_kv_ptr;
data.block_table = reinterpret_cast<const int*>(block_table->Data<int>());
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}
Expand All @@ -208,6 +274,14 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}
if (use_memory_efficient_attention) {
data.gathered_key = reinterpret_cast<CudaT*>(gathered_key_buffer.get());
data.gathered_value = reinterpret_cast<CudaT*>(gathered_value_buffer.get());
if (fmha_buffer != nullptr) {
data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
data.total_kv_tokens = total_kv_tokens;
}

cublasHandle_t cublas = GetCublasHandle(context);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/paged_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class PagedAttention final : public CudaKernel {
float scale_;
float softcap_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
const AttentionKernelOptions* kernel_options_;
};

Expand Down
Loading
Loading