36 changes: 28 additions & 8 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -204,10 +204,14 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (use_seqlen_k_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_head_sink_) {
shader.AddInput("head_sink", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform);

return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
@@ -300,11 +304,15 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_head_sink_) {
shader.AddInput("head_sink", ShaderUsage::UseUniform);
}
shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);

const uint32_t tile_size_k_vec = 8u;

return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
@@ -324,10 +332,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
uint32_t num_present_sequence_length_tile,
uint32_t tile_size,
bool use_indirect_dispatch,
uint32_t present_sequence_length) {
uint32_t present_sequence_length,
const Tensor* head_sink) {
const int components = 4;
const bool has_head_sink = head_sink != nullptr;
int head_size_vec = parameters.v_head_size_ / components;
FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch};
FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink};
program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
{qk, ProgramTensorMetadataDependency::TypeAndRank},
{present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -339,14 +349,18 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
} else {
program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
}
program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
if (has_head_sink) {
program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
}
program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
.SetWorkgroupSize(64)
.AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(head_size_vec)},
present_sequence_length,
{static_cast<uint32_t>(parameters.n_reps)},
num_present_sequence_length_tile,
{batch_heads}});
{batch_heads},
{static_cast<uint32_t>(parameters.num_heads_)}});

return context.RunProgram(program);
}
@@ -399,7 +413,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
const Tensor* cos_cache, const Tensor* sin_cache) {
const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) {
constexpr uint32_t tile_size = 64;

// Create present_key and present_value tensors if they are nullptr
Expand Down Expand Up @@ -467,6 +481,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
bool has_head_sink = head_sink != nullptr;
FlashAttentionProgram program{"FlashAttention",
has_attention_bias,
is_qualcomm,
@@ -476,7 +491,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.is_unidirectional_,
is_nvidia,
q_BNSH,
use_seqlen_k};
use_seqlen_k,
has_head_sink};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
{present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
@@ -486,6 +502,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
if (use_seqlen_k) {
program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::None}});
}
if (has_head_sink) {
program.AddInputs({{head_sink, ProgramTensorMetadataDependency::Type}});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
@@ -502,7 +521,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co

program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
.SetWorkgroupSize(tile_size)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
.CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k, has_head_sink)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(present_sequence_length)},
@@ -542,7 +561,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
seqlen_k, parameters, indirect_buffer_ptr,
num_total_seq_length_tile,
num_present_sequence_length_tile, tile_size,
use_indirect_dispatch, present_sequence_length));
use_indirect_dispatch, present_sequence_length,
head_sink));
ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
num_total_seq_length_tile,
num_present_sequence_length_tile, tile_size, use_indirect_dispatch));
17 changes: 11 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -77,7 +77,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool is_unidirectional,
bool is_nvidia,
bool q_BNSH,
bool use_seqlen_k = false)
bool use_seqlen_k = false,
bool has_head_sink = false)
: Program{kernel_name},
has_attention_bias_(has_attention_bias),
is_qualcomm_(is_qualcomm),
@@ -87,7 +88,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
is_unidirectional_(is_unidirectional),
is_nvidia_(is_nvidia),
q_BNSH_(q_BNSH),
use_seqlen_k_(use_seqlen_k) {
use_seqlen_k_(use_seqlen_k),
has_head_sink_(has_head_sink) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -112,6 +114,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
bool is_nvidia_;
bool q_BNSH_;
bool use_seqlen_k_;
bool has_head_sink_;
};

class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
@@ -142,8 +145,8 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecode

class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
public:
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -153,12 +156,14 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDe
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"batch_heads", ProgramUniformVariableDataType::Uint32});
{"batch_heads", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
int head_size_vec_;
bool use_indirect_dispatch_;
bool has_head_sink_;
};

class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
@@ -184,7 +189,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD
Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr,
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr);
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr);

bool CanApplyFlashAttention(const Tensor* bias,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
bert/flash_attention.wgsl.template
@@ -1,5 +1,6 @@

#param has_attention_bias
#param has_head_sink
#param is_fp16
#param is_qualcomm
#param is_unidirectional
@@ -174,8 +175,14 @@ $MAIN {
loadq(batch_idx, q_idx_global, head_idx, q_element_t(uniforms.alpha));
}

#if has_head_sink
let sink_value = q_element_t(head_sink[head_idx]);
var previous_max : q_element_t = sink_value;
var previous_denom : q_element_t = 1;
#else
var previous_max : q_element_t = min_value;
var previous_denom : q_element_t = 0;
#endif
let total_sequence_length = get_total_sequence_length();

#if is_unidirectional
@@ -295,7 +302,7 @@ $MAIN {
let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w;

// Compute the lhs term of the updated di prime, then compute di prime.
let dleft = previous_denom * exp(previous_max - new_max);
let dleft = previous_denom * q_element_t(exp(f32(previous_max) - f32(new_max)));
var d = dleft + sum;
d = select(d, q_element_t(0.0000001), d == 0);
qk_1 = qk_1 / d;
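The template change above seeds the streaming-softmax state with the head sink instead of (-inf, 0). A minimal C++ sketch of that update rule, outside any WGSL or ORT types (function and variable names here are illustrative, not part of the codebase):

```cpp
// Streaming softmax over attention scores with a per-head "sink" logit.
// Seeding running_max with the sink and running_denom with 1 is equivalent to
// appending a virtual score equal to the sink: it adds exp(sink - max) to the
// denominator but contributes no value row, so the real probabilities sum to < 1.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> SoftmaxWithHeadSink(const std::vector<float>& scores, float sink) {
  float running_max = sink;    // previous_max = sink_value
  float running_denom = 1.0f;  // previous_denom = 1, i.e. exp(sink - sink)
  for (float s : scores) {
    const float new_max = std::max(running_max, s);
    // Rescale what was accumulated so far, as in
    // dleft = previous_denom * exp(previous_max - new_max).
    running_denom = running_denom * std::exp(running_max - new_max) + std::exp(s - new_max);
    running_max = new_max;
  }
  std::vector<float> probs(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) {
    probs[i] = std::exp(scores[i] - running_max) / running_denom;  // qk / d
  }
  return probs;  // the missing mass, exp(sink - running_max) / running_denom, belongs to the sink
}
```

The exp in the dleft update is also evaluated in f32 even when q_element_t is fp16, presumably to keep precision and range in the subtraction and exponential before casting back.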
bert/flash_attention_decode_split_vx.wgsl.template
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param has_head_sink
#param tile_size
#param head_size_vec
#param tile_size_k_vec
@@ -56,6 +57,11 @@ $MAIN {

// Calculate the global max and sum in qk.
var g_max = f32(-3.4028234663852886e+38f);
#if has_head_sink
let head_idx = batch_head_idx % uniforms.num_heads;
let sink_value = f32(head_sink[head_idx]);
g_max = max(g_max, sink_value);
#endif
var g_sum = f32(0);
for (var i = 0u; i < num_total_seq_length_tile; i++)
{
@@ -68,6 +74,9 @@
let m_value = metadata[meta_offset];
g_sum += exp(m_value.x - g_max) * m_value.y;
}
#if has_head_sink
g_sum += exp(sink_value - g_max);
#endif

if (total_seq_offset + local_idx < total_sequence_length) {
tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
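In the decode split-K path, each tile writes a (max, sum) pair to metadata and the reduce step above folds the sink into the global denominator. A rough scalar sketch of that reduction, assuming a simple TileMeta struct (names are illustrative, not ORT types):

```cpp
// Combine per-tile softmax statistics into a global (max, denominator) pair,
// with the head sink contributing one extra exp(sink - g_max) term, mirroring
// the reduction in flash_attention_decode_split_vx.wgsl.template.
#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

struct TileMeta {
  float max;  // max QK score within the tile
  float sum;  // sum of exp(score - max) within the tile
};

std::pair<float, float> ReduceWithHeadSink(const std::vector<TileMeta>& tiles, float sink_value) {
  float g_max = -3.4028234663852886e+38f;  // same lowest-finite-f32 seed as the shader
  g_max = std::max(g_max, sink_value);     // the #if has_head_sink branch
  for (const TileMeta& t : tiles) {
    g_max = std::max(g_max, t.max);
  }
  float g_sum = 0.0f;
  for (const TileMeta& t : tiles) {
    g_sum += std::exp(t.max - g_max) * t.sum;  // rescale each tile's partial sum
  }
  g_sum += std::exp(sink_value - g_max);       // the sink's virtual logit
  return {g_max, g_sum};  // each score is then written as exp(score - g_max) / g_sum
}
```

The sink is indexed per head, which is why the shader derives head_idx as batch_head_idx % uniforms.num_heads and why num_heads is added to the SplitVx program's uniform list.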
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -253,7 +253,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
// Use a sliding window if the total sequence exceeds the window's length.
bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.total_sequence_length_);
bool will_use_flash_attention = false;
if (head_sink == nullptr && !use_smooth_softmax_ && !use_sliding_window) {
if (!use_smooth_softmax_ && !use_sliding_window) {
// Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking
WebgpuAttentionParameters temp_params = parameters;
temp_params.is_packed_qkv_ = false;
@@ -266,7 +266,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
// Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled
// query points to packed QKV, K and V are nullptr since they're not needed
return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context, seqlen_k, cos_cache, sin_cache);
present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink);
}
// Fused: splitQKV + rotary QK
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
@@ -310,7 +310,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&

if (will_use_flash_attention) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context, seqlen_k);
present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink);
}

TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,