From 3a927136b5e35577e3fabcd648c5abe0ac48a6e1 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 28 Mar 2025 14:36:54 +0800 Subject: [PATCH] [webgpu] Use 1D Dispatch groups --- .../contrib_ops/webgpu/bert/attention.cc | 70 ++++++++++--------- .../contrib_ops/webgpu/bert/attention.h | 8 ++- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 0d4afc8c13f4b..b93df1ec1c172 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -99,23 +99,24 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - << "let m = workgroup_id.y * TILE_SIZE;\n" - << "let n = workgroup_id.x * TILE_SIZE;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n" + << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n" + << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let kOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" " }\n" " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" @@ -123,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" - << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" + << " let pastKeyOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" @@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " workgroupBarrier();\n" << "}\n"; - shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" - << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" - << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n" + << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; @@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; - program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile) .SetWorkgroupSize(tile_size, tile_size) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_ ? 1 : 0)}}) + {static_cast(parameters.is_first_prompt_ ? 1 : 0)}, + {num_total_seq_length_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); @@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; - shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let sequence_length = uniforms.sequence_length;\n" + shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n" + << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" @@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) .CacheHint(work_group_size) - .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, @@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var tileQ: array;\n" << "var tileK: array;\n"; - shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n" + << "let head_idx = batch_head_idx % uniforms.num_heads;\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n" + << "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n" + << "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + shader.MainFunctionBody() << "let vOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; } shader.MainFunctionBody() << "var value = output_value_t(0);\n" @@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" - << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" + << " let pastValueOffset = (batch_head_idx / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n" @@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components}); } - program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_)}}) + {static_cast(parameters.is_first_prompt_)}, + {num_head_size_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 164ea72b07d9d..cef2ddf33f6dd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -106,7 +108,9 @@ class VxAttentionScoreProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_head_size_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});