From 5bd6cf4a72162aeb5b1e25a678c90c7e600d0ce9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 31 Mar 2026 20:50:34 -0700 Subject: [PATCH 1/2] Cleanup for op schema API tests for plugin EPs --- .../core/session/onnxruntime_cxx_api.h | 17 +++++------------ .../test/framework/ep_plugin_provider_test.cc | 17 +++++------------ 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a938688fcfd5a..e457a2a57065e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3613,12 +3613,6 @@ struct KernelRegistry : detail::Base { }; namespace detail { -/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. - * - * Holds a single type constraint from an operator schema, providing access to - * the constraint's name, allowed data types, and associated input/output indices. - * This is a non-owning view — the lifetime is tied to the parent OrtOpSchema. - */ template struct OpSchemaTypeConstraintImpl : Base { using B = Base; @@ -3639,15 +3633,11 @@ struct OpSchemaTypeConstraintImpl : Base { } // namespace detail /// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. +/// Holds a single type constraint from an operator schema, providing access to +/// the constraint's name, allowed data types, and associated input/output indices. using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl>; namespace detail { -/** \brief Owning wrapper around an `OrtOpSchema*`. - * - * Provides access to operator schema metadata such as version, input/output names, - * and type constraints. The underlying OrtOpSchema is owned by this wrapper and - * released automatically on destruction. - */ template struct OpSchemaImpl : Base { using B = Base; @@ -3685,6 +3675,9 @@ struct OpSchemaImpl : Base { } // namespace detail /// Owning wrapper around an `OrtOpSchema*`. +/// Provides access to operator schema metadata such as version, input/output names, +/// and type constraints. The underlying OrtOpSchema is owned by this wrapper and +/// released automatically on destruction. using OpSchema = detail::OpSchemaImpl; /// \brief Get an operator schema from the global schema registry. diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index da958ba6fc970..d192a3a772cf5 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -876,11 +876,8 @@ TEST(OpSchemaTypeConstraintTest, Add_SingleConstraint) { // T should allow tensor(float) and tensor(double) among others auto allowed_types = tc.GetAllowedTypes(); EXPECT_GT(allowed_types.size(), 1u); - auto has_type = [&](const char* t) { - return std::find(allowed_types.begin(), allowed_types.end(), t) != allowed_types.end(); - }; - EXPECT_TRUE(has_type("tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type("tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(allowed_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // Both inputs use T auto input_indices = tc.GetInputIndices(); @@ -921,22 +918,18 @@ TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) { ASSERT_NE(t_ptr, nullptr) << "Expected to find type constraint 'T'"; ASSERT_NE(t1_ptr, nullptr) << "Expected to find type constraint 'T1'"; - auto has_type = [](gsl::span types, const char* t) { - return std::find(types.begin(), types.end(), t) != types.end(); - }; - // T should include tensor(float) and tensor(double) auto t_types = t_tc.GetAllowedTypes(); EXPECT_GT(t_types.size(), 0u); - EXPECT_TRUE(has_type(t_types, "tensor(float)")) << "Expected T to allow tensor(float)"; - EXPECT_TRUE(has_type(t_types, "tensor(double)")) << "Expected T to allow tensor(double)"; + EXPECT_THAT(t_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_THAT(t_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)"; // T1 should include tensor(int32) (sequence_lens is int32) auto t1_types = t1_tc.GetAllowedTypes(); EXPECT_GT(t1_types.size(), 0u); // T1 is for sequence_lens which is int32 - EXPECT_TRUE(has_type(t1_types, "tensor(int32)")) << "Expected T1 to allow tensor(int32)"; + EXPECT_TRUE(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)"; // T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7) auto t_inputs = t_tc.GetInputIndices(); From 92d1868da6dd7ba4eaf922af8f6fce6b593712e8 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 31 Mar 2026 20:58:05 -0700 Subject: [PATCH 2/2] Use EXPECT_THAT --- onnxruntime/test/framework/ep_plugin_provider_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index d192a3a772cf5..9640d94aebe58 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -929,7 +929,7 @@ TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) { EXPECT_GT(t1_types.size(), 0u); // T1 is for sequence_lens which is int32 - EXPECT_TRUE(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)"; + EXPECT_THAT(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)"; // T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7) auto t_inputs = t_tc.GetInputIndices();