Merged
31 changes: 24 additions & 7 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -512,8 +512,25 @@ def __init__(
     def _create_inference_session(self, providers, provider_options, disabled_optimizers=None):
         available_providers = C.get_available_providers()

-        # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
-        if "TensorrtExecutionProvider" in available_providers:
+        # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified
Copilot AI · Jan 23, 2026

The PR description/issue mentions a typo that affects using NvTensorRTRTXExecutionProvider when it is passed as a tuple. That typo still exists in _register_ep_custom_ops (which currently checks providers[i][0] == "NvTensorrtRTXExecutionProvider"), so passing ("NvTensorRTRTXExecutionProvider", {...}) will skip plugin registration. Please align the tuple string check with the canonical provider name used elsewhere (e.g., the tests use NvTensorRTRTXExecutionProvider).
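A standalone sketch of the mismatch the comment describes (helper names here are hypothetical, not ONNX Runtime API): provider entries may be a plain string or a (name, options) tuple, so comparing the tuple's first element against a misspelled literal silently fails to match.

```python
def provider_name(provider):
    """Return the provider name from either a string or a (name, options) tuple."""
    return provider[0] if isinstance(provider, tuple) else provider

CANONICAL = "NvTensorRTRTXExecutionProvider"
MISSPELLED = "NvTensorrtRTXExecutionProvider"  # the typo flagged in the review

providers = [("NvTensorRTRTXExecutionProvider", {"device_id": 0})]

# The buggy comparison never matches the canonical name passed by the caller:
buggy_match = any(provider_name(p) == MISSPELLED for p in providers)
# Comparing against the canonical spelling does:
fixed_match = any(provider_name(p) == CANONICAL for p in providers)

print(buggy_match, fixed_match)  # → False True
```

A single shared name-extraction helper like this would also remove the duplicated string-or-tuple checks scattered through the diff below.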
+        if providers:
+            has_tensorrt = any(
+                provider == "TensorrtExecutionProvider"
+                or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
+                for provider in providers
+            )
+            has_tensorrt_rtx = any(
+                provider == "NvTensorRTRTXExecutionProvider"
+                or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
+                for provider in providers
+            )
+            if has_tensorrt and has_tensorrt_rtx:
+                raise ValueError(
+                    "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' "
+                    "in the same session."
+                )
+        # Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU.
+        if "NvTensorRTRTXExecutionProvider" in available_providers:
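The mutual-exclusion check added above can be exercised in isolation. This is a self-contained sketch of the same logic on the string-or-tuple provider list shape, not the actual InferenceSession method:

```python
def _has(providers, name):
    # A provider entry is either a name string or a (name, options) tuple.
    return any(
        p == name or (isinstance(p, tuple) and p[0] == name)
        for p in providers
    )

def validate_providers(providers):
    """Raise if TensorRT and TensorRT RTX are both requested, as in the diff."""
    if (
        providers
        and _has(providers, "TensorrtExecutionProvider")
        and _has(providers, "NvTensorRTRTXExecutionProvider")
    ):
        raise ValueError(
            "Cannot enable both 'TensorrtExecutionProvider' and "
            "'NvTensorRTRTXExecutionProvider' in the same session."
        )

validate_providers(["TensorrtExecutionProvider", "CUDAExecutionProvider"])  # OK
try:
    validate_providers([("TensorrtExecutionProvider", {}), "NvTensorRTRTXExecutionProvider"])
except ValueError as e:
    print("raised:", e)
```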
Copilot AI · Jan 23, 2026

The fallback-selection branching is driven by available_providers (line 533) rather than by the providers the user actually requested. If both NvTensorRTRTXExecutionProvider and TensorrtExecutionProvider are available, the first if will always run and the elif (TensorRT) branch will never be evaluated, so a session requested with ["TensorrtExecutionProvider", "CUDAExecutionProvider"] will incorrectly get ['CPUExecutionProvider'] as its fallback. Consider basing the branch on the requested providers (e.g., the has_tensorrt/has_tensorrt_rtx flags) and only then deciding whether CUDA is explicitly present for the CUDA+CPU fallback.
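The reviewer's suggestion can be sketched as a pure function that derives the fallback list from the requested providers rather than from what happens to be installed. The function name and structure here are illustrative, not ONNX Runtime API:

```python
def _name(provider):
    return provider[0] if isinstance(provider, tuple) else provider

def select_fallback_providers(requested):
    """Pick fallback EPs from the caller's requested list, per the review comment."""
    names = {_name(p) for p in (requested or [])}
    wants_trt = (
        "TensorrtExecutionProvider" in names
        or "NvTensorRTRTXExecutionProvider" in names
    )
    # Fall back to CUDA only when the caller explicitly requested it alongside
    # a TensorRT-family provider; everything else falls back to CPU only.
    if wants_trt and "CUDAExecutionProvider" in names:
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]

# The scenario from the comment: classic TensorRT plus CUDA now yields the
# CUDA+CPU fallback regardless of whether TensorRT RTX is also installed.
print(select_fallback_providers(["TensorrtExecutionProvider", "CUDAExecutionProvider"]))
```

Because the decision depends only on the requested list, the order in which providers appear in available_providers no longer changes the outcome.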
             if (
                 providers
                 and any(
@@ -522,15 +539,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
                     for provider in providers
                 )
                 and any(
-                    provider == "TensorrtExecutionProvider"
-                    or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
+                    provider == "NvTensorRTRTXExecutionProvider"
+                    or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider")
                     for provider in providers
                 )
             ):
                 self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
             else:
                 self._fallback_providers = ["CPUExecutionProvider"]
-        if "NvTensorRTRTXExecutionProvider" in available_providers:
+        elif "TensorrtExecutionProvider" in available_providers:
             if (
                 providers
                 and any(
@@ -539,8 +556,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
                     for provider in providers
                 )
                 and any(
-                    provider == "NvTensorRTRTXExecutionProvider"
-                    or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider")
+                    provider == "TensorrtExecutionProvider"
+                    or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider")
                     for provider in providers
                 )
             ):