diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index b7997ce86737a..93b673f2df5bd 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -20,7 +20,6 @@ #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_graph.h" -#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include "core/common/parse_string.h" @@ -85,40 +84,6 @@ struct ShutdownProtobuf { namespace onnxruntime { -namespace cuda { -template <> -void Impl_Cast( - cudaStream_t stream, - const int64_t* input_data, int32_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const int32_t* input_data, int64_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const double* input_data, float* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const float* input_data, double* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} -} // namespace cuda - void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -372,51 +337,19 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - skip_input_binding_allowed = false; \ - if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - data = scratch_buffers.back().get(); \ - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - data = scratch_buffers.back().get(); \ - } \ - break; \ - } - #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ data_ptr = output_tensor_ptr; \ if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - buffers[output_name] = output_tensor_ptr; \ + buffer = output_tensor_ptr; \ } else { \ scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ + buffer = scratch_buffers.back().get(); \ } \ break; \ } -#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ - case DATA_TYPE: { \ - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ - data_ptr = output_tensor_ptr; \ - skip_output_binding_allowed = false; \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = static_cast(elem_cnt); \ - } else { \ - 
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = 1; \ - } \ - break; \ - } - #define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ @@ -426,15 +359,6 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ - } \ - break; \ - } - /* * Set Nv executio context input. * @@ -557,7 +481,6 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -582,8 +505,6 @@ Status BindContextInput(Ort::KernelContext& ctx, * param output_type - Data type of the output * param i - Output iteration index * param output_tensors - Output iteration index to output's ORT value - * param output_dim_sizes - Output iteration index to the multiplocation of its shape's dimensions - * param dds_output_set - DDS output set * param dds_output_allocator_map - DDS output to its allocator * param scratch_buffer - The allocation buffer created by TRT EP * param allocator - ORT allocator @@ -595,16 +516,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, const char* output_name, size_t output_index, size_t output_type, - size_t i, - std::unordered_map& output_tensors, - std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, - std::unordered_map& buffers, nvinfer1::Dims& dims, - void*& data_ptr, - bool& skip_output_binding_allowed) { + void*& data_ptr) { // Get output shape dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; @@ -634,10 +550,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding. 
} } else { - output_tensors[i] = ctx.GetOutput(output_index, dims.d, nb_dims); - auto& output_tensor = output_tensors[i]; + auto output_tensor = ctx.GetOutput(output_index, dims.d, nb_dims); const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + void* buffer = nullptr; + switch (output_type) { // below macros set data_ptr and skip_output_binding_allowed variables CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) @@ -648,13 +565,12 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - trt_context->setTensorAddress(output_name, buffers[output_name]); + trt_context->setTensorAddress(output_name, buffer); } return Status::OK(); @@ -711,7 +627,6 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -2837,7 +2752,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); @@ -2853,7 +2767,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], + &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, @@ -2891,7 +2805,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; - int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; if (alloc_ == nullptr) { @@ -2966,16 +2879,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = 
true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -2993,16 +2897,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3082,14 +2985,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } @@ -3213,7 +3108,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); - int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -3283,16 +3177,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3311,16 +3196,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } 
// Set execution context memory @@ -3401,14 +3284,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 22b8314649757..9e5fd03756f02 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -195,7 +195,6 @@ struct TensorrtFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; nvinfer1::IBuilder* builder; - tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; std::unique_ptr* network = nullptr; @@ -386,7 +385,6 @@ class NvExecutionProvider : public IExecutionProvider { // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. - std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; std::unordered_map> builders_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 24554560b4dde..9679da7cea2ff 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1782,7 +1782,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra type = OrtDevice::GPU; vendor = OrtDevice::VendorIds::MICROSOFT; } else if (type == OrtDevice::GPU) { -#if USE_CUDA +#if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; diff --git a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py new file mode 100644 index 0000000000000..d5c80a4a1f4ba --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py @@ -0,0 +1,468 @@ +# Copyright (c) NVIDIA Corporation. All rights reserved. +# Licensed under the MIT License. 
+from __future__ import annotations + +import sys +import unittest +from collections.abc import Sequence + +import numpy as np +import torch +from autoep_helper import AutoEpTestCase +from helper import get_name +from numpy.testing import assert_almost_equal +from onnx import TensorProto, helper +from onnx.defs import onnx_opset_version + +import onnxruntime as onnxrt +from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue +from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding + + +class TestNvTensorRTRTXAutoEP(AutoEpTestCase): + """ + Test suite for the NvTensorRTRTX Execution Provider. + + This class contains tests for registering the NvTensorRTRTX EP, + selecting it using different policies, and running inference with various + I/O binding configurations. + """ + + ep_lib_path = "onnxruntime_providers_nv_tensorrt_rtx.dll" + ep_name = "NvTensorRTRTXExecutionProvider" + + def setUp(self): + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + self.register_execution_provider_library(self.ep_name, self.ep_lib_path) + + def tearDown(self): + self.unregister_execution_provider_library(self.ep_name) + + def _create_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 + ) + + def _create_ortvalue_alternate_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), + device, + 0, + ) + + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) + + def _create_numpy_input(self): + return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + + def _create_expected_output(self): + return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + + def _create_expected_output_alternate(self): + return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) + + def torch_to_onnx_type(self, torch_dtype): + if torch_dtype == torch.float32: + return TensorProto.FLOAT + elif torch_dtype == torch.float16: + return TensorProto.FLOAT16 + elif torch_dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + elif torch_dtype == torch.int8: + return TensorProto.INT8 + elif torch_dtype == torch.int32: + return TensorProto.INT32 + elif torch_dtype == torch.int64: + return TensorProto.INT64 + else: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + + def test_nv_tensorrt_rtx_ep_register_and_inference(self): + """ + Test registration of the NvTensorRTRTX EP, discovery of its OrtEpDevice, and running inference.
+ """ + ep_devices = onnxrt.get_ep_devices() + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + self.assertEqual(nv_tensorrt_rtx_ep_device.ep_vendor, "NVIDIA") + + hw_device = nv_tensorrt_rtx_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_prefer_gpu_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the PREFER_GPU policy and running inference. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_selection_delegate_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the custom EP selection delegate function and then run inference. + """ + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 1) + self.assertGreaterEqual(max_selections, 2) + + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + + # Select the NvTensorRTRTX device + return [nv_tensorrt_rtx_ep_device] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_bind_input_only(self): + """ + Test I/O binding with input data only. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. 
+ sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Bind output to CPU + io_binding.bind_output("Y") + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] + + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) + + def test_bind_input_and_bind_output_with_ortvalues(self): + """ + Test I/O binding with OrtValues for both input and output. + """ + # Set a policy to prefer GPU. NvTensorRTRTX EP should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue) + + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_output("Y", output_ortvalue) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) + + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue_2) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) + + def test_bind_input_and_non_preallocated_output(self): + """ + Test I/O binding with non-preallocated output. 
+ """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu("cuda") + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + io_binding.bind_output("Y", "cuda") + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu("cuda") + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) + + def test_bind_input_and_preallocated_output(self): + """ + Test I/O binding with preallocated output. 
+ """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) + + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) + + def test_bind_input_types(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + for dtype in [ + np.float32, + # np.float64, + np.int32, + # np.uint32, + np.int64, + # np.uint64, + # np.int16, + # np.uint16, + # np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + def test_bind_onnx_types_from_torch(self): + """ + Test I/O binding with various input data types. 
+ """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + + for dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + ]: + with self.subTest(dtype=dtype): + proto_dtype = self.torch_to_onnx_type(dtype) + + x_ = helper.make_tensor_value_info("X", proto_dtype, [None]) + y_ = helper.make_tensor_value_info("Y", proto_dtype, [None]) + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [x_], [y_], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=10, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + dev = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + if dev == "cuda" + else C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + ) + + x = torch.arange(8, dtype=dtype, device=dev) + y = torch.empty(8, dtype=dtype, device=dev) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, proto_dtype, x.shape, x.data_ptr()) + bind.bind_output("Y", device, proto_dtype, y.shape, y.data_ptr()) + sess._sess.run_with_iobinding(bind, None) + self.assertTrue(torch.equal(x, y)) + + +if __name__ == "__main__": + unittest.main(verbosity=1)