Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9013904
webgpu plugin pipeline
fs-eire Mar 24, 2026
23782b9
2
fs-eire Mar 24, 2026
0a8b240
linux
fs-eire Mar 24, 2026
f7c6c34
change npm feed
fs-eire Mar 25, 2026
4639345
fix linux build 2
fs-eire Mar 25, 2026
8cd1590
include version info
fs-eire Mar 25, 2026
794c6a3
fix build tool path
fs-eire Mar 25, 2026
1492f57
fix linux build 3
fs-eire Mar 25, 2026
ad6d885
exclude build tools from code sign
fs-eire Mar 25, 2026
781627e
dxc copy
fs-eire Mar 25, 2026
7334ea9
publish U package
fs-eire Mar 25, 2026
fe278cd
validate params
fs-eire Mar 25, 2026
e4f654d
use version string to detect ORT API version
fs-eire Mar 28, 2026
f6d38f5
update packaging pipeline
fs-eire Mar 28, 2026
113c823
fix build
fs-eire Mar 28, 2026
1b3cee6
resolve comments
fs-eire Mar 31, 2026
9037880
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-plugin…
fs-eire Mar 31, 2026
2a1ffff
fix build flag for arm64
fs-eire Apr 1, 2026
31c4f4c
apply fix according to comments
fs-eire Apr 1, 2026
4b97ea1
save current ort api version
fs-eire Apr 1, 2026
85eb070
resolve comments
fs-eire Apr 1, 2026
04278d9
fix build
fs-eire Apr 1, 2026
a81e76d
resolve comments
fs-eire Apr 1, 2026
4d8a762
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-plugin…
fs-eire Apr 2, 2026
ea3aa72
resolve comments
fs-eire Apr 2, 2026
92bfc30
fix build
fs-eire Apr 4, 2026
1214dd2
suppress MSVC ICE
fs-eire Apr 4, 2026
70174de
Address PR review feedback: null checks, binary verification, and nits
fs-eire Apr 4, 2026
1c90606
Add DXC hash verification and fail on RC versioning
fs-eire Apr 4, 2026
1b868f3
Update tools/ci_build/github/linux/build_webgpu_plugin_package.sh
fs-eire Apr 6, 2026
e0a4ce8
sign dxc dlls
fs-eire Apr 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@
add_definitions("-DONNX_NAMESPACE=onnx")
add_definitions("-DONNX_USE_LITE_PROTO=1")

# Default the plugin EP version to "<ORT_VERSION>-dev" when the caller did not provide one.
# An explicitly empty value (for example -Donnxruntime_PLUGIN_EP_VERSION= coming from an unset
# pipeline parameter) is treated the same as "not provided", so that ORT_PLUGIN_EP_VERSION is
# never compiled in as an empty string.
if(NOT DEFINED onnxruntime_PLUGIN_EP_VERSION OR "${onnxruntime_PLUGIN_EP_VERSION}" STREQUAL "")
  set(onnxruntime_PLUGIN_EP_VERSION "${ORT_VERSION}-dev")
endif()

# Expose the plugin EP version to the provider sources as the string-literal macro ORT_PLUGIN_EP_VERSION.
target_compile_definitions(onnxruntime_providers_webgpu PRIVATE ORT_PLUGIN_EP_VERSION="${onnxruntime_PLUGIN_EP_VERSION}")
Comment thread
fs-eire marked this conversation as resolved.

