diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 457867061d6a7..11f67d342da0d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -37,13 +37,13 @@ namespace onnxruntime { namespace webgpu { -void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) { - std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() { +void WebGpuContext::Initialize(const WebGpuContextConfig& config) { + std::call_once(init_flag_, [this, &config]() { if (device_ == nullptr) { // Create wgpu::Adapter wgpu::RequestAdapterOptions req_adapter_options = {}; - req_adapter_options.backendType = static_cast(backend_type); - req_adapter_options.powerPreference = static_cast(power_preference_); + req_adapter_options.backendType = static_cast(config.backend_type); + req_adapter_options.powerPreference = static_cast(config.power_preference); #if !defined(__wasm__) auto enabled_adapter_toggles = GetEnabledAdapterToggles(); @@ -134,9 +134,9 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // create buffer manager buffer_mgr_ = BufferManagerFactory::Create(*this, - buffer_cache_config.storage.mode, - buffer_cache_config.uniform.mode, - buffer_cache_config.query_resolve.mode); + config.buffer_cache_config.storage.mode, + config.buffer_cache_config.uniform.mode, + config.buffer_cache_config.query_resolve.mode); // create initializer buffer manager. cache is always disabled for initializer buffer manager initializer_buffer_mgr_ = BufferManagerFactory::Create(*this, @@ -161,7 +161,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi } else { query_type_ = TimestampQueryType::None; } - if (enable_pix_capture) { + if (config.enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator pix_frame_generator_ = std::make_unique(instance_, @@ -979,8 +979,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co device, config.validation_mode, config.preserve_device, - config.max_storage_buffer_binding_size, - config.power_preference)); + config.max_storage_buffer_binding_size)); it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first; } else if (context_id != 0) { ORT_ENFORCE(it->second.context->instance_.Get() == instance && @@ -988,6 +987,10 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device."); } it->second.ref_count++; + + // perform initialization + it->second.context->Initialize(config); + return *it->second.context; } @@ -1017,6 +1020,11 @@ void WebGpuContextFactory::Cleanup() { default_instance_ = nullptr; } +WebGpuContext& WebGpuContextFactory::DefaultContext() { + WebGpuContextConfig config{}; + return WebGpuContextFactory::CreateContext(config); +} + void CleanupWebGpuContexts() { WebGpuContextFactory::Cleanup(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 84dfb47ef4687..5a97ef662855e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -34,26 +34,51 @@ struct CapturedCommandInfo { WGPUBuffer indirect_buffer; // WGPUBuffer for indirect dispatch, nullptr if not using indirect dispatch }; -struct WebGpuContextConfig { - int context_id; - WGPUInstance instance; - WGPUDevice device; - const void* dawn_proc_table; - ValidationMode validation_mode; - bool preserve_device; - uint64_t max_storage_buffer_binding_size; - int power_preference; -}; - struct WebGpuBufferCacheConfig { struct ConfigEntry { BufferCacheMode mode; - std::string config_string; + std::string config_string; // preserved for customized configuration, eg. bucket sizes + }; + ConfigEntry storage{BufferCacheMode::Bucket, {}}; + ConfigEntry uniform{BufferCacheMode::Simple, {}}; + ConfigEntry query_resolve{BufferCacheMode::Disabled, {}}; + ConfigEntry default_entry{BufferCacheMode::Disabled, {}}; +}; + +/// +/// Represents the configuration options for creating a WebGpuContext. +/// +struct WebGpuContextConfig { + int context_id{0}; + WGPUInstance instance{nullptr}; + WGPUDevice device{nullptr}; + const void* dawn_proc_table{nullptr}; + ValidationMode validation_mode{ +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::Basic // for release build, enable basic validation by default +#endif // !NDEBUG }; - ConfigEntry storage; - ConfigEntry uniform; - ConfigEntry query_resolve; - ConfigEntry default_entry; + bool preserve_device{false}; + uint64_t max_storage_buffer_binding_size{0}; + WebGpuBufferCacheConfig buffer_cache_config{}; + int power_preference{static_cast(WGPUPowerPreference_HighPerformance)}; + int backend_type{ +#ifdef _WIN32 + // Setup Windows default backend type based on the build configuration +#if defined(DAWN_ENABLE_D3D12) + static_cast(WGPUBackendType_D3D12) +#elif defined(DAWN_ENABLE_VULKAN) + static_cast(WGPUBackendType_Vulkan) +#else + 0 +#endif +#else + 0 +#endif + }; + bool enable_pix_capture{false}; }; class WebGpuContextFactory { @@ -63,13 +88,28 @@ class WebGpuContextFactory { int ref_count; }; + /// + /// Create a new WebGPU context for the specified context ID if not present, or return the existing one. (ref-count based) + /// static WebGpuContext& CreateContext(const WebGpuContextConfig& config); + + /// + /// Get the WebGPU context for the specified context ID. Throw if not present. + /// static WebGpuContext& GetContext(int context_id); + /// + /// Release the WebGPU context. (ref-count based) + /// static void ReleaseContext(int context_id); static void Cleanup(); + /// + /// Return the default context. Create if not present. + /// + static WebGpuContext& DefaultContext(); + private: WebGpuContextFactory() {} @@ -82,8 +122,6 @@ class WebGpuContextFactory { // Class WebGpuContext includes all necessary resources for the context. class WebGpuContext final { public: - void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture); - Status Wait(wgpu::Future f); const wgpu::Device& Device() const { return device_; } @@ -190,20 +228,20 @@ class WebGpuContext final { WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device, - uint64_t max_storage_buffer_binding_size, - int power_preference = static_cast(wgpu::PowerPreference::HighPerformance)) + uint64_t max_storage_buffer_binding_size) : instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device}, - max_storage_buffer_binding_size_{max_storage_buffer_binding_size}, - power_preference_{power_preference} { + max_storage_buffer_binding_size_{max_storage_buffer_binding_size} { ORT_ENFORCE(max_storage_buffer_binding_size_ == 0 || max_storage_buffer_binding_size_ >= 134217728, "max_storage_buffer_binding_size must be 0 or at least 128MB"); } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + void Initialize(const WebGpuContextConfig& config); + void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder, const std::vector& bind_buffers, const std::vector& bind_buffers_segments, @@ -292,7 +330,6 @@ class WebGpuContext final { bool is_profiling_ = false; bool preserve_device_; uint64_t max_storage_buffer_binding_size_; - int power_preference_; GraphCaptureState graph_capture_state_{GraphCaptureState::Default}; // External vector to store captured commands, owned by EP diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index a9282a028c803..ad012423d0486 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -27,18 +27,9 @@ struct CapturedCommandInfo; } // namespace webgpu struct WebGpuExecutionProviderConfig { - WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture) - : data_layout{data_layout}, - enable_graph_capture{enable_graph_capture}, - enable_pix_capture{enable_pix_capture} {} - WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default; - WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default; - ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig); - - DataLayout data_layout; - bool enable_graph_capture; - bool enable_pix_capture; - std::vector force_cpu_node_names; + DataLayout data_layout{DataLayout::NHWC}; // preferred layout is NHWC by default + bool enable_graph_capture{false}; // graph capture feature is disabled by default + std::vector force_cpu_node_names{}; }; class WebGpuExecutionProvider : public IExecutionProvider { diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index fdd7caa1706f5..c92c3624678ea 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -14,51 +14,13 @@ #include "core/providers/webgpu/webgpu_provider_options.h" #include "core/providers/webgpu/data_transfer.h" + +using namespace onnxruntime::webgpu; using namespace onnxruntime::webgpu::options; namespace onnxruntime { -// Helper struct that holds configuration parameters for creating a WebGPU context with default settings. -// This is used during lazy initialization of the data transfer to create a context if one doesn't exist. -struct WebGpuContextParams { - webgpu::WebGpuContextConfig context_config; // WebGPU context configuration - webgpu::WebGpuBufferCacheConfig buffer_cache_config; // Buffer cache settings - int backend_type; // Dawn backend type (D3D12, Vulkan, etc.) - bool enable_pix_capture; // Enable PIX GPU capture for debugging -}; - -static WebGpuContextParams GetDefaultWebGpuContextParams() { - WebGpuContextParams params; - params.context_config.context_id = 0; - params.context_config.instance = nullptr; - params.context_config.device = nullptr; - params.context_config.dawn_proc_table = nullptr; - params.context_config.validation_mode = webgpu::ValidationMode::Disabled; - params.context_config.preserve_device = false; - params.context_config.max_storage_buffer_binding_size = 0; - params.context_config.power_preference = static_cast(WGPUPowerPreference_HighPerformance); - - params.buffer_cache_config.storage.mode = webgpu::BufferCacheMode::Bucket; - params.buffer_cache_config.uniform.mode = webgpu::BufferCacheMode::Simple; - params.buffer_cache_config.query_resolve.mode = webgpu::BufferCacheMode::Disabled; - params.buffer_cache_config.default_entry.mode = webgpu::BufferCacheMode::Disabled; - -#ifdef _WIN32 -#if defined(DAWN_ENABLE_D3D12) - params.backend_type = static_cast(WGPUBackendType_D3D12); -#elif defined(DAWN_ENABLE_VULKAN) - params.backend_type = static_cast(WGPUBackendType_Vulkan); -#else - params.backend_type = static_cast(WGPUBackendType_D3D12); -#endif -#else - params.backend_type = 0; -#endif - params.enable_pix_capture = false; - return params; -} - struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) + WebGpuProviderFactory(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) : context_id_{context_id}, context_{context}, config_{std::move(webgpu_ep_config)} { } @@ -68,25 +30,17 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { private: int context_id_; - webgpu::WebGpuContext& context_; + WebGpuContext& context_; WebGpuExecutionProviderConfig config_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { - // - // STEP.1 - prepare WebGpuExecutionProviderConfig - // - WebGpuExecutionProviderConfig webgpu_ep_config{ - // preferred layout is NHWC by default - DataLayout::NHWC, - // graph capture feature is disabled by default - false, - // enable pix capture feature is diabled by default - false, - }; +namespace { - std::string preferred_layout_str; - if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { +WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) { + WebGpuExecutionProviderConfig webgpu_ep_config{}; + + if (std::string preferred_layout_str; + config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { if (preferred_layout_str == kPreferredLayout_NHWC) { webgpu_ep_config.data_layout = DataLayout::NHWC; } else if (preferred_layout_str == kPreferredLayout_NCHW) { @@ -95,11 +49,9 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid preferred layout: ", preferred_layout_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_config.data_layout) << " (parsed from \"" - << preferred_layout_str << "\")"; - std::string enable_graph_capture_str; - if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (std::string enable_graph_capture_str; + config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { if (enable_graph_capture_str == kEnableGraphCapture_ON) { webgpu_ep_config.enable_graph_capture = true; } else if (enable_graph_capture_str == kEnableGraphCapture_OFF) { @@ -108,13 +60,13 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture; // parse force CPU node names // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. // each line is a node name that will be forced to run on CPU. - std::string force_cpu_node_names_str; - if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { + + if (std::string force_cpu_node_names_str; + config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { // split the string by EOL (\n or \r\n) std::istringstream ss(force_cpu_node_names_str); std::string line; @@ -127,209 +79,182 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu_ep_config.force_cpu_node_names.push_back(line); } } + + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_config.data_layout); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture; LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size(); - // - // STEP.2 - prepare WebGpuContextConfig - // - int context_id = 0; - std::string context_id_str; - if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + return webgpu_ep_config; +} + +WebGpuContextConfig ParseWebGpuContextConfig(const ConfigOptions& config_options) { + WebGpuContextConfig config{}; + + if (std::string context_id_str; + config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { ORT_ENFORCE(std::errc{} == - std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), config.context_id).ec); } - size_t webgpu_instance = 0; - std::string webgpu_instance_str; - if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + if (std::string webgpu_instance_str; + config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + size_t webgpu_instance = 0; ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + config.instance = reinterpret_cast(webgpu_instance); } - size_t webgpu_device = 0; - std::string webgpu_device_str; - if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + if (std::string webgpu_device_str; + config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + size_t webgpu_device = 0; ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + config.device = reinterpret_cast(webgpu_device); } - size_t dawn_proc_table = 0; - std::string dawn_proc_table_str; - if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + if (std::string dawn_proc_table_str; + config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + size_t dawn_proc_table = 0; ORT_ENFORCE(std::errc{} == std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + config.dawn_proc_table = reinterpret_cast(dawn_proc_table); } - webgpu::ValidationMode validation_mode = -#ifndef NDEBUG - webgpu::ValidationMode::Full // for debug build, enable full validation by default -#else - webgpu::ValidationMode::Basic // for release build, enable basic validation by default -#endif // !NDEBUG - ; - std::string validation_mode_str; - if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (std::string validation_mode_str; + config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { if (validation_mode_str == kValidationMode_Disabled) { - validation_mode = webgpu::ValidationMode::Disabled; + config.validation_mode = ValidationMode::Disabled; } else if (validation_mode_str == kValidationMode_wgpuOnly) { - validation_mode = webgpu::ValidationMode::WGPUOnly; + config.validation_mode = ValidationMode::WGPUOnly; } else if (validation_mode_str == kValidationMode_basic) { - validation_mode = webgpu::ValidationMode::Basic; + config.validation_mode = ValidationMode::Basic; } else if (validation_mode_str == kValidationMode_full) { - validation_mode = webgpu::ValidationMode::Full; + config.validation_mode = ValidationMode::Full; } else { ORT_THROW("Invalid validation mode: ", validation_mode_str); } } - std::string preserve_device_str; - bool preserve_device = false; - if (config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) { + if (std::string preserve_device_str; + config_options.TryGetConfigEntry(kPreserveDevice, preserve_device_str)) { if (preserve_device_str == kPreserveDevice_ON) { - preserve_device = true; + config.preserve_device = true; } else if (preserve_device_str == kPreserveDevice_OFF) { - preserve_device = false; + config.preserve_device = false; } else { ORT_THROW("Invalid preserve device: ", preserve_device_str); } } - uint64_t max_storage_buffer_binding_size = 0; std::string max_storage_buffer_binding_size_str; if (config_options.TryGetConfigEntry(kMaxStorageBufferBindingSize, max_storage_buffer_binding_size_str)) { ORT_ENFORCE( std::errc{} == std::from_chars( max_storage_buffer_binding_size_str.data(), max_storage_buffer_binding_size_str.data() + max_storage_buffer_binding_size_str.size(), - max_storage_buffer_binding_size) + config.max_storage_buffer_binding_size) .ec, "Invalid maxStorageBufferBindingSize value: ", max_storage_buffer_binding_size_str); } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP max storage buffer binding size: " << max_storage_buffer_binding_size; - // power preference - int power_preference = static_cast(WGPUPowerPreference_HighPerformance); // default - std::string power_preference_str; - if (config_options.TryGetConfigEntry(kPowerPreference, power_preference_str)) { - if (power_preference_str == kPowerPreference_HighPerformance) { - power_preference = static_cast(WGPUPowerPreference_HighPerformance); - } else if (power_preference_str == kPowerPreference_LowPower) { - power_preference = static_cast(WGPUPowerPreference_LowPower); - } else { - ORT_THROW("Invalid power preference: ", power_preference_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP power preference: " << power_preference; - - webgpu::WebGpuContextConfig context_config{ - context_id, - reinterpret_cast(webgpu_instance), - reinterpret_cast(webgpu_device), - reinterpret_cast(dawn_proc_table), - validation_mode, - preserve_device, - max_storage_buffer_binding_size, - power_preference, - }; - - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << webgpu_instance; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << webgpu_device; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << preserve_device; - LOGS_DEFAULT(VERBOSE) << "WebGPU EP PowerPreference: " << power_preference; - - // - // STEP.3 - prepare parameters for WebGPU context initialization. - // - - int backend_type = 0; -#ifdef _WIN32 - // Setup Windows default backend type based on the build configuration -#if defined(DAWN_ENABLE_D3D12) - backend_type = static_cast(WGPUBackendType_D3D12); -#elif defined(DAWN_ENABLE_VULKAN) - backend_type = static_cast(WGPUBackendType_Vulkan); -#endif -#endif - - std::string backend_type_str; - if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { - if (backend_type_str == kDawnBackendType_D3D12) { - backend_type = static_cast(WGPUBackendType_D3D12); - } else if (backend_type_str == kDawnBackendType_Vulkan) { - backend_type = static_cast(WGPUBackendType_Vulkan); - } else { - ORT_THROW("Invalid Dawn backend type: ", backend_type_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << backend_type; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << config.context_id; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUInstance: " << reinterpret_cast(config.instance); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP WGPUDevice: " << reinterpret_cast(config.device); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << reinterpret_cast(config.dawn_proc_table); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << config.validation_mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << config.preserve_device; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP max storage buffer binding size: " << config.max_storage_buffer_binding_size; // buffer cache modes auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, - webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + BufferCacheMode& value) -> void { std::string buffer_cache_mode_str; if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { - return webgpu::BufferCacheMode::Disabled; + value = BufferCacheMode::Disabled; } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { - return webgpu::BufferCacheMode::LazyRelease; + value = BufferCacheMode::LazyRelease; } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { - return webgpu::BufferCacheMode::Simple; + value = BufferCacheMode::Simple; } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { - return webgpu::BufferCacheMode::Bucket; + value = BufferCacheMode::Bucket; } else { - ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + ORT_THROW("Invalid buffer cache mode: ", buffer_cache_mode_str); } - } else { - return default_value; } }; - webgpu::WebGpuBufferCacheConfig buffer_cache_config; + WebGpuBufferCacheConfig& buffer_cache_config = config.buffer_cache_config; + parse_buffer_cache_mode(kStorageBufferCacheMode, buffer_cache_config.storage.mode); + parse_buffer_cache_mode(kUniformBufferCacheMode, buffer_cache_config.uniform.mode); + parse_buffer_cache_mode(kQueryResolveBufferCacheMode, buffer_cache_config.query_resolve.mode); + parse_buffer_cache_mode(kDefaultBufferCacheMode, buffer_cache_config.default_entry.mode); - buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, - webgpu::BufferCacheMode::Bucket); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << buffer_cache_config.storage.mode; - - buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, - webgpu::BufferCacheMode::Simple); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << buffer_cache_config.uniform.mode; + // power preference + if (std::string power_preference_str; + config_options.TryGetConfigEntry(kPowerPreference, power_preference_str)) { + if (power_preference_str == kPowerPreference_HighPerformance) { + config.power_preference = static_cast(WGPUPowerPreference_HighPerformance); + } else if (power_preference_str == kPowerPreference_LowPower) { + config.power_preference = static_cast(WGPUPowerPreference_LowPower); + } else { + ORT_THROW("Invalid power preference: ", power_preference_str); + } + } - buffer_cache_config.query_resolve.mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << buffer_cache_config.query_resolve.mode; + // backend type + if (std::string backend_type_str; + config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { + if (backend_type_str == kDawnBackendType_D3D12) { + config.backend_type = static_cast(WGPUBackendType_D3D12); + } else if (backend_type_str == kDawnBackendType_Vulkan) { + config.backend_type = static_cast(WGPUBackendType_Vulkan); + } else { + ORT_THROW("Invalid Dawn backend type: ", backend_type_str); + } + } - buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode; + // enable pix capture - bool enable_pix_capture = false; - std::string enable_pix_capture_str; - if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { + if (std::string enable_pix_capture_str; + config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { if (enable_pix_capture_str == kEnablePIXCapture_ON) { - enable_pix_capture = true; + config.enable_pix_capture = true; } else if (enable_pix_capture_str == kEnablePIXCapture_OFF) { - enable_pix_capture = false; + config.enable_pix_capture = false; } else { ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture; - // - // STEP.4 - start initialization. - // + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << config.buffer_cache_config.storage.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << config.buffer_cache_config.uniform.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << config.buffer_cache_config.query_resolve.mode; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << config.buffer_cache_config.default_entry.mode; - // Load the Dawn library and create the WebGPU instance. - auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP power preference: " << config.power_preference; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << config.backend_type; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << config.enable_pix_capture; + + return config; +} - // Create WebGPU device and initialize the context. - context.Initialize(buffer_cache_config, backend_type, enable_pix_capture); +} // namespace + +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { + // prepare WebGpuExecutionProviderConfig + WebGpuExecutionProviderConfig webgpu_ep_config = ParseEpConfig(config_options); + + // prepare WebGpuContextConfig + WebGpuContextConfig config = ParseWebGpuContextConfig(config_options); + + // Load the Dawn library and create the WebGPU instance. + auto& context = WebGpuContextFactory::CreateContext(config); // Create WebGPU EP factory. - return std::make_shared(context_id, context, std::move(webgpu_ep_config)); + return std::make_shared(config.context_id, context, std::move(webgpu_ep_config)); } // WebGPU DataTransfer implementation wrapper for the C API with lazy initialization @@ -406,16 +331,17 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { std::lock_guard lock(impl.init_mutex_); if (impl.data_transfer_ == nullptr) { // Always create a new context with context_id 0 - WebGpuContextParams params = GetDefaultWebGpuContextParams(); - params.context_config.context_id = impl.context_id_; - auto* context_ptr = &webgpu::WebGpuContextFactory::CreateContext(params.context_config); - context_ptr->Initialize(params.buffer_cache_config, params.backend_type, params.enable_pix_capture); + if (impl.context_id_ != 0) { + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Shared data transfer can only be created for the default device (0)."); + } + + auto& context = WebGpuContextFactory::DefaultContext(); // Create the DataTransfer instance // Note: The DataTransfer holds a const reference to BufferManager. The BufferManager's lifecycle // is managed by the WebGpuContext, which is stored in a static WebGpuContextFactory and persists // for the lifetime of the application, ensuring the reference remains valid. - impl.data_transfer_ = std::make_unique(context_ptr->BufferManager()); + impl.data_transfer_ = std::make_unique(context.BufferManager()); } } @@ -441,15 +367,15 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { } delete p_impl; if (data_transfer_initialized) { - webgpu::WebGpuContextFactory::ReleaseContext(context_id); + WebGpuContextFactory::ReleaseContext(context_id); } } const OrtApi& ort_api; const OrtEpApi& ep_api; - std::unique_ptr data_transfer_; // Lazy-initialized - int context_id_; // Track which context we're using - std::mutex init_mutex_; // Protects lazy initialization + std::unique_ptr data_transfer_; // Lazy-initialized + int context_id_; // Track which context we're using + std::mutex init_mutex_; // Protects lazy initialization }; OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() {