Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3613,12 +3613,6 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
};

namespace detail {
/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
*
* Holds a single type constraint from an operator schema, providing access to
* the constraint's name, allowed data types, and associated input/output indices.
* This is a non-owning view — the lifetime is tied to the parent OrtOpSchema.
*/
template <typename T>
struct OpSchemaTypeConstraintImpl : Base<T> {
using B = Base<T>;
Expand All @@ -3639,15 +3633,11 @@ struct OpSchemaTypeConstraintImpl : Base<T> {
} // namespace detail

/// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
/// Holds a single type constraint from an operator schema, providing access to
/// the constraint's name, allowed data types, and associated input/output indices.
Comment thread
adrianlizarraga marked this conversation as resolved.
using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl<detail::Unowned<const OrtOpSchemaTypeConstraint>>;

namespace detail {
/** \brief Owning wrapper around an `OrtOpSchema*`.
*
* Provides access to operator schema metadata such as version, input/output names,
* and type constraints. The underlying OrtOpSchema is owned by this wrapper and
* released automatically on destruction.
*/
template <typename T>
struct OpSchemaImpl : Base<T> {
using B = Base<T>;
Expand Down Expand Up @@ -3685,6 +3675,9 @@ struct OpSchemaImpl : Base<T> {
} // namespace detail

/// Owning wrapper around an `OrtOpSchema*`.
/// Provides access to operator schema metadata such as version, input/output names,
/// and type constraints. The underlying OrtOpSchema is owned by this wrapper and
/// released automatically on destruction.
using OpSchema = detail::OpSchemaImpl<OrtOpSchema>;

/// \brief Get an operator schema from the global schema registry.
Expand Down
17 changes: 5 additions & 12 deletions onnxruntime/test/framework/ep_plugin_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -876,11 +876,8 @@ TEST(OpSchemaTypeConstraintTest, Add_SingleConstraint) {
// T should allow tensor(float) and tensor(double) among others
auto allowed_types = tc.GetAllowedTypes();
EXPECT_GT(allowed_types.size(), 1u);
auto has_type = [&](const char* t) {
return std::find(allowed_types.begin(), allowed_types.end(), t) != allowed_types.end();
};
EXPECT_TRUE(has_type("tensor(float)")) << "Expected T to allow tensor(float)";
EXPECT_TRUE(has_type("tensor(double)")) << "Expected T to allow tensor(double)";
EXPECT_THAT(allowed_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)";
EXPECT_THAT(allowed_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)";

// Both inputs use T
auto input_indices = tc.GetInputIndices();
Expand Down Expand Up @@ -921,22 +918,18 @@ TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) {
ASSERT_NE(t_ptr, nullptr) << "Expected to find type constraint 'T'";
ASSERT_NE(t1_ptr, nullptr) << "Expected to find type constraint 'T1'";

auto has_type = [](gsl::span<const std::string> types, const char* t) {
return std::find(types.begin(), types.end(), t) != types.end();
};

// T should include tensor(float) and tensor(double)
auto t_types = t_tc.GetAllowedTypes();
EXPECT_GT(t_types.size(), 0u);
EXPECT_TRUE(has_type(t_types, "tensor(float)")) << "Expected T to allow tensor(float)";
EXPECT_TRUE(has_type(t_types, "tensor(double)")) << "Expected T to allow tensor(double)";
EXPECT_THAT(t_types, ::testing::Contains("tensor(float)")) << "Expected T to allow tensor(float)";
EXPECT_THAT(t_types, ::testing::Contains("tensor(double)")) << "Expected T to allow tensor(double)";

// T1 should include tensor(int32) (sequence_lens is int32)
auto t1_types = t1_tc.GetAllowedTypes();
EXPECT_GT(t1_types.size(), 0u);

// T1 is for sequence_lens which is int32
EXPECT_TRUE(has_type(t1_types, "tensor(int32)")) << "Expected T1 to allow tensor(int32)";
EXPECT_THAT(t1_types, ::testing::Contains("tensor(int32)")) << "Expected T1 to allow tensor(int32)";

// T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7)
auto t_inputs = t_tc.GetInputIndices();
Expand Down
Loading