@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
34
34
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
35
35
public:
36
36
// Constructor: forwards kernel_name to the Program base and copies each remaining
// argument into the member of the same name (with a trailing underscore); the body is empty.
// Diff note: the removed ('-') signature took a trailing `bool is_gqa` and had no default
// arguments; the added ('+') signature drops is_gqa and defaults n_reps to 1, seqlen_k to
// nullptr and past_present_share_buffer to false.
// NOTE(review): seqlen_k is stored as a raw pointer and may now be nullptr by default —
// confirm that the code reading seqlen_k_ (not visible in this hunk) handles a null value.
AttentionProbsProgram (const std::string& kernel_name, bool feed_past_key, bool has_present_key,
37
- bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa )
38
- : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
37
+ bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1 , const Tensor* seqlen_k = nullptr , bool past_present_share_buffer = false )
38
+ : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
39
39
}
40
40
41
41
Status GenerateShaderCode (ShaderHelper& sh) const override ;
@@ -63,7 +63,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
63
63
const Tensor* seqlen_k_;
64
64
bool past_present_share_buffer_;
65
65
bool is_first_prompt_;
66
- bool is_gqa_;
67
66
};
68
67
69
68
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
@@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
90
89
91
90
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
92
91
public:
93
// Constructor: forwards kernel_name to the Program base and copies each remaining
// argument into the member of the same name (with a trailing underscore); the body is empty.
// Diff note: the removed ('-') signature took a trailing `bool is_gqa` and had no default
// arguments; the added ('+') signature drops is_gqa and defaults n_reps to 1, seqlen_k to
// nullptr and past_present_share_buffer to false.
// NOTE(review): seqlen_k is stored as a raw pointer and may now be nullptr by default —
// confirm that the code reading seqlen_k_ (not visible in this hunk) handles a null value.
- VxAttentionScoreProgram (const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps, const Tensor* seqlen_k, bool past_present_share_buffer, bool is_gqa )
94
- : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_gqa_(is_gqa) {
92
+ VxAttentionScoreProgram (const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1 , const Tensor* seqlen_k = nullptr , bool past_present_share_buffer = false )
93
+ : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
95
94
}
96
95
97
96
Status GenerateShaderCode (ShaderHelper& sh) const override ;
@@ -117,7 +116,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
117
116
const Tensor* seqlen_k_;
118
117
bool past_present_share_buffer_;
119
118
bool is_first_prompt_;
120
- bool is_gqa_;
121
119
};
122
120
123
121
} // namespace webgpu
0 commit comments