diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 400ba64b21ab3..2ab1cd7ec0ce2 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -204,10 +204,14 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (use_seqlen_k_) {
     shader.AddInput("seqlens_k", ShaderUsage::None);
   }
+  if (has_head_sink_) {
+    shader.AddInput("head_sink", ShaderUsage::UseUniform);
+  }
   shader.AddOutput("output", ShaderUsage::UseUniform);
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template",
                              WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
+                             WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
                              WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
                              WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_),
                              WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
@@ -300,11 +304,15 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad
   if (use_indirect_dispatch_) {
     shader.AddInput("seqlens_k", ShaderUsage::None);
   }
+  if (has_head_sink_) {
+    shader.AddInput("head_sink", ShaderUsage::UseUniform);
+  }
   shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
 
   const uint32_t tile_size_k_vec = 8u;
 
   return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template",
+                             WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
                              WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
                              WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
                              WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
@@ -324,10 +332,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
                                                uint32_t num_present_sequence_length_tile,
                                                uint32_t tile_size,
                                                bool use_indirect_dispatch,
-                                               uint32_t present_sequence_length) {
+                                               uint32_t present_sequence_length,
+                                               const Tensor* head_sink) {
   const int components = 4;
+  const bool has_head_sink = head_sink != nullptr;
   int head_size_vec = parameters.v_head_size_ / components;
-  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch};
+  FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink};
   program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
                      {qk, ProgramTensorMetadataDependency::TypeAndRank},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -339,14 +349,18 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
   } else {
     program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
   }
-  program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch)
+  if (has_head_sink) {
+    program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
+  }
+  program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
       .SetWorkgroupSize(64)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(head_size_vec)},
                             present_sequence_length,
                             {static_cast<uint32_t>(parameters.n_reps)},
                             num_present_sequence_length_tile,
-                            {batch_heads}});
+                            {batch_heads},
+                            {static_cast<uint32_t>(parameters.num_heads_)}});
 
   return context.RunProgram(program);
 }
@@ -399,7 +413,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&
 Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output,
                            const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
                            const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
-                           const Tensor* cos_cache, const Tensor* sin_cache) {
+                           const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) {
   constexpr uint32_t tile_size = 64;
 
   // Create present_key and present_value tensors if they are nullptr
@@ -467,6 +481,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
   bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
   bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
+  bool has_head_sink = head_sink != nullptr;
   FlashAttentionProgram program{"FlashAttention",
                                 has_attention_bias,
                                 is_qualcomm,
@@ -476,7 +491,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                 parameters.is_unidirectional_,
                                 is_nvidia,
                                 q_BNSH,
-                                use_seqlen_k};
+                                use_seqlen_k,
+                                has_head_sink};
   program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4},
                      {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}});
@@ -486,6 +502,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   if (use_seqlen_k) {
     program.AddInputs({{seqlen_k, ProgramTensorMetadataDependency::None}});
   }
+  if (has_head_sink) {
+    program.AddInputs({{head_sink, ProgramTensorMetadataDependency::Type}});
+  }
   program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}});
   const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                 : parameters.scale_;
@@ -502,7 +521,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
 
   program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile)
       .SetWorkgroupSize(tile_size)
-      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k)
+      .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, q_BNSH, use_seqlen_k, has_head_sink)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(parameters.total_sequence_length_)},
                             {static_cast<uint32_t>(present_sequence_length)},
@@ -542,7 +561,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
                                                               seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size,
-                                                              use_indirect_dispatch, present_sequence_length));
+                                                              use_indirect_dispatch, present_sequence_length,
+                                                              head_sink));
   ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
                                                           num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch));
 
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index 489a6673aae0f..fc2843f6ea908 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -77,7 +77,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
                        bool is_unidirectional,
                        bool is_nvidia,
                        bool q_BNSH,
-                       bool use_seqlen_k = false)
+                       bool use_seqlen_k = false,
+                       bool has_head_sink = false)
       : Program{kernel_name},
         has_attention_bias_(has_attention_bias),
         is_qualcomm_(is_qualcomm),
@@ -87,7 +88,8 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
         is_unidirectional_(is_unidirectional),
         is_nvidia_(is_nvidia),
         q_BNSH_(q_BNSH),
-        use_seqlen_k_(use_seqlen_k) {
+        use_seqlen_k_(use_seqlen_k),
+        has_head_sink_(has_head_sink) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -112,6 +114,7 @@ class FlashAttentionProgram final : public Program<FlashAttentionProgram> {
   bool is_nvidia_;
   bool q_BNSH_;
   bool use_seqlen_k_;
+  bool has_head_sink_;
 };
 
 class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
@@ -142,8 +145,8 @@ class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
  public:
-  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch)
-      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch) {
+  FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
+      : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -153,12 +156,14 @@ class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
@@ -184,7 +189,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -310,7 +310,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   if (will_use_flash_attention) {
     return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
-                               present_value, parameters, context, seqlen_k);
+                               present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink);
   }
 
   TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
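A note on what `head_sink` does (context for the patch, not part of it): the tensor supplies one learned "sink" logit per attention head, and the `has_head_sink` branches above thread it into the WGSL templates so the online softmax can fold that logit into its running maximum and normalization sum. Below is a minimal scalar sketch of that interpretation, assuming the usual attention-sink formulation; `SoftmaxWithHeadSink` is an illustrative name, not an ONNX Runtime API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax over one head's attention scores with a per-head sink logit
// (illustrative sketch). The sink participates in the max-subtraction and
// the denominator but emits no weight of its own, so its only effect is to
// dampen all the real attention weights.
std::vector<float> SoftmaxWithHeadSink(const std::vector<float>& scores, float sink) {
  float max_logit = sink;
  for (float s : scores) max_logit = std::max(max_logit, s);

  std::vector<float> weights(scores.size());
  float denom = std::exp(sink - max_logit);  // the sink enters the denominator only
  for (std::size_t i = 0; i < scores.size(); ++i) {
    weights[i] = std::exp(scores[i] - max_logit);
    denom += weights[i];
  }
  for (float& w : weights) w /= denom;
  return weights;
}
```

This reading also suggests why the decode split-VX path gains an extra uniform carrying `parameters.num_heads_`: the dispatch there is flattened over batch × heads, so the kernel presumably needs the head count to turn its flat workgroup index into a head index before it can read the matching `head_sink` element.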