Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ namespace webgpu {

void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
if (device_ == nullptr) {
// Create wgpu::Adapter
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
//
Expand Down Expand Up @@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
req_adapter_options.nextInChain = &adapter_toggles_desc;
#endif

wgpu::Adapter adapter;
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
&req_adapter_options,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message});
*ptr = adapter;
*ptr = std::move(adapter);
},
&adapter_),
&adapter),
UINT64_MAX));
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
}
ORT_ENFORCE(adapter != nullptr, "Failed to get a WebGPU adapter.");

// Create wgpu::Device
if (device_ == nullptr) {
// Create wgpu::Device
wgpu::DeviceDescriptor device_desc = {};

#if !defined(__wasm__)
Expand All @@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
#endif

std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter);
if (required_features.size() > 0) {
device_desc.requiredFeatures = required_features.data();
device_desc.requiredFeatureCount = required_features.size();
}
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_);
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter);
device_desc.requiredLimits = &required_limits;

// TODO: revise temporary error handling
Expand All @@ -123,20 +122,20 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
});

ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
&device_desc,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message});
*ptr = device;
*ptr = std::move(device);
},
&device_),
UINT64_MAX));
ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device.");
}

// cache adapter info
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
// cache device limits
wgpu::SupportedLimits device_supported_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
Expand Down Expand Up @@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
WGPUInstance instance = config.instance;
WGPUAdapter adapter = config.adapter;
WGPUDevice device = config.device;

if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
ORT_ENFORCE(instance == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");

std::call_once(init_default_flag_, [
#if !defined(__wasm__)
Expand Down Expand Up @@ -750,23 +748,22 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
});
instance = default_instance_.Get();
} else {
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device.");
// for context ID > 0, user must provide custom WebGPU instance and device.
ORT_ENFORCE(instance != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
}

std::lock_guard<std::mutex> lock(mutex_);

auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->adapter_.Get() == adapter &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
}
it->second.ref_count++;
return *it->second.context;
Expand Down
7 changes: 2 additions & 5 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class ProgramBase;
struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
Expand Down Expand Up @@ -76,7 +75,6 @@ class WebGpuContext final {

Status Wait(wgpu::Future f);

const wgpu::Adapter& Adapter() const { return adapter_; }
const wgpu::Device& Device() const { return device_; }

const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
Expand Down Expand Up @@ -149,8 +147,8 @@ class WebGpuContext final {
AtPasses
};

WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

std::vector<const char*> GetEnabledAdapterToggles() const;
Expand Down Expand Up @@ -198,7 +196,6 @@ class WebGpuContext final {
LibraryHandles modules_;

wgpu::Instance instance_;
wgpu::Adapter adapter_;
wgpu::Device device_;

webgpu::ValidationMode validation_mode_;
Expand Down
11 changes: 1 addition & 10 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec);
}

size_t webgpu_adapter = 0;
std::string webgpu_adapter_str;
if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) {
static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch");
ORT_ENFORCE(std::errc{} ==
std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec);
}

size_t webgpu_device = 0;
std::string webgpu_device_str;
if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) {
Expand Down Expand Up @@ -154,7 +146,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
webgpu::WebGpuContextConfig context_config{
context_id,
reinterpret_cast<WGPUInstance>(webgpu_instance),
reinterpret_cast<WGPUAdapter>(webgpu_adapter),
reinterpret_cast<WGPUDevice>(webgpu_device),
reinterpret_cast<const void*>(dawn_proc_table),
validation_mode,
Expand Down Expand Up @@ -238,7 +229,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
// STEP.4 - start initialization.
//

// Load the Dawn library and create the WebGPU instance and adapter.
// Load the Dawn library and create the WebGPU instance.
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);

// Create WebGPU device and initialize the context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType";

constexpr const char* kDeviceId = "WebGPU:deviceId";
constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance";
constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter";
constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice";

constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode";
Expand Down
Loading