diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 66e4443de06ad..eed2c7603777c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7242,6 +7242,28 @@ struct OrtApi { * \since Version 1.25. */ ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options); + + /** \brief Fetch an array of strings stored as an attribute in the graph node. + * + * If `out` is nullptr, `*size` is set to the number of strings in the attribute + * array and a success status is returned; `allocator` is not used in this case. + * + * Otherwise, the pointer array and each string are allocated separately using `allocator`. + * The caller must free each of the `*size` strings and the pointer array with `allocator`. + * If the attribute array is empty, `*out` is set to nullptr and `*size` is set to 0. + * + * \param[in] info OrtKernelInfo instance to read the attribute from + * \param[in] name name of the attribute to be parsed + * \param[in] allocator allocator used to allocate the returned string array and strings; must not be null when `out` is not nullptr + * \param[out] out pointer to the returned array of null-terminated UTF-8 strings, or nullptr to query only the array size + * \param[out] size number of strings in the attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.25. 
+ */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 8dae24a3bffe7..0235abe0ab40b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2868,6 +2868,7 @@ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); +void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector&); } // namespace attr_utils template @@ -2884,7 +2885,7 @@ struct KernelInfoImpl : Base { return val; } - template // R is only implemented for std::vector, std::vector + template // R is only implemented for float, int64_t, and string std::vector GetAttributes(const char* name) const { std::vector result; attr_utils::GetAttrs(this->p_, name, result); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index bce2aa97d47cd..8d42c5f3e99a0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -3063,6 +3063,39 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std:: Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size)); out.swap(result); } + +inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector& result) { + AllocatorWithDefaultOptions allocator; + char** out = nullptr; + size_t size = 0; + + 
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, nullptr, &size)); + if (size == 0) { + result.clear(); + return; + } + + Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, &out, &size)); + + auto deleter = detail::AllocatedFree(allocator); + std::unique_ptr array_guard(out, deleter); + auto strings_deleter = [&deleter, size](char** values) { + for (size_t i = 0; i < size; ++i) { + if (values[i] != nullptr) { + deleter(values[i]); + } + } + }; + std::unique_ptr strings_guard(out, strings_deleter); + + std::vector strings; + strings.reserve(size); + for (size_t i = 0; i < size; ++i) { + strings.emplace_back(out[i]); + } + + strings.swap(result); +} } // namespace detail inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl{info} {} diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index f61be7d1c51b3..2c8b81e4ffefe 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -11,6 +11,8 @@ #include #include +#include "core/common/safeint.h" +#include "core/common/string_helper.h" #include "core/common/logging/logging.h" #include "core/framework/data_types.h" #include "core/framework/error_code_helper.h" @@ -545,6 +547,65 @@ static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, s return Status::OK(); } +static char* DuplicateStringToAllocatorMemory(const std::string& value, OrtAllocator* allocator) { + SafeInt allocation_size(value.size()); + allocation_size += 1; + + char* duplicated_value = static_cast(allocator->Alloc(allocator, allocation_size)); + if (duplicated_value == nullptr) { + return nullptr; + } + + std::memcpy(duplicated_value, value.data(), value.size()); + duplicated_value[value.size()] = '\0'; + return duplicated_value; +} + +static Status CopyStringDataFromVectorToMemory(const std::vector& values, OrtAllocator* allocator, char*** out, size_t* size) { + *size = 
values.size(); + + if (out == nullptr) { + return Status::OK(); + } + + ORT_RETURN_IF_NOT(allocator != nullptr, "allocator must not be null when out is provided"); + *out = nullptr; + + if (values.empty()) { + return Status::OK(); + } + + auto free_with_allocator = [allocator](void* value) { + allocator->Free(allocator, value); + }; + SafeInt alloc_count(values.size()); + char** array = reinterpret_cast(allocator->Alloc(allocator, alloc_count * sizeof(char*))); + ORT_RETURN_IF_NOT(array != nullptr, "Failed to allocate string attribute pointer array"); + std::unique_ptr array_guard(array, free_with_allocator); + + size_t allocated_string_count = 0; + for (size_t i = 0; i < values.size(); ++i) { + char* duplicated_value = DuplicateStringToAllocatorMemory(values[i], allocator); + if (duplicated_value == nullptr) { + for (size_t j = 0; j < allocated_string_count; ++j) { + if (array[j] != nullptr) { + allocator->Free(allocator, array[j]); + } + } + + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, + "Failed to allocate string attribute array"); + } + + array[i] = duplicated_value; + ++allocated_string_count; + } + + *out = array; + array_guard.release(); + return Status::OK(); +} + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size) { return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { @@ -569,6 +630,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKe }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + std::vector values; + auto status = reinterpret_cast(info)->GetAttrs(name, values); + if (status.IsOK()) { + status = CopyStringDataFromVectorToMemory(values, 
allocator, out, size); + } + return onnxruntime::ToOrtStatus(status); + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7a027c8eafb81..37a74a5de22a6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4807,6 +4807,8 @@ static constexpr OrtApi ort_api_1_to_25 = { &OrtApis::RunOptionsEnableProfiling, &OrtApis::RunOptionsDisableProfiling, + &OrtApis::KernelInfoGetAttributeArray_string, + // End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -4844,6 +4846,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); +static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.25.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3d990909cfb41..290aa71b12239 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -125,6 +125,8 @@ ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ 
OrtRunOptions* options); ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); ORT_API_STATUS_IMPL(RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); ORT_API_STATUS_IMPL(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options); +ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size); ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, diff --git a/onnxruntime/test/framework/kernel_info_test.cc b/onnxruntime/test/framework/kernel_info_test.cc new file mode 100644 index 0000000000000..1d94b97c78ff9 --- /dev/null +++ b/onnxruntime/test/framework/kernel_info_test.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" +#include "core/graph/model.h" +#include "core/graph/op.h" +#include "core/graph/onnx_protobuf.h" +#include "core/framework/execution_providers.h" +#include "core/framework/op_kernel.h" +#include "core/framework/external_data_loader_manager.h" +#include "core/framework/session_state.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/onnxruntime_cxx_api.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +ONNX_OPERATOR_SCHEMA(KernelInfoStringArrayAttrOp) + .SetDoc("Test op for kernel info string-array attributes.") + .Attr("strings_attr", "Repeated string attribute for kernel info API tests.", + AttrType::AttributeProto_AttributeType_STRINGS, std::vector{}) + .Output(0, "output_1", "docstr for output_1.", "tensor(int32)"); + +static void VerifyKernelInfoStringArrayAttribute(const std::vector& attribute_values) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); + + onnxruntime::Model model("graph_kernel_info_string_attr", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ExecutionProviders execution_providers; + auto tmp_cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo(false)); + auto* cpu_execution_provider = tmp_cpu_execution_provider.get(); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider))); + + DataTransferManager dtm; + ExternalDataLoaderManager edlm; + profiling::Profiler profiler; + + SessionOptions sess_options; + sess_options.enable_mem_pattern = true; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = true; + + SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm, + 
DefaultLoggingManager().DefaultLogger(), profiler, sess_options); + + std::vector inputs; + std::vector outputs; + TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + onnxruntime::NodeArg output_arg("node_1_out_1", &output_type); + outputs.push_back(&output_arg); + + onnxruntime::Node& node = graph.AddNode("node_1", "KernelInfoStringArrayAttrOp", "node 1.", inputs, outputs); + node.AddAttribute("strings_attr", gsl::make_span(attribute_values)); + ASSERT_STATUS_OK(graph.Resolve()); + + auto kernel_def = KernelDefBuilder().SetName("KernelInfoStringArrayAttrOp").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build(); + + OpKernelInfo kernel_info(node, *kernel_def, *cpu_execution_provider, session_state.GetConstantInitializedTensors(), + session_state.GetOrtValueNameIdxMap(), session_state.GetDataTransferMgr(), session_state.GetAllocators(), + session_state.GetSessionOptions().config_options); + + const OrtApi& ort_api = Ort::GetApi(); + OrtAllocator* allocator = nullptr; + ASSERT_EQ(nullptr, ort_api.GetAllocatorWithDefaultOptions(&allocator)); + + size_t size = 0; + ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast(&kernel_info), "strings_attr", + allocator, nullptr, &size)); + ASSERT_EQ(attribute_values.size(), size); + + char** out = nullptr; + ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast(&kernel_info), "strings_attr", + allocator, &out, &size)); + ASSERT_EQ(attribute_values.size(), size); + + if (attribute_values.empty()) { + ASSERT_EQ(nullptr, out); + } else { + ASSERT_NE(nullptr, out); + for (size_t i = 0; i < size; ++i) { + EXPECT_STREQ(attribute_values[i].c_str(), out[i]); + allocator->Free(allocator, out[i]); + } + allocator->Free(allocator, out); + } + + Ort::ConstKernelInfo ort_kernel_info{reinterpret_cast(&kernel_info)}; + EXPECT_EQ(attribute_values, 
ort_kernel_info.GetAttributes("strings_attr")); + + OrtStatus* status = ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast(&kernel_info), "missing_attr", + allocator, nullptr, &size); + ASSERT_NE(nullptr, status); + ort_api.ReleaseStatus(status); +} + +TEST(KernelInfoTests, KernelInfoGetAttributeArrayString) { + VerifyKernelInfoStringArrayAttribute({"alpha", "beta", "gamma"}); +} + +TEST(KernelInfoTests, KernelInfoGetAttributeArrayStringEmpty) { + VerifyKernelInfoStringArrayAttribute({}); +} + +} // namespace test +} // namespace onnxruntime