Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3050,11 +3050,12 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void*
std::unordered_map<std::string, DDSOutputAllocatorMap>& dds_output_allocator_maps = ep.GetDDSOutputAllocators();
auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name];

// Get default OrtMemoryInfo from factory
const OrtMemoryInfo* mem_info = nullptr;
if (ep.factory_.cuda_gpu_memory_infos.find(device_id) !=
ep.factory_.cuda_gpu_memory_infos.end()) {
mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get();
// Get default OrtMemoryInfo from factory's device cache
const OrtMemoryInfo* mem_info = ep.factory_.GetMemoryInfoByOrdinal(device_id, /* is pinned */false);
if (mem_info == nullptr) {
std::string err_msg = "TensorRT EP failed to get OrtMemoryInfo for device_id "
+ std::to_string(device_id) + " from provider factory.";
return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str());
}

// Get allocator from OrtKernelContext
Expand Down Expand Up @@ -3770,11 +3771,12 @@ OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_p
std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64; // same as above but for int64 shape tensor input

// Get default OrtMemoryInfo from factory
const OrtMemoryInfo* mem_info = nullptr;
if (ep.factory_.cuda_gpu_memory_infos.find(device_id) !=
ep.factory_.cuda_gpu_memory_infos.end()) {
mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get();
// Get default OrtMemoryInfo from factory's device cache
const OrtMemoryInfo* mem_info = ep.factory_.GetMemoryInfoByOrdinal(device_id, /* is pinned */false);
if (mem_info == nullptr) {
std::string err_msg = "TensorRT EP failed to get OrtMemoryInfo for device_id "
+ std::to_string(device_id) + " from provider factory.";
return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str());
}

// Get allocator from OrtKernelContext
Expand Down
2 changes: 1 addition & 1 deletion src/tensorrt_execution_provider.def
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
LIBRARY "TensorRTEp.dll"
LIBRARY "ORTTensorRTEp.dll"
EXPORTS
CreateEpFactories @1
ReleaseEpFactory @2
Expand Down
14 changes: 3 additions & 11 deletions src/tensorrt_execution_provider_data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this
auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device);
auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device);

// 0x10DE is the PCI vendor ID for NVIDIA
constexpr uint32_t nvidia_vendor_id = 0x10DE;

// Reject if GPU device is not NVIDIA
if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != nvidia_vendor_id) ||
(dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != nvidia_vendor_id)) {
if ((src_type == OrtMemoryInfoDeviceType_GPU && src_vendor_id != kNvidiaVendorId) ||
(dst_type == OrtMemoryInfoDeviceType_GPU && dst_vendor_id != kNvidiaVendorId)) {
return false;
}

Expand Down Expand Up @@ -110,11 +107,6 @@ OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl*

/*static*/
void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept {
// In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore
// the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h)
//
// If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here
// delete static_cast<TRTEpDataTransfer*>(this_ptr);
;
delete static_cast<TRTEpDataTransfer*>(this_ptr);
}
} // namespace trt_ep
8 changes: 1 addition & 7 deletions src/tensorrt_execution_provider_data_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
namespace trt_ep {

struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector<const OrtMemoryDevice*>& device_mem_infos,
std::vector<const OrtMemoryDevice*>& shared_mem_infos)
: ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} {
TRTEpDataTransfer(ApiPtrs api_ptrs) : OrtDataTransferImpl{}, ApiPtrs(api_ptrs) {
CanCopy = CanCopyImpl;
CopyTensors = CopyTensorsImpl;
Release = ReleaseImpl;
Expand All @@ -26,9 +24,5 @@ struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs {
OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr,
size_t num_tensors) noexcept;
static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept;

private:
std::vector<const OrtMemoryDevice*>& cuda_gpu_mem_devices_;
std::vector<const OrtMemoryDevice*>& cuda_pinned_mem_devices_;
};
} // namespace trt_ep
Loading
Loading