Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,31 +99,32 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
<< "let m = workgroup_id.y * TILE_SIZE;\n"
<< "let n = workgroup_id.x * TILE_SIZE;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
<< "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
<< "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n"
<< "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
<< "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n";
}

shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
" if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
" }\n"
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";

if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
<< " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
<< " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
Expand Down Expand Up @@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< "}\n";

shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
// Emit the guarded write-back of this invocation's accumulated value.
// The guard treats (m + local_id.y) as the row and (n + local_id.x) as the column, so the
// flat output index must scale the whole row by uniforms.N. Without the parentheses only
// local_id.y is multiplied and any tile with m > 0 writes to the wrong location.
shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n"
<< " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n"
<< " let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";

shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
Expand Down Expand Up @@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
}

const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile)
.SetWorkgroupSize(tile_size, tile_size)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
Expand All @@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)},
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)}})
{static_cast<uint32_t>(parameters.is_first_prompt_ ? 1 : 0)},
{num_total_seq_length_tile},
{num_seq_length_tile}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
Expand All @@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let sequence_length = uniforms.sequence_length;\n"
shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
<< "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
Expand Down Expand Up @@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size)
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetDispatchGroupSize(batch_size * num_heads * sequence_length)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
Expand Down Expand Up @@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << ">;\n";
shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let m = global_id.y;\n"
<< "let n = global_id.x;\n"
<< "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n"
<< "let head_idx = batch_head_idx % uniforms.num_heads;\n"
<< "let batch_idx = batch_head_idx / uniforms.num_heads;\n"
<< "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n"
<< "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n"
<< "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
shader.MainFunctionBody() << "let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n";
}

shader.MainFunctionBody() << "var value = output_value_t(0);\n"
Expand All @@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) {
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
<< " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
<< " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
Expand Down Expand Up @@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components});
}

program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size;
const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size;
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
Expand All @@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)},
{static_cast<uint32_t>(parameters.is_first_prompt_)}})
{static_cast<uint32_t>(parameters.is_first_prompt_)},
{num_head_size_tile},
{num_seq_length_tile}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

return context.RunProgram(program);
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

Expand Down Expand Up @@ -105,7 +107,9 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
{"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
{"num_seq_length_tile", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

Expand Down
Loading