diff --git a/src/python/py/_dll_directory.py b/src/python/py/_dll_directory.py index 7b6aee39ea..558337a881 100644 --- a/src/python/py/_dll_directory.py +++ b/src/python/py/_dll_directory.py @@ -25,32 +25,35 @@ def add_onnxruntime_dependency(package_id: str): so that they can be found by the dynamic linker. """ if _is_windows(): - import importlib.util - - ort_package = importlib.util.find_spec("onnxruntime") - if not ort_package: - raise ImportError("Could not find the onnxruntime package.") - ort_package_path = ort_package.submodule_search_locations[0] - os.add_dll_directory(os.path.join(ort_package_path, "capi")) + import ctypes - # Load the DirectML.dll library to avoid loading it again in the native code. - # This avoids needing to know the exact path of the shared library from native code. - dml_path = os.path.join(ort_package_path, "capi", "DirectML.dll") - # The dependent onnxruntime package may have multiple execution providers. - # Check to see if DirectML.dll exists before trying to load it. - if os.path.exists(dml_path): - import ctypes - - _ = ctypes.CDLL(dml_path) - - # Workaround for onnxruntime.dll loading - ort_path = os.path.join(ort_package_path, "capi", "onnxruntime.dll") - # The dependent onnxruntime package may have multiple execution providers. - # Check to see if onnxruntime.dll exists before trying to load it. - if os.path.exists(ort_path): - import ctypes - - _ = ctypes.CDLL(ort_path) + kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + ort_handle = kernel32.GetModuleHandleW("onnxruntime.dll") + # Only manually load the dlls if onnxruntime.dll is not already loaded. + # This allows WinML to use its packed dlls. + if not ort_handle: + import importlib.util + + ort_package = importlib.util.find_spec("onnxruntime") + if not ort_package: + raise ImportError("Could not find the onnxruntime package.") + ort_package_path = ort_package.submodule_search_locations[0] + os.add_dll_directory(os.path.join(ort_package_path, "capi")) + + # Load the DirectML.dll library to avoid loading it again in the native code. + # This avoids needing to know the exact path of the shared library from native code. + dml_path = os.path.join(ort_package_path, "capi", "DirectML.dll") + # The dependent onnxruntime package may have multiple execution providers. + # Check to see if DirectML.dll exists before trying to load it. + if os.path.exists(dml_path): + _ = ctypes.CDLL(dml_path) + + # Workaround for onnxruntime.dll loading + ort_path = os.path.join(ort_package_path, "capi", "onnxruntime.dll") + # The dependent onnxruntime package may have multiple execution providers. + # Check to see if onnxruntime.dll exists before trying to load it. + if os.path.exists(ort_path): + _ = ctypes.CDLL(ort_path) elif _is_linux() or _is_macos(): import ctypes