(file path not shown)
@@ -379,11 +379,6 @@ Status CheckCustomAttentionInputs(const T* position_ids,
}

if (head_sink != nullptr) {
if (parameters.use_smooth_softmax) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_sink should not be provided when use_smooth_softmax is true.");
}

const auto& head_sink_shape = head_sink->Shape();
if (head_sink_shape.NumDimensions() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
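The dropped mutual-exclusion check above matches the new semantics of head_sink: a per-head sink logit generalizes smooth softmax, and (as the kernel changes below show) providing head_sink simply enables the smooth-softmax path with a learned logit instead of an implicit zero. A minimal host-side sketch of the intended math, not part of this diff; softmax_with_sink is a hypothetical helper for illustration only:

#include <algorithm>
#include <cmath>
#include <vector>

// Softmax over one row of attention scores with an extra sink logit.
// Passing sink = 0.0f reproduces the existing smooth-softmax behavior.
std::vector<float> softmax_with_sink(const std::vector<float>& scores, float sink) {
  float m = sink;
  for (float s : scores) m = std::max(m, s);        // the sink participates in the running max
  float denom = std::exp(sink - m);                 // ...and in the denominator
  for (float s : scores) denom += std::exp(s - m);
  std::vector<float> probs(scores.size());
  for (size_t i = 0; i < scores.size(); ++i) {
    probs[i] = std::exp(scores[i] - m) / denom;     // no output weight is emitted for the sink
  }
  return probs;                                     // sums to less than 1; the rest went to the sink
}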

onnxruntime/contrib_ops/cuda/bert/attention_data.h (1 addition, 0 deletions)
@@ -156,6 +156,7 @@ struct GroupQueryAttentionData {
int* seqlens_k = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
const T* head_sink = nullptr;

// Flash buffers
T* softmax_lse = nullptr;

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h (1 addition, 0 deletions)
@@ -123,6 +123,7 @@ struct Flash_fwd_params : public Qkv_params {

bool is_rotary_interleaved = false;

void* __restrict__ head_sink_ptr = nullptr;
bool smooth_softmax = false;

int num_splits = 0; // For split-KV version

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc (11 additions, 3 deletions)
@@ -28,6 +28,7 @@ void set_params_fprop(Flash_fwd_params& params,
void* q,
void* k,
void* v,
void* head_sink,
void* out,
void* cu_seqlens_q_d,
void* cu_seqlens_k_d,
@@ -50,7 +51,9 @@
params.o_ptr = out;

params.is_bf16 = is_bf16;

params.smooth_softmax = use_smooth_softmax;
params.head_sink_ptr = head_sink;

// All stride are in elements, not bytes.
if (kv_bsnh) {
@@ -297,14 +300,16 @@ Status mha_fwd(const cudaDeviceProp& dprops,
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

constexpr void* head_sink = nullptr;

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
q, k, v, head_sink, out,
/*cu_seqlens_q*/ nullptr,
/*cu_seqlens_k*/ nullptr,
/*seqused_k=*/nullptr,
@@ -376,14 +381,16 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
const bool paged_KV = block_table != nullptr;

constexpr void* head_sink = nullptr;

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
q, k, v, head_sink, out,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
@@ -443,6 +450,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,
@@ -480,7 +488,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, kcache, vcache, out,
q, kcache, vcache, head_sink, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,

(next file in the diff; path not shown)
@@ -98,6 +98,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,

(next file in the diff; path not shown)
@@ -369,8 +369,10 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
}

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax);
float sink = (params.head_sink_ptr != nullptr)
? reinterpret_cast<Element*>(params.head_sink_ptr)[bidh]
: (params.smooth_softmax ? 0.0f : -kInfinity);
Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, sink);

// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
@@ -928,8 +930,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
}

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax, params.smooth_softmax);
float sink = (params.head_sink_ptr != nullptr)
? reinterpret_cast<Element*>(params.head_sink_ptr)[bidh]
: (params.smooth_softmax ? 0.0f : -std::numeric_limits<float>::infinity());
Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax, sink);

Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
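Both epilogues above resolve the sink the same way before calling normalize_softmax_lse: read the per-head value when head_sink is provided, fall back to 0 for smooth softmax, and otherwise disable the sink with negative infinity. A small device-side sketch of that selection (illustrative only; select_sink is a hypothetical helper, not part of this diff):

#include <limits>

// Element is the compute type of head_sink (half or bfloat16 in these kernels).
template <typename Element>
__device__ inline float select_sink(const void* head_sink_ptr, int head_idx, bool smooth_softmax) {
  if (head_sink_ptr != nullptr) {
    // One learned sink logit per attention head.
    return static_cast<float>(reinterpret_cast<const Element*>(head_sink_ptr)[head_idx]);
  }
  // Smooth softmax behaves like a sink logit of 0; otherwise the sink is disabled.
  return smooth_softmax ? 0.0f : -std::numeric_limits<float>::infinity();
}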

onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h (48 additions, 45 deletions)
@@ -18,6 +18,7 @@ namespace flash {
using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr float kInfinity = std::numeric_limits<float>::infinity();

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
@@ -72,9 +73,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tenso
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -std::numeric_limits<float>::infinity()
? 0.f
: max(mi) * (Scale_max ? scale : float(M_LOG2E));
const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
@@ -85,38 +84,6 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tenso
}
}

// Apply the exp to all the elements.
template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1>& max, Tensor<Engine1, Layout1>& sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -std::numeric_limits<float>::infinity() ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNRows>
@@ -143,17 +110,18 @@ struct Softmax {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
for (int mi = 0; mi < size<0>(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -std::numeric_limits<float>::infinity() ? 0.0f : row_max(mi));
: (row_max(mi) == -kInfinity ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}

flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
@@ -162,27 +130,62 @@
};

template <bool Split = false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) {
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o,
float softmax_scale,
float sink) { // IMPORTANT: sink is a pre-scaled logit

SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);

const bool use_sink = (sink != -kInfinity);

#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
float sum = row_sum(mi);
float max_unscaled = row_max(mi); // Max of the qk scores, NOT scaled.

if (use_sink) {
const float max_scaled = (max_unscaled == -kInfinity)
? -kInfinity
: max_unscaled * softmax_scale;

const float true_max_scaled = max(max_scaled, sink);

// Rescale the intermediate output accumulator (acc_o) and the sum.
// They were calculated relative to `max_scaled` and must be
// rescaled to be relative to `true_max_scaled`.
const float rescale_factor = expf(max_scaled - true_max_scaled);

#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= rescale_factor;
}

sum *= rescale_factor;

// Add the sink to the sum.
sum += expf(sink - true_max_scaled);

// The unscaled max that reflects the sink; it is used for the LSE calculation below.
max_unscaled = true_max_scaled / softmax_scale;
}

lse(mi) = (sum == 0.f || sum != sum)
? (Split ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity())
: row_max(mi) * softmax_scale + __logf(sum);
float scale = inv_sum;
? (Split ? -kInfinity : kInfinity)
: max_unscaled * softmax_scale + __logf(sum);

float inv_sum = (sum == 0.f || !isfinite(sum)) ? 1.f : 1.f / sum;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scale;
acc_o_rowcol(mi, ni) *= inv_sum;
}
}

return lse;
};
}
};

} // namespace flash
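Taken together, the new normalize_softmax_lse computes the running max and partial sums without the sink and only folds the sink in at the end, by rescaling the accumulator and the sum to the combined maximum before normalizing and forming the LSE. A simplified single-row, host-side sketch of that epilogue (illustrative only; the blocked tensor indexing and exp2-based accumulation of the kernel are elided, and RowState/finalize_with_sink are hypothetical names):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

struct RowState {
  float row_max;             // max of the raw qk scores for this row (not yet scaled)
  float row_sum;             // sum of exp(softmax_scale * (qk - row_max)) accumulated so far
  std::vector<float> acc_o;  // output accumulator, also relative to row_max
};

// Folds an optional pre-scaled sink logit into the running statistics, normalizes acc_o
// in place, and returns the row's log-sum-exp over the scaled logits.
float finalize_with_sink(RowState& r, float softmax_scale, float sink) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  float max_scaled = (r.row_max == neg_inf) ? neg_inf : r.row_max * softmax_scale;
  float sum = r.row_sum;
  if (sink != neg_inf) {
    const float true_max = std::max(max_scaled, sink);
    const float rescale = std::exp(max_scaled - true_max);  // shift running stats to the new max
    for (float& o : r.acc_o) o *= rescale;
    sum = sum * rescale + std::exp(sink - true_max);         // the sink adds probability mass, no value
    max_scaled = true_max;
  }
  const float inv_sum = (sum == 0.f || !std::isfinite(sum)) ? 1.f : 1.f / sum;
  for (float& o : r.acc_o) o *= inv_sum;
  return max_scaled + std::log(sum);                         // LSE including the sink term
}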

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc (19 additions, 2 deletions)
@@ -78,6 +78,14 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
const Tensor* position_ids = context->Input<Tensor>(9);
const Tensor* attention_bias = context->Input<Tensor>(10);
const Tensor* head_sink = context->Input<Tensor>(11);

if (position_ids != nullptr || attention_bias != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids and attention_bias are not supported in GroupQueryAttention cuda kernel.");
}

auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -99,12 +107,17 @@
scale_,
softcap_,
device_prop.maxThreadsPerBlock));

ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
head_sink,
parameters));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;

int sequence_length = parameters.sequence_length;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;
@@ -276,6 +289,10 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}

if (head_sink != nullptr) {
data.head_sink = reinterpret_cast<const CudaT*>(head_sink->Data<T>());
}
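For reference, the contract on the new optional inputs read above: position_ids (input 9) and attention_bias (input 10) are rejected by this CUDA kernel, and head_sink (input 11) is a 1-D tensor that the flash path indexes per attention head. A hedged host-side sketch of that shape check, for illustration only; validate_head_sink and the exact error strings are hypothetical, and the real validation lives in CheckCustomAttentionInputs:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Assumed shape contract for the optional head_sink input: 1-D, one logit per attention head.
void validate_head_sink(const std::vector<int64_t>& head_sink_dims, int64_t num_heads) {
  if (head_sink_dims.size() != 1) {
    throw std::invalid_argument("head_sink must be a 1D tensor");
  }
  if (head_sink_dims[0] != num_heads) {
    throw std::invalid_argument("head_sink must have num_heads elements");
  }
}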

cublasHandle_t cublas = GetCublasHandle(context);

return QkvToContext<CudaT>(

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu (13 additions, 7 deletions)
@@ -460,11 +460,18 @@ Status FlashAttention(
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
void* head_sink = reinterpret_cast<void*>(const_cast<T*>(data.head_sink));

bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;

DUMP_TENSOR_INIT();
DUMP_TENSOR("Q", reinterpret_cast<T*>(query), batch_size, sequence_length, num_heads, head_size);
DUMP_TENSOR("K", reinterpret_cast<T*>(present_key), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size);
DUMP_TENSOR("V", reinterpret_cast<T*>(present_value), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size);

ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
@@ -475,7 +482,6 @@ Status FlashAttention(
// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
// }

DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);

return Status::OK();
@@ -680,6 +686,11 @@ template Status QkvToContext<half>(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<half>& data);

template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(
const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block);

template struct GroupQueryAttentionData<BFloat16>;

template Status QkvToContext<BFloat16>(
Expand All @@ -689,11 +700,6 @@ template Status QkvToContext<BFloat16>(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<BFloat16>& data);

template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(
const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block);

template Status LaunchUnpackQKV<BFloat16, LAYOUT_BNSH>(
const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,