diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index cdd0f4ed57c15..400ba64b21ab3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -32,7 +32,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_), + WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_), WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0), WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), WGSL_TEMPLATE_VARIABLE(present_key, present_key), @@ -594,10 +596,11 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput const uint32_t present_sequence_length = gsl::narrow_cast(present_key->Shape()[2]); const bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset(); - SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch); + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset); program - .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch) + .CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset) .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) .AddInputs({ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h 
b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index b98e18b71bc27..489a6673aae0f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -17,10 +17,11 @@ using namespace onnxruntime::webgpu; class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { public: - SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch) + SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram(bool interleaved, bool prepare_indirect_dispatch, uint32_t multi_rotary_cache_concat_offset) : Program{"SplitPackedQKVWithRotaryEmbeddingAndCopyKV"}, interleaved_(interleaved), - prepare_indirect_dispatch_(prepare_indirect_dispatch) {} + prepare_indirect_dispatch_(prepare_indirect_dispatch), + multi_rotary_cache_concat_offset_(multi_rotary_cache_concat_offset) {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -39,6 +40,7 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index e6b0189b6ee53..81bfe8a436c9c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -32,6 +32,8 @@ Status SplitPackedQKVWithRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding.wgsl.template", WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_), + WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_), + WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0), WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), WGSL_TEMPLATE_VARIABLE(key, key), WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), @@ -74,9 +76,10 @@ Status 
RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& const auto work_per_head_vec = head_size_vec - half_rotary_embedding_dim_vec; auto dispatch_size = static_cast(params.batch_size_ * params.sequence_length_ * params.num_heads_ * work_per_head_vec); - SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_); + const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset(); + SplitPackedQKVWithRotaryEmbeddingProgram program(params.rotary_interleaved_, multi_rotary_cache_concat_offset); program - .CacheHint(params.rotary_interleaved_) + .CacheHint(params.rotary_interleaved_, multi_rotary_cache_concat_offset) .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) .AddInputs({ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 077ec7768ea07..4127a8928f38e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -16,7 +16,10 @@ using namespace onnxruntime::webgpu; class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program { public: - SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved) : Program{"SplitPackedQKVWithRotaryEmbedding"}, interleaved_{interleaved} {} + SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved, uint32_t multi_rotary_cache_concat_offset) + : Program{"SplitPackedQKVWithRotaryEmbedding"}, + interleaved_{interleaved}, + multi_rotary_cache_concat_offset_{multi_rotary_cache_concat_offset} {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -32,6 +35,7 @@ class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program multi_rotary_cache_concat_offset); +#else + let base_position = 0u; +#endif // Process a rotary pair (i, j) - let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx)); 
- let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx)); + let cos_v = cos_cache.getByIndices(vec2(base_position + position_id, in_head_idx)); + let sin_v = sin_cache.getByIndices(vec2(base_position + position_id, in_head_idx)); // Calculate actual indices in the head for i and j #if interleaved diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index d6cb654afa756..c64bdf45cdcf8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -1,5 +1,7 @@ #param interleaved +#param multi_rotary_cache_concat_offset #param prepare_indirect_dispatch +#param use_multi_rotary_cache_concat #use guardAgainstOutOfBoundsWorkgroupSizes #use .setByIndices .getByIndices .getByOffset @@ -32,6 +34,11 @@ $MAIN { let past_seqlen = total_seqlen - uniforms.sequence_length; // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value let position_id = past_seqlen + seq_idx; +#if use_multi_rotary_cache_concat + let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset); +#else + let base_position = 0u; +#endif #if prepare_indirect_dispatch // Prepare indirect dispatch buffer for thread 0 @@ -45,8 +52,8 @@ $MAIN { if (in_head_idx < uniforms.half_rotary_dim) { // Process a rotary pair (i, j) - let cos_v = cos_cache.getByIndices(vec2(position_id, in_head_idx)); - let sin_v = sin_cache.getByIndices(vec2(position_id, in_head_idx)); + let cos_v = cos_cache.getByIndices(vec2(base_position + position_id, in_head_idx)); + let sin_v = sin_cache.getByIndices(vec2(base_position + position_id, in_head_idx)); // Calculate actual indices in the head for i and j #if 
interleaved diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 5b694a7a2e3f1..5277d64ad3611 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -96,6 +96,13 @@ class ComputeContextBase { return ep_.IsGraphCaptureEnabled(); } + // + // Get the multi rotary cache concatenation offset (0 = disabled). + // + inline uint32_t MultiRotaryCacheConcatOffset() const { + return ep_.MultiRotaryCacheConcatOffset(); + } + // // Get the logger. // diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 844591a930c0c..1891775c45057 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -829,6 +829,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, enable_graph_capture_{config.enable_graph_capture}, enable_int64_{config.enable_graph_capture || config.enable_int64}, + multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { // If graph capture is enabled, create a dedicated buffer manager for graph mode if (enable_graph_capture_) { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 3cfd0536668c2..b5a6b5f167faf 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -31,10 +31,11 @@ struct CapturedCommandInfo; } // namespace webgpu struct WebGpuExecutionProviderConfig { - DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default - bool enable_graph_capture{false}; // graph 
capture feature is disabled by default - bool enable_pix_capture{false}; // PIX capture is disabled by default - bool enable_int64{false}; // int64 ops are not enabled by default + DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default + bool enable_graph_capture{false}; // graph capture feature is disabled by default + bool enable_pix_capture{false}; // PIX capture is disabled by default + bool enable_int64{false}; // int64 ops are not enabled by default + uint32_t multi_rotary_cache_concat_offset{0}; // offset for concatenated multi rotary cache (0 = disabled) std::vector force_cpu_node_names{}; }; @@ -82,6 +83,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } + uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; } private: bool IsGraphCaptureAllowed() const; @@ -94,6 +96,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::vector force_cpu_node_names_; bool enable_graph_capture_ = false; bool enable_int64_ = false; + uint32_t multi_rotary_cache_concat_offset_ = 0; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
index b3e1eb831de10..fc2496f0c7b68 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
@@ -72,6 +72,19 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options)
     }
   }
 
