Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7196,6 +7196,31 @@ struct OrtApi {
*/
ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);

/** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse).
*
* \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for
* the shape data. The OrtValue instance's internal shape data is returned directly.
*
* \note For a sparse tensor, the returned shape is the tensor's dense shape.
*
* \note Returns an error if the underlying OrtValue is not an allocated tensor (dense or sparse), if any output
* parameter is NULL, or if the tensor does not have a valid or supported element data type.
*
* \param[in] value The OrtValue instance.
* \param[out] elem_type Output parameter set to the tensor element data type.
* \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array.
* For a scalar, `shape_data` is NULL and `shape_data_count` is 0.
* Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid
* when the OrtValue is released or if the underlying shape data is updated or reallocated.
* \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`.
* `shape_data_count` is 0 for a scalar.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_result_maybenull_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);

/** \brief Enable profiling for this run
*
* \param[in] options
Expand Down
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,19 @@ struct ConstValueImpl : Base<T> {
const R* GetSparseTensorValues() const;

#endif

/// <summary>
/// Returns the tensor's element type and a reference to the tensor's internal shape data. The shape data is owned
/// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is
/// updated or reallocated. The caller must not release the returned shape data.
///
/// For a scalar, shape.shape is nullptr and shape.shape_len is 0.
/// For a sparse tensor, the reported shape is the tensor's dense shape.
///
/// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference. Any error status returned by the C API is passed
/// to ThrowOnError.
/// </summary>
/// <param name="elem_type">Output parameter set to the tensor element data type.</param>
/// <param name="shape">Output parameter set to the OrtValue instance's shape data and number of elements.</param>
void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const;
};

template <typename T>
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2387,6 +2387,13 @@ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {

#endif

// Thin wrapper over OrtApi::GetTensorElementTypeAndShapeDataReference: forwards the wrapped OrtValue pointer and
// writes the element type, shape pointer and shape length straight into the caller-provided outputs.
// The resulting status is handed to ThrowOnError.
template <typename T>
void ConstValueImpl<T>::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type,
                                                                  Shape& shape) const {
  const auto& api = GetApi();
  auto* status = api.GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type,
                                                               &shape.shape, &shape.shape_len);
  ThrowOnError(status);
}

template <typename T>
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
Expand Down
58 changes: 58 additions & 0 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,64 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
return GetTensorShapeAndTypeHelper(type, shape, dim_params);
}

// Implements OrtApi::GetTensorElementTypeAndShapeDataReference.
//
// Reports the element type of a dense or sparse tensor OrtValue together with a non-owning view of its
// dimensions. No shape array is allocated: `*shape_data` points directly at the dims returned by the tensor's
// shape (the dense shape for a sparse tensor), so the caller must not free it.
//
// Parameters:
//   value            - OrtValue to query; must be non-NULL and hold an allocated Tensor or SparseTensor.
//   elem_type        - [out] tensor element data type; must be non-NULL.
//   shape_data       - [out] pointer to the dims, or NULL when the shape has no dims (scalar); must be non-NULL.
//   shape_data_count - [out] number of dims (0 for a scalar); must be non-NULL.
//
// Returns nullptr on success. Returns an ORT_INVALID_ARGUMENT status for a NULL argument or a value that is not
// an allocated tensor, ORT_NOT_IMPLEMENTED for a sparse tensor when sparse tensors are compiled out, and ORT_FAIL
// when the element type cannot be mapped to a defined ONNXTensorElementDataType.
// The output parameters are written only on success.
ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
                    _Out_ ONNXTensorElementDataType* elem_type,
                    _Outptr_result_maybenull_ const int64_t** shape_data,
                    _Out_ size_t* shape_data_count) {
  API_IMPL_BEGIN
  // Guard against a NULL `value` before it is dereferenced below, mirroring the output-parameter checks.
  if (value == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Input parameter `value` must not be NULL");
  }

  if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Input parameter `value` must contain a constructed tensor or sparse tensor");
  }

  if (elem_type == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Output parameter `elem_type` must not be NULL");
  }

  if (shape_data == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Output parameter `shape_data` must not be NULL");
  }

  if (shape_data_count == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Output parameter `shape_data_count` must not be NULL");
  }

  gsl::span<const int64_t> shape_span;
  onnxruntime::MLDataType ml_data_type = nullptr;
  ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;

  if (value->IsTensor()) {
    const Tensor& tensor = value->Get<onnxruntime::Tensor>();
    ml_data_type = tensor.DataType();
    shape_span = tensor.Shape().GetDims();
  } else {
    // The earlier validation guarantees this branch is a sparse tensor.
#if !defined(DISABLE_SPARSE_TENSORS)
    const SparseTensor& tensor = value->Get<onnxruntime::SparseTensor>();
    ml_data_type = tensor.DataType();
    // A sparse tensor reports its dense shape.
    shape_span = tensor.DenseShape().GetDims();
#else
    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build.");
#endif
  }

  if (ml_data_type != nullptr) {
    type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type);
  }

  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
    return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type");
  }

  *elem_type = type;
  // Normalize an empty shape (scalar) to a NULL pointer so callers get the documented NULL/0 pair.
  *shape_data = shape_span.empty() ? nullptr : shape_span.data();
  *shape_data_count = shape_span.size();
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
_In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
API_IMPL_BEGIN
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4802,6 +4802,7 @@ static constexpr OrtApi ort_api_1_to_25 = {
&OrtApis::EpAssignedNode_GetDomain,
&OrtApis::EpAssignedNode_GetOperatorType,
&OrtApis::RunOptionsSetSyncStream,
&OrtApis::GetTensorElementTypeAndShapeDataReference,
// End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information)

&OrtApis::RunOptionsEnableProfiling,
Expand Down Expand Up @@ -4842,7 +4843,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz

static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
static_assert(offsetof(OrtApi, RunOptionsSetSyncStream) / sizeof(void*) == 413, "Size of version 24 API cannot change");
static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.25.0",
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,4 +810,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap
ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);

// Declaration of the implementation behind OrtApi::GetTensorElementTypeAndShapeDataReference.
// Outputs the tensor element type plus a pointer to (and element count of) the OrtValue's shape data.
ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_result_maybenull_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
} // namespace OrtApis
88 changes: 88 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,94 @@ TEST(CApiTest, dim_param) {
ASSERT_EQ(strcmp(dim_param, ""), 0);
}

// Checks that GetTensorElementTypeAndShapeDataReference, called on a dense 3x2 float tensor, reports the same
// element type and dims as the TensorTypeAndShapeInfo obtained from GetTensorTypeAndShapeInfo().
TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_DenseTensor) {
  Ort::MemoryInfo cpu_mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);

  // Build the tensor under test from a fixed 3x2 float buffer.
  const std::array<int64_t, 2> dims = {3, 2};
  std::array<float, 3 * 2> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  Ort::Value tensor = Ort::Value::CreateTensor(cpu_mem_info, data.data(), data.size(),
                                               dims.data(), dims.size());

  // Reference answer from the existing type-and-shape API.
  Ort::TensorTypeAndShapeInfo reference_info = tensor.GetTensorTypeAndShapeInfo();

  ONNXTensorElementDataType actual_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  Ort::Value::Shape shape_ref{};
  tensor.GetTensorElementTypeAndShapeDataReference(actual_type, shape_ref);

  ASSERT_EQ(actual_type, reference_info.GetElementType());

  std::vector<int64_t> expected_dims = reference_info.GetShape();
  ASSERT_EQ(gsl::span<const int64_t>(shape_ref.shape, shape_ref.shape_len),
            gsl::span<const int64_t>(expected_dims));
}

// Checks that GetTensorElementTypeAndShapeDataReference, called on a scalar float tensor (empty shape), matches
// GetTensorTypeAndShapeInfo() and reports the scalar as a null shape pointer with zero length.
TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_Scalar) {
  Ort::MemoryInfo cpu_mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);

  // An empty dims vector denotes a scalar.
  std::vector<int64_t> dims = {};
  std::array<float, 1> data = {1.0f};
  Ort::Value scalar = Ort::Value::CreateTensor(cpu_mem_info, data.data(), data.size(),
                                               dims.data(), dims.size());

  // Reference answer from the existing type-and-shape API.
  Ort::TensorTypeAndShapeInfo reference_info = scalar.GetTensorTypeAndShapeInfo();

  ONNXTensorElementDataType actual_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  Ort::Value::Shape shape_ref{};
  scalar.GetTensorElementTypeAndShapeDataReference(actual_type, shape_ref);

  ASSERT_EQ(actual_type, reference_info.GetElementType());

  std::vector<int64_t> expected_dims = reference_info.GetShape();
  ASSERT_EQ(gsl::span<const int64_t>(shape_ref.shape, shape_ref.shape_len),
            gsl::span<const int64_t>(expected_dims));

  // Scalars are reported as the NULL/0 pair.
  ASSERT_EQ(shape_ref.shape, nullptr);
  ASSERT_EQ(shape_ref.shape_len, 0);
}

#if !defined(DISABLE_SPARSE_TENSORS)
// Checks that GetTensorElementTypeAndShapeDataReference, called on a COO sparse tensor with a 9x9 dense shape,
// reports the same element type and dims as GetTensorTypeAndShapeInfo().
TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_SparseTensor) {
  std::vector<int64_t> dense_dims{9, 9};
  std::vector<float> coo_values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
                                10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0,
                                18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
                                26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0,
                                34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0,
                                42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0,
                                50.0, 51.0, 52.0, 53.0};

  // 2-D COO indices: one (row, col) pair per value.
  std::vector<int64_t> coo_index_dims{gsl::narrow<int64_t>(coo_values.size()), 2};
  std::vector<int64_t> coo_indices{0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1,
                                   1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1,
                                   2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3,
                                   5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5,
                                   4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5,
                                   6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3,
                                   6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7,
                                   4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4,
                                   8, 5};

  Ort::MemoryInfo mem_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
  Ort::Value::Shape dense_shape{dense_dims.data(), dense_dims.size()};
  // The values shape is 1-D: only the first entry of coo_index_dims (the value count) is used.
  Ort::Value::Shape values_shape{coo_index_dims.data(), 1U};
  auto sparse = Ort::Value::CreateSparseTensor(mem_info, coo_values.data(), dense_shape, values_shape);
  sparse.UseCooIndices(coo_indices.data(), coo_indices.size());

  // Reference answer from the existing type-and-shape API.
  Ort::TensorTypeAndShapeInfo reference_info = sparse.GetTensorTypeAndShapeInfo();

  ONNXTensorElementDataType actual_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  Ort::Value::Shape shape_ref{};
  sparse.GetTensorElementTypeAndShapeDataReference(actual_type, shape_ref);

  ASSERT_EQ(actual_type, reference_info.GetElementType());

  std::vector<int64_t> expected_dims = reference_info.GetShape();
  ASSERT_EQ(gsl::span<const int64_t>(shape_ref.shape, shape_ref.shape_len),
            gsl::span<const int64_t>(expected_dims));
}
#endif // !defined(DISABLE_SPARSE_TENSORS)

static std::pair<bool, bool> LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) {
Ort::Session session(*ort_env, model_url, Ort::SessionOptions{});
const auto input_num = session.GetInputCount();
Expand Down
Loading