Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 346 additions & 0 deletions CMakeLists.txt

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript

parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.")

parser.add_argument("--use_webgpu", action="store_true", help="Whether to use WebGPU. Default is to not use WebGPU.")

parser.add_argument("--use_guidance", action="store_true", help="Whether to add guidance support. Default is False.")

# The following options are mutually exclusive (cross compiling options such as android, ios, etc.)
Expand Down Expand Up @@ -527,6 +529,7 @@ def update(args: argparse.Namespace, env: dict[str, str]):
f"-DUSE_TRT_RTX={'ON' if args.use_trt_rtx else 'OFF'}",
f"-DUSE_ROCM={'ON' if args.use_rocm else 'OFF'}",
f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}",
f"-DUSE_WEBGPU={'ON' if args.use_webgpu else 'OFF'}",
f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}",
f"-DBUILD_WHEEL={build_wheel}",
f"-DUSE_GUIDANCE={'ON' if args.use_guidance else 'OFF'}",
Expand Down
1 change: 1 addition & 0 deletions cmake/options.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ option(USE_CUDA "Build with CUDA support" ON)
option(USE_TRT_RTX "Build with TensorRT-RTX support" OFF)
option(USE_ROCM "Build with ROCm support" ON)
option(USE_DML "Build with DML support" OFF)
option(USE_WEBGPU "Build with WebGPU support" OFF)
option(USE_WINML "Build with WinML support" OFF)
option(USE_GUIDANCE "Build with guidance support" OFF)

Expand Down
7 changes: 7 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,13 @@ bool IsGraphCaptureEnabled(const Config::SessionOptions& session_options) {
}
} else if (provider_options->name == "DML") {
return true;
} else if (provider_options->name == "WebGPU") {
for (const auto& value : provider_options->options) {
if (value.first == "enableGraphCapture" && value.second == "1") {
return true;
}
}
return false;
} else if (provider_options->name == "NvTensorRtRtx") {
for (const auto& value : provider_options->options) {
if (value.first == "enable_cuda_graph" && value.second == "1") {
Expand Down
18 changes: 15 additions & 3 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "marian.h"
#include "decoder_only_pipeline.h"
#include "../dml/interface.h"
#include "../webgpu/interface.h"

#if defined(_WIN32)
#include <direct.h>
Expand Down Expand Up @@ -577,9 +578,18 @@ DeviceInterface* SetProviderSessionOptions(OrtSessionOptions& session_options,
opt_it != provider_options.options.end() && opt_it->second == "1") {
p_device = GetDeviceInterface(DeviceType::QNN);
}
} else if (provider_options.name == "WebGPU")
} else if (provider_options.name == "WebGPU") {
p_device = GetDeviceInterface(DeviceType::WEBGPU);
else if (provider_options.name == "OpenVINO")
// Convert provider options to unordered_map for SetWebGPUProvider
std::unordered_map<std::string, std::string> webgpu_options;
for (const auto& option : provider_options.options) {
webgpu_options[option.first] = option.second;
}

// Register the WebGPU execution provider through SetWebGPUProvider rather than the generic path
SetWebGPUProvider(session_options, webgpu_options);
continue; // Skip the generic AppendExecutionProvider below
} else if (provider_options.name == "OpenVINO")
p_device = GetDeviceInterface(DeviceType::OpenVINO);
else if (provider_options.name == "VitisAI") {
session_options.AddConfigEntry("session.inter_op.allow_spinning", "0");
Expand Down Expand Up @@ -871,7 +881,9 @@ Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
EnsureDeviceOrtInit(*p_device_, *config_);

// Only CUDA, TRT-RTX and DML keep every input on the device
if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx)
// For WebGPU, use device memory only if graph capture is enabled, otherwise use CPU
if (p_device_->GetType() == DeviceType::CUDA || p_device_->GetType() == DeviceType::DML || p_device_->GetType() == DeviceType::NvTensorRtRtx ||
(p_device_->GetType() == DeviceType::WEBGPU && IsGraphCaptureEnabled(config_->model.decoder.session_options)))
p_device_inputs_ = p_device_;
else
p_device_inputs_ = GetDeviceInterface(DeviceType::CPU);
Expand Down
Loading
Loading