Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
89d414f
[Plugin EP] Add C APIs to add OpSchemas for custom ops, get opschema …
adrianlizarraga Mar 17, 2026
c57cc57
Add OpSchema APIs to get SinceVersion, NumInputs, and NumOutputs
adrianlizarraga Mar 17, 2026
df4058a
CXX apis
adrianlizarraga Mar 17, 2026
c79cad5
Add unit tests
adrianlizarraga Mar 17, 2026
a052871
null checks
adrianlizarraga Mar 17, 2026
183ea29
Update OpSchema_HasTypeConstraint to return a status
adrianlizarraga Mar 17, 2026
a05ade9
Update C++ API to return std::string
adrianlizarraga Mar 17, 2026
5b64fec
Remove unnecessary forward declaration
adrianlizarraga Mar 17, 2026
bfe6853
Review comments
adrianlizarraga Mar 17, 2026
b2b67f9
Merge branch 'main' into adrianl/PluginEp_Kernels_GetOpSchema
adrianlizarraga Mar 24, 2026
b0491bb
Merge branch 'main' into adrianl/PluginEp_Kernels_GetOpSchema
adrianlizarraga Mar 27, 2026
419735c
Add static_assert for ep_api
adrianlizarraga Mar 27, 2026
4f58b99
First attempt at getting more info about type constraints
adrianlizarraga Mar 27, 2026
29f1f7b
Remove HasTypeConstraint
adrianlizarraga Mar 28, 2026
0b71072
Refactor OrtOpSchema from reinterpret_cast to opaque owning struct
adrianlizarraga Mar 30, 2026
d5642a2
Refactor type constraints to per-entity OrtOpSchemaTypeConstraint API
adrianlizarraga Mar 30, 2026
09d5ecd
Rename OpSchemaTypeConstraint_GetName to GetTypeParamName
adrianlizarraga Mar 30, 2026
87dc453
Mention domains and handle 'ai.onnx' alias
adrianlizarraga Mar 30, 2026
50a6541
Update docs and move ReleaseOpSchema
adrianlizarraga Mar 30, 2026
742fd26
Add comment about there only being one type param per input or output
adrianlizarraga Mar 30, 2026
061b39c
Add doc comment about use of OpSchema_GetTypeConstraint
adrianlizarraga Mar 30, 2026
2cf04fc
Merge branch 'main' into adrianl/PluginEp_Kernels_GetOpSchema
adrianlizarraga Mar 30, 2026
8437870
Review comments
adrianlizarraga Mar 30, 2026
a195d9a
Make tests test what comments say
adrianlizarraga Mar 30, 2026
82d5322
Consolidate tests
adrianlizarraga Mar 30, 2026
a3a60fb
Address review comments
adrianlizarraga Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDef, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi);

// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
// but the struct has V2 in its name to indicate that it is the second version of the options.
Expand Down Expand Up @@ -3521,6 +3522,88 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
void* kernel_create_func_state);
};

namespace detail {
/** \brief Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
Comment thread
adrianlizarraga marked this conversation as resolved.
*
* Holds a single type constraint from an operator schema, providing access to
* the constraint's name, allowed data types, and associated input/output indices.
* This is a non-owning view — the lifetime is tied to the parent OrtOpSchema.
*/
template <typename T>
struct OpSchemaTypeConstraintImpl : Base<T> {
using B = Base<T>;
using B::B;

///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetTypeParamName
std::string GetTypeParamName() const;

///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetAllowedTypes
std::vector<std::string> GetAllowedTypes() const;

///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetInputIndices
std::vector<size_t> GetInputIndices() const;

///< Wraps OrtEpApi::OpSchemaTypeConstraint_GetOutputIndices
std::vector<size_t> GetOutputIndices() const;
};
} // namespace detail

/// Non-owning wrapper around a `const OrtOpSchemaTypeConstraint*`.
using ConstOpSchemaTypeConstraint = detail::OpSchemaTypeConstraintImpl<detail::Unowned<const OrtOpSchemaTypeConstraint>>;

namespace detail {
/** \brief Owning wrapper around an `OrtOpSchema*`.
*
* Provides access to operator schema metadata such as version, input/output names,
* and type constraints. The underlying OrtOpSchema is owned by this wrapper and
* released automatically on destruction.
*/
template <typename T>
struct OpSchemaImpl : Base<T> {
using B = Base<T>;
using B::B;

///< Wraps OrtEpApi::OpSchema_GetSinceVersion
int GetSinceVersion() const;

///< Wraps OrtEpApi::OpSchema_GetNumInputs
size_t GetNumInputs() const;

///< Wraps OrtEpApi::OpSchema_GetInputName
std::string GetInputName(size_t index) const;

///< Wraps OrtEpApi::OpSchema_GetInputTypeConstraint. Returns the type constraint for the given input,
///< or a wrapper around nullptr if the input has no type constraint.
ConstOpSchemaTypeConstraint GetInputTypeConstraint(size_t index) const;

///< Wraps OrtEpApi::OpSchema_GetNumOutputs
size_t GetNumOutputs() const;

///< Wraps OrtEpApi::OpSchema_GetOutputName
std::string GetOutputName(size_t index) const;

///< Wraps OrtEpApi::OpSchema_GetOutputTypeConstraint. Returns the type constraint for the given output,
///< or a wrapper around nullptr if the output has no type constraint.
ConstOpSchemaTypeConstraint GetOutputTypeConstraint(size_t index) const;

///< Wraps OrtEpApi::OpSchema_GetTypeConstraintCount
size_t GetTypeConstraintCount() const;

///< Wraps OrtEpApi::OpSchema_GetTypeConstraint. Returns the i-th type constraint.
ConstOpSchemaTypeConstraint GetTypeConstraint(size_t index) const;
};
} // namespace detail

