diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index cc9d9f3da1d81..2da4d27b848ce 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1093,7 +1093,7 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { } std::vector NvExecutionProvider::CreatePreferredAllocators() { - OrtArenaCfg arena_cfg(0, static_cast(ArenaExtendStrategy::kSameAsRequested), + OrtArenaCfg arena_cfg(0, static_cast(ArenaExtendStrategy::kNextPowerOfTwo), -1, -1, -1, -1); AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, @@ -2650,7 +2650,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } if (trt_state->context_memory_size != mem_size) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size; - trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/); + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); trt_state->context_memory_size = mem_size; trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size); } @@ -2963,7 +2963,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra } if (trt_state->context_memory_size != mem_size) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] A new context memory was allocated with size " << mem_size; - trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, false /*use_reserve*/); + trt_state->context_memory = IAllocator::MakeUniquePtrFromOrtAllocator(alloc, mem_size, true /*use_reserve*/); // trt_state->context_memory = IAllocator::MakeUniquePtr(alloc, mem_size, false /*use_reserve*/, stream); trt_state->context_memory_size = mem_size; trt_context->setDeviceMemoryV2(trt_state->context_memory.get(), mem_size);