diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 0d920ab7dac89..21aa797ce16eb 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -758,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr break; } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + + Ort::Value tensor(ort_value); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + size_t element_size = 0; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); return Ort::Status(err_msg.c_str(), ORT_FAIL); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cf5ad29b03801..2899a219bdda0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -276,6 +276,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, ORT_OP_ATTR_GRAPH, + ORT_OP_ATTR_TENSOR, } OrtOpAttrType; //! @} @@ -6065,6 +6066,20 @@ struct OrtApi { ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); + /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. + * + * \param[in] node The OrtNode instance. + * \param[in] attribute The OrtOpAttr instance. + * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. + Must be freed with OrtApi::ReleaseValue. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); + /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * * \param[in] attribute The OrtOpAttr instance. diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6383d29d7a2bc..504b102e782fd 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -251,6 +251,16 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; + /// + /// Gets the node's 'TENSOR' attribute as an OrtValue. + /// + /// Node's 'TENSOR' attribute. + /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, + /// only if the attribute is of type 'TENSOR' + /// A status indicating success or an error. + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, + OrtValue*& value) const = 0; + /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 493fbff897af8..eb7fb6937c29e 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -248,6 +248,32 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& graph_viewer = ep_graph_->GetGraphViewer(); + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + result = tensor_attribute_value.release(); + return Status::OK(); +} + Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index e61bb4d62dba6..be78d77360cb8 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,6 +183,9 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + OrtValue*& attr_tensor) const override; + // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5d84e48182bfe..d3795d911b22f 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); + } + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ae9a86aa923fc..4c7b4d7b29c2f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3018,6 +3018,20 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) { API_IMPL_BEGIN const auto attr = attribute->attr_proto; @@ -3055,6 +3069,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + *type = OrtOpAttrType::ORT_OP_ATTR_TENSOR; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -4037,6 +4055,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, + &OrtApis::Node_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9636c41938a2b..3eee174ff81f4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -679,6 +679,8 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index bdbc60c1a0c48..188edad572182 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -306,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& outpu output_data.assign(output_values, output_values + num_output_elems); } +static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {3}; + std::vector input_data = {2, 3, 4}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'x' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("x"); + + // Run session and get outputs + std::array output_names{"y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 24); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. // Checks that the outputs of the serialized and original models are identical. TEST(EpGraphTest, SerializeToProto_Mnist) { @@ -436,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } } +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "constant_of_shape_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + static_cast(value_info); + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConstantOfShapeModel(original_model_path, output_original); + RunConstantOfShapeModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -978,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."));