diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index 286db9070766d..be860b45ee7ea 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -123,10 +123,11 @@ void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /
   // even for empty tensors, so allocate a dummy byte.
   size = std::max(size, static_cast<uint64_t>(1));
   if (size > allocated_size) {
-    cudaFree(outputPtr);
+    alloc_->Free(alloc_, outputPtr);
     outputPtr = nullptr;
     allocated_size = 0;
-    if (cudaMalloc(&outputPtr, size) == cudaSuccess) {
+    outputPtr = alloc_->Alloc(alloc_, size);
+    if (outputPtr) {
      allocated_size = size;
    }
  }
@@ -800,7 +801,7 @@ Status BindContextOutput(Ort::KernelContext& ctx,
   // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3.
   if (is_DDS || known_DDS) {
     if (!known_DDS) {
-      auto allocatorPtr = std::make_unique<OutputAllocator>();
+      auto allocatorPtr = std::make_unique<OutputAllocator>(alloc);
       trt_context->setOutputAllocator(output_name, allocatorPtr.get());
       dds_output_allocator_map[output_name] = std::move(allocatorPtr);
     }
@@ -1081,7 +1082,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
     engine_decryption_lib_path_ = info.engine_decryption_lib_path;
   }
   force_sequential_engine_build_ = info.force_sequential_engine_build;
-  context_memory_sharing_enable_ = info.context_memory_sharing_enable;
   sparsity_enable_ = info.sparsity_enable;
   auxiliary_streams_ = info.auxiliary_streams;
   profile_min_shapes = info.profile_min_shapes;
@@ -1225,7 +1225,6 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
                         << ", nv_engine_decryption_enable: " << engine_decryption_enable_
                         << ", nv_engine_decryption_lib_path: " << engine_decryption_lib_path_
                         << ", nv_force_sequential_engine_build: " << force_sequential_engine_build_
-                        << ", nv_context_memory_sharing_enable: " << context_memory_sharing_enable_
                         << ", nv_sparsity_enable: " << sparsity_enable_
                         << ", nv_auxiliary_streams: " << auxiliary_streams_
                         << ", nv_cuda_graph_enable: " << cuda_graph_enable_
@@ -1298,9 +1297,15 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
 }
 
 std::vector<AllocatorPtr> NvExecutionProvider::CreatePreferredAllocators() {
+  OrtArenaCfg arena_cfg(0, static_cast<int>(ArenaExtendStrategy::kSameAsRequested),
+                        -1, -1, -1, -1);
   AllocatorCreationInfo default_memory_info(
       [](OrtDevice::DeviceId device_id) { return std::make_unique<CUDAAllocator>(device_id, CUDA); },
-      narrow<OrtDevice::DeviceId>(device_id_));
+      narrow<OrtDevice::DeviceId>(device_id_),
+      true,
+      arena_cfg,
+      // make it stream aware
+      true);
 
   AllocatorCreationInfo pinned_allocator_info(
       [](OrtDevice::DeviceId device_id) {
@@ -2349,21 +2354,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   ShapeRangesMap input_explicit_shape_ranges;
   ShapeRangesMap input_implicit_shape_ranges;
 
-  auto tensor_is_dynamic = [&](nvinfer1::ITensor* tensor) -> bool {
-    if (tensor->isShapeTensor()) {
-      return true;
-    } else {
-      nvinfer1::Dims dims = tensor->getDimensions();
-      // Execution tensor
-      for (int j = 0, end = dims.nbDims; j < end; ++j) {
-        if (dims.d[j] == -1) {
-          return true;
-        }
-      }
-    }
-    return false;
-  };
-
   bool has_dynamic_shape = false;  // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false
   if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) {
     has_explicit_profile = true;
@@ -2375,7 +2365,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   } else {
     for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
       auto input = trt_network->getInput(i);
-      has_dynamic_shape |= tensor_is_dynamic(input);
+      has_dynamic_shape |= checkTrtTensorIsDynamic(input);
     }
     if (has_dynamic_shape) {
       LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] No explicit optimization profile was specified. "
@@ -2574,31 +2564,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   // Build context
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  if (context_memory_sharing_enable_) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
-    size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
-    if (mem_size > max_ctx_mem_size_) {
-      max_ctx_mem_size_ = mem_size;
-    }
-    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
-  } else {
-    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
-  }
+  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "Nv EP could not build execution context for fused node: " + fused_node.Name());
   }
 
+  bool is_dynamic_shape_context = false;
   // Create input to index map
   for (int i = 0; i < num_inputs; ++i) {
     auto input = trt_network->getInput(i);
     const std::string& input_name = input->getName();
+    is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(input_name.c_str()));
     const auto& iter = input_map.find(input_name);
     if (iter != input_map.end()) {
       input_indexes[input_name] = iter->second;
@@ -2639,10 +2616,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
           input_shape_ranges_[context->node_name], &tensorrt_mu_, trt_node_name_with_precision,
           engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name],
