Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7242,6 +7242,28 @@ struct OrtApi {
* \since Version 1.25.
*/
ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);

/** \brief Fetch an array of strings stored as an attribute in the graph node
*
* If `out` is nullptr, the value of `size` is set to the true size of the attribute
* array and a success status is returned.
*
* Otherwise, the strings and pointer array are allocated using `allocator`.
* The caller must free each string and the pointer array with `allocator`.
* If the attribute array is empty, `*out` is set to nullptr and `*size` is set to 0.
*
* \param[in] info instance
* \param[in] name name of the attribute to be parsed
* \param[in] allocator allocator used to allocate the returned string array and strings
Comment thread
tianleiwu marked this conversation as resolved.
* \param[out] out pointer to the returned array of null-terminated UTF-8 strings
* \param[out] size actual size of attribute array
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.25.
*/
ORT_API2_STATUS(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size);
};

/*
Expand Down
3 changes: 2 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,7 @@ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<std::string>&);
} // namespace attr_utils

template <typename T>
Expand All @@ -2884,7 +2885,7 @@ struct KernelInfoImpl : Base<T> {
return val;
}

template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
template <typename R> // R is only implemented for float, int64_t, and string
std::vector<R> GetAttributes(const char* name) const {
std::vector<R> result;
attr_utils::GetAttrs(this->p_, name, result);
Expand Down
33 changes: 33 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -3063,6 +3063,39 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
out.swap(result);
}

inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<std::string>& result) {
AllocatorWithDefaultOptions allocator;
char** out = nullptr;
size_t size = 0;

Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, nullptr, &size));
if (size == 0) {
result.clear();
return;
}

Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_string(p, name, allocator, &out, &size));

auto deleter = detail::AllocatedFree(allocator);
std::unique_ptr<void, decltype(deleter)> array_guard(out, deleter);
auto strings_deleter = [&deleter, size](char** values) {
for (size_t i = 0; i < size; ++i) {
if (values[i] != nullptr) {
deleter(values[i]);
}
}
};
std::unique_ptr<char*, decltype(strings_deleter)> strings_guard(out, strings_deleter);

std::vector<std::string> strings;
strings.reserve(size);
for (size_t i = 0; i < size; ++i) {
strings.emplace_back(out[i]);
Comment thread
tianleiwu marked this conversation as resolved.
}

strings.swap(result);
}
} // namespace detail

inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
Expand Down
73 changes: 73 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <unordered_set>

#include <gsl/gsl>
#include "core/common/safeint.h"
#include "core/common/string_helper.h"
#include "core/common/logging/logging.h"
#include "core/framework/data_types.h"
#include "core/framework/error_code_helper.h"
Expand Down Expand Up @@ -545,6 +547,65 @@ static Status CopyDataFromVectorToMemory(const std::vector<T>& values, T* out, s
return Status::OK();
}

static char* DuplicateStringToAllocatorMemory(const std::string& value, OrtAllocator* allocator) {
SafeInt<size_t> allocation_size(value.size());
allocation_size += 1;

char* duplicated_value = static_cast<char*>(allocator->Alloc(allocator, allocation_size));
if (duplicated_value == nullptr) {
return nullptr;
}

std::memcpy(duplicated_value, value.data(), value.size());
duplicated_value[value.size()] = '\0';
return duplicated_value;
}

static Status CopyStringDataFromVectorToMemory(const std::vector<std::string>& values, OrtAllocator* allocator, char*** out, size_t* size) {
*size = values.size();

if (out == nullptr) {
return Status::OK();
}

ORT_RETURN_IF_NOT(allocator != nullptr, "allocator must not be null when out is provided");
*out = nullptr;

if (values.empty()) {
return Status::OK();
}

auto free_with_allocator = [allocator](void* value) {
allocator->Free(allocator, value);
};
SafeInt<size_t> alloc_count(values.size());
char** array = reinterpret_cast<char**>(allocator->Alloc(allocator, alloc_count * sizeof(char*)));
ORT_RETURN_IF_NOT(array != nullptr, "Failed to allocate string attribute pointer array");
std::unique_ptr<void, decltype(free_with_allocator)> array_guard(array, free_with_allocator);

size_t allocated_string_count = 0;
for (size_t i = 0; i < values.size(); ++i) {
char* duplicated_value = DuplicateStringToAllocatorMemory(values[i], allocator);
if (duplicated_value == nullptr) {
for (size_t j = 0; j < allocated_string_count; ++j) {
if (array[j] != nullptr) {
allocator->Free(allocator, array[j]);
}
}

return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL,
"Failed to allocate string attribute array");
Comment thread
tianleiwu marked this conversation as resolved.
}

array[i] = duplicated_value;
++allocated_string_count;
}

*out = array;
array_guard.release();
return Status::OK();
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ float* out, _Inout_ size_t* size) {
return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
Expand All @@ -569,6 +630,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKe
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size) {
return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
std::vector<std::string> values;
auto status = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAttrs<std::string>(name, values);
if (status.IsOK()) {
status = CopyStringDataFromVectorToMemory(values, allocator, out, size);
}
return onnxruntime::ToOrtStatus(status);
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) {
return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4807,6 +4807,8 @@ static constexpr OrtApi ort_api_1_to_25 = {

&OrtApis::RunOptionsEnableProfiling,
&OrtApis::RunOptionsDisableProfiling,
&OrtApis::KernelInfoGetAttributeArray_string,
// End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information)
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down Expand Up @@ -4844,6 +4846,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");
static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.25.0",
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options);
ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);
ORT_API_STATUS_IMPL(RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
ORT_API_STATUS_IMPL(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_string, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*size) char*** out, _Out_ size_t* size);

ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator,
_In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type,
Expand Down
114 changes: 114 additions & 0 deletions onnxruntime/test/framework/kernel_info_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
Comment thread Fixed
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/test_environment.h"
#include "test/util/include/asserts.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/execution_providers.h"
#include "core/framework/op_kernel.h"
#include "core/framework/external_data_loader_manager.h"
#include "core/framework/session_state.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/session/onnxruntime_cxx_api.h"

using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace test {

ONNX_OPERATOR_SCHEMA(KernelInfoStringArrayAttrOp)
.SetDoc("Test op for kernel info string-array attributes.")
.Attr("strings_attr", "Repeated string attribute for kernel info API tests.",
AttrType::AttributeProto_AttributeType_STRINGS, std::vector<std::string>{})
.Output(0, "output_1", "docstr for output_1.", "tensor(int32)");

static void VerifyKernelInfoStringArrayAttribute(const std::vector<std::string>& attribute_values) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);

onnxruntime::Model model("graph_kernel_info_string_attr", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();

ExecutionProviders execution_providers;
auto tmp_cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo(false));
auto* cpu_execution_provider = tmp_cpu_execution_provider.get();
ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(tmp_cpu_execution_provider)));

DataTransferManager dtm;
ExternalDataLoaderManager edlm;
profiling::Profiler profiler;

SessionOptions sess_options;
sess_options.enable_mem_pattern = true;
sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
sess_options.use_deterministic_compute = false;
sess_options.enable_mem_reuse = true;

SessionState session_state(graph, execution_providers, tp.get(), nullptr, dtm, edlm,
DefaultLoggingManager().DefaultLogger(), profiler, sess_options);

std::vector<onnxruntime::NodeArg*> inputs;
std::vector<onnxruntime::NodeArg*> outputs;
TypeProto output_type;
output_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32);
output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
onnxruntime::NodeArg output_arg("node_1_out_1", &output_type);
outputs.push_back(&output_arg);

onnxruntime::Node& node = graph.AddNode("node_1", "KernelInfoStringArrayAttrOp", "node 1.", inputs, outputs);
node.AddAttribute("strings_attr", gsl::make_span(attribute_values));
ASSERT_STATUS_OK(graph.Resolve());

auto kernel_def = KernelDefBuilder().SetName("KernelInfoStringArrayAttrOp").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();

OpKernelInfo kernel_info(node, *kernel_def, *cpu_execution_provider, session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(), session_state.GetDataTransferMgr(), session_state.GetAllocators(),
session_state.GetSessionOptions().config_options);

const OrtApi& ort_api = Ort::GetApi();
OrtAllocator* allocator = nullptr;
ASSERT_EQ(nullptr, ort_api.GetAllocatorWithDefaultOptions(&allocator));

size_t size = 0;
ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "strings_attr",
allocator, nullptr, &size));
ASSERT_EQ(attribute_values.size(), size);

char** out = nullptr;
ASSERT_EQ(nullptr, ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "strings_attr",
allocator, &out, &size));
ASSERT_EQ(attribute_values.size(), size);

if (attribute_values.empty()) {
ASSERT_EQ(nullptr, out);
} else {
ASSERT_NE(nullptr, out);
for (size_t i = 0; i < size; ++i) {
EXPECT_STREQ(attribute_values[i].c_str(), out[i]);
allocator->Free(allocator, out[i]);
}
allocator->Free(allocator, out);
}

Ort::ConstKernelInfo ort_kernel_info{reinterpret_cast<const OrtKernelInfo*>(&kernel_info)};
EXPECT_EQ(attribute_values, ort_kernel_info.GetAttributes<std::string>("strings_attr"));

OrtStatus* status = ort_api.KernelInfoGetAttributeArray_string(reinterpret_cast<const OrtKernelInfo*>(&kernel_info), "missing_attr",
allocator, nullptr, &size);
ASSERT_NE(nullptr, status);
ort_api.ReleaseStatus(status);
}

TEST(KernelInfoTests, KernelInfoGetAttributeArrayString) {
VerifyKernelInfoStringArrayAttribute({"alpha", "beta", "gamma"});
}

TEST(KernelInfoTests, KernelInfoGetAttributeArrayStringEmpty) {
VerifyKernelInfoStringArrayAttribute({});
}

} // namespace test
} // namespace onnxruntime
Loading