diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 3f058a38d9bfb..dd2736c4a7598 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -7196,6 +7196,31 @@ struct OrtApi {
    */
   ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);
 
+  /** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse).
+   *
+   * \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for
+   * the shape data. The OrtValue instance's internal shape data is returned directly.
+   *
+   * \note Returns an error if the OrtValue does not contain an allocated dense or sparse tensor.
+   *
+   * \param[in] value The OrtValue instance.
+   * \param[out] elem_type Output parameter set to the tensor element data type.
+   * \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array.
+   *                        For a scalar, `shape_data` is NULL and `shape_data_count` is 0.
+   *                        Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid
+   *                        when the OrtValue is released or if the underlying shape data is updated or reallocated.
+   * \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`.
+   *                              `shape_data_count` is 0 for a scalar.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.24.
+   */
+  ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
+                  _Out_ ONNXTensorElementDataType* elem_type,
+                  _Outptr_result_maybenull_ const int64_t** shape_data,
+                  _Out_ size_t* shape_data_count);
+
   /** \brief Enable profiling for this run
    *
    * \param[in] options
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 3e7ddf0075adb..2c1d52894e7f3 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2233,6 +2233,19 @@ struct ConstValueImpl : Base {
   const R* GetSparseTensorValues() const;
 
 #endif
+
+  /// <summary>
+  /// Returns the tensor's element type and a reference to the tensor's internal shape data. The shape data is owned
+  /// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is
+  /// updated or reallocated.
+  ///
+  /// For a scalar, shape.shape is nullptr and shape.shape_len is 0.
+  ///
+  /// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference.
+  /// </summary>
+  /// <param name="elem_type">Output parameter set to the element's data type.</param>
+  /// <param name="shape">Output parameter set to the OrtValue instance's shape data and number of elements.</param>
+ void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const; }; template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e416e470f8144..745128fe6c7b4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2387,6 +2387,13 @@ inline const R* ConstValueImpl::GetSparseTensorValues() const { #endif +template +void ConstValueImpl::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, + Shape& shape) const { + ThrowOnError(GetApi().GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type, &shape.shape, + &shape.shape_len)); +} + template void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 0bac24a2c3aa0..16817ba1707bd 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -310,6 +310,64 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS return GetTensorShapeAndTypeHelper(type, shape, dim_params); } +ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count) { + API_IMPL_BEGIN + if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Input parameter `value` must contain a constructed tensor or sparse tensor"); + } + + if (elem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `elem_type` must not be NULL"); + } + + if (shape_data == nullptr) { + return 
OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `shape_data` must not be NULL"); + } + + if (shape_data_count == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `shape_data_count` must not be NULL"); + } + + gsl::span shape_span; + onnxruntime::MLDataType ml_data_type = nullptr; + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + if (value->IsTensor()) { + const Tensor& tensor = value->Get(); + ml_data_type = tensor.DataType(); + shape_span = tensor.Shape().GetDims(); + } else { +#if !defined(DISABLE_SPARSE_TENSORS) + const SparseTensor& tensor = value->Get(); + ml_data_type = tensor.DataType(); + shape_span = tensor.DenseShape().GetDims(); +#else + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build."); +#endif + } + + if (ml_data_type != nullptr) { + type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type); + } + + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type"); + } + + *elem_type = type; + *shape_data = shape_span.empty() ? 
nullptr : shape_span.data(); + *shape_data_count = shape_span.size(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5d8aeb521be08..7a027c8eafb81 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4802,6 +4802,7 @@ static constexpr OrtApi ort_api_1_to_25 = { &OrtApis::EpAssignedNode_GetDomain, &OrtApis::EpAssignedNode_GetOperatorType, &OrtApis::RunOptionsSetSyncStream, + &OrtApis::GetTensorElementTypeAndShapeDataReference, // End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::RunOptionsEnableProfiling, @@ -4842,7 +4843,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); -static_assert(offsetof(OrtApi, RunOptionsSetSyncStream) / sizeof(void*) == 413, "Size of version 24 API cannot change"); +static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.25.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fad7e8e9c31bb..3d990909cfb41 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -810,4 +810,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ 
const char** out); ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + +ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index a96a2c48b4ca6..4e991716dd108 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -480,6 +480,94 @@ TEST(CApiTest, dim_param) { ASSERT_EQ(strcmp(dim_param, ""), 0); } +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a dense OrtValue tensor. +TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_DenseTensor) { + Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); + + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(), + x_shape.data(), x_shape.size()); + Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); +} + +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a scalar OrtValue tensor. 
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_Scalar) { + Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); + + std::vector x_shape = {}; // Scalar (no shape) + std::array x_values = {1.0f}; + Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(), + x_shape.data(), x_shape.size()); + Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); + ASSERT_EQ(shape.shape, nullptr); + ASSERT_EQ(shape.shape_len, 0); +} + +#if !defined(DISABLE_SPARSE_TENSORS) +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a sparse OrtValue tensor. 
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_SparseTensor) { + std::vector common_shape{9, 9}; + std::vector A_values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, + 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, + 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, + 50.0, 51.0, 52.0, 53.0}; + + // 2 - D index + std::vector indices_shape{gsl::narrow(A_values.size()), 2}; + std::vector A_indices{0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1, + 1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1, + 2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3, + 5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5, + 4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5, + 6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3, + 6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7, + 4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4, + 8, 5}; + + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + Ort::Value::Shape ort_dense_shape{common_shape.data(), common_shape.size()}; + Ort::Value::Shape ort_values_shape{&indices_shape[0], 1U}; + auto value_sparse = Ort::Value::CreateSparseTensor(info, A_values.data(), ort_dense_shape, ort_values_shape); + value_sparse.UseCooIndices(A_indices.data(), A_indices.size()); + + Ort::TensorTypeAndShapeInfo type_shape_info = value_sparse.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + value_sparse.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); +} +#endif // !defined(DISABLE_SPARSE_TENSORS) + static std::pair LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) { Ort::Session session(*ort_env, model_url, Ort::SessionOptions{}); 
const auto input_num = session.GetInputCount();