2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -54,6 +54,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
const Tensor* sin_cache = context->Input<Tensor>(8);
const Tensor* position_ids = context->Input<Tensor>(9);
const Tensor* attention_bias = context->Input<Tensor>(10);
const Tensor* head_sink = context->Input<Tensor>(11);

GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -73,6 +74,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {

ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,
head_sink,
parameters));

const int batch_size = parameters.batch_size;
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
@@ -340,6 +340,7 @@ Status CheckInputs(const T* query,
template <typename T = Tensor>
Status CheckCustomAttentionInputs(const T* position_ids,
const T* attention_bias,
const T* head_sink,
const GroupQueryAttentionParameters& parameters) {
if (position_ids != nullptr) {
const auto& pos_ids_shape = position_ids->Shape();
@@ -377,6 +378,23 @@ Status CheckCustomAttentionInputs(const T* position_ids,
}
}

if (head_sink != nullptr) {
if (parameters.use_smooth_softmax) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_sink should not be provided when use_smooth_softmax is true.");
}

const auto& head_sink_shape = head_sink->Shape();
if (head_sink_shape.NumDimensions() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor");
}

if (head_sink_shape[0] != parameters.num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_sink dimension 0 must be equal to the num heads, got ", head_sink_shape[0]);
}
}

return Status::OK();
}
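
head_sink is validated as a 1-D tensor carrying one learned sink logit per attention head. Judging by the WebGPU shader changes below (this diff adds no CPU kernel math), the sink joins the softmax normalizer without receiving an output weight; a minimal scalar sketch of that formulation, not code from this PR:

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch: sink-augmented softmax over one attention row. The sink logit
// enlarges the denominator, letting a head attend "nowhere", but it
// contributes no probability mass to any value vector.
std::vector<float> SoftmaxWithSink(const std::vector<float>& logits, float sink) {
  float m = sink;
  for (float x : logits) m = std::max(m, x);   // running max includes the sink
  float denom = std::exp(sink - m);            // sink's share of the normalizer
  for (float x : logits) denom += std::exp(x - m);
  std::vector<float> out;
  out.reserve(logits.size());
  for (float x : logits) out.push_back(std::exp(x - m) / denom);
  return out;
}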

62 changes: 43 additions & 19 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -69,8 +69,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
return context.RunProgram(program);
};

void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k) {
if (seqlen_k != nullptr) {
void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
if (has_seqlen_k) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);\n";
} else {
@@ -87,7 +87,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
if (seqlen_k_ != nullptr) {
if (has_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
@@ -107,7 +107,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
InitVarStub(oss, has_seqlen_k_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
@@ -182,7 +182,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_};
components, parameters.is_first_prompt_, seqlen_k != nullptr, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
@@ -224,30 +224,44 @@ }
}

Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (seqlen_k_) {
if (has_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
if (has_head_sink_) {
shader.AddInput("head_sink", ShaderUsage::UseUniform);
}
shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n"
<< "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n"
<< "let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_);
InitVarStub(oss, has_seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
<< "workgroupBarrier();\n"
<< "var max_value = f32(-3.402823e+38f);\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< "workgroupBarrier();\n";

if (has_head_sink_) {
// Handle head sink
shader.MainFunctionBody() << "let sink_value: f32 = head_sink[head_idx];\n"
<< "var max_value = sink_value;\n";
} else if (use_smooth_softmax_) {
shader.MainFunctionBody() << "var max_value: f32 = 0.0;\n";
} else {
shader.MainFunctionBody() << "var max_value = f32(-3.402823e+38f);\n";
}

shader.MainFunctionBody() << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " max_value = max(thread_max[i], max_value);\n"
<< "}\n"
<< "var sum_vector = f32_val_t(0);\n"
@@ -259,8 +273,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "var sum: f32 = 0;\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " sum += thread_sum[i]\n;"
<< "}\n"
<< "if (sum == 0) {\n"
<< "}\n";

if (has_head_sink_) {
shader.MainFunctionBody() << "sum += exp(sink_value - max_value);\n";
} else if (use_smooth_softmax_) {
shader.MainFunctionBody() << "sum += exp(-max_value);\n";
}

shader.MainFunctionBody() << "if (sum == 0) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
<< " }\n"
@@ -270,7 +291,7 @@
<< " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " }\n"
<< "}\n";
if (seqlen_k_) {
if (has_seqlen_k_) {
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
<< " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
<< "}\n";
@@ -280,7 +301,7 @@ }
}
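
Taken together, the new branches give the softmax denominator one common shape. Writing m for the row max (which includes the sink when one is present), the generated shader effectively computes:

softmax(x)_i = e^(x_i - m) / ( Σ_j e^(x_j - m) + e^(s - m) )

with s = head_sink[head_idx] when head_sink is supplied, and s = 0 under use_smooth_softmax, whose exp(-max_value) term is exactly e^(0 - m). In other words, smooth softmax behaves like a fixed sink logit of zero; this reading is inferred from the shader above, not stated in the PR.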

Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
const Tensor* seqlen_k, bool is_first_prompt) {
const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) {
const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
@@ -289,12 +310,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, seqlen_k};
InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
if (head_sink != nullptr) {
program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size)
.CacheHint(work_group_size, use_smooth_softmax)
.SetDispatchGroupSize(batch_size * num_heads * sequence_length)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
@@ -443,7 +467,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length =
@@ -457,7 +481,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
parameters, past_sequence_length, total_sequence_length, seqlen_k));

ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink));

ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));
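
For reference, a sketch of a call site with the widened signature; the GroupQueryAttention operator's actual wiring is outside this diff, so the variable names here are illustrative:

// head_sink slots in between context and seqlen_k, matching the new
// default arguments in attention_common.h below.
ORT_RETURN_IF_ERROR(ApplyAttention(Q, K, V, attention_bias,
                                   past_key, past_value, output,
                                   present_key, present_value,
                                   parameters, context,
                                   head_sink, seqlen_k));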
14 changes: 8 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -62,15 +62,15 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
bool has_attention_bias_;
int tile_size_;
int components_;
const Tensor* seqlen_k_;
bool has_seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)
: Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -86,7 +86,9 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
bool use_smooth_softmax_;
bool has_seqlen_k_;
bool has_head_sink_;
};
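
The constructor now takes presence flags rather than Tensor* members, so every input that shapes the generated WGSL is a plain value, and the smooth-softmax variant can feed the cache hint (see CacheHint(work_group_size, use_smooth_softmax) in attention.cc above); the likely motivation, inferred rather than stated, is that a cached shader must not be reused across variants that generate different code. A construction sketch mirroring ComputeInPlaceSoftmax:

InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax,
                              seqlen_k != nullptr, head_sink != nullptr};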

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention_common.h
@@ -123,7 +123,8 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr);

} // namespace webgpu
} // namespace contrib
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -382,6 +382,8 @@
// sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi)
// o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i
//
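// For reference, the full online-softmax recurrence these terms come from,
// reconstructed under the standard flash-attention formulation (the
// "expression above" lies outside this hunk):
// M_i  = max(M_(i-1), max_j X_i[j])
// d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(X_i[j]-M_i)
// o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i
//        + (Σ_j=1:b e^(X_i[j]-M_i) * V[j]) / d'_i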

// TODO: support smooth softmax and head_sink

[GitHub Actions / Optional Lint C++] cpplint warning on line 386 in onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
shader.MainFunctionBody() << R"MAIN_FN(
var local_max_temp = max(qk_1, qk_2);
if (sg_size > 8)