+  std::string multi_rotary_cache_concat_offset_str;
+  if (config_options.TryGetConfigEntry(kMultiRotaryCacheConcatOffset, multi_rotary_cache_concat_offset_str)) {
+    uint32_t offset_value = 0;
+    const char* const offset_str_end = multi_rotary_cache_concat_offset_str.data() + multi_rotary_cache_concat_offset_str.size();
+    // from_chars succeeds on a leading digit run, so also require the whole entry to be consumed (rejects e.g. "128abc").
+    const auto result = std::from_chars(multi_rotary_cache_concat_offset_str.data(), offset_str_end, offset_value);
+    if (result.ec == std::errc{} && result.ptr == offset_str_end) {
+      webgpu_ep_config.multi_rotary_cache_concat_offset = offset_value;
+    } else {
+      ORT_THROW("Invalid multiRotaryCacheConcatOffset value: ", multi_rotary_cache_concat_offset_str, ". Must be a non-negative integer.");
+    }
+  }
+
   // parse force CPU node names
   // The force CPU node names are separated by EOL (\n or \r\n) in the config entry.
   // each line is a node name that will be forced to run on CPU.
@@ -108,6 +121,7 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size(); LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << webgpu_ep_config.enable_pix_capture; LOGS_DEFAULT(VERBOSE) << "WebGPU EP enable int64: " << webgpu_ep_config.enable_int64; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP multi rotary cache concat offset: " << webgpu_ep_config.multi_rotary_cache_concat_offset; return webgpu_ep_config; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index c7651277ca85b..d2faccdb8c4a5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -12,6 +12,7 @@ namespace options { constexpr const char* kPreferredLayout = "ep.webgpuexecutionprovider.preferredLayout"; constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGraphCapture"; constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64"; +constexpr const char* kMultiRotaryCacheConcatOffset = "ep.webgpuexecutionprovider.multiRotaryCacheConcatOffset"; constexpr const char* kDawnProcTable = "ep.webgpuexecutionprovider.dawnProcTable";