Skip to content

Commit 60af2f5

Browse files
Revert "Calculate output chunk size based on whether the kernel is GQA or not."
This reverts commit e448b1a.
1 parent e448b1a commit 60af2f5

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

+5-6
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
172172
<< "}\n";
173173

174174
shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
175-
<< " let headOffset = workgroup_id.z * uniforms.M * " << (is_gqa_ ? "uniforms.present_sequence_length" : "uniforms.N") << ";\n"
175+
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
176176
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
177177
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
178178

@@ -200,7 +200,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
200200
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
201201

202202
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
203-
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
203+
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
204204
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
205205
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
206206
if (feed_past_key) {
@@ -416,9 +416,8 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
416416

417417
shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
418418
<< "if (m < uniforms.M && n < uniforms.N) {\n"
419-
<< " let tmp = " << (is_gqa_ ? "uniforms.num_heads * uniforms.present_sequence_length" : "uniforms.v_hidden_size") << ";\n"
420-
<< " let outputIdx = batch_idx * uniforms.M * tmp + "
421-
<< " m * tmp + head_idx * uniforms.N + n;\n"
419+
<< " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
420+
<< " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
422421
<< " output[outputIdx] = value;\n"
423422
<< "}\n";
424423

@@ -439,7 +438,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
439438
const bool has_present_value = output_count > 1 && past_value != nullptr;
440439
const int tile_size = 12;
441440

442-
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_, parameters.is_gqa_};
441+
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
443442
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
444443
{V, ProgramTensorMetadataDependency::TypeAndRank}});
445444
if (feed_past_value) {

onnxruntime/contrib_ops/webgpu/bert/attention.h

+4-6
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
3434
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
3535
public:
3636
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
37-
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
38-
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
37+
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
38+
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
3939
}
4040

4141
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
6363
const Tensor* seqlen_k_;
6464
bool past_present_share_buffer_;
6565
bool is_first_prompt_;
66-
bool is_gqa_;
6766
};
6867

6968
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
@@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
9089

9190
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
9291
public:
93-
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa)
94-
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
92+
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
93+
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
9594
}
9695

9796
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
117116
const Tensor* seqlen_k_;
118117
bool past_present_share_buffer_;
119118
bool is_first_prompt_;
120-
bool is_gqa_;
121119
};
122120

123121
} // namespace webgpu

0 commit comments

Comments (0)