-
Notifications
You must be signed in to change notification settings - Fork 4k
Add webgpu plugin EP pipeline #27841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 23 commits
9013904
23782b9
0a8b240
f7c6c34
4639345
8cd1590
794c6a3
1492f57
ad6d885
781627e
7334ea9
fe278cd
e4f654d
f6d38f5
113c823
1b3cee6
9037880
2a1ffff
31c4f4c
4b97ea1
85eb070
04278d9
a81e76d
4d8a762
ea3aa72
92bfc30
1214dd2
70174de
1c90606
1b868f3
e0a4ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
| #include <optional> | ||
| #include <string_view> | ||
|
|
||
| #include "core/session/onnxruntime_c_api.h" | ||
|
|
||
| namespace onnxruntime::version_check { | ||
|
|
||
| // Parse a non-negative integer from a string_view without leading zeros. | ||
| // Returns std::nullopt on failure (empty, leading zero, non-digit, or overflow). | ||
| consteval std::optional<uint32_t> ParseUint(std::string_view str) { | ||
|
fs-eire marked this conversation as resolved.
Outdated
|
||
| if (str.empty()) return std::nullopt; | ||
| // Leading zeros are not allowed (except "0" itself). | ||
| if (str.size() > 1 && str[0] == '0') return std::nullopt; | ||
| uint64_t result = 0; | ||
| for (char c : str) { | ||
| if (c < '0' || c > '9') return std::nullopt; | ||
| result = result * 10 + static_cast<uint64_t>(c - '0'); | ||
| if (result > UINT32_MAX) return std::nullopt; | ||
| } | ||
| return static_cast<uint32_t>(result); | ||
| } | ||
|
|
||
| // Validates a version string at compile time. | ||
| // It must be in the format "1.Y.Z" where: | ||
| // - Major version is 1 | ||
| // - Y and Z are non-negative integers without leading zeros | ||
| // - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION) | ||
| consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) { | ||
| size_t first_dot = version.find('.'); | ||
| if (first_dot == std::string_view::npos) return false; | ||
| size_t second_dot = version.find('.', first_dot + 1); | ||
| if (second_dot == std::string_view::npos) return false; | ||
| if (version.find('.', second_dot + 1) != std::string_view::npos) return false; // Exactly two dots | ||
| std::string_view major = version.substr(0, first_dot); | ||
| std::string_view minor = version.substr(first_dot + 1, second_dot - first_dot - 1); | ||
| std::string_view patch = version.substr(second_dot + 1); | ||
| if (major != "1") { | ||
| return false; | ||
| } | ||
| auto minor_val = ParseUint(minor); | ||
| auto patch_val = ParseUint(patch); | ||
| if (!minor_val || !patch_val) { | ||
| return false; | ||
| } | ||
| if (*minor_val != expected_api_version) { | ||
| return false; | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| } // namespace onnxruntime::version_check | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,31 +1,64 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "onnxruntime_config.h" | ||
| #include "core/session/onnxruntime_cxx_api.h" | ||
| #include "core/session/ort_version_check.h" | ||
|
|
||
| #include <cstdint> | ||
| #include <charconv> | ||
| #include <optional> | ||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "absl/strings/str_split.h" | ||
| #include "gtest/gtest.h" | ||
|
|
||
| TEST(CApiTest, VersionConsistencyWithApiVersion) { | ||
| const auto version_string = Ort::GetVersionString(); | ||
| const std::vector<std::string> version_string_components = absl::StrSplit(version_string, '.'); | ||
| ASSERT_EQ(version_string_components.size(), size_t{3}); | ||
|
|
||
| auto to_uint32_t = [](const std::string& s) -> std::optional<uint32_t> { | ||
| uint32_t result{}; | ||
| if (std::from_chars(s.data(), s.data() + s.size(), result).ec == std::errc{}) { | ||
| return result; | ||
| } | ||
| return std::nullopt; | ||
| }; | ||
|
|
||
| ASSERT_NE(to_uint32_t(version_string_components[0]), std::nullopt); | ||
| ASSERT_EQ(to_uint32_t(version_string_components[1]), uint32_t{ORT_API_VERSION}); | ||
| ASSERT_NE(to_uint32_t(version_string_components[0]), std::nullopt); | ||
| using onnxruntime::version_check::IsOrtVersionValid; | ||
| using onnxruntime::version_check::ParseUint; | ||
|
|
||
| // Compile-time tests for ParseUint | ||
| static_assert(ParseUint("0") == 0u); | ||
|
Check failure on line 14 in onnxruntime/test/shared_lib/test_version.cc
|
||
| static_assert(ParseUint("1") == 1u); | ||
|
Check failure on line 15 in onnxruntime/test/shared_lib/test_version.cc
|
||
| static_assert(ParseUint("25") == 25u); | ||
|
Check failure on line 16 in onnxruntime/test/shared_lib/test_version.cc
|
||
| static_assert(ParseUint("123") == 123u); | ||
|
Check failure on line 17 in onnxruntime/test/shared_lib/test_version.cc
|
||
| static_assert(ParseUint("4294967295") == 4294967295u); // UINT32_MAX | ||
|
Check failure on line 18 in onnxruntime/test/shared_lib/test_version.cc
|
||
| static_assert(ParseUint("4294967296") == std::nullopt); // UINT32_MAX + 1 overflows | ||
| static_assert(ParseUint("") == std::nullopt); // empty | ||
| static_assert(ParseUint("01") == std::nullopt); // leading zero | ||
| static_assert(ParseUint("00") == std::nullopt); // leading zero | ||
| static_assert(ParseUint("abc") == std::nullopt); // non-digit | ||
| static_assert(ParseUint("1a") == std::nullopt); // trailing non-digit | ||
| static_assert(ParseUint("-1") == std::nullopt); // negative sign | ||
| static_assert(ParseUint("1.0") == std::nullopt); // contains dot | ||
| static_assert(ParseUint("0").has_value()); | ||
| static_assert(!ParseUint("").has_value()); | ||
|
|
||
| // Compile-time tests for IsOrtVersionValid (default expected_api_version = ORT_API_VERSION) | ||
| static_assert(IsOrtVersionValid(ORT_VERSION)); // current version must be valid | ||
|
|
||
| // Invalid formats | ||
| static_assert(!IsOrtVersionValid("")); | ||
| static_assert(!IsOrtVersionValid("1")); | ||
| static_assert(!IsOrtVersionValid("1.0")); | ||
| static_assert(!IsOrtVersionValid("1.0.0.0")); // too many dots | ||
| static_assert(!IsOrtVersionValid("2.0.0")); // major != 1 | ||
| static_assert(!IsOrtVersionValid("1.02.0")); // leading zero in minor | ||
| static_assert(!IsOrtVersionValid("1.0.01")); // leading zero in patch | ||
| static_assert(!IsOrtVersionValid("1..0")); // empty minor | ||
| static_assert(!IsOrtVersionValid("1.0.")); // empty patch | ||
| static_assert(!IsOrtVersionValid(".1.0")); // empty major | ||
| static_assert(!IsOrtVersionValid("abc")); // non-numeric | ||
| static_assert(!IsOrtVersionValid("1.abc.0")); // non-numeric minor | ||
| static_assert(!IsOrtVersionValid("1.0.abc")); // non-numeric patch | ||
|
|
||
| // Compile-time tests for IsOrtVersionValid with explicit expected_api_version | ||
| static_assert(IsOrtVersionValid("1.0.0", 0)); | ||
| static_assert(IsOrtVersionValid("1.1.0", 1)); | ||
| static_assert(IsOrtVersionValid("1.25.0", 25)); | ||
| static_assert(IsOrtVersionValid("1.25.3", 25)); | ||
| static_assert(IsOrtVersionValid("1.100.0", 100)); | ||
| static_assert(!IsOrtVersionValid("1.25.0", 24)); // minor doesn't match expected | ||
| static_assert(!IsOrtVersionValid("1.25.0", 26)); // minor doesn't match expected | ||
| static_assert(!IsOrtVersionValid("1.0.0", 1)); // minor 0 != expected 1 | ||
| static_assert(!IsOrtVersionValid("2.0.0", 0)); // major != 1 | ||
| static_assert(!IsOrtVersionValid("1.02.0", 2)); // leading zero in minor | ||
| static_assert(!IsOrtVersionValid("1.0.01", 0)); // leading zero in patch | ||
|
|
||
| TEST(CApiTest, VersionIsValid) { | ||
| // Runtime sanity check — the version string returned by the API is the expected one. | ||
| EXPECT_STREQ(Ort::GetVersionString(), ORT_VERSION); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.