diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 926f39f3cebc0..ecdac25af8003 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -661,6 +661,7 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. @@ -3521,6 +3522,88 @@ struct KernelRegistry : detail::Base { void* kernel_create_func_state); }; +namespace detail { +/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. + * + * Holds a single type constraint from an operator schema, providing access to + * the constraint's name, allowed data types, and associated input/output indices. + * This is a non-owning view — the lifetime is tied to the parent OrtOpSchema. + */ +template +struct OpSchemaTypeConstraintImpl : Base { + using B = Base; + using B::B; + + ///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetTypeParamName + std::string GetTypeParamName() const; + + ///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetAllowedTypes + std::vector GetAllowedTypes() const; + + ///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetInputIndices + std::vector GetInputIndices() const; + + ///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetOutputIndices + std::vector GetOutputIndices() const; +}; +} // namespace detail + +/// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`. +using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl>; + +namespace detail { +/** \brief Owning wrapper around an `OrtOpSchema*`. + * + * Provides access to operator schema metadata such as version, input/output names, + * and type constraints. The underlying OrtOpSchema is owned by this wrapper and + * released automatically on destruction. + */ +template +struct OpSchemaImpl : Base { + using B = Base; + using B::B; + + ///< Wraps OrtEpApi::OpSchema_GetSinceVersion + int GetSinceVersion() const; + + ///< Wraps OrtEpApi::OpSchema_GetNumInputs + size_t GetNumInputs() const; + + ///< Wraps OrtEpApi::OpSchema_GetInputName + std::string GetInputName(size_t index) const; + + ///< Wraps OrtEpApi::OpSchema_GetInputTypeConstraint. Returns the type constraint for the given input, + ///< or a wrapper around nullptr if the input has no type constraint. + ConstOpSchemaTypeConstraint GetInputTypeConstraint(size_t index) const; + + ///< Wraps OrtEpApi::OpSchema_GetNumOutputs + size_t GetNumOutputs() const; + + ///< Wraps OrtEpApi::OpSchema_GetOutputName + std::string GetOutputName(size_t index) const; + + ///< Wraps OrtEpApi::OpSchema_GetOutputTypeConstraint. Returns the type constraint for the given output, + ///< or a wrapper around nullptr if the output has no type constraint. + ConstOpSchemaTypeConstraint GetOutputTypeConstraint(size_t index) const; + + ///< Wraps OrtEpApi::OpSchema_GetTypeConstraintCount + size_t GetTypeConstraintCount() const; + + ///< Wraps OrtEpApi::OpSchema_GetTypeConstraint. Returns the i-th type constraint. + ConstOpSchemaTypeConstraint GetTypeConstraint(size_t index) const; +}; +} // namespace detail + +/// Owning wrapper around an `OrtOpSchema*`. +using OpSchema = detail::OpSchemaImpl; + +/// \brief Get an operator schema from the global schema registry. +/// +/// Wraps OrtEpApi::GetOpSchema. Returns an OpSchema that may wrap nullptr if the schema is not found. +/// Available schemas include standard ONNX ops (domain "" or "ai.onnx"), ONNX ML ops ("ai.onnx.ml"), +/// and ORT contrib ops ("com.microsoft"). +OpSchema GetOpSchema(const char* name, int max_inclusive_version, const char* domain); + namespace detail { template struct SharedPrePackedWeightCacheImpl : Ort::detail::Base { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 938d596cc56df..284304a12e4c8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -3946,4 +3946,113 @@ inline Ort::KeyValuePairs GetEnvConfigEntries() { return Ort::KeyValuePairs{entries}; } + +namespace detail { +template +inline int OpSchemaImpl::GetSinceVersion() const { + int version = 0; + ThrowOnError(GetEpApi().OpSchema_GetSinceVersion(this->p_, &version)); + return version; +} + +template +inline size_t OpSchemaImpl::GetNumInputs() const { + size_t num = 0; + ThrowOnError(GetEpApi().OpSchema_GetNumInputs(this->p_, &num)); + return num; +} + +template +inline std::string OpSchemaImpl::GetInputName(size_t index) const { + const char* name = nullptr; + ThrowOnError(GetEpApi().OpSchema_GetInputName(this->p_, index, &name)); + return std::string(name); +} + +template +inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl::GetInputTypeConstraint(size_t index) const { + const OrtOpSchemaTypeConstraint* tc = nullptr; + ThrowOnError(GetEpApi().OpSchema_GetInputTypeConstraint(this->p_, index, &tc)); + return Ort::ConstOpSchemaTypeConstraint{tc}; +} + +template +inline size_t OpSchemaImpl::GetNumOutputs() const { + size_t num = 0; + ThrowOnError(GetEpApi().OpSchema_GetNumOutputs(this->p_, &num)); + return num; +} + +template +inline std::string OpSchemaImpl::GetOutputName(size_t index) const { + const char* name = nullptr; + ThrowOnError(GetEpApi().OpSchema_GetOutputName(this->p_, index, &name)); + return std::string(name); +} + +template +inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl::GetOutputTypeConstraint(size_t index) const { + const OrtOpSchemaTypeConstraint* tc = nullptr; + ThrowOnError(GetEpApi().OpSchema_GetOutputTypeConstraint(this->p_, index, &tc)); + return Ort::ConstOpSchemaTypeConstraint{tc}; +} + +template +inline size_t OpSchemaImpl::GetTypeConstraintCount() const { + size_t count = 0; + ThrowOnError(GetEpApi().OpSchema_GetTypeConstraintCount(this->p_, &count)); + return count; +} + +template +inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl::GetTypeConstraint(size_t index) const { + const OrtOpSchemaTypeConstraint* tc = nullptr; + ThrowOnError(GetEpApi().OpSchema_GetTypeConstraint(this->p_, index, &tc)); + return Ort::ConstOpSchemaTypeConstraint{tc}; +} + +template +inline std::string OpSchemaTypeConstraintImpl::GetTypeParamName() const { + const char* name = nullptr; + ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetTypeParamName(this->p_, &name)); + return std::string(name); +} + +template +inline std::vector OpSchemaTypeConstraintImpl::GetAllowedTypes() const { + const char* const* types = nullptr; + size_t num_types = 0; + ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetAllowedTypes(this->p_, &types, &num_types)); + std::vector result; + result.reserve(num_types); + for (size_t i = 0; i < num_types; ++i) { + result.emplace_back(types[i]); + } + return result; +} + +template +inline std::vector OpSchemaTypeConstraintImpl::GetInputIndices() const { + const size_t* indices = nullptr; + size_t count = 0; + ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetInputIndices(this->p_, &indices, &count)); + if (count == 0) return {}; + return std::vector(indices, indices + count); +} + +template +inline std::vector OpSchemaTypeConstraintImpl::GetOutputIndices() const { + const size_t* indices = nullptr; + size_t count = 0; + ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetOutputIndices(this->p_, &indices, &count)); + if (count == 0) return {}; + return std::vector(indices, indices + count); +} +} // namespace detail + +inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const char* domain) { + OrtOpSchema* schema = nullptr; + ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema)); + return OpSchema{schema}; +} } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 81cb32a2321e0..499e68e360889 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -26,6 +26,9 @@ ORT_RUNTIME_CLASS(SyncStreamImpl); ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); +ORT_RUNTIME_CLASS(OpSchema); +ORT_RUNTIME_CLASS(OpSchemaTypeConstraint); + /** \brief Base struct for imported external memory handles. * * EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA). @@ -1444,6 +1447,238 @@ struct OrtEpApi { * \since Version 1.24 */ ORT_API2_STATUS(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entries); + + /** \brief Get an operator schema from the global schema registry. + * + * Looks up a schema by name, maximum inclusive version, and domain. + * The returned pointer is owned by the caller and must be released via ReleaseOpSchema. + * If the schema is not found, *out_schema is set to nullptr (no allocation occurs). + * + * Available schemas include standard ONNX operators (domain "" or "ai.onnx"), ONNX ML operators + * (domain "ai.onnx.ml"), and ORT contrib operators (domain "com.microsoft"). + * + * \param[in] name A null-terminated string for the operator name. + * \param[in] max_inclusive_version The maximum inclusive opset version. + * \param[in] domain A null-terminated string for the operator domain. + * \param[out] out_schema Output parameter set to the schema pointer, or nullptr if not found. + * Must be released via OrtEpApi::ReleaseOpSchema. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(GetOpSchema, _In_ const char* name, _In_ int max_inclusive_version, + _In_ const char* domain, _Outptr_result_maybenull_ OrtOpSchema** out_schema); + + ORT_CLASS_RELEASE(OpSchema); + + /** \brief Get the first ONNX opset version that introduced this operator schema. + * + * If an operator has had no changes that break backwards compatibility, the `since_version` is + * just the first opset version that introduced the operator. However, if the operator has had breaking changes, + * then `since_version` corresponds to the opset version that introduced the breaking change. + * + * For example, suppose operator "Foo" was added in version 3 and had a breaking change in version 6. + * Then, there will be an operator schema entry for "Foo" with a since_version of 3 and another updated + * operator schema entry for "Foo" with a since_version of 6. + * + * \param[in] schema The OrtOpSchema instance. + * \param[out] out Output parameter set to the ONNX opset version. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetSinceVersion, _In_ const OrtOpSchema* schema, _Out_ int* out); + + /** \brief Get the number of inputs defined by the operator schema. + * + * \param[in] schema The OrtOpSchema instance. + * \param[out] out Output parameter set to the number of inputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetNumInputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out); + + /** \brief Get the name of the i-th input formal parameter from an operator schema. + * + * \param[in] schema The OrtOpSchema instance. + * \param[in] index Zero-based index of the input parameter. + * \param[out] out Output parameter set to the name of the input parameter (null-terminated UTF8 string). + * Valid as long as the OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetInputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out); + + /** \brief Get the type constraint for the i-th input formal parameter from an operator schema. + * + * Returns a non-owning pointer to the OrtOpSchemaTypeConstraint associated with the given input. + * The returned pointer is valid as long as the parent OrtOpSchema is alive. + * If the input has no type constraint, *out is set to nullptr. + * + * Multiple inputs sharing the same type constraint (e.g., both using "T") return the same pointer. + * + * \param[in] schema The OrtOpSchema instance. + * \param[in] index Zero-based index of the input parameter. + * \param[out] out Output parameter set to the type constraint, or NULL if the input has no type constraint. + * Valid as long as the OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetInputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out); + + /** \brief Get the number of outputs defined by the operator schema. + * + * \param[in] schema The OrtOpSchema instance. + * \param[out] out Output parameter set to the number of outputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetNumOutputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out); + + /** \brief Get the name of the i-th output formal parameter from an operator schema. + * + * \param[in] schema The OrtOpSchema instance. + * \param[in] index Zero-based index of the output parameter. + * \param[out] out Output parameter set to the name of the output parameter (null-terminated UTF8 string). + * Valid as long as the OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetOutputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out); + + /** \brief Get the type constraint for the i-th output formal parameter from an operator schema. + * + * Returns a non-owning pointer to the OrtOpSchemaTypeConstraint associated with the given output. + * The returned pointer is valid as long as the parent OrtOpSchema is alive. + * If the output has no type constraint, *out is set to nullptr. + * + * Multiple outputs sharing the same type constraint return the same pointer. + * Pointer equality can be used to check if two outputs share a type constraint. + * + * \param[in] schema The OrtOpSchema instance. + * \param[in] index Zero-based index of the output parameter. + * \param[out] out Output parameter set to the type constraint, or NULL if the output has no type constraint. + * Valid as long as the OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetOutputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out); + + /** \brief Get the number of unique type constraints in the operator schema. + * + * \param[in] schema The OrtOpSchema instance. + * \param[out] out Output set to the number of type constraints. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetTypeConstraintCount, _In_ const OrtOpSchema* schema, _Out_ size_t* out); + + /** \brief Get the i-th type constraint from the operator schema. + * + * Returns a non-owning pointer to the OrtOpSchemaTypeConstraint at the given index. + * The returned pointer is valid as long as the parent OrtOpSchema is alive. + * + * Constraints are returned in the order they are declared in the ONNX operator schema + * definition. The order is stable but has no semantic significance. + * + * Use this API to iterate all type constraints (e.g., to register allowed types for + * each constraint). Use OpSchema_GetInputTypeConstraint / OpSchema_GetOutputTypeConstraint + * to look up the constraint for a specific input or output. + * + * \param[in] schema The OrtOpSchema instance. + * \param[in] index Zero-based index of the type constraint. + * \param[out] out Output parameter set to the type constraint. + * Valid as long as the OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchema_GetTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const OrtOpSchemaTypeConstraint** out); + + /** \brief Get the type parameter name of a type constraint (e.g., "T", "T1"). + * + * \param[in] type_constraint The OrtOpSchemaTypeConstraint instance. + * \param[out] out Output parameter set to the type parameter name. + * Valid as long as the parent OrtOpSchema exists. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchemaTypeConstraint_GetTypeParamName, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char** out); + + /** \brief Get the allowed type strings for a type constraint. + * + * Returns an array of null-terminated strings representing the allowed data types + * (e.g., "tensor(float)", "tensor(double)"). The array and its contents are valid + * as long as the parent OrtOpSchema exists. + * + * \param[in] type_constraint The OrtOpSchemaTypeConstraint instance. + * \param[out] out_types Output parameter set to the output array of type strings. + * Valid as long as the parent OrtOpSchema exists. + * \param[out] num_types Output parameter set to the number of elements in the output array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchemaTypeConstraint_GetAllowedTypes, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char* const** out_types, _Out_ size_t* num_types); + + /** \brief Get the input indices that use a type constraint. + * + * Returns an array of zero-based input indices whose formal parameter type string + * matches this type constraint. The array is valid as long as the parent OrtOpSchema exists. + * + * \param[in] type_constraint The OrtOpSchemaTypeConstraint instance. + * \param[out] out_indices Output parameter set to the output array of input indices. + * \param[out] count Output parameter set to the number of elements in the output array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchemaTypeConstraint_GetInputIndices, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count); + + /** \brief Get the output indices that use a type constraint. + * + * Returns an array of zero-based output indices whose formal parameter type string + * matches this type constraint. The array is valid as long as the parent OrtOpSchema exists. + * + * \param[in] type_constraint The OrtOpSchemaTypeConstraint instance. + * \param[out] out_indices Output parameter set to the output array of output indices. + * \param[out] count Output parameter set to the number of elements in the output array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. + */ + ORT_API2_STATUS(OpSchemaTypeConstraint_GetOutputIndices, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count); }; /** diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index cb0a56756d8aa..490b779c0586c 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -5,6 +5,7 @@ #include "core/common/status.h" #include "core/common/exceptions.h" +#include "core/common/make_string.h" #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -39,6 +40,14 @@ Status ToStatusAndRelease(OrtStatus* ort_status, #define API_IMPL_END } #endif +// Check condition. If met, return an OrtStatus* error with the given OrtErrorCode. +#define ORT_API_RETURN_IF(condition, ort_error_code, ...) \ + do { \ + if (condition) { \ + return OrtApis::CreateStatus(ort_error_code, ::onnxruntime::MakeString(__VA_ARGS__).c_str()); \ + } \ + } while (false) + // Return the OrtStatus if it indicates an error #define ORT_API_RETURN_IF_ERROR(expr) \ do { \ diff --git a/onnxruntime/core/session/abi_opschema.h b/onnxruntime/core/session/abi_opschema.h new file mode 100644 index 0000000000000..8abfcf4c74edc --- /dev/null +++ b/onnxruntime/core/session/abi_opschema.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +namespace ONNX_NAMESPACE { +class OpSchema; +} // namespace ONNX_NAMESPACE + +/// A single type constraint entry (e.g., "T" or "T1") from an operator schema. +/// Holds the constraint name, allowed data types, and which input/output formal parameters use it. +/// Non-owning — lifetime is tied to the parent OrtOpSchema. +struct OrtOpSchemaTypeConstraint { + std::string type_param_str; // e.g., "T" + std::vector allowed_type_strs; // e.g., {"tensor(float)", "tensor(double)"} + std::vector allowed_type_ptrs; // C API view into allowed_type_strs + std::vector input_indices; // input indices using this constraint + std::vector output_indices; // output indices using this constraint +}; + +/// Opaque struct wrapping an ONNX operator schema pointer and its precomputed type constraints. +/// Allocated by GetOpSchema and released by ReleaseOpSchema. +struct OrtOpSchema { + const ONNX_NAMESPACE::OpSchema* onnx_schema; + std::vector constraints; + // O(1) lookup: input/output index → constraint pointer (nullptr if no constraint) + std::vector input_to_constraint; + std::vector output_to_constraint; +}; diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index cd14926425625..6663eb1f3b57c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -4,6 +4,7 @@ #include "core/session/plugin_ep/ep_api.h" #include +#include #include #include #include @@ -18,9 +19,12 @@ #include "core/framework/ortmemoryinfo.h" #include "core/framework/plugin_ep_stream.h" #include "core/framework/tensor.h" +#include "core/graph/constants.h" #include "core/graph/ep_api_types.h" +#include "core/graph/onnx_protobuf.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" +#include "core/session/abi_opschema.h" #include "core/session/environment.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" @@ -806,6 +810,248 @@ ORT_API_STATUS_IMPL(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entr API_IMPL_END } +ORT_API_STATUS_IMPL(GetOpSchema, _In_ const char* name, _In_ int max_inclusive_version, + _In_ const char* domain, _Outptr_result_maybenull_ OrtOpSchema** out_schema) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(name == nullptr, ORT_INVALID_ARGUMENT, "name must not be null"); + ORT_API_RETURN_IF(domain == nullptr, ORT_INVALID_ARGUMENT, "domain must not be null"); + ORT_API_RETURN_IF(out_schema == nullptr, ORT_INVALID_ARGUMENT, "out_schema must not be null"); + + // Normalize "ai.onnx" to "" (the canonical ONNX domain used by the schema registry). + const char* lookup_domain = (strcmp(domain, kOnnxDomainAlias) == 0) ? kOnnxDomain : domain; + + const auto* onnx_schema = ONNX_NAMESPACE::OpSchemaRegistry::Instance()->GetSchema( + name, max_inclusive_version, lookup_domain); + + if (onnx_schema == nullptr) { + *out_schema = nullptr; + return nullptr; + } + + auto result = std::make_unique(); + result->onnx_schema = onnx_schema; + + // Eagerly build type constraint data. + for (const auto& param : onnx_schema->typeConstraintParams()) { + OrtOpSchemaTypeConstraint constraint; + constraint.type_param_str = param.type_param_str; + constraint.allowed_type_strs = param.allowed_type_strs; + + const auto& inputs = onnx_schema->inputs(); + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i].GetTypeStr() == param.type_param_str) { + constraint.input_indices.push_back(i); + } + } + + const auto& outputs = onnx_schema->outputs(); + for (size_t i = 0; i < outputs.size(); ++i) { + if (outputs[i].GetTypeStr() == param.type_param_str) { + constraint.output_indices.push_back(i); + } + } + + result->constraints.push_back(std::move(constraint)); + } + + // Build the C-compatible pointer arrays after all entries are in their final locations. + for (auto& constraint : result->constraints) { + constraint.allowed_type_ptrs.reserve(constraint.allowed_type_strs.size()); + for (const auto& s : constraint.allowed_type_strs) { + constraint.allowed_type_ptrs.push_back(s.c_str()); + } + } + + // Build input/output → constraint lookup tables. + // ONNX guarantees each input/output has at most one type parameter (FormalParameter::type_str_ is a single string). + const auto& inputs = onnx_schema->inputs(); + result->input_to_constraint.resize(inputs.size(), nullptr); + for (auto& constraint : result->constraints) { + for (size_t idx : constraint.input_indices) { + result->input_to_constraint[idx] = &constraint; + } + } + + const auto& outputs = onnx_schema->outputs(); + result->output_to_constraint.resize(outputs.size(), nullptr); + for (auto& constraint : result->constraints) { + for (size_t idx : constraint.output_indices) { + result->output_to_constraint[idx] = &constraint; + } + } + + *out_schema = result.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseOpSchema, _Frees_ptr_opt_ OrtOpSchema* schema) { + delete schema; +} + +ORT_API_STATUS_IMPL(OpSchema_GetSinceVersion, _In_ const OrtOpSchema* schema, _Out_ int* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + *out = schema->onnx_schema->since_version(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetNumInputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + *out = schema->onnx_schema->inputs().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetInputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + const auto& inputs = schema->onnx_schema->inputs(); + ORT_API_RETURN_IF(index >= inputs.size(), ORT_INVALID_ARGUMENT, "Input index ", index, " out of range. Schema has ", + inputs.size(), " inputs."); + *out = inputs[index].GetName().c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetInputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + ORT_API_RETURN_IF(index >= schema->input_to_constraint.size(), ORT_INVALID_ARGUMENT, + "Input index ", index, " out of range. Schema has ", + schema->input_to_constraint.size(), " inputs."); + + *out = schema->input_to_constraint[index]; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetNumOutputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + *out = schema->onnx_schema->outputs().size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetOutputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + const auto& outputs = schema->onnx_schema->outputs(); + ORT_API_RETURN_IF(index >= outputs.size(), ORT_INVALID_ARGUMENT, "Output index ", index, " out of range. Schema has ", + outputs.size(), " outputs."); + *out = outputs[index].GetName().c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetOutputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + ORT_API_RETURN_IF(index >= schema->output_to_constraint.size(), ORT_INVALID_ARGUMENT, + "Output index ", index, " out of range. Schema has ", + schema->output_to_constraint.size(), " outputs."); + + *out = schema->output_to_constraint[index]; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetTypeConstraintCount, _In_ const OrtOpSchema* schema, _Out_ size_t* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + *out = schema->constraints.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchema_GetTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const OrtOpSchemaTypeConstraint** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(schema == nullptr, ORT_INVALID_ARGUMENT, "schema must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + ORT_API_RETURN_IF(index >= schema->constraints.size(), ORT_INVALID_ARGUMENT, + "Type constraint index ", index, " out of range. Schema has ", + schema->constraints.size(), " constraints."); + + *out = &schema->constraints[index]; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetTypeParamName, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char** out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(type_constraint == nullptr, ORT_INVALID_ARGUMENT, "type_constraint must not be null"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "out must not be null"); + + *out = type_constraint->type_param_str.c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetAllowedTypes, + _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char* const** out_types, _Out_ size_t* num_types) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(type_constraint == nullptr, ORT_INVALID_ARGUMENT, "type_constraint must not be null"); + ORT_API_RETURN_IF(out_types == nullptr, ORT_INVALID_ARGUMENT, "out_types must not be null"); + ORT_API_RETURN_IF(num_types == nullptr, ORT_INVALID_ARGUMENT, "num_types must not be null"); + + *out_types = type_constraint->allowed_type_ptrs.data(); + *num_types = type_constraint->allowed_type_ptrs.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetInputIndices, + _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(type_constraint == nullptr, ORT_INVALID_ARGUMENT, "type_constraint must not be null"); + ORT_API_RETURN_IF(out_indices == nullptr, ORT_INVALID_ARGUMENT, "out_indices must not be null"); + ORT_API_RETURN_IF(count == nullptr, ORT_INVALID_ARGUMENT, "count must not be null"); + + *out_indices = type_constraint->input_indices.data(); + *count = type_constraint->input_indices.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetOutputIndices, + _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(type_constraint == nullptr, ORT_INVALID_ARGUMENT, "type_constraint must not be null"); + ORT_API_RETURN_IF(out_indices == nullptr, ORT_INVALID_ARGUMENT, "out_indices must not be null"); + ORT_API_RETURN_IF(count == nullptr, ORT_INVALID_ARGUMENT, "count must not be null"); + + *out_indices = type_constraint->output_indices.data(); + *count = type_constraint->output_indices.size(); + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -869,6 +1115,23 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::ReleaseKernelImpl, &OrtExecutionProviderApi::GetEnvConfigEntries, // End of Version 24 - DO NOT MODIFY ABOVE + + &OrtExecutionProviderApi::GetOpSchema, + &OrtExecutionProviderApi::ReleaseOpSchema, + &OrtExecutionProviderApi::OpSchema_GetSinceVersion, + &OrtExecutionProviderApi::OpSchema_GetNumInputs, + &OrtExecutionProviderApi::OpSchema_GetInputName, + &OrtExecutionProviderApi::OpSchema_GetInputTypeConstraint, + &OrtExecutionProviderApi::OpSchema_GetNumOutputs, + &OrtExecutionProviderApi::OpSchema_GetOutputName, + &OrtExecutionProviderApi::OpSchema_GetOutputTypeConstraint, + &OrtExecutionProviderApi::OpSchema_GetTypeConstraintCount, + &OrtExecutionProviderApi::OpSchema_GetTypeConstraint, + &OrtExecutionProviderApi::OpSchemaTypeConstraint_GetTypeParamName, + &OrtExecutionProviderApi::OpSchemaTypeConstraint_GetAllowedTypes, + &OrtExecutionProviderApi::OpSchemaTypeConstraint_GetInputIndices, + &OrtExecutionProviderApi::OpSchemaTypeConstraint_GetOutputIndices, + // End of Version 25 - DO NOT MODIFY ABOVE }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned @@ -878,6 +1141,8 @@ static_assert(offsetof(OrtEpApi, GetSyncIdForLastWaitOnSyncStream) / sizeof(void "Size of version 23 API cannot change"); static_assert(offsetof(OrtEpApi, GetEnvConfigEntries) / sizeof(void*) == 49, "Size of version 24 API cannot change"); +static_assert(offsetof(OrtEpApi, OpSchemaTypeConstraint_GetOutputIndices) / sizeof(void*) == 64, + "Size of version 25 API cannot change"); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 23bbe23026f9a..e937bf105fc4d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -119,4 +119,31 @@ ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl); // Env config entries ORT_API_STATUS_IMPL(GetEnvConfigEntries, _Outptr_ OrtKeyValuePairs** config_entries); + +// OpSchema APIs +ORT_API_STATUS_IMPL(GetOpSchema, _In_ const char* name, _In_ int max_inclusive_version, + _In_ const char* domain, _Outptr_result_maybenull_ OrtOpSchema** out_schema); +ORT_API(void, ReleaseOpSchema, _Frees_ptr_opt_ OrtOpSchema* schema); +ORT_API_STATUS_IMPL(OpSchema_GetSinceVersion, _In_ const OrtOpSchema* schema, _Out_ int* out); +ORT_API_STATUS_IMPL(OpSchema_GetNumInputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out); +ORT_API_STATUS_IMPL(OpSchema_GetInputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out); +ORT_API_STATUS_IMPL(OpSchema_GetInputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out); +ORT_API_STATUS_IMPL(OpSchema_GetNumOutputs, _In_ const OrtOpSchema* schema, _Out_ size_t* out); +ORT_API_STATUS_IMPL(OpSchema_GetOutputName, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const char** out); +ORT_API_STATUS_IMPL(OpSchema_GetOutputTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_result_maybenull_ const OrtOpSchemaTypeConstraint** out); +ORT_API_STATUS_IMPL(OpSchema_GetTypeConstraintCount, _In_ const OrtOpSchema* schema, _Out_ size_t* out); +ORT_API_STATUS_IMPL(OpSchema_GetTypeConstraint, _In_ const OrtOpSchema* schema, _In_ size_t index, + _Outptr_ const OrtOpSchemaTypeConstraint** out); +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetTypeParamName, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char** out); +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetAllowedTypes, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const char* const** out_types, _Out_ size_t* num_types); +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetInputIndices, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count); +ORT_API_STATUS_IMPL(OpSchemaTypeConstraint_GetOutputIndices, _In_ const OrtOpSchemaTypeConstraint* type_constraint, + _Outptr_ const size_t** out_indices, _Out_ size_t* count); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 9a740a5fa33d6..688f0ac19de82 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -3,6 +3,7 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include #include #include #include "gsl/gsl" @@ -779,4 +780,202 @@ TEST(PluginExecutionProviderTest, IsConcurrentRunSupported) { #endif // !defined(ORT_NO_EXCEPTIONS) } +// Tests for the Ort::OpSchema C++ wrapper API and Ort::GetOpSchema free function. +// These test the C++ layer over the OrtEpApi OpSchema functions using well-known ONNX operator schemas +// from the global ONNX schema registry. + +// Test that GetOpSchema returns null for various not-found cases. +TEST(OpSchemaCxxApiTest, GetOpSchema_NotFound) { + // Unknown op name + Ort::OpSchema schema_unknown = Ort::GetOpSchema("NonExistentOpXYZ_12345", 20, ""); + EXPECT_EQ(static_cast(schema_unknown), nullptr); + + // Relu was introduced in opset 1, so max_inclusive_version=0 should not find it. + Ort::OpSchema schema_v0 = Ort::GetOpSchema("Relu", 0, ""); + EXPECT_EQ(static_cast(schema_v0), nullptr); + + // Wrong domain + Ort::OpSchema schema_bad_domain = Ort::GetOpSchema("Relu", 20, "com.nonexistent.domain"); + EXPECT_EQ(static_cast(schema_bad_domain), nullptr); +} + +// Test version differentiation and "ai.onnx" domain alias normalization. +TEST(OpSchemaCxxApiTest, DifferentVersionsAndDomainAlias) { + // Relu was introduced in opset 1 and updated in opset 6, 13, and 14. + // Querying at version 5 should return the opset 1 schema. + Ort::OpSchema schema_v5 = Ort::GetOpSchema("Relu", 5, ""); + ASSERT_NE(static_cast(schema_v5), nullptr); + EXPECT_EQ(schema_v5.GetSinceVersion(), 1); + + // Querying at version 6 with "ai.onnx" domain alias should return the opset 6 schema. + Ort::OpSchema schema_v6 = Ort::GetOpSchema("Relu", 6, kOnnxDomainAlias); + ASSERT_NE(static_cast(schema_v6), nullptr); + EXPECT_EQ(schema_v6.GetSinceVersion(), 6); + + // "ai.onnx" and "" should resolve to the same schema at the same version. + Ort::OpSchema schema_canonical = Ort::GetOpSchema("Relu", 20, ""); + Ort::OpSchema schema_alias = Ort::GetOpSchema("Relu", 20, kOnnxDomainAlias); + ASSERT_NE(static_cast(schema_canonical), nullptr); + ASSERT_NE(static_cast(schema_alias), nullptr); + EXPECT_EQ(schema_canonical.GetSinceVersion(), schema_alias.GetSinceVersion()); +} + +// Test OpSchema methods on the "Add" operator (2 inputs, 1 output, shared constraint T). +// Also tests pointer identity: inputs/output sharing a constraint return the same pointer. +TEST(OpSchemaCxxApiTest, AddSchemaProperties) { + int opset_version = 20; + Ort::OpSchema schema = Ort::GetOpSchema("Add", opset_version, ""); + ASSERT_NE(static_cast(schema), nullptr); + + // The "since version" will be <= to the opset version used to retrieve the schema. + EXPECT_LT(schema.GetSinceVersion(), opset_version + 1); + EXPECT_GT(schema.GetSinceVersion(), 0); + + // Add has 2 inputs: A, B + ASSERT_EQ(schema.GetNumInputs(), 2u); + EXPECT_EQ(schema.GetInputName(0), "A"); + EXPECT_EQ(schema.GetInputName(1), "B"); + + // Both inputs should have a type constraint named "T" + Ort::ConstOpSchemaTypeConstraint tc_input0 = schema.GetInputTypeConstraint(0); + Ort::ConstOpSchemaTypeConstraint tc_input1 = schema.GetInputTypeConstraint(1); + ASSERT_NE(static_cast(tc_input0), nullptr); + ASSERT_NE(static_cast(tc_input1), nullptr); + EXPECT_EQ(tc_input0.GetTypeParamName(), "T"); + EXPECT_EQ(tc_input1.GetTypeParamName(), "T"); + + // Add has 1 output: C + ASSERT_EQ(schema.GetNumOutputs(), 1u); + EXPECT_EQ(schema.GetOutputName(0), "C"); + + Ort::ConstOpSchemaTypeConstraint tc_output0 = schema.GetOutputTypeConstraint(0); + ASSERT_NE(static_cast(tc_output0), nullptr); + EXPECT_EQ(tc_output0.GetTypeParamName(), "T"); + + // Both inputs and the output share constraint "T" — should return the same pointer. + EXPECT_EQ(static_cast(tc_input0), + static_cast(tc_input1)); + EXPECT_EQ(static_cast(tc_input0), + static_cast(tc_output0)); +} + +// Tests for the OrtOpSchemaTypeConstraint API (per-constraint entity). + +// Test type constraints for the Add operator (single constraint T on all inputs/outputs). +TEST(OpSchemaTypeConstraintTest, Add_SingleConstraint) { + Ort::OpSchema schema = Ort::GetOpSchema("Add", 20, ""); + ASSERT_NE(static_cast(schema), nullptr); + + ASSERT_EQ(schema.GetTypeConstraintCount(), 1u); + + // Constraint "T" + Ort::ConstOpSchemaTypeConstraint tc = schema.GetTypeConstraint(0); + EXPECT_EQ(tc.GetTypeParamName(), "T"); + + // T should allow tensor(float) and tensor(double) among others + auto allowed_types = tc.GetAllowedTypes(); + EXPECT_GT(allowed_types.size(), 1u); + auto has_type = [&](const char* t) { + return std::find(allowed_types.begin(), allowed_types.end(), t) != allowed_types.end(); + }; + EXPECT_TRUE(has_type("tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_TRUE(has_type("tensor(double)")) << "Expected T to allow tensor(double)"; + + // Both inputs use T + auto input_indices = tc.GetInputIndices(); + ASSERT_EQ(input_indices.size(), 2u); + EXPECT_EQ(input_indices[0], 0u); + EXPECT_EQ(input_indices[1], 1u); + + // Output uses T + auto output_indices = tc.GetOutputIndices(); + ASSERT_EQ(output_indices.size(), 1u); + EXPECT_EQ(output_indices[0], 0u); +} + +// Test type constraints for LSTM (multiple constraints: T and T1). +TEST(OpSchemaTypeConstraintTest, LSTM_MultipleConstraints) { + Ort::OpSchema schema = Ort::GetOpSchema("LSTM", 20, ""); + ASSERT_NE(static_cast(schema), nullptr); + + // LSTM has at least T and T1 + ASSERT_GE(schema.GetTypeConstraintCount(), 2u); + + // Find the T and T1 constraints by name + const OrtOpSchemaTypeConstraint* t_ptr = nullptr; + const OrtOpSchemaTypeConstraint* t1_ptr = nullptr; + Ort::ConstOpSchemaTypeConstraint t_tc{nullptr}; + Ort::ConstOpSchemaTypeConstraint t1_tc{nullptr}; + for (size_t i = 0; i < schema.GetTypeConstraintCount(); ++i) { + auto tc = schema.GetTypeConstraint(i); + if (tc.GetTypeParamName() == "T") { + t_ptr = static_cast(tc); + t_tc = tc; + } else if (tc.GetTypeParamName() == "T1") { + t1_ptr = static_cast(tc); + t1_tc = tc; + } + } + + ASSERT_NE(t_ptr, nullptr) << "Expected to find type constraint 'T'"; + ASSERT_NE(t1_ptr, nullptr) << "Expected to find type constraint 'T1'"; + + auto has_type = [](gsl::span types, const char* t) { + return std::find(types.begin(), types.end(), t) != types.end(); + }; + + // T should include tensor(float) and tensor(double) + auto t_types = t_tc.GetAllowedTypes(); + EXPECT_GT(t_types.size(), 0u); + EXPECT_TRUE(has_type(t_types, "tensor(float)")) << "Expected T to allow tensor(float)"; + EXPECT_TRUE(has_type(t_types, "tensor(double)")) << "Expected T to allow tensor(double)"; + + // T1 should include tensor(int32) (sequence_lens is int32) + auto t1_types = t1_tc.GetAllowedTypes(); + EXPECT_GT(t1_types.size(), 0u); + + // T1 is for sequence_lens which is int32 + EXPECT_TRUE(has_type(t1_types, "tensor(int32)")) << "Expected T1 to allow tensor(int32)"; + + // T should map to inputs X (0), W (1), R (2), B (3), initial_h (5), initial_c (6), P (7) + auto t_inputs = t_tc.GetInputIndices(); + EXPECT_EQ(t_inputs.size(), 7u); + EXPECT_EQ(t_inputs[0], 0u); // X + EXPECT_EQ(t_inputs[1], 1u); // W + EXPECT_EQ(t_inputs[2], 2u); // R + EXPECT_EQ(t_inputs[3], 3u); // B + EXPECT_EQ(t_inputs[4], 5u); // initial_h + EXPECT_EQ(t_inputs[5], 6u); // initial_c + EXPECT_EQ(t_inputs[6], 7u); // P + + // T should map to outputs Y (0), Y_h (1), Y_c (2) + auto t_outputs = t_tc.GetOutputIndices(); + ASSERT_EQ(t_outputs.size(), 3u); + EXPECT_EQ(t_outputs[0], 0u); // Y + EXPECT_EQ(t_outputs[1], 1u); // Y_h + EXPECT_EQ(t_outputs[2], 2u); // Y_c + + // T1 should map to the sequence_lens input (index 4) + auto t1_inputs = t1_tc.GetInputIndices(); + ASSERT_EQ(t1_inputs.size(), 1u); + EXPECT_EQ(t1_inputs[0], 4u); // sequence_lens is the 5th input (index 4) + + // T1 should not map to any outputs + auto t1_outputs = t1_tc.GetOutputIndices(); + EXPECT_EQ(t1_outputs.size(), 0u); +} + +#if !defined(ORT_NO_EXCEPTIONS) +// Test out-of-range index for type constraint accessors. +TEST(OpSchemaTypeConstraintTest, OutOfRangeIndex) { + Ort::OpSchema schema = Ort::GetOpSchema("Add", 20, ""); + ASSERT_NE(static_cast(schema), nullptr); + + size_t count = schema.GetTypeConstraintCount(); + + // Accessing beyond the count should throw + EXPECT_THROW(schema.GetTypeConstraint(count), Ort::Exception); +} +#endif // !defined(ORT_NO_EXCEPTIONS) + } // namespace onnxruntime::test