Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5397,24 +5397,20 @@ struct OrtApi {
/** \brief Get the device memory type from ::OrtMemoryInfo
*
* \param[in] ptr The OrtMemoryInfo instance to query.
* \param[out] out The device memory type.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \return The device memory type.
*
* \since Version 1.23
*/
ORT_API2_STATUS(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out);
ORT_API_T(OrtDeviceMemoryType, MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr);

/** \brief Get the vendor id from ::OrtMemoryInfo
*
* \param[in] ptr The OrtMemoryInfo instance to query.
* \param[out] out The vendor id.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \return The vendor id.
*
* \since Version 1.23
*/
ORT_API2_STATUS(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out);
ORT_API_T(uint32_t, MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr);

/// \name OrtValueInfo
/// @{
Expand Down Expand Up @@ -6081,15 +6077,14 @@ struct OrtApi {
*
* \param[in] options The OrtRunOptions instance.
* \param[in] config_key The configuration entry key. A null-terminated string.
* \param[out] config_value Output parameter set to the configuration entry value. Either a null-terminated string if
* a configuration entry exists or NULL otherwise.
* Do not free this value. It is owned by `options` and will be invalidated if another call
* to `AddRunConfigEntry()` overwrites it.
* \return The configuration entry value. Either a null-terminated string if the entry was found. nullptr otherwise.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);
ORT_API_T(const char*, GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key);

/// @}

Expand Down Expand Up @@ -6173,7 +6168,7 @@ struct OrtApi {
/** \brief Get a const pointer to the raw data inside a tensor
*
* Used to read the internal tensor data directly.
* \note The returned pointer is valid until the \p value is destroyed.
* \note The returned pointer is valid until the OrtValue is destroyed.
*
* \param[in] value A tensor type (string tensors are not supported)
* \param[out] out Filled in with a pointer to the internal storage
Expand Down
4 changes: 1 addition & 3 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -768,9 +768,7 @@ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char
}

inline const char* RunOptions::GetConfigEntry(const char* config_key) {
const char* out{};
ThrowOnError(GetApi().GetRunConfigEntry(p_, config_key, &out));
return out;
return GetApi().GetRunConfigEntry(p_, config_key);
}

inline RunOptions& RunOptions::SetTerminate() {
Expand Down
5 changes: 2 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,11 @@ struct OrtEpApi {
/** \brief Get the OrtMemoryDevice from an OrtValue instance if it contains a Tensor.
*
* \param[in] value The OrtValue instance to get the memory device from.
* \param[out] device The OrtMemoryDevice associated with the OrtValue instance.
* \return Status Success if OrtValue contains a Tensor. Otherwise, an error status is returned.
* \return Memory device if OrtValue contains a Tensor, nullptr otherwise.
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device);
ORT_API_T(const OrtMemoryDevice*, Value_GetMemoryDevice, _In_ const OrtValue* value);

/** \brief Compare two OrtMemoryDevice instances for equality.
*
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,10 @@ ORT_API(void, OrtApis::MemoryInfoGetDeviceType, _In_ const OrtMemoryInfo* info,
*out = static_cast<OrtMemoryInfoDeviceType>(info->device.Type());
}

ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out) {
*out = static_cast<OrtDeviceMemoryType>(ptr->device.MemType());
return nullptr;
ORT_API(OrtDeviceMemoryType, OrtApis::MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr) {
return static_cast<OrtDeviceMemoryType>(ptr->device.MemType());
}

ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out) {
*out = ptr->device.Vendor();
return nullptr;
ORT_API(uint32_t, OrtApis::MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr) {
return ptr->device.Vendor();
}
10 changes: 3 additions & 7 deletions onnxruntime/core/framework/run_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,13 @@ ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options,
return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value));
}

ORT_API_STATUS_IMPL(OrtApis::GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value_out) {
API_IMPL_BEGIN
ORT_API(const char*, OrtApis::GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key) {
const auto& config_options = options->config_options.GetConfigOptionsMap();
if (auto it = config_options.find(config_key); it != config_options.end()) {
*config_value_out = it->second.c_str();
return it->second.c_str();
} else {
*config_value_out = nullptr;
return nullptr;
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options,
Expand Down
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,8 @@ struct CudaDataTransferImpl : OrtDataTransferImpl {
OrtValue* dst_tensor = dst_tensors[idx];
OrtSyncStream* stream = streams ? streams[idx] : nullptr;

const OrtMemoryDevice *src_device = nullptr, *dst_device = nullptr;
RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensor, &src_device));
RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensor, &dst_device));
const OrtMemoryDevice* src_device = impl.ep_api.Value_GetMemoryDevice(src_tensor);
const OrtMemoryDevice* dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensor);

size_t bytes;
RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensor, &bytes));
Expand Down
9 changes: 3 additions & 6 deletions onnxruntime/core/session/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,13 @@ ORT_API(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemory
return static_cast<const OrtMemoryDevice*>(&memory_info->device);
}

ORT_API_STATUS_IMPL(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device) {
*device = nullptr;
ORT_API(const OrtMemoryDevice*, Value_GetMemoryDevice, _In_ const OrtValue* value) {
if (value == nullptr || value->IsTensor() == false) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue does not contain an allocated tensor.");
return nullptr; // Tensor always has a device, so we don't need a more specific error here.
}

auto& tensor = value->Get<Tensor>();
*device = static_cast<const OrtMemoryDevice*>(&tensor.Location().device);

return nullptr;
return static_cast<const OrtMemoryDevice*>(&tensor.Location().device);
}

ORT_API(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ORT_API_STATUS_IMPL(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device,
_In_ const OrtMemoryInfo* allocator_memory_info);

ORT_API(const OrtMemoryDevice*, MemoryInfo_GetMemoryDevice, _In_ const OrtMemoryInfo* memory_info);
ORT_API_STATUS_IMPL(Value_GetMemoryDevice, _In_ const OrtValue* value, _Out_ const OrtMemoryDevice** device);
ORT_API(const OrtMemoryDevice*, Value_GetMemoryDevice, _In_ const OrtValue* value);

ORT_API(bool, MemoryDevice_AreEqual, _In_ const OrtMemoryDevice* a, _In_ const OrtMemoryDevice* b);
ORT_API(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemoryDevice* memory_device);
Expand Down
7 changes: 3 additions & 4 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ ORT_API_STATUS_IMPL(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMem
_In_ size_t alignment, enum OrtAllocatorType allocator_type,
_Outptr_ OrtMemoryInfo** out);

ORT_API_STATUS_IMPL(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out);
ORT_API_STATUS_IMPL(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out);
ORT_API(OrtDeviceMemoryType, MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr);
ORT_API(uint32_t, MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr);

// OrtValueInfo
ORT_API_STATUS_IMPL(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info,
Expand Down Expand Up @@ -685,8 +685,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node,
ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph);
ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out);

ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);
ORT_API(const char*, GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key);

ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device,
_In_ OrtDeviceMemoryType memory_type);
Expand Down
6 changes: 2 additions & 4 deletions onnxruntime/test/autoep/library/ep_data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(OrtDataTransferImpl
for (size_t i = 0; i < num_tensors; ++i) {
// the implementation for a 'real' EP would be something along these lines.
// See CudaDataTransferImpl in onnxruntime\core\providers\cuda\cuda_provider_factory.cc
const OrtMemoryDevice* src_device = nullptr;
const OrtMemoryDevice* dst_device = nullptr;
RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device));
RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensors[i], &dst_device));
const OrtMemoryDevice* src_device = impl.ep_api.Value_GetMemoryDevice(src_tensors[i]);
const OrtMemoryDevice* dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensors[i]);

OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device);
OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device);
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/test/shared_lib/test_data_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,8 @@ TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) {

for (size_t idx = 0; idx < num_inputs; ++idx) {
const OrtMemoryInfo* mem_info = input_locations[idx];
OrtDeviceMemoryType mem_type;
OrtDeviceMemoryType mem_type = api->MemoryInfoGetDeviceMemType(mem_info);
OrtMemoryInfoDeviceType device_type;
ASSERT_ORTSTATUS_OK(api->MemoryInfoGetDeviceMemType(mem_info, &mem_type));
api->MemoryInfoGetDeviceType(mem_info, &device_type);

if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) {
Expand Down
Loading