Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha

return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_),
WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0),
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
WGSL_TEMPLATE_VARIABLE(present_key, present_key),
Expand Down Expand Up @@ -594,10 +596,11 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);

const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset();

SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch);
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset);
program
.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch)
.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset)
.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
.AddInputs({
{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ using namespace onnxruntime::webgpu;

class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram> {
public:
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch)
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch, uint32_t multi_rotary_cache_concat_offset)
: Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"},
interleaved_(interleaved),
prepare_indirect_dispatch_(prepare_indirect_dispatch) {}
prepare_indirect_dispatch_(prepare_indirect_dispatch),
multi_rotary_cache_concat_offset_(multi_rotary_cache_concat_offset) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

Expand All @@ -39,6 +40,7 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program<S
private:
const bool interleaved_;
const bool prepare_indirect_dispatch_;
const uint32_t multi_rotary_cache_concat_offset_;
};

class CopyKVCacheProgram final : public Program<CopyKVCacheProgram> {
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ Status SplitPackedQKVWithRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper

return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding.wgsl.template",
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_),
WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0),
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
WGSL_TEMPLATE_VARIABLE(key, key),
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
Expand Down Expand Up @@ -74,9 +76,10 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext&
const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec;
auto dispatch_size = static_cast<uint32_t>(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec);

SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_);
const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset();
SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_, multi_rotary_cache_concat_offset);
program
.CacheHint(params.rotary_interleaved_)
.CacheHint(params.rotary_interleaved_, multi_rotary_cache_concat_offset)
.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
.AddInputs({
{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ using namespace onnxruntime::webgpu;

class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program<SplitPackedQKVWithRotaryEmbeddingProgram> {
public:
SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved) : Program{"SplitPackedQKVWithRotaryEmbedding"}, interleaved_{interleaved} {}
SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved, uint32_t multi_rotary_cache_concat_offset)
: Program{"SplitPackedQKVWithRotaryEmbedding"},
interleaved_{interleaved},
multi_rotary_cache_concat_offset_{multi_rotary_cache_concat_offset} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

Expand All @@ -32,6 +35,7 @@ class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program<SplitPacke

private:
const bool interleaved_;
const uint32_t multi_rotary_cache_concat_offset_;
};

class GroupQueryAttention final : public WebGpuKernel {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#param interleaved
#param multi_rotary_cache_concat_offset
#param use_multi_rotary_cache_concat

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .setByIndices .getByIndices .getByOffset
Expand Down Expand Up @@ -30,9 +32,14 @@ $MAIN {
let total_seqlen = seqlen + 1u;
let past_seqlen = total_seqlen - uniforms.sequence_length;
let position_id = past_seqlen + seq_idx;
#if use_multi_rotary_cache_concat
let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset);
#else
let base_position = 0u;
#endif
// Process a rotary pair (i, j)
let cos_v = cos_cache.getByIndices(vec2<u32>(position_id, in_head_idx));
let sin_v = sin_cache.getByIndices(vec2<u32>(position_id, in_head_idx));
let cos_v = cos_cache.getByIndices(vec2<u32>(base_position + position_id, in_head_idx));
let sin_v = sin_cache.getByIndices(vec2<u32>(base_position + position_id, in_head_idx));

// Calculate actual indices in the head for i and j
#if interleaved
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#param interleaved
#param multi_rotary_cache_concat_offset
#param prepare_indirect_dispatch
#param use_multi_rotary_cache_concat

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .setByIndices .getByIndices .getByOffset
Expand Down Expand Up @@ -32,6 +34,11 @@ $MAIN {
let past_seqlen = total_seqlen - uniforms.sequence_length;
// `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value
let position_id = past_seqlen + seq_idx;
#if use_multi_rotary_cache_concat
let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset);
#else
let base_position = 0u;
#endif

#if prepare_indirect_dispatch
// Prepare indirect dispatch buffer for thread 0
Expand All @@ -45,8 +52,8 @@ $MAIN {

if (in_head_idx < uniforms.half_rotary_dim) {
// Process a rotary pair (i, j)
let cos_v = cos_cache.getByIndices(vec2<u32>(position_id, in_head_idx));
let sin_v = sin_cache.getByIndices(vec2<u32>(position_id, in_head_idx));
let cos_v = cos_cache.getByIndices(vec2<u32>(base_position + position_id, in_head_idx));
let sin_v = sin_cache.getByIndices(vec2<u32>(base_position + position_id, in_head_idx));

// Calculate actual indices in the head for i and j
#if interleaved
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ class ComputeContextBase {
return ep_.IsGraphCaptureEnabled();
}

//
// Get the multi rotary cache concatenation offset (0 = disabled).
//
inline uint32_t MultiRotaryCacheConcatOffset() const {
return ep_.MultiRotaryCacheConcatOffset();
}

//
// Get the logger.
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture},
enable_int64_{config.enable_graph_capture || config.enable_int64},
multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset},
prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), false)} {
// If graph capture is enabled, create a dedicated buffer manager for graph mode
if (enable_graph_capture_) {
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ struct CapturedCommandInfo;
} // namespace webgpu

struct WebGpuExecutionProviderConfig {
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
bool enable_graph_capture{false}; // graph capture feature is disabled by default
bool enable_pix_capture{false}; // PIX capture is disabled by default
bool enable_int64{false}; // int64 ops are not enabled by default
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
bool enable_graph_capture{false}; // graph capture feature is disabled by default
bool enable_pix_capture{false}; // PIX capture is disabled by default
bool enable_int64{false}; // int64 ops are not enabled by default
uint32_t multi_rotary_cache_concat_offset{0}; // offset for concatenated multi rotary cache (0 = disabled)
std::vector<std::string> force_cpu_node_names{};
};

Expand Down Expand Up @@ -82,6 +83,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
Status ReplayGraph(int graph_annotation_id) override;
webgpu::BufferManager& BufferManager() const;
AllocatorPtr PrepackAllocator() const { return prepack_allocator_; }
uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; }

