-
Notifications
You must be signed in to change notification settings - Fork 540
Fix runtime lib loading logic #2297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -235,31 +235,6 @@ def _get_sys_extension() -> str: | |||||||||
| raise RuntimeError(f"Unsupported operating system ({system})") | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
| def _load_nvidia_cuda_library(lib_name: str): | ||||||||||
| """ | ||||||||||
| Attempts to load shared object file installed via pip. | ||||||||||
|
|
||||||||||
| `lib_name`: Name of package as found in the `nvidia` dir in python environment. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| so_paths = glob.glob( | ||||||||||
| os.path.join( | ||||||||||
| sysconfig.get_path("purelib"), | ||||||||||
| f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]", | ||||||||||
| ) | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| path_found = len(so_paths) > 0 | ||||||||||
| ctypes_handles = [] | ||||||||||
|
|
||||||||||
| if path_found: | ||||||||||
| for so_path in so_paths: | ||||||||||
| ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) | ||||||||||
|
|
||||||||||
| return path_found, ctypes_handles | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
| def _nvidia_cudart_include_dir() -> str: | ||||||||||
| """Returns the include directory for cuda_runtime.h if exists in python environment.""" | ||||||||||
|
|
@@ -279,101 +254,87 @@ def _nvidia_cudart_include_dir() -> str: | |||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
| def _load_cudnn(): | ||||||||||
| """Load CUDNN shared library.""" | ||||||||||
| def _load_cuda_library_from_python(lib_name: str): | ||||||||||
| """ | ||||||||||
| Attempts to load shared object file installed via python packages. | ||||||||||
|
|
||||||||||
| # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set | ||||||||||
| cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") | ||||||||||
| if cudnn_home: | ||||||||||
| libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) | ||||||||||
| libs.sort(reverse=True, key=os.path.basename) | ||||||||||
| if libs: | ||||||||||
| return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
| `lib_name`: Name of package as found in the `nvidia` dir in python environment. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda | ||||||||||
| cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" | ||||||||||
| libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) | ||||||||||
| libs.sort(reverse=True, key=os.path.basename) | ||||||||||
| if libs: | ||||||||||
| return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
| ext = _get_sys_extension() | ||||||||||
| nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia") | ||||||||||
|
|
||||||||||
| # Attempt to locate cuDNN in Python dist-packages | ||||||||||
| found, handle = _load_nvidia_cuda_library("cudnn") | ||||||||||
| if found: | ||||||||||
| return handle | ||||||||||
| # PyPI packages provided by nvidia libs exist | ||||||||||
| # in 3 possible direcories inside `nvidia`. | ||||||||||
| if os.path.isdir(os.path.join(nvidia_dir, "cu13")): | ||||||||||
| so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib*{ext}.*[0-9]")) | ||||||||||
| elif os.path.isdir(os.path.join(nvidia_dir, lib_name)): | ||||||||||
| so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]")) | ||||||||||
| else: | ||||||||||
| so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]")) | ||||||||||
|
|
||||||||||
| # Attempt to locate libcudnn via ldconfig | ||||||||||
| libs = subprocess.check_output(["ldconfig", "-p"]) | ||||||||||
| libs = libs.decode("utf-8").split("\n") | ||||||||||
| sos = [] | ||||||||||
| for lib in libs: | ||||||||||
| if "libcudnn" in lib and "=>" in lib: | ||||||||||
| sos.append(lib.split(">")[1].strip()) | ||||||||||
| if sos: | ||||||||||
| return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
| path_found = len(so_paths) > 0 | ||||||||||
| ctypes_handles = [] | ||||||||||
|
|
||||||||||
| # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise | ||||||||||
| return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) | ||||||||||
| if path_found: | ||||||||||
| for so_path in so_paths: | ||||||||||
| ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) | ||||||||||
|
|
||||||||||
| return path_found, ctypes_handles | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
| def _load_nvrtc(): | ||||||||||
| """Load NVRTC shared library.""" | ||||||||||
| # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda | ||||||||||
| cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" | ||||||||||
| libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True) | ||||||||||
| libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) | ||||||||||
| libs.sort(reverse=True, key=os.path.basename) | ||||||||||
| if libs: | ||||||||||
| return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
|
|
||||||||||
| # Attempt to locate NVRTC in Python dist-packages | ||||||||||
| found, handle = _load_nvidia_cuda_library("cuda_nvrtc") | ||||||||||
| if found: | ||||||||||
| return handle | ||||||||||
| def _load_cuda_library_from_system(lib_name: str): | ||||||||||
| """ | ||||||||||
| Attempts to load shared object file installed via system/cuda-toolkit. | ||||||||||
|
|
||||||||||
| `lib_name`: Name of library to load without extension or `lib` prefix. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| # Attempt to locate NVRTC via ldconfig | ||||||||||
| libs = subprocess.check_output(["ldconfig", "-p"]) | ||||||||||
| libs = libs.decode("utf-8").split("\n") | ||||||||||
| sos = [] | ||||||||||
| for lib in libs: | ||||||||||
| if "libnvrtc" in lib and "=>" in lib: | ||||||||||
| sos.append(lib.split(">")[1].strip()) | ||||||||||
| if sos: | ||||||||||
| return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
| # Where to look for the shared lib in decreasing order of preference. | ||||||||||
| paths = ( | ||||||||||
| os.environ.get(f"{lib_name.upper()}_HOME"), | ||||||||||
| os.environ.get(f"{lib_name.upper()}_PATH"), | ||||||||||
| os.environ.get("CUDA_HOME"), | ||||||||||
| os.environ.get("CUDA_PATH"), | ||||||||||
| "/usr/local/cuda", | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise | ||||||||||
| return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) | ||||||||||
| for path in paths: | ||||||||||
| if path is None: | ||||||||||
| continue | ||||||||||
| libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True) | ||||||||||
| libs.sort(reverse=True, key=os.path.basename) | ||||||||||
| if libs: | ||||||||||
| return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
|
Comment on lines
+306
to
+309
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: glob uses |
||||||||||
|
|
||||||||||
| # Search in LD_LIBRARY_PATH. | ||||||||||
| try: | ||||||||||
| _lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) | ||||||||||
| return True, _lib_handle | ||||||||||
| except OSError: | ||||||||||
| return False, None | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
| def _load_curand(): | ||||||||||
| """Load cuRAND shared library.""" | ||||||||||
| # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda | ||||||||||
| cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" | ||||||||||
| libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) | ||||||||||
| libs = list(filter(lambda x: not ("stub" in x), libs)) | ||||||||||
| libs.sort(reverse=True, key=os.path.basename) | ||||||||||
| if libs: | ||||||||||
| return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
|
|
||||||||||
| # Attempt to locate cuRAND in Python dist-packages | ||||||||||
| found, handle = _load_nvidia_cuda_library("curand") | ||||||||||
| def _load_cuda_library(lib_name: str): | ||||||||||
| """ | ||||||||||
| Load given shared library. | ||||||||||
| Prioritize loading from system/toolkit | ||||||||||
| before checking python packages. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| # Attempt to locate library in system. | ||||||||||
| found, handle = _load_cuda_library_from_system(lib_name) | ||||||||||
| if found: | ||||||||||
| return handle | ||||||||||
|
|
||||||||||
| # Attempt to locate cuRAND via ldconfig | ||||||||||
| libs = subprocess.check_output(["ldconfig", "-p"]) | ||||||||||
| libs = libs.decode("utf-8").split("\n") | ||||||||||
| sos = [] | ||||||||||
| for lib in libs: | ||||||||||
| if "libcurand" in lib and "=>" in lib: | ||||||||||
| sos.append(lib.split(">")[1].strip()) | ||||||||||
| if sos: | ||||||||||
| return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) | ||||||||||
| # Attempt to locate library in Python dist-packages. | ||||||||||
| found, handle = _load_cuda_library_from_python(lib_name) | ||||||||||
| if found: | ||||||||||
| return handle | ||||||||||
timmoon10 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
| # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise | ||||||||||
| return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) | ||||||||||
| raise RuntimeError(f"{lib_name} shared object not found.") | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @functools.lru_cache(maxsize=None) | ||||||||||
|
|
@@ -384,11 +345,16 @@ def _load_core_library(): | |||||||||
|
|
||||||||||
| if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): | ||||||||||
| sanity_checks_for_pypi_installation() | ||||||||||
| _CUDNN_LIB_CTYPES = _load_cudnn() | ||||||||||
| _NVRTC_LIB_CTYPES = _load_nvrtc() | ||||||||||
| _CURAND_LIB_CTYPES = _load_curand() | ||||||||||
| _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") | ||||||||||
| _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") | ||||||||||
|
|
||||||||||
| # `_load_cuda_library` is used for packages that must be loaded | ||||||||||
| # during runtime. Both system and pypi packages are searched | ||||||||||
| # and an error is thrown if not found. | ||||||||||
| _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we also need to handle the case where cuDNN is installed in |
||||||||||
| _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") | ||||||||||
| _CURAND_LIB_CTYPES = _load_cuda_library("curand") | ||||||||||
| _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas") | ||||||||||
| _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime") | ||||||||||
|
Comment on lines
+355
to
+356
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: these calls skip system search and only look in python packages, but they don't error if not found (returns (False, [])); this will silently fail and assign False/[] to the ctypes vars instead of a handle There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we only load those libraries from Python? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we prioritize system installs of the CUDA Runtime and cuBLAS? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're absolutely right! Looking at the code, there's an inconsistency in the loading logic. Lines 352-354 use _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
_NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
_CURAND_LIB_CTYPES = _load_cuda_library("curand")But lines 355-356 skip the system search entirely and only look in Python packages: _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas")
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime")This inconsistency can lead to exactly the bug described in the PR - loading incorrect shared objects with mismatching versions from PyPI packages when a proper system installation exists. For consistency with the other CUDA libraries and to prioritize system installations, these should be:
Suggested change
This would ensure that cuBLAS and CUDA Runtime follow the same loading priority as the other CUDA libraries: system first, then Python packages as fallback. |
||||||||||
|
|
||||||||||
| _TE_LIB_CTYPES = _load_core_library() | ||||||||||
|
|
||||||||||
| # Needed to find the correct headers for NVRTC kernels. | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: if-elif chain issue already reported; besides the missing
elifon line 271, the wildcardlib*{ext}pattern (lines 270, 272, 274) will match ALL .so files in those directories, not just the requested library. Filter bylib_nameafter globbing or embedlib_namein the pattern. Should the glob patterns filter tolib{lib_name}*{ext}.*[0-9]or post-filter by basename?