Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 139 additions & 24 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,32 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Generates the WGSL shader for the fused "split packed QKV + rotary embedding
// + copy KV cache" program. This method only registers the tensor bindings and
// compile-time parameters; the shader body itself comes from the
// .wgsl.template file referenced below.
Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
// Inputs: the packed QKV activations, per-batch sequence lengths, and the
// precomputed rotary cos/sin caches.
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform);
const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform);

// Outputs: the query slice of the packed input plus the present K/V cache
// tensors that the shader writes into directly.
const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform);
const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);

// Optional output: a buffer of indirect-dispatch arguments — presumably
// filled by the shader template when this flag is set (see the .wgsl
// template for the exact layout).
if (prepare_indirect_dispatch_) {
sh.AddOutput("indirect_buffer", ShaderUsage::None);
}

// Instantiate the shared template; parameters first, then the shader
// variables declared above (listed alphabetically by template name).
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
WGSL_TEMPLATE_VARIABLE(present_key, present_key),
WGSL_TEMPLATE_VARIABLE(present_value, present_value),
WGSL_TEMPLATE_VARIABLE(query, query),
WGSL_TEMPLATE_VARIABLE(seqlens, seqlens),
WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache));
}

Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Expectations are
// qkv have same number of heads and hidden dimension (head size).
Expand Down Expand Up @@ -351,17 +377,54 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k,
const Tensor* cos_cache, const Tensor* sin_cache) {
constexpr uint32_t tile_size = 64;

// Extract present_sequence_length directly from present_key tensor shape:
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);

const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled();

// Declare query_output at function scope to ensure it persists throughout the function
Tensor query_output;

// Create indirect dispatch buffer if using indirect dispatch
Tensor* indirect_buffer_ptr = nullptr;
Tensor indirect_buffer;

// Prepare indirect dispatch buffer for decode path with static KV cache
const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
parameters.past_present_share_buffer_ &&
seqlen_k != nullptr &&
context.IsGraphCaptureEnabled();
if (use_indirect_dispatch) {
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
indirect_buffer_ptr = &indirect_buffer;
}

const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);

if (do_rotary) {
ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input.");
ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache.");

// Q points to the packed QKV tensor in this case, create query output tensor
query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));

ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
Q, seqlen_k,
cos_cache, sin_cache,
&query_output, present_key, present_value,
indirect_buffer_ptr, tile_size));
Q = &query_output;
} else {
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_indirect_dispatch ? seqlen_k : nullptr, indirect_buffer_ptr));
}

if (parameters.sequence_length_ > 1) {
const uint32_t tile_size = 64;
// For encode path, use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, nullptr));
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
Expand Down Expand Up @@ -406,29 +469,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
parameters.sequence_length_, present_sequence_length});
const TensorShape qk_shape(qk_dims);
Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
constexpr uint32_t tile_size = 64;
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;

// Determine if we should use indirect dispatch
const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
seqlen_k != nullptr &&
context.IsGraphCaptureEnabled();

// Create indirect dispatch buffer if using indirect dispatch
Tensor* indirect_buffer_ptr = nullptr;
Tensor indirect_buffer;
if (use_indirect_dispatch) {
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
indirect_buffer_ptr = &indirect_buffer;
// Use the fused CopyKVCache that also prepares the indirect dispatch buffer
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, indirect_buffer_ptr));
} else {
// Use the original CopyKVCache without indirect dispatch preparation
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, seqlen_k, nullptr));
}

// The metadata is used to store the max and sum of each tile.
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
num_present_sequence_length_tile, 2});
Expand Down Expand Up @@ -467,6 +510,78 @@ bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}

