@@ -379,11 +379,6 @@ Status CheckCustomAttentionInputs(const T* position_ids,
}

if (head_sink != nullptr) {
if (parameters.use_smooth_softmax) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_sink should not be provided when use_smooth_softmax is true.");
}

const auto& head_sink_shape = head_sink->Shape();
if (head_sink_shape.NumDimensions() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_data.h
@@ -156,6 +156,7 @@ struct GroupQueryAttentionData {
int* seqlens_k = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
const T* head_sink = nullptr;

// Flash buffers
T* softmax_lse = nullptr;
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -123,6 +123,7 @@ struct Flash_fwd_params : public Qkv_params {

bool is_rotary_interleaved = false;

void* __restrict__ head_sink_ptr = nullptr;
bool smooth_softmax = false;

int num_splits = 0; // For split-KV version
14 changes: 11 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -28,6 +28,7 @@ void set_params_fprop(Flash_fwd_params& params,
void* q,
void* k,
void* v,
void* head_sink,
void* out,
void* cu_seqlens_q_d,
void* cu_seqlens_k_d,
@@ -50,7 +51,9 @@ void set_params_fprop(Flash_fwd_params& params,
params.o_ptr = out;

params.is_bf16 = is_bf16;

params.smooth_softmax = use_smooth_softmax;
params.head_sink_ptr = head_sink;

// All stride are in elements, not bytes.
if (kv_bsnh) {
@@ -297,14 +300,16 @@ Status mha_fwd(const cudaDeviceProp& dprops,
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

constexpr void* head_sink = nullptr;

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
seqlen_q, seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
q, k, v, head_sink, out,
/*cu_seqlens_q*/ nullptr,
/*cu_seqlens_k*/ nullptr,
/*seqused_k=*/nullptr,
@@ -376,14 +381,16 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops,
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
const bool paged_KV = block_table != nullptr;

constexpr void* head_sink = nullptr;

Flash_fwd_params params;
set_params_fprop(params,
batch_size,
max_seqlen_q, max_seqlen_k,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, k, v, out,
q, k, v, head_sink, out,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
@@ -443,6 +450,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,
@@ -480,7 +488,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
seqlen_q_rounded, seqlen_k_rounded,
num_heads, num_heads_k,
head_size, head_size_rounded,
q, kcache, vcache, out,
q, kcache, vcache, head_sink, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
@@ -98,6 +98,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
void* head_sink, // num_heads
int* block_table, // batch_size x max_num_blocks_per_seq
int batch_size,
int num_heads,
@@ -262,6 +262,10 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q,
params.window_size_left, params.window_size_right, alibi_slope);

const float sink = (params.head_sink_ptr != nullptr)
? reinterpret_cast<Element*>(params.head_sink_ptr)[bidh]
: (params.smooth_softmax ? 0.0f : -std::numeric_limits<float>::infinity());

// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
Expand Down Expand Up @@ -314,8 +318,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi

// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2, sink)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2, sink);

// Convert acc_s from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(acc_s);
@@ -359,7 +363,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);

softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2, sink);

Tensor rP = flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
@@ -369,8 +373,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
}

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax);
Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax);

// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
@@ -779,6 +782,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
: reinterpret_cast<float*>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

const float sink = (params.head_sink_ptr != nullptr)
? reinterpret_cast<Element*>(params.head_sink_ptr)[bidh]
: (params.smooth_softmax ? 0.0f : -std::numeric_limits<float>::infinity());

// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@@ -851,8 +858,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons

// We have key_padding_mask so we'll need to Check_inf
masking_step == 0
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2, sink)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2, sink);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

// Convert acc_s from fp32 to fp16/bf16
@@ -917,7 +924,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons

mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2, sink);

Tensor rP = flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
@@ -929,7 +936,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons

// Epilogue

Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax, params.smooth_softmax);
Tensor lse = softmax.template normalize_softmax_lse<Split>(acc_o, params.scale_softmax);

Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
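For reference, the sink handling added above amounts to treating the per-head sink value as one extra logit in each attention row. Writing \(\lambda\) for the softmax scale, \(s_j\) for the row's \(QK^T\) logits, and \(s_h\) for head_sink[bidh] (falling back to \(0\) on the smooth-softmax path and to \(-\infty\) otherwise), the kernel effectively computes

\[
m = \max\bigl(s_h,\ \max_j s_j\bigr), \qquad
Z = e^{\lambda (s_h - m)} + \sum_j e^{\lambda (s_j - m)}, \qquad
p_i = \frac{e^{\lambda (s_i - m)}}{Z}.
\]

With \(s_h = 0\) this reduces to the existing smooth softmax (the denominator gains \(e^{-\lambda m}\)), and with \(s_h = -\infty\) it reduces to the standard softmax, which is why the sink term can replace the separate smooth_softmax correction in the epilogue.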
42 changes: 31 additions & 11 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h
@@ -18,6 +18,7 @@ namespace flash {
using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr float kInfinity = std::numeric_limits<float>::infinity();

template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
@@ -72,9 +73,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tenso
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -std::numeric_limits<float>::infinity()
? 0.f
: max(mi) * (Scale_max ? scale : float(M_LOG2E));
const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
@@ -102,7 +101,7 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0>& ten
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -std::numeric_limits<float>::infinity() ? 0.f : max(mi) * scale;
const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
@@ -127,53 +126,74 @@ struct Softmax {
__forceinline__ __device__ Softmax() {};

template <bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2) {
__forceinline__ __device__ void softmax_rescale_o(Tensor0& acc_s, Tensor1& acc_o, float softmax_scale_log2, float sink) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);

const bool use_sink = (sink != -kInfinity);
if (use_sink) {
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
row_max(mi) = max(row_max(mi), sink);
}
}

flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);

if (use_sink) {
#pragma unroll
for (int mi = 0; mi < size(row_sum); ++mi) {
const float max_scaled = row_max(mi) == -kInfinity ? 0.f : row_max(mi) * softmax_scale_log2;
row_sum(mi) += exp2f(sink * softmax_scale_log2 - max_scaled);
}
}
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);

flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))

Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);

#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -std::numeric_limits<float>::infinity() ? 0.0f : row_max(mi));
: (row_max(mi) == -kInfinity ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;

#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale;
}
}

flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
}

template <bool Split = false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) {
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi);
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum)
? (Split ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity())
? (Split ? -kInfinity : kInfinity)
: row_max(mi) * softmax_scale + __logf(sum);
float scale = inv_sum;
#pragma unroll
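The sink is now folded into row_max and row_sum inside softmax_rescale_o on the first iteration, and the ordinary max-rescaling carries it through later blocks, which is why normalize_softmax_lse no longer needs the smooth_softmax correction term. A minimal host-side sketch of that streaming scheme (illustrative only, not part of the diff; the softmax scale is folded into the logits and the block size is arbitrary):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const float kInf = std::numeric_limits<float>::infinity();
  std::vector<float> logits = {0.3f, -1.2f, 2.5f, 0.0f, 1.7f, -0.4f, 0.9f, 3.1f};
  const float sink = 0.8f;     // per-row sink logit; -inf would disable it
  const size_t block = 3;      // process logits block by block, as the kernel does

  // Streaming pass: maintain a running max and a running (max-shifted) sum.
  float row_max = -kInf, row_sum = 0.f;
  for (size_t start = 0; start < logits.size(); start += block) {
    size_t end = std::min(start + block, logits.size());
    float prev_max = row_max;
    for (size_t i = start; i < end; ++i) row_max = std::max(row_max, logits[i]);
    if (start == 0) {
      // First block: fold the sink into the running max and sum exactly once
      // (cf. the Is_first branch of softmax_rescale_o).
      row_max = std::max(row_max, sink);
      row_sum = std::exp(sink - row_max);
    } else {
      // Later blocks: rescale the old sum to the new max; the sink contribution
      // is carried along automatically, so no special case is needed.
      row_sum *= std::exp(prev_max - row_max);
    }
    for (size_t i = start; i < end; ++i) row_sum += std::exp(logits[i] - row_max);
  }
  float lse_stream = row_max + std::log(row_sum);  // cf. normalize_softmax_lse

  // Reference two-pass computation over [logits..., sink].
  float m = sink;
  for (float x : logits) m = std::max(m, x);
  float z = std::exp(sink - m);
  for (float x : logits) z += std::exp(x - m);
  float lse_ref = m + std::log(z);

  std::printf("streaming lse = %f, reference lse = %f\n", lse_stream, lse_ref);
  assert(std::fabs(lse_stream - lse_ref) < 1e-5f);
  return 0;
}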
21 changes: 19 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -78,6 +78,14 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* total_seqlen = context->Input<Tensor>(6);
const Tensor* cos_cache = context->Input<Tensor>(7);
const Tensor* sin_cache = context->Input<Tensor>(8);
const Tensor* position_ids = context->Input<Tensor>(9);
const Tensor* attention_bias = context->Input<Tensor>(10);
const Tensor* head_sink = context->Input<Tensor>(11);

if (position_ids != nullptr || attention_bias != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids and attention_bias are not supported in GroupQueryAttention cuda kernel.");
}

auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -99,12 +107,17 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
scale_,
softcap_,
device_prop.maxThreadsPerBlock));

ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
head_sink,
parameters));
parameters.local_window_size = local_window_size_;
parameters.is_unidirectional = is_unidirectional_;
parameters.use_smooth_softmax = use_smooth_softmax_;
parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr;
parameters.zeros_count = kZerosCount;
parameters.zero_ptr = zeros_.get();
// parameters.left_padding = left_padding_;

int sequence_length = parameters.sequence_length;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;
@@ -270,6 +283,10 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}

if (head_sink != nullptr) {
data.head_sink = reinterpret_cast<const CudaT*>(head_sink->Data<T>());
}

cublasHandle_t cublas = GetCublasHandle(context);

return QkvToContext<CudaT>(
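At the operator level, the new optional inputs occupy slots 9-11, and a provided head_sink simply forces the same smooth-softmax kernel path. A hedged restatement of that dispatch logic (simplified stand-in types and names, not the ONNX Runtime API):

#include <stdexcept>

struct GqaOptionalInputs {
  const void* position_ids = nullptr;    // input 9, unsupported by the CUDA kernel
  const void* attention_bias = nullptr;  // input 10, unsupported by the CUDA kernel
  const void* head_sink = nullptr;       // input 11, optional 1-D tensor of num_heads values
};

// Returns whether the smooth-softmax / sink kernel path should be taken.
bool UseSmoothSoftmaxPath(const GqaOptionalInputs& in, bool use_smooth_softmax_attr) {
  if (in.position_ids != nullptr || in.attention_bias != nullptr) {
    throw std::invalid_argument(
        "position_ids and attention_bias are not supported in the CUDA GroupQueryAttention kernel");
  }
  // A provided head_sink generalizes smooth softmax (sink logit 0), so either enables the path.
  return use_smooth_softmax_attr || in.head_sink != nullptr;
}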
@@ -460,11 +460,12 @@ Status FlashAttention(
void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
void* head_sink = reinterpret_cast<void*>(const_cast<T*>(data.head_sink));

bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
device_prop, stream, query, present_key, present_value, key, value, data.output,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,