Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include "core/framework/data_transfer_manager.h"
#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/run_options.h"
#include "core/graph/function_utils.h"
#include "core/graph/indexed_sub_graph.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#include "core/common/parse_string.h"

#include "core/providers/webgpu/webgpu_context.h"
#include "core/providers/webgpu/data_transfer.h"
Expand Down Expand Up @@ -692,7 +695,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Flatten)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile)>,

Expand Down Expand Up @@ -907,7 +909,7 @@ Status WebGpuExecutionProvider::OnSessionInitializationEnd() {
return Status::OK();
}

Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
if (context_.ValidationMode() >= ValidationMode::Basic) {
context_.PushErrorScope();
}
Expand All @@ -916,20 +918,32 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_
context_.StartProfiling();
}

if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
if (IsGraphCaptureEnabled()) {
auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation);
int graph_annotation_id = 0;
if (graph_annotation_str.has_value()) {
ORT_ENFORCE(onnxruntime::TryParseStringWithClassicLocale<int>(*graph_annotation_str, graph_annotation_id),
"Failed to parse the graph annotation id: ",
*graph_annotation_str);
}

if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
}
m_current_graph_annotation_id = graph_annotation_id;
}

return Status::OK();
}

Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) {
Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) {
context_.Flush(BufferManager());

if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) {
if (IsGraphCaptureAllowed()) {
if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) {
if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) {
context_.CaptureEnd();
is_graph_captured_ = true;
ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id));
} else {
IncrementRegularRunCountBeforeGraphCapture();
}
Expand All @@ -952,12 +966,12 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
return enable_graph_capture_;
}

// NOTE(review): this span looks like a diff rendering with its +/- markers
// stripped; the first signature/return pair appears to be the removed version
// and the second pair the added version -- confirm against the real file.
//
// Reports whether a captured graph is available for `graph_annotation_id`.
// In the second (annotated) version the result is true only when
// is_graph_captured_ is set AND the id is not -1; the id is not otherwise used
// to select between captures.
bool WebGpuExecutionProvider::IsGraphCaptured(int) const {
return is_graph_captured_;
bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
// An id of -1 always reports "not captured", regardless of is_graph_captured_.
return is_graph_captured_ && graph_annotation_id != -1;
}

// NOTE(review): this span looks like a diff rendering with its +/- markers
// stripped; the first signature/ORT_ENFORCE pair appears to be the removed
// version and the second pair the added version -- confirm against the real file.
//
// Replays the previously captured commands.
// Enforces (via ORT_ENFORCE) that IsGraphCaptured() holds for the given id,
// then hands captured_commands_ and *graph_buffer_mgr_ to context_.Replay().
// Always returns Status::OK() when the enforce check passes.
Status WebGpuExecutionProvider::ReplayGraph(int) {
ORT_ENFORCE(IsGraphCaptured(0));
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
// The id is only used for the precondition check; the replay call itself
// takes the single captured_commands_ list.
ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
context_.Replay(captured_commands_, *graph_buffer_mgr_);
return Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
// Graph annotation id for the run in progress. In the OnRunStart hunk it is
// assigned (only when graph capture is enabled) from the
// kOrtRunOptionsConfigCudaGraphAnnotation run option parsed as int, staying 0
// when the option is absent; the OnRunEnd hunk reads it back. A value of -1
// makes both hunks skip CaptureBegin/CaptureEnd for that run.
int m_current_graph_annotation_id = 0;
webgpu::GpuBufferAllocator* allocator_ = nullptr;

// Buffer manager specifically for graph capture mode
Expand Down
Loading