// Launches the fused shader that splits a packed QKV tensor, applies rotary
// embedding to Q/K, and writes K/V straight into the present KV cache.
// When `indirect_buffer` is non-null the shader additionally fills it with
// dispatch arguments for a later indirect dispatch.
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
                                                     const WebgpuAttentionParameters& params,
                                                     const Tensor* packedQKV,
                                                     const Tensor* seqlen_k,
                                                     const Tensor* cos_cache,
                                                     const Tensor* sin_cache,
                                                     Tensor* query,
                                                     Tensor* present_key,
                                                     Tensor* present_value,
                                                     Tensor* indirect_buffer,
                                                     uint32_t tile_size) {
  // cos_cache shape is (..., half_rotary_dim); dim 1 gives the half rotary width.
  const auto rotary_dim_half = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);

  // Pick the widest vector width that evenly divides both the head size and
  // the half rotary dimension. Interleaved RoPE is not vectorized.
  int vec_width = 1;
  if (!params.rotary_interleaved_) {
    if (params.head_size_ % 4 == 0 && rotary_dim_half % 4 == 0) {
      vec_width = 4;
    } else if (params.head_size_ % 2 == 0 && rotary_dim_half % 2 == 0) {
      vec_width = 2;
    }
  }
  // Dimensions expressed in vector elements rather than scalars.
  const auto rotary_dim_half_vec = rotary_dim_half / vec_width;
  const auto head_size_vec = params.head_size_ / vec_width;

  // Work per head = half_rotary_dim rotations + (head_size - 2 * half_rotary_dim)
  // plain copies = head_size - half_rotary_dim units.
  const auto work_per_head = head_size_vec - rotary_dim_half_vec;
  const auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head);

  // present_key layout: (batch, num_heads, present_sequence_length, head_size).
  const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);

  const bool prepare_indirect_dispatch = indirect_buffer != nullptr;

  SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
  program.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
      .AddInputs({{packedQKV, ProgramTensorMetadataDependency::TypeAndRank, vec_width},
                  {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
                  {cos_cache, ProgramTensorMetadataDependency::Rank, vec_width},
                  {sin_cache, ProgramTensorMetadataDependency::Rank, vec_width}})
      .AddOutputs({{query, ProgramTensorMetadataDependency::None, vec_width},
                   {present_key, ProgramTensorMetadataDependency::None, vec_width},
                   {present_value, ProgramTensorMetadataDependency::None, vec_width}});
  if (prepare_indirect_dispatch) {
    program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
  }

  // Uniform order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES in the header.
  program.AddUniformVariables({
      {static_cast<uint32_t>(params.sequence_length_)},
      {static_cast<uint32_t>(params.hidden_size_ / vec_width)},
      {static_cast<uint32_t>(params.kv_hidden_size_ / vec_width)},
      {static_cast<uint32_t>(params.num_heads_)},
      {static_cast<uint32_t>(params.kv_num_heads_)},
      {head_size_vec},
      {rotary_dim_half_vec},
      {present_sequence_length},
      {tile_size},
      {dispatch_size},
  });

  program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
42 changes: 41 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,32 @@ namespace webgpu {

using namespace onnxruntime::webgpu;

// Fused GPU program that splits a packed QKV tensor, applies rotary embedding,
// and copies K/V into the present KV cache in a single shader pass.
// `interleaved` selects the interleaved RoPE layout; `prepare_indirect_dispatch`
// makes the shader also emit an indirect-dispatch arguments buffer.
class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> {
public:
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch)
: Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"},
interleaved_(interleaved),
prepare_indirect_dispatch_(prepare_indirect_dispatch) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

// Uniform declaration order must match the AddUniformVariables call in
// RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV (flash_attention.cc).
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"hidden_size", ProgramUniformVariableDataType::Uint32},
{"kv_hidden_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"kv_num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"half_rotary_dim", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"tile_size", ProgramUniformVariableDataType::Uint32},
{"dispatch_size", ProgramUniformVariableDataType::Uint32});

private:
// True when RoPE uses the interleaved (pairwise) element layout.
const bool interleaved_;
// True when the shader should also fill the indirect dispatch buffer.
const bool prepare_indirect_dispatch_;
};

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
public:
CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH,
Expand Down Expand Up @@ -145,10 +171,24 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr,
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);

// Split packed QKV with Q/K rotary embedding and copy KV cache fusion
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
const WebgpuAttentionParameters& params,
const Tensor* packedQKV,
const Tensor* seqlen_k,
const Tensor* cos_cache,
const Tensor* sin_cache,
Tensor* query,
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer,
uint32_t tile_size);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Loading
Loading