/// Owning wrapper around an `OrtOpSchema*`.
using OpSchema = detail::OpSchemaImpl<OrtOpSchema>;

/// \brief Get an operator schema from the global schema registry.
///
/// Wraps OrtEpApi::GetOpSchema. Returns an OpSchema that may wrap nullptr if the schema is not found.
/// Available schemas include standard ONNX ops (domain "" or "ai.onnx"), ONNX ML ops ("ai.onnx.ml"),
/// and ORT contrib ops ("com.microsoft").
OpSchema GetOpSchema(const char* name, int max_inclusive_version, const char* domain);

namespace detail {
template <typename T>
struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
Expand Down
109 changes: 109 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -3946,4 +3946,113 @@ inline Ort::KeyValuePairs GetEnvConfigEntries() {

return Ort::KeyValuePairs{entries};
}

namespace detail {
template <typename T>
inline int OpSchemaImpl<T>::GetSinceVersion() const {
int version = 0;
ThrowOnError(GetEpApi().OpSchema_GetSinceVersion(this->p_, &version));
return version;
}

template <typename T>
inline size_t OpSchemaImpl<T>::GetNumInputs() const {
size_t num = 0;
ThrowOnError(GetEpApi().OpSchema_GetNumInputs(this->p_, &num));
return num;
}

template <typename T>
inline std::string OpSchemaImpl<T>::GetInputName(size_t index) const {
const char* name = nullptr;
ThrowOnError(GetEpApi().OpSchema_GetInputName(this->p_, index, &name));
return std::string(name);
}

template <typename T>
inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl<T>::GetInputTypeConstraint(size_t index) const {
const OrtOpSchemaTypeConstraint* tc = nullptr;
ThrowOnError(GetEpApi().OpSchema_GetInputTypeConstraint(this->p_, index, &tc));
return Ort::ConstOpSchemaTypeConstraint{tc};
}

template <typename T>
inline size_t OpSchemaImpl<T>::GetNumOutputs() const {
size_t num = 0;
ThrowOnError(GetEpApi().OpSchema_GetNumOutputs(this->p_, &num));
return num;
}

template <typename T>
inline std::string OpSchemaImpl<T>::GetOutputName(size_t index) const {
const char* name = nullptr;
ThrowOnError(GetEpApi().OpSchema_GetOutputName(this->p_, index, &name));
return std::string(name);
}

template <typename T>
inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl<T>::GetOutputTypeConstraint(size_t index) const {
const OrtOpSchemaTypeConstraint* tc = nullptr;
ThrowOnError(GetEpApi().OpSchema_GetOutputTypeConstraint(this->p_, index, &tc));
return Ort::ConstOpSchemaTypeConstraint{tc};
}

template <typename T>
inline size_t OpSchemaImpl<T>::GetTypeConstraintCount() const {
size_t count = 0;
ThrowOnError(GetEpApi().OpSchema_GetTypeConstraintCount(this->p_, &count));
return count;
}

template <typename T>
inline Ort::ConstOpSchemaTypeConstraint OpSchemaImpl<T>::GetTypeConstraint(size_t index) const {
const OrtOpSchemaTypeConstraint* tc = nullptr;
ThrowOnError(GetEpApi().OpSchema_GetTypeConstraint(this->p_, index, &tc));
return Ort::ConstOpSchemaTypeConstraint{tc};
}

template <typename T>
inline std::string OpSchemaTypeConstraintImpl<T>::GetTypeParamName() const {
const char* name = nullptr;
ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetTypeParamName(this->p_, &name));
return std::string(name);
}

template <typename T>
inline std::vector<std::string> OpSchemaTypeConstraintImpl<T>::GetAllowedTypes() const {
const char* const* types = nullptr;
size_t num_types = 0;
ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetAllowedTypes(this->p_, &types, &num_types));
std::vector<std::string> result;
result.reserve(num_types);
for (size_t i = 0; i < num_types; ++i) {
result.emplace_back(types[i]);
}
return result;
}

template <typename T>
inline std::vector<size_t> OpSchemaTypeConstraintImpl<T>::GetInputIndices() const {
const size_t* indices = nullptr;
size_t count = 0;
ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetInputIndices(this->p_, &indices, &count));
if (count == 0) return {};
return std::vector<size_t>(indices, indices + count);
Comment thread
adrianlizarraga marked this conversation as resolved.
}

template <typename T>
inline std::vector<size_t> OpSchemaTypeConstraintImpl<T>::GetOutputIndices() const {
const size_t* indices = nullptr;
size_t count = 0;
ThrowOnError(GetEpApi().OpSchemaTypeConstraint_GetOutputIndices(this->p_, &indices, &count));
if (count == 0) return {};
return std::vector<size_t>(indices, indices + count);
Comment thread
adrianlizarraga marked this conversation as resolved.
}
} // namespace detail

inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const char* domain) {
OrtOpSchema* schema = nullptr;
ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema));
return OpSchema{schema};
}
} // namespace Ort
Loading
Loading