Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 119 additions & 14 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,32 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {

Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);

const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);

if (prepare_indirect_dispatch_) {
sh.AddOutput("indirect_buffer", ShaderUsage::None);
}

return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
WGSL_TEMPLATE_VARIABLE(present_key, present_key),
WGSL_TEMPLATE_VARIABLE(present_value, present_value),
WGSL_TEMPLATE_VARIABLE(query, query),
WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache));
}

Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Expectations are
// qkv have same number of heads and hidden dimension (head size).
Expand Down Expand Up @@ -351,7 +377,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, bool skip_copy_kv, Tensor* indirect_buffer) {
// Extract present_sequence_length directly from present_key tensor shape:
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
Expand All @@ -361,7 +387,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
if (parameters.sequence_length_ > 1) {
const uint32_t tile_size = 64;
// For encode path, use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
if (!skip_copy_kv) {
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
}
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
Expand Down Expand Up @@ -415,18 +443,22 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
seqlen_k != nullptr &&
context.IsGraphCaptureEnabled();

// Create indirect dispatch buffer if using indirect dispatch
Tensor* indirect_buffer_ptr = nullptr;
Tensor indirect_buffer;
if (use_indirect_dispatch) {
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
indirect_buffer_ptr = &indirect_buffer;
// Use the fused CopyKVCache that also prepares the indirect dispatch buffer
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
} else {
// Use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
Tensor* indirect_buffer_ptr = indirect_buffer;
Tensor indirect_buffer_obj;

if (!skip_copy_kv) {
if (use_indirect_dispatch) {
if (indirect_buffer == nullptr) {
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
indirect_buffer_obj = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
indirect_buffer_ptr = &indirect_buffer_obj;
}
// Use the fused CopyKVCache that also prepares the indirect dispatch buffer
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
} else {
// Use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, nullptr, nullptr));
}
Comment thread
xiaofeihan1 marked this conversation as resolved.
Outdated
}

// The metadata is used to store the max and sum of each tile.
Expand Down Expand Up @@ -467,6 +499,79 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}

Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
const WebgpuAttentionParameters& params,
const Tensor* packedQKV,
const Tensor* seqlen_k,
const Tensor* cos_cache,
const Tensor* sin_cache,
Tensor* query,
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer) {
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;

int components = 1;
// Currently we only support vectorization when RoPE is not interleaved
if (!params.rotary_interleaved_) {
if ((params.head_size_ % 4 == 0) && (half_rotary_embedding_dim % 4 == 0)) {
components = 4;
} else if ((params.head_size_ % 2 == 0) && (half_rotary_embedding_dim % 2 == 0)) {
components = 2;
}
}
// Adjust dimensions for vectorization
const auto half_rotary_embedding_dim_vec = half_rotary_embedding_dim / components;
const auto head_size_vec = head_size / components;

// Dispatch: batch_size * sequence_length * num_heads * (half_rotary_dim + need_copy_dim)
// work_per_head = half_rotary_dim + (head_size - 2 * half_rotary_dim)
// = head_size - half_rotary_dim
const auto work_per_head = head_size_vec - half_rotary_embedding_dim_vec;
auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);

// Extract present_sequence_length from present_key tensor shape
const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);

const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);

constexpr uint32_t tile_size = 64;
Comment thread
xiaofeihan1 marked this conversation as resolved.
Outdated

SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
program
.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
.AddInputs({
{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
{cos_cache, ProgramTensorMetadataDependency::Rank, components},
{sin_cache, ProgramTensorMetadataDependency::Rank, components},
});
program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components},
{present_key, ProgramTensorMetadataDependency::None, components},
{present_value, ProgramTensorMetadataDependency::None, components}});

if (prepare_indirect_dispatch) {
program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
}

program.AddUniformVariables({
{static_cast<uint32_t>(params.sequence_length_)},
{static_cast<uint32_t>(params.hidden_size_ / components)},
{static_cast<uint32_t>(params.kv_hidden_size_ / components)},
{static_cast<uint32_t>(params.num_heads_)},
{static_cast<uint32_t>(params.kv_num_heads_)},
{head_size_vec},
{half_rotary_embedding_dim_vec},
{present_sequence_length},
{tile_size},
{static_cast<uint32_t>(dispatch_size)},
});

program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
40 changes: 39 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,32 @@ namespace webgpu {

using namespace onnxruntime::webgpu;

class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> {
public:
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch)
: Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"},
interleaved_(interleaved),
prepare_indirect_dispatch_(prepare_indirect_dispatch) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"hidden_size", ProgramUniformVariableDataType::Uint32},
{"kv_hidden_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"kv_num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"half_rotary_dim", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"tile_size", ProgramUniformVariableDataType::Uint32},
{"dispatch_size", ProgramUniformVariableDataType::Uint32});

private:
const bool interleaved_;
const bool prepare_indirect_dispatch_;
};

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
Expand Down Expand Up @@ -145,10 +171,22 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, bool skip_copy_kv = false, Tensor* indirect_buffer = nullptr);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);

// Split packed QKV with Q/K rotary embedding and copy KV cache fusion
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
const WebgpuAttentionParameters& params,
const Tensor* packedQKV,
const Tensor* seqlen_k,
const Tensor* cos_cache,
const Tensor* sin_cache,
Tensor* query,
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Loading
Loading