# Patch: make WebGPU graph capture honor the per-run graph annotation id.
#
# Files touched:
#   onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
#   onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
#
# What the patch does:
#   * Adds includes for run_options.h, onnxruntime_run_options_config_keys.h and parse_string.h.
#   * Removes one BuildKernelCreateInfo entry from RegisterKernels().
#   * OnRunStart: the run_options parameter is now used. When graph capture is enabled, the
#     kOrtRunOptionsConfigCudaGraphAnnotation config entry is read and parsed as an int
#     (the id stays 0 when the entry is absent; ORT_ENFORCE fires when parsing fails).
#     CaptureBegin is called only when the id is not -1, capture is allowed, and
#     IsGraphCaptured(id) is false. The id is then stored in current_graph_annotation_id_.
#   * OnRunEnd: uses the stored id. When capture ends it sets is_graph_captured_ and
#     calls ReplayGraph(id), propagating its Status; otherwise it calls
#     IncrementRegularRunCountBeforeGraphCapture().
#   * IsGraphCaptured(id) now returns false for id == -1; ReplayGraph(id) enforces
#     IsGraphCaptured(id) before replaying.
#   * The header gains the int member current_graph_annotation_id_, initialized to 0.
#
# Review changes applied to the patch:
#   * The new member was spelled m_current_graph_annotation_id; it is renamed to
#     current_graph_annotation_id_ so it matches the trailing-underscore naming of the
#     neighbouring members (is_graph_captured_, regular_run_count_before_graph_capture_,
#     allocator_). Only added ('+') lines changed, so hunk line counts are unaffected.
#   * The "index" blob hashes below were not recomputed after the rename.
#
# NOTE(review): is_graph_captured_ is a single flag and captured_commands_ a single command
#   list, so after one capture completes IsGraphCaptured() is true for every id other
#   than -1. Confirm whether distinct annotation ids are meant to get distinct captures.
# NOTE(review): the line structure of this copy of the patch appears collapsed, and template
#   arguments (e.g. on std::unique_ptr / BuildKernelCreateInfo) are not visible. The hunk
#   text is kept exactly as found rather than guessed at.
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..bca41b7851c28 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -18,8 +18,11 @@ #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/common/parse_string.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" @@ -692,7 +695,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -907,7 +909,7 @@ Status WebGpuExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); } -Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { if (context_.ValidationMode() >= ValidationMode::Basic) { context_.PushErrorScope(); } @@ -916,20 +918,32 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ context_.StartProfiling(); } - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + if (IsGraphCaptureEnabled()) { + auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + int graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(onnxruntime::TryParseStringWithClassicLocale(*graph_annotation_str, graph_annotation_id), + "Failed to parse the graph 
annotation id: ", + *graph_annotation_str); + } + + if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + } + current_graph_annotation_id_ = graph_annotation_id; } return Status::OK(); } -Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) { context_.Flush(BufferManager()); - if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(current_graph_annotation_id_)) { + if (current_graph_annotation_id_ != -1 && IsGraphCaptureAllowed()) { context_.CaptureEnd(); is_graph_captured_ = true; + ORT_RETURN_IF_ERROR(ReplayGraph(current_graph_annotation_id_)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -952,12 +966,12 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool WebGpuExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; +bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return is_graph_captured_ && graph_annotation_id != -1; } -Status WebGpuExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); +Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2567be2a1eb18..3bbec164a0190 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -99,6 +99,7 @@ class WebGpuExecutionProvider : public 
IExecutionProvider { bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + int current_graph_annotation_id_ = 0; webgpu::GpuBufferAllocator* allocator_ = nullptr; // Buffer manager specifically for graph capture mode