Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7217,6 +7217,29 @@ struct OrtApi {
* \since Version 1.25.
*/
ORT_API2_STATUS(RunOptionsDisableProfiling, _Inout_ OrtRunOptions* options);

/** \brief Get the element data type and shape for an OrtValue that represents a Tensor (sparse or dense).
*
* \note Returns an error if the underlying OrtValue is not a Tensor.
*
* \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for
* the shape data. The OrtValue instance's internal shape data is returned directly.
*
* \param[in] value The OrtValue instance.
* \param[out] elem_type Output parameter set to the tensor element data type.
* \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array.
* Must not be released as it is owned by the OrtValue instance. This data becomes invalid
* when the OrtValue is released.
* \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.25.
Comment thread
adrianlizarraga marked this conversation as resolved.
Outdated
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
*/
ORT_API2_STATUS(Value_GetTensorElementTypeAndShape, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
};

/*
Expand Down
11 changes: 11 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2233,6 +2233,17 @@ struct ConstValueImpl : Base<T> {
const R* GetSparseTensorValues() const;

#endif

/// <summary>
/// Wraps OrtApi::Value_GetTensorElementTypeAndShape. Returns the tensor's type and shape without allocating a new
/// buffer for the shape.
/// </summary>
/// <param name="elem_type"></param>
/// <param name="shape_data"></param>
/// <param name="shape_data_count"></param>
/// <returns></returns>
Ort::Status GetTensorElementTypeAndShape(ONNXTensorElementDataType& elem_type,
const int64_t*& shape_data, size_t& shape_data_count) const;
};

template <typename T>
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2387,6 +2387,14 @@ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {

#endif

template <typename T>
Ort::Status ConstValueImpl<T>::GetTensorElementTypeAndShape(ONNXTensorElementDataType& elem_type,
const int64_t*& shape_data,
size_t& shape_data_count) const {
Ort::Status status{GetApi().Value_GetTensorElementTypeAndShape(this->p_, &elem_type, &shape_data, &shape_data_count)};
Comment thread
adrianlizarraga marked this conversation as resolved.
Outdated
return status;
}

template <typename T>
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
Expand Down
58 changes: 58 additions & 0 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,64 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
return GetTensorShapeAndTypeHelper(type, shape, dim_params);
}

ORT_API_STATUS_IMPL(OrtApis::Value_GetTensorElementTypeAndShape, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_ const int64_t** shape_data,
_Out_ size_t* shape_data_count) {
API_IMPL_BEGIN
if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"the ort_value must contain a constructed tensor or sparse tensor");
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
}

if (elem_type == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Output parameter `elem_type` must not be NULL");
}

if (shape_data == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Output parameter `shape_data` must not be NULL");
}

if (shape_data_count == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Output parameter `shape_data_count` must not be NULL");
}

gsl::span<const int64_t> shape_span;
onnxruntime::MLDataType ml_data_type = nullptr;
ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;

if (value->IsTensor()) {
const Tensor& tensor = value->Get<onnxruntime::Tensor>();
ml_data_type = tensor.DataType();
shape_span = tensor.Shape().GetDims();
} else {
#if !defined(DISABLE_SPARSE_TENSORS)
const SparseTensor& tensor = value->Get<onnxruntime::SparseTensor>();
ml_data_type = tensor.DataType();
shape_span = tensor.DenseShape().GetDims();
#else
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build.");
#endif
}

if (ml_data_type != nullptr) {
type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type);
}

if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type");
}

*elem_type = type;
*shape_data = shape_span.data();
*shape_data_count = shape_span.size();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
_In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
API_IMPL_BEGIN
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4806,6 +4806,7 @@ static constexpr OrtApi ort_api_1_to_25 = {

&OrtApis::RunOptionsEnableProfiling,
&OrtApis::RunOptionsDisableProfiling,
&OrtApis::Value_GetTensorElementTypeAndShape,
Comment thread
adrianlizarraga marked this conversation as resolved.
Outdated
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -810,4 +810,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap
ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);

ORT_API_STATUS_IMPL(Value_GetTensorElementTypeAndShape, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
} // namespace OrtApis
22 changes: 22 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "test/shared_lib/custom_op_utils.h"
#include "test/shared_lib/test_fixture.h"
#include "test/shared_lib/utils.h"
#include "test/util/include/api_asserts.h"
#include "test/util/include/providers.h"
#include "test/util/include/test_allocator.h"

Expand Down Expand Up @@ -480,6 +481,27 @@ TEST(CApiTest, dim_param) {
ASSERT_EQ(strcmp(dim_param, ""), 0);
}

TEST(CApiTest, Value_GetTensorElementTypeAndShape) {
Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);

const std::array<int64_t, 2> x_shape = {3, 2};
std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(),
x_shape.data(), x_shape.size());
Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo();

ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
const int64_t* shape_data = nullptr;
size_t shape_data_count = 0;
ASSERT_ORTSTATUS_OK(x_value.GetTensorElementTypeAndShape(elem_type, shape_data, shape_data_count));

ASSERT_EQ(elem_type, type_shape_info.GetElementType());

std::vector<int64_t> expected_shape = type_shape_info.GetShape();
gsl::span<const int64_t> actual_shape(shape_data, shape_data_count);
ASSERT_EQ(actual_shape, gsl::span<const int64_t>(expected_shape));
}

static std::pair<bool, bool> LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) {
Ort::Session session(*ort_env, model_url, Ort::SessionOptions{});
const auto input_num = session.GetInputCount();
Expand Down
Loading