-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Fix for #25145 #26994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix for #25145 #26994
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -512,8 +512,25 @@ def __init__( | |
| def _create_inference_session(self, providers, provider_options, disabled_optimizers=None): | ||
| available_providers = C.get_available_providers() | ||
|
|
||
| # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. | ||
| if "TensorrtExecutionProvider" in available_providers: | ||
| # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified | ||
| if providers: | ||
| has_tensorrt = any( | ||
| provider == "TensorrtExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") | ||
| for provider in providers | ||
| ) | ||
| has_tensorrt_rtx = any( | ||
| provider == "NvTensorRTRTXExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") | ||
| for provider in providers | ||
| ) | ||
| if has_tensorrt and has_tensorrt_rtx: | ||
| raise ValueError( | ||
| "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' " | ||
| "in the same session." | ||
| ) | ||
| # Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. | ||
| if "NvTensorRTRTXExecutionProvider" in available_providers: | ||
|
||
| if ( | ||
| providers | ||
| and any( | ||
|
|
@@ -522,15 +539,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi | |
| for provider in providers | ||
| ) | ||
| and any( | ||
| provider == "TensorrtExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") | ||
| provider == "NvTensorRTRTXExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") | ||
| for provider in providers | ||
| ) | ||
| ): | ||
| self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] | ||
| else: | ||
| self._fallback_providers = ["CPUExecutionProvider"] | ||
| if "NvTensorRTRTXExecutionProvider" in available_providers: | ||
| elif "TensorrtExecutionProvider" in available_providers: | ||
| if ( | ||
| providers | ||
| and any( | ||
|
|
@@ -539,8 +556,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi | |
| for provider in providers | ||
| ) | ||
| and any( | ||
| provider == "NvTensorRTRTXExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider") | ||
| provider == "TensorrtExecutionProvider" | ||
| or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") | ||
| for provider in providers | ||
| ) | ||
| ): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR description/issue mentions a typo impacting the use of
`NvTensorRTRTXExecutionProvider` when passed as a tuple. That typo still exists in `_register_ep_custom_ops` (which currently checks `providers[i][0] == "NvTensorrtRTXExecutionProvider"`), so passing `("NvTensorRTRTXExecutionProvider", {...})` will skip plugin registration. Please align the tuple string check with the canonical provider name used elsewhere (e.g., the tests use `NvTensorRTRTXExecutionProvider`).