private:
bool IsGraphCaptureAllowed() const;
Expand All @@ -94,6 +96,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
std::vector<std::string> force_cpu_node_names_;
bool enable_graph_capture_ = false;
bool enable_int64_ = false;
uint32_t multi_rotary_cache_concat_offset_ = 0;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
}
}

std::string multi_rotary_cache_concat_offset_str;
if (config_options.TryGetConfigEntry(kMultiRotaryCacheConcatOffset, multi_rotary_cache_concat_offset_str)) {
uint32_t offset_value = 0;
auto result = std::from_chars(multi_rotary_cache_concat_offset_str.data(),
multi_rotary_cache_concat_offset_str.data() + multi_rotary_cache_concat_offset_str.size(),
offset_value);
if (result.ec == std::errc{}) {
webgpu_ep_config.multi_rotary_cache_concat_offset = offset_value;
} else {
ORT_THROW("Invalid multiRotaryCacheConcatOffset value: ", multi_rotary_cache_concat_offset_str, ". Must be a non-negative integer.");
}
}

// parse force CPU node names
// The force CPU node names are separated by EOL (\n or \r\n) in the config entry.
// each line is a node name that will be forced to run on CPU.
Expand Down Expand Up @@ -108,6 +121,7 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size();
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP enable int64: " << webgpu_ep_config.enable_int64;
LOGS_DEFAULT(VERBOSE) << "WebGPU EP multi rotary cache concat offset: " << webgpu_ep_config.multi_rotary_cache_concat_offset;

return webgpu_ep_config;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace options {
constexpr const char* kPreferredLayout = "ep.webgpuexecutionprovider.preferredLayout";
constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGraphCapture";
constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64";
constexpr const char* kMultiRotaryCacheConcatOffset = "ep.webgpuexecutionprovider.multiRotaryCacheConcatOffset";

constexpr const char* kDawnProcTable = "ep.webgpuexecutionprovider.dawnProcTable";

Expand Down
Loading