-          context_memory_sharing_enable_, &max_ctx_mem_size_,
           engine_decryption_enable_, engine_decryption_, engine_encryption_, detailed_build_log_, sparsity_enable_,
-          auxiliary_streams_, cuda_graph_enable_, cache_prefix_, cache_suffix};
+          auxiliary_streams_, cuda_graph_enable_, is_dynamic_shape_context, cache_prefix_, cache_suffix};
     *state = p.release();
     return 0;
   };
@@ -2676,7 +2652,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
     auto trt_engine = trt_state->engine->get();
     auto trt_context = trt_state->context->get();
     auto trt_profiles = trt_state->profiles;
-    auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
     int num_inputs = static_cast<int>(input_indexes.size());
     int num_outputs = static_cast<int>(output_indexes.size());
     std::unordered_set<std::string> input_names;
@@ -2785,9 +2760,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
       if (iter != input_indexes.end()) {
        input_index = iter->second;
       }
-      auto input_tensor = ctx.GetInput(input_index);
-      auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
-      const auto tensor_shapes = tensor_info.GetShape();
       auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values,
                                     shape_tensor_values_int64, scratch_buffers, alloc, stream);
       if (status != Status::OK()) {
@@ -2829,20 +2801,16 @@
     }
 
     // Set execution context memory
-    if (trt_state->context_memory_sharing_enable) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
-      size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
-      if (mem_size > *max_context_mem_size_ptr) {
-        *max_context_mem_size_ptr = mem_size;
-      }
-      trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
+    size_t mem_size = trt_engine->getDeviceMemorySizeV2();
+    if (trt_state->is_dynamic_shape) {
+      mem_size = trt_context->updateDeviceMemorySizeForShapes();
+    }
+    if (trt_state->context_memory_size != mem_size) {
+      LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size;
+      trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, mem_size, false /*use_reserve*/);
+      trt_state->context_memory_size = mem_size;
     }
+    trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size);
 
     // Start CUDA graph capture.
     // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
@@ -2961,33 +2929,19 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
   //
   // Note: Creating an execution context from an engine is thread safe per TRT doc
   // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
-  if (context_memory_sharing_enable_) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
-    size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
-    if (mem_size > max_ctx_mem_size_) {
-      max_ctx_mem_size_ = mem_size;
-    }
-    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
-
-  } else {
-    trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
-  }
+  trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
   if (!trt_context) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "Nv EP could not build execution context for fused node: " + fused_node.Name());
   }
 
+  bool is_dynamic_shape_context = false;
   // Create input/output to index maps
   for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) {
     auto const& name = trt_engine->getIOTensorName(i);
     auto const& mode = trt_engine->getTensorIOMode(name);
     if (mode == nvinfer1::TensorIOMode::kINPUT) {
+      is_dynamic_shape_context |= checkTrtDimIsDynamic(trt_engine->getTensorShape(name));
       const auto& iter = input_map.find(name);
       if (iter != input_map.end()) {
         input_indexes[name] = iter->second;
@@ -3027,9 +2981,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
           &contexts_[context->node_name],
           input_info_[context->node_name],
           output_info_[context->node_name],
-          context_memory_sharing_enable_,
-          &max_ctx_mem_size_,
-          &tensorrt_mu_};
+          &tensorrt_mu_,
+          is_dynamic_shape_context};
     *state = p.release();
     return 0;
   };
@@ -3056,7 +3009,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
     auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
     auto trt_engine = trt_state->engine->get();
     auto trt_context = trt_state->context->get();
-    auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
     int num_outputs = static_cast<int>(output_indexes.size());
     std::unordered_map<std::string, std::vector<int32_t>> shape_tensor_values;        // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run
     std::unordered_map<std::string, std::vector<int64_t>> shape_tensor_values_int64;  // same as above but for int64 shape tensor input
@@ -3144,20 +3096,17 @@
     }
 
     // Set execution context memory
-    if (trt_state->context_memory_sharing_enable) {
-#if defined(_MSC_VER)
-#pragma warning(push)
-#pragma warning(disable : 4996)
-#endif
-      size_t mem_size = trt_engine->getDeviceMemorySizeV2();
-#if defined(_MSC_VER)
-#pragma warning(pop)
-#endif
-      if (mem_size > *max_context_mem_size_ptr) {
-        *max_context_mem_size_ptr = mem_size;
-      }
-      trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
+    size_t mem_size = trt_engine->getDeviceMemorySizeV2();
+    if (trt_state->is_dynamic_shape) {
+      mem_size = trt_context->updateDeviceMemorySizeForShapes();
+    }
+    if (trt_state->context_memory_size != mem_size) {
+      LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size;
+      trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, mem_size, false /*use_reserve*/);
+      // trt_state->context_memory = IAllocator::MakeUniquePtr<void>(alloc, mem_size, false /*use_reserve*/, stream);
+      trt_state->context_memory_size = mem_size;
     }
+    trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size);
 
     // Start CUDA graph capture.
     // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
