Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@
namespace onnxruntime {
namespace webgpu {

void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
void WebGpuContext::Initialize(const WebGpuContextConfig& config) {
std::call_once(init_flag_, [this, &config]() {
if (device_ == nullptr) {
// Create wgpu::Adapter
wgpu::RequestAdapterOptions req_adapter_options = {};
req_adapter_options.backendType = static_cast<wgpu::BackendType>(backend_type);
req_adapter_options.powerPreference = static_cast<wgpu::PowerPreference>(power_preference_);
req_adapter_options.backendType = static_cast<wgpu::BackendType>(config.backend_type);
req_adapter_options.powerPreference = static_cast<wgpu::PowerPreference>(config.power_preference);

#if !defined(__wasm__)
auto enabled_adapter_toggles = GetEnabledAdapterToggles();
Expand Down Expand Up @@ -134,9 +134,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi

// create buffer manager
buffer_mgr_ = BufferManagerFactory::Create(*this,
buffer_cache_config.storage.mode,
buffer_cache_config.uniform.mode,
buffer_cache_config.query_resolve.mode);
config.buffer_cache_config.storage.mode,
config.buffer_cache_config.uniform.mode,
config.buffer_cache_config.query_resolve.mode);

// create initializer buffer manager. cache is always disabled for initializer buffer manager
initializer_buffer_mgr_ = BufferManagerFactory::Create(*this,
Expand All @@ -161,7 +161,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
} else {
query_type_ = TimestampQueryType::None;
}
if (enable_pix_capture) {
if (config.enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
Expand Down Expand Up @@ -979,15 +979,18 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
device,
config.validation_mode,
config.preserve_device,
config.max_storage_buffer_binding_size,
config.power_preference));
config.max_storage_buffer_binding_size));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
}
it->second.ref_count++;

// perform initialization
it->second.context->Initialize(config);

return *it->second.context;
}

Expand Down Expand Up @@ -1017,6 +1020,11 @@ void WebGpuContextFactory::Cleanup() {
default_instance_ = nullptr;
}

WebGpuContext& WebGpuContextFactory::DefaultContext() {
WebGpuContextConfig config{};
return WebGpuContextFactory::CreateContext(config);
}

void CleanupWebGpuContexts() {
WebGpuContextFactory::Cleanup();
}
Expand Down
83 changes: 60 additions & 23 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,51 @@ struct CapturedCommandInfo {
WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch
};

struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
bool preserve_device;
uint64_t max_storage_buffer_binding_size;
int power_preference;
};

struct WebGpuBufferCacheConfig {
struct ConfigEntry {
BufferCacheMode mode;
std::string config_string;
std::string config_string; // preserved for customized configuration, eg. bucket sizes
};
ConfigEntry storage{BufferCacheMode::Bucket, {}};
ConfigEntry uniform{BufferCacheMode::Simple, {}};
ConfigEntry query_resolve{BufferCacheMode::Disabled, {}};
ConfigEntry default_entry{BufferCacheMode::Disabled, {}};
};

/// <summary>
/// Represents the configuration options for creating a WebGpuContext.
/// </summary>
struct WebGpuContextConfig {
int context_id{0};
WGPUInstance instance{nullptr};
WGPUDevice device{nullptr};
const void* dawn_proc_table{nullptr};
ValidationMode validation_mode{
#ifndef NDEBUG
webgpu::ValidationMode::Full // for debug build, enable full validation by default
#else
webgpu::ValidationMode::Basic // for release build, enable basic validation by default
#endif // !NDEBUG
};
ConfigEntry storage;
ConfigEntry uniform;
ConfigEntry query_resolve;
ConfigEntry default_entry;
bool preserve_device{false};
uint64_t max_storage_buffer_binding_size{0};
WebGpuBufferCacheConfig buffer_cache_config{};
int power_preference{static_cast<int>(WGPUPowerPreference_HighPerformance)};
int backend_type{
#ifdef _WIN32
// Setup Windows default backend type based on the build configuration
#if defined(DAWN_ENABLE_D3D12)
static_cast<int>(WGPUBackendType_D3D12)
#elif defined(DAWN_ENABLE_VULKAN)
static_cast<int>(WGPUBackendType_Vulkan)
#else
0
#endif
#else
0
#endif
};
bool enable_pix_capture{false};
};

class WebGpuContextFactory {
Expand All @@ -63,13 +88,28 @@ class WebGpuContextFactory {
int ref_count;
};

/// <summary>
/// Create a new WebGPU context for the specified context ID if not present, or return the existing one. (ref-count based)
/// </summary>
static WebGpuContext& CreateContext(const WebGpuContextConfig& config);

/// <summary>
/// Get the WebGPU context for the specified context ID. Throw if not present.
/// </summary>
static WebGpuContext& GetContext(int context_id);

/// <summary>
/// Release the WebGPU context. (ref-count based)
/// </summary>
static void ReleaseContext(int context_id);

static void Cleanup();

/// <summary>
/// Return the default context. Create if not present.
/// </summary>
static WebGpuContext& DefaultContext();

private:
WebGpuContextFactory() {}

Expand All @@ -82,8 +122,6 @@ class WebGpuContextFactory {
// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture);

Status Wait(wgpu::Future f);

const wgpu::Device& Device() const { return device_; }
Expand Down Expand Up @@ -190,20 +228,20 @@ class WebGpuContext final {
WGPUDevice device,
webgpu::ValidationMode validation_mode,
bool preserve_device,
uint64_t max_storage_buffer_binding_size,
int power_preference = static_cast<int>(wgpu::PowerPreference::HighPerformance))
uint64_t max_storage_buffer_binding_size)
: instance_{instance},
device_{device},
validation_mode_{validation_mode},
query_type_{TimestampQueryType::None},
preserve_device_{preserve_device},
max_storage_buffer_binding_size_{max_storage_buffer_binding_size},
power_preference_{power_preference} {
max_storage_buffer_binding_size_{max_storage_buffer_binding_size} {
ORT_ENFORCE(max_storage_buffer_binding_size_ == 0 || max_storage_buffer_binding_size_ >= 134217728,
"max_storage_buffer_binding_size must be 0 or at least 128MB");
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

void Initialize(const WebGpuContextConfig& config);

void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder,
const std::vector<WGPUBuffer>& bind_buffers,
const std::vector<uint32_t>& bind_buffers_segments,
Expand Down Expand Up @@ -292,7 +330,6 @@ class WebGpuContext final {
bool is_profiling_ = false;
bool preserve_device_;
uint64_t max_storage_buffer_binding_size_;
int power_preference_;
GraphCaptureState graph_capture_state_{GraphCaptureState::Default};

// External vector to store captured commands, owned by EP
Expand Down
15 changes: 3 additions & 12 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,9 @@ struct CapturedCommandInfo;
} // namespace webgpu

struct WebGpuExecutionProviderConfig {
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture},
enable_pix_capture{enable_pix_capture} {}
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);

DataLayout data_layout;
bool enable_graph_capture;
bool enable_pix_capture;
std::vector<std::string> force_cpu_node_names;
DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default
bool enable_graph_capture{false}; // graph capture feature is disabled by default
std::vector<std::string> force_cpu_node_names{};
};

class WebGpuExecutionProvider : public IExecutionProvider {
Expand Down
Loading
Loading