diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs index 098a18b7444cf..2467475b6b189 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs @@ -23,8 +23,8 @@ internal enum ErrorCode ModelLoaded = 8, NotImplemented = 9, InvalidGraph = 10, - ShapeInferenceNotRegistered = 11, - RequirementNotRegistered = 12, + ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL + RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED } /// diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index da9735aa4e418..8cf6420f2d0f7 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d87e9e083185b..cf5ad29b03801 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -264,6 +264,7 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, ORT_MODEL_LOAD_CANCELED, ORT_MODEL_REQUIRES_COMPILATION, + ORT_NOT_FOUND, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -6031,6 +6032,11 @@ struct OrtApi { * Typical usage sets this to the result of Node_GetNumAttributes(). An error status is * returned if `num_attributes` is less than the number of node attributes. * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. @@ -6042,14 +6048,22 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute_name The name of the attribute - * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists. + * For an unset optional attribute, `attribute` is set to NULL and a non-error status is + * returned. For an invalid attribute name, `attribute` is set to NULL and an error status with + * code ORT_NOT_FOUND is returned. + * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * diff --git a/java/src/main/java/ai/onnxruntime/OrtException.java b/java/src/main/java/ai/onnxruntime/OrtException.java index 5ec58ea137124..06c3d3cbc770c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtException.java +++ b/java/src/main/java/ai/onnxruntime/OrtException.java @@ -81,11 +81,17 @@ public enum OrtErrorCode { /** The ONNX graph is invalid. */ ORT_INVALID_GRAPH(10), /** The ORT execution provider failed. */ - ORT_EP_FAIL(11); + ORT_EP_FAIL(11), + /** Model load was canceled. */ + ORT_MODEL_LOAD_CANCELED(12), + /** Model requires compilation. */ + ORT_MODEL_REQUIRES_COMPILATION(13), + /** Item was not found. */ + ORT_NOT_FOUND(14); private final int value; - private static final OrtErrorCode[] values = new OrtErrorCode[12]; + private static final OrtErrorCode[] values = new OrtErrorCode[15]; static { for (OrtErrorCode ot : OrtErrorCode.values()) { diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) { return 10; case ORT_EP_FAIL: return 11; + case ORT_MODEL_LOAD_CANCELED: + return 12; + case ORT_MODEL_REQUIRES_COMPILATION: + return 13; + case ORT_NOT_FOUND: + return 14; default: return -1; // Unknown error code } diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..493fbff897af8 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -87,6 +87,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, } } +#if !defined(ORT_MINIMAL_BUILD) +static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) { + const ONNX_NAMESPACE::OpSchema* op_schema = node.Op(); + if (op_schema == nullptr) { + return false; + } + + auto attr_schema_iter = op_schema->attributes().find(attr_name); + if (attr_schema_iter == op_schema->attributes().end()) { + return false; // Not an attribute for this operator type. + } + + const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second; + + return !attr_schema.required; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // // EpNode // @@ -268,13 +286,20 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } -const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { +const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const { auto iter = attributes_map_.find(name); - if (iter == attributes_map_.end()) { - return nullptr; - } else { + if (iter != attributes_map_.end()) { + is_unset_optional_attr = false; return reinterpret_cast(iter->second.get()); } + +#if !defined(ORT_MINIMAL_BUILD) + is_unset_optional_attr = IsOptionalAttribute(node_, name); +#else + // This is not properly set in a minimal build because it does not have access to the operator schema. + is_unset_optional_attr = false; +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } const std::string& EpNode::GetEpName() const { diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..e61bb4d62dba6 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -209,8 +209,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; - // Helper that gets the node's attributes by name. - const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the + // output parameter `is_unset_optional_attr` to true if this is an unset optional attribute. + const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const; // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 37f4fe7312bb4..ae9a86aa923fc 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2993,7 +2993,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_result_maybenull_ const OrtOpAttr** attribute) { API_IMPL_BEGIN if (attribute == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); @@ -3004,14 +3005,16 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); } - *attribute = ep_node->GetAttribute(attribute_name); + bool is_unset_optional_attr = false; + *attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr); - if (*attribute) { + if (*attribute || is_unset_optional_attr) { return nullptr; } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + std::ostringstream oss; + oss << "Node attribute does not exist: " << attribute_name; + return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str()); } - API_IMPL_END } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..9636c41938a2b 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -678,7 +678,7 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 8f3b97c8c7786..6b3062205b52e 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "EPFail"); pybind11::register_exception(m, "ModelLoadCanceled"); pybind11::register_exception(m, "ModelRequiresCompilation"); + pybind11::register_exception(m, "NotFound"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw ModelLoadCanceled(std::move(msg)); case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: throw ModelRequiresCompilation(std::move(msg)); + case onnxruntime::common::StatusCode::NOT_FOUND: + throw NotFound(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index 86bc4a5da8d46..7680c06c59d79 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error { struct ModelRequiresCompilation : std::runtime_error { explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} }; +struct NotFound : std::runtime_error { + explicit NotFound(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..bdbc60c1a0c48 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,92 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, GetAttributeByName) { + // Load model with a single Conv that has no explicit attributes set. + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // + // Pre-check + // + + // Original Conv has no explicit attributes but Graph::Resolve() fills in default values for + // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not + // have statically computable default values, so will not be filled in by Graph::Resolve(). + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_nodes = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + ASSERT_EQ(num_nodes, 1); + + std::vector nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + const OrtNode* conv_node = nodes[0]; + const char* op_type = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); + ASSERT_STREQ(op_type, "Conv"); + + size_t num_attrs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); + ASSERT_EQ(num_attrs, 2); + + std::vector attrs(num_attrs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); + for (const OrtOpAttr* attr : attrs) { + const char* attr_name_cstr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); + std::string_view attr_name = attr_name_cstr; + ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set + } + + // + // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 3: Get attribute that is known to be set. + // + { + const OrtOpAttr* attr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + ASSERT_NE(attr, nullptr); + + OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); + ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); + + std::string auto_pad_val; + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + size_t total_attr_bytes = 0; + Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; + auto_pad_val.resize(total_attr_bytes); + + ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, + &total_attr_bytes)); + ASSERT_EQ(auto_pad_val, "NOTSET"); + } +} + // Check correctness of an OrtGraph that has external initializers. TEST(EpGraphTest, CheckModelExternalInitializers) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); diff --git a/onnxruntime/test/testdata/conv_default_attrs.onnx b/onnxruntime/test/testdata/conv_default_attrs.onnx new file mode 100644 index 0000000000000..fc7ee58dee15e Binary files /dev/null and b/onnxruntime/test/testdata/conv_default_attrs.onnx differ diff --git a/onnxruntime/test/testdata/make_conv_default_attrs.py b/onnxruntime/test/testdata/make_conv_default_attrs.py new file mode 100644 index 0000000000000..fc092bf8b25fb --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_default_attrs.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def main(): + inp_shape = (1, 2, 8, 8) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + + weight_data = [ + [[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]], + [[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias") + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + + onnx.checker.check_model(model, True) + onnx.save_model(model, "conv_default_attrs.onnx") + + +if __name__ == "__main__": + main()