index 7a0c47d28c81d..bc626a79b4256 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -78,6 +78,9 @@ using unique_pointer = std::unique_ptr<T>;
 //
 class OutputAllocator : public nvinfer1::IOutputAllocator {
  public:
+  OutputAllocator() = delete;
+  OutputAllocator(OrtAllocator* allocator) : alloc_(allocator) {};
+
   void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size,
                               uint64_t alignment, cudaStream_t stream) noexcept override;
   void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;
@@ -95,10 +98,11 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
   }
 
   ~OutputAllocator() override {
-    cudaFree(outputPtr);
+    alloc_->Free(alloc_, outputPtr);
   }
 
  private:
+  OrtAllocator* alloc_;
   void* outputPtr{nullptr};
   uint64_t allocated_size = 0;
   std::vector<int64_t> output_shapes;
@@ -130,8 +134,6 @@ struct TensorrtFuncState {
   std::string engine_cache_path;
   nvinfer1::IRuntime* runtime = nullptr;
   std::vector<nvinfer1::IOptimizationProfile*> profiles;
-  bool context_memory_sharing_enable = false;
-  size_t* max_context_mem_size_ptr = nullptr;
   bool engine_decryption_enable = false;
   int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
   int (*engine_encryption)(const char*, char*, size_t) = nullptr;
@@ -139,8 +141,11 @@ struct TensorrtFuncState {
   bool sparsity_enable = false;
   int auxiliary_streams = -1;
   bool cuda_graph_enable = 0;
+  bool is_dynamic_shape = false;
   std::string cache_prefix;
   std::string cache_suffix;
+  IAllocatorUniquePtr<void> context_memory = nullptr;
+  size_t context_memory_size = 0;
 };
 
 // Minimum information to construct kernel function state for direct engine load code path
@@ -153,9 +158,10 @@ struct TensorrtShortFuncState {
   std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
-  bool context_memory_sharing_enable = false;
-  size_t* max_context_mem_size_ptr = nullptr;
   std::mutex* tensorrt_mu_ptr = nullptr;
+  bool is_dynamic_shape = false;
+  IAllocatorUniquePtr<void> context_memory = nullptr;
+  size_t context_memory_size = 0;
 };
 
 // Holds important information for building valid ORT graph.
@@ -251,9 +257,7 @@ class NvExecutionProvider : public IExecutionProvider {
   std::mutex tensorrt_mu_;
   int device_id_;
   std::string compute_capability_;
-  bool context_memory_sharing_enable_ = false;
   size_t max_ctx_mem_size_ = 0;
-  IAllocatorUniquePtr<void> context_memory_ = nullptr;
   mutable char model_path_[4096] = {};  // Reserved for max path length
   bool engine_decryption_enable_ = false;
   int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
index 2a67f3c3bec4d..4d6c6fe116076 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h
@@ -34,7 +34,6 @@ struct NvExecutionProviderInfo {
   bool engine_decryption_enable{false};
   std::string engine_decryption_lib_path{""};
   bool force_sequential_engine_build{false};
-  bool context_memory_sharing_enable{false};
   std::string timing_cache_path{""};
   bool detailed_build_log{false};
   bool sparsity_enable{false};
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
index 22e5eea6924de..ea586ba445ba2 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h
@@ -683,4 +683,29 @@ std::string GetCacheSuffix(const std::string& fused_node_name, const std::string
   }
   return "";
 }
+
+/*
+ * Checks if there is an element with value `-1` in nvinfer1::Dims
+ */
+static bool checkTrtDimIsDynamic(nvinfer1::Dims dims) {
+  for (int j = 0, end = dims.nbDims; j < end; ++j) {
+    if (dims.d[j] == -1) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*
+ * Checks if an nvinfer1::ITensor signals a dynamic shape,
+ * either because it has dynamic dimensions or because it is a shape tensor
+ */
+static bool checkTrtTensorIsDynamic(nvinfer1::ITensor* tensor) {
+  if (tensor->isShapeTensor()) {
+    return true;
+  } else {
+    // Execution tensor
+    return checkTrtDimIsDynamic(tensor->getDimensions());
+  }
+}
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
index 0559699670c4a..19505da1bbe56 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -394,6 +394,7 @@ TYPED_TEST(NvExecutionProviderTest, IOTypeTests) {
   }
 }
 
+#if defined(WIN32)
 static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
   // Access the underlying InferenceSession.
   const OrtSession* ort_session = session;
@@ -409,11 +410,10 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) {
   return has_ep;
 }
 
-#if defined(WIN32)
 // Tests autoEP feature to automatically select an EP that supports the GPU.
 // Currently only works on Windows.
 TEST(NvExecutionProviderTest, AutoEp_PreferGpu) {
-  PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx");
+  PathString model_name = ORT_TSTR("nv_execution_provider_auto_ep.onnx");
   std::string graph_name = "test";
   std::vector<int64_t> dims = {1, 3, 2};