# Set preprocessor definitions used in onnxruntime_providers_webgpu.rc
if(WIN32)
set(WEBGPU_DLL_FILE_DESCRIPTION "ONNX Runtime WebGPU Provider")
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/ep/adapter/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ struct OpKernelContext {
}
// Stub: always returns false for now.
// NOTE(review): presumably this should report whether deterministic compute was requested for the
// session; confirm once the ORT API exposes that setting.
bool GetUseDeterministicCompute() const {
// TODO(fs-eire): Implement GetUseDeterministicCompute().
// Intended shape once the ORT API provides the query (expected in API v25):
// if (CurrentOrtApiVersion() >= 25) {
// return /* TBD: wait for GetUseDeterministicCompute to be added in ORT API v25 */;
// } else {
return false;
// }
}
void* GetGPUComputeStream() const {
return context_.GetGPUComputeStream();
Expand Down
71 changes: 69 additions & 2 deletions include/onnxruntime/ep/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <charconv>
#include <cstdint>
#include <cstring>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
Expand All @@ -26,8 +28,30 @@

namespace detail {
inline std::optional<ApiPtrs> g_api_ptrs;

/// <summary>
/// Extract the ORT API version from an ORT version string of the form "1.{API_VERSION}.*".
/// </summary>
/// <param name="version_str">Null-terminated version string. May be null, in which case parsing fails.</param>
/// <param name="api_version">Receives the parsed API version on success. Left untouched on failure.</param>
/// <returns>
/// true if the string starts with "1." and the text between the first and the second '.' is, in its
/// entirety, a decimal number that fits in uint32_t; false otherwise.
/// </returns>
inline bool TryGetAPIVersionFromVersionString(const char* version_str, uint32_t& api_version) noexcept {
  // A valid version string should always be in the format of "1.{API_VERSION}.*".
  // Reading index 1 is safe: it is only evaluated when index 0 is '1', i.e. not the terminator.
  if (!version_str || version_str[0] != '1' || version_str[1] != '.') {
    return false;
  }
  const char* const begin = version_str + 2;
  const char* const end = std::strchr(begin, '.');
  if (!end) {
    // No second '.', so there is no delimited API version component.
    return false;
  }
  uint32_t version = 0;
  const auto [ptr, ec] = std::from_chars(begin, end, version);
  // Require the whole component to be consumed so that inputs such as "1.24a.0" are rejected.
  if (ec != std::errc{} || ptr != end) {
    return false;
  }
  api_version = version;
  return true;
}

inline uint32_t g_current_ort_api_version{};

} // namespace detail

/// <summary>
/// Get the global instance of ApiPtrs.
/// </summary>
Expand All @@ -45,16 +69,59 @@
inline void ApiInit(const OrtApiBase* ort_api_base) {
  // The initializer below runs at most once successfully. If it throws, std::call_once leaves the flag
  // unset and propagates the exception, so a later call to ApiInit() retries the initialization.
  //
  // Throws std::runtime_error when `ort_api_base` is null, when the detected ORT API version is older
  // than the minimum supported version, or when ORT does not return one of the required API tables.
  static std::once_flag init_flag;
  std::call_once(init_flag, [&]() {
    // Report a null entry point as an error instead of dereferencing it.
    if (!ort_api_base) {
      throw std::runtime_error("Failed to initialize EP API: the OrtApiBase pointer is null.");
    }

    // The following initialization process is composed of 3 steps:
    // 1) Get the ORT API version string
    // 2) Try to parse the ORT API version from the version string. If parsing fails, we assume the version is 24.
    // 3) Get the ORT API for the parsed version and initialize the global API instance with it.
    constexpr uint32_t ORT_BASE_API_VERSION = 24;
    const char* version_str = ort_api_base->GetVersionString();
    if (!version_str) {
      // Keep a printable placeholder so the error messages below stay well-formed.
      version_str = "unknown";
    }
    uint32_t current_ort_version = 0;
    if (!detail::TryGetAPIVersionFromVersionString(version_str, current_ort_version)) {
      // If we fail to parse the version string, we can still try to get the API for the base version and hope it works.
      current_ort_version = ORT_BASE_API_VERSION;
    }
    if (current_ort_version < ORT_BASE_API_VERSION) {
      throw std::runtime_error("Failed to initialize EP API: the minimum required ORT API version is " + std::to_string(ORT_BASE_API_VERSION) +
                               ", but the current version is \"" + version_str +
                               "\" (parsed API version: " + std::to_string(current_ort_version) + ").");
    }

    const OrtApi* ort_api = ort_api_base->GetApi(current_ort_version);
    if (!ort_api) {
      throw std::runtime_error("Failed to initialize EP API: the current ORT version is \"" + std::string(version_str) +
                               "\" but it does not support the parsed API version " + std::to_string(current_ort_version) + ".");
    }

    // Record the negotiated version so CurrentOrtApiVersion() can report it after initialization.
    detail::g_current_ort_api_version = current_ort_version;

    const OrtEpApi* ep_api = ort_api->GetEpApi();
    const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi();
    if (!ep_api || !model_editor_api) {
      throw std::runtime_error("Failed to initialize EP API: GetEpApi or GetModelEditorApi returned null.");
    }

    // Manual init for the C++ API
    Ort::InitApi(ort_api);

    // Initialize the global API instance
    detail::g_api_ptrs.emplace(*ort_api, *ep_api, *model_editor_api);
  });
}

/// <summary>
/// Get the current ORT API version that the EP API has been initialized with.
///
/// This function should be called after ApiInit() to get the actual API version.
/// </summary>
inline uint32_t CurrentOrtApiVersion() {
  // The recorded version is only meaningful once the global API pointers have been populated.
  if (detail::g_api_ptrs.has_value()) {
    return detail::g_current_ort_api_version;
  }
  throw std::logic_error("onnxruntime::ep::CurrentOrtApiVersion() called before ApiInit().");
}

} // namespace ep
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/ep/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ uint32_t ORT_API_CALL Factory::GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/)
}

// Returns the version string of this plugin EP.
// The value comes from the ORT_PLUGIN_EP_VERSION macro, which is expected to expand to a string literal,
// so the returned pointer has static storage duration and must not be freed by the caller.
const char* ORT_API_CALL Factory::GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
  return ORT_PLUGIN_EP_VERSION;
}

OrtStatus* ORT_API_CALL Factory::GetSupportedDevicesImpl(
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/ort_apis.h"
#include "core/session/ort_env.h"
#include "core/session/ort_version_check.h"
#include "core/session/utils.h"

#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
Expand Down Expand Up @@ -4830,6 +4831,9 @@ ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
}

// Returns the ORT_VERSION string.
ORT_API(const char*, OrtApis::GetVersionString) {
// Compile-time guard: a malformed ORT_VERSION, or one whose minor component differs from
// ORT_API_VERSION, fails the build here.
static_assert(onnxruntime::version_check::IsOrtVersionValid(ORT_VERSION),
"ORT_VERSION must be in the format '1.Y.Z' where Y and Z are non-negative integers without leading "
"zeros, and Y must equal ORT_API_VERSION");
return ORT_VERSION;
}

Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/core/session/ort_version_check.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <string_view>

#include "core/session/onnxruntime_c_api.h"

namespace onnxruntime::version_check {

// Minimal optional-like holder for a parsed uint32_t that is usable from consteval code.
// std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval.
struct ParseUintResult {
  uint32_t value;  // payload; compared only when has_value is true
  bool has_value;  // whether `value` holds a result

  // Equal only when a value is present and it matches `other`.
  consteval bool operator==(uint32_t other) const {
    if (!has_value) return false;
    return value == other;
  }

  // Exact negation of operator==: a result without a value is unequal to every number.
  consteval bool operator!=(uint32_t other) const { return !has_value || value != other; }
};

// Builds the "no value" result: value 0 with has_value == false.
inline consteval ParseUintResult ParseUintNone() {
  return ParseUintResult{0, false};
}

// Converts a string of decimal digits to uint32_t at compile time.
// Yields a result with has_value == false for: empty input, a leading '0' on multi-digit input,
// any non-digit character, or a value above UINT32_MAX.
consteval ParseUintResult ParseUint(std::string_view str) {
  // "0" itself is fine; only longer inputs that start with '0' are rejected.
  const bool has_leading_zero = str.size() > 1 && str.front() == '0';
  if (str.empty() || has_leading_zero) {
    return ParseUintNone();
  }

  // Accumulate in 64 bits so the range check below cannot itself wrap around.
  uint64_t acc = 0;
  for (size_t i = 0; i < str.size(); ++i) {
    const char ch = str[i];
    const bool is_digit = ch >= '0' && ch <= '9';
    if (!is_digit) {
      return ParseUintNone();
    }
    acc = acc * 10 + static_cast<uint64_t>(ch - '0');
    if (acc > UINT32_MAX) {
      return ParseUintNone();
    }
  }
  return {static_cast<uint32_t>(acc), true};
}

// Compile-time check that `version` has the shape "1.Y.Z":
// - exactly three dot-separated components
// - the first component is exactly "1"
// - Y and Z are non-negative integers without leading zeros
// - Y equals expected_api_version (defaults to ORT_API_VERSION)
consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
  constexpr auto npos = std::string_view::npos;

  // Locate both separators; a third one means there are too many components.
  const size_t dot1 = version.find('.');
  if (dot1 == npos) return false;
  const size_t dot2 = version.find('.', dot1 + 1);
  if (dot2 == npos) return false;
  if (version.find('.', dot2 + 1) != npos) return false;

  if (version.substr(0, dot1) != "1") return false;

  const auto minor = ParseUint(version.substr(dot1 + 1, dot2 - dot1 - 1));
  const auto patch = ParseUint(version.substr(dot2 + 1));
  return minor.has_value && patch.has_value && minor.value == expected_api_version;
}

} // namespace onnxruntime::version_check
79 changes: 56 additions & 23 deletions onnxruntime/test/shared_lib/test_version.cc
Original file line number Diff line number Diff line change
@@ -1,31 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "onnxruntime_config.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_version_check.h"

#include <cstdint>
#include <charconv>
#include <optional>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "gtest/gtest.h"

TEST(CApiTest, VersionConsistencyWithApiVersion) {
const auto version_string = Ort::GetVersionString();
const std::vector<std::string> version_string_components = absl::StrSplit(version_string, '.');
ASSERT_EQ(version_string_components.size(), size_t{3});

auto to_uint32_t = [](const std::string& s) -> std::optional<uint32_t> {
uint32_t result{};
if (std::from_chars(s.data(), s.data() + s.size(), result).ec == std::errc{}) {
return result;
}
return std::nullopt;
};

ASSERT_NE(to_uint32_t(version_string_components[0]), std::nullopt);
ASSERT_EQ(to_uint32_t(version_string_components[1]), uint32_t{ORT_API_VERSION});
ASSERT_NE(to_uint32_t(version_string_components[2]), std::nullopt);
using onnxruntime::version_check::IsOrtVersionValid;
using onnxruntime::version_check::ParseUint;

// Compile-time tests for ParseUint.
// These are evaluated while this file is compiled: a failing case is a build error rather than a
// runtime test failure. Successful parses are compared with an integer through
// ParseUintResult::operator==, which is true only when a value is present and equal.
static_assert(ParseUint("0") == 0u);
static_assert(ParseUint("1") == 1u);
static_assert(ParseUint("25") == 25u);
static_assert(ParseUint("123") == 123u);
static_assert(ParseUint("4294967295") == 4294967295u); // UINT32_MAX
static_assert(!(ParseUint("4294967296").has_value)); // UINT32_MAX + 1 overflows
static_assert(!(ParseUint("").has_value)); // empty
static_assert(!(ParseUint("01").has_value)); // leading zero
static_assert(!(ParseUint("00").has_value)); // leading zero
static_assert(!(ParseUint("abc").has_value)); // non-digit
static_assert(!(ParseUint("1a").has_value)); // trailing non-digit
static_assert(!(ParseUint("-1").has_value)); // negative sign
static_assert(!(ParseUint("1.0").has_value)); // contains dot
// Direct checks of the has_value flag itself (same inputs as the "0" and empty cases above).
static_assert(ParseUint("0").has_value);
static_assert(!ParseUint("").has_value);

// Compile-time tests for IsOrtVersionValid (default expected_api_version = ORT_API_VERSION)
static_assert(IsOrtVersionValid(ORT_VERSION)); // current version must be valid

// Invalid formats
// Each input below fails a structural rule, so it is rejected regardless of the value of ORT_API_VERSION.
static_assert(!IsOrtVersionValid(""));
static_assert(!IsOrtVersionValid("1"));
static_assert(!IsOrtVersionValid("1.0"));
static_assert(!IsOrtVersionValid("1.0.0.0")); // too many dots
static_assert(!IsOrtVersionValid("2.0.0")); // major != 1
static_assert(!IsOrtVersionValid("1.02.0")); // leading zero in minor
static_assert(!IsOrtVersionValid("1.0.01")); // leading zero in patch
static_assert(!IsOrtVersionValid("1..0")); // empty minor
static_assert(!IsOrtVersionValid("1.0.")); // empty patch
static_assert(!IsOrtVersionValid(".1.0")); // empty major
static_assert(!IsOrtVersionValid("abc")); // non-numeric
static_assert(!IsOrtVersionValid("1.abc.0")); // non-numeric minor
static_assert(!IsOrtVersionValid("1.0.abc")); // non-numeric patch

// Compile-time tests for IsOrtVersionValid with explicit expected_api_version
// Passing the expected version explicitly keeps these cases independent of the ORT_API_VERSION
// this test happens to be built against.
static_assert(IsOrtVersionValid("1.0.0", 0));
static_assert(IsOrtVersionValid("1.1.0", 1));
static_assert(IsOrtVersionValid("1.25.0", 25));
static_assert(IsOrtVersionValid("1.25.3", 25));
static_assert(IsOrtVersionValid("1.100.0", 100));
static_assert(!IsOrtVersionValid("1.25.0", 24)); // minor doesn't match expected
static_assert(!IsOrtVersionValid("1.25.0", 26)); // minor doesn't match expected
static_assert(!IsOrtVersionValid("1.0.0", 1)); // minor 0 != expected 1
static_assert(!IsOrtVersionValid("2.0.0", 0)); // major != 1
static_assert(!IsOrtVersionValid("1.02.0", 2)); // leading zero in minor
static_assert(!IsOrtVersionValid("1.0.01", 0)); // leading zero in patch

TEST(CApiTest, VersionIsValid) {
  // Runtime sanity check: the library must report exactly the ORT_VERSION this test was built with.
  const auto reported_version = Ort::GetVersionString();
  EXPECT_STREQ(reported_version.c_str(), ORT_VERSION);
}
Loading
Loading