Skip to content
Open
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 74 additions & 108 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,31 +235,6 @@ def _get_sys_extension() -> str:
raise RuntimeError(f"Unsupported operating system ({system})")


@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.

`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""

so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
)
)

path_found = len(so_paths) > 0
ctypes_handles = []

if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))

return path_found, ctypes_handles


@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
Expand All @@ -279,101 +254,87 @@ def _nvidia_cudart_include_dir() -> str:


@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
def _load_cuda_library_from_python(lib_name: str):
"""
Attempts to load shared object file installed via python packages.

# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""

# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
ext = _get_sys_extension()
nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")

# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# PyPI packages provided by nvidia libs exist
# in 3 possible directories inside `nvidia`.
if os.path.isdir(os.path.join(nvidia_dir, "cu13")):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib*{ext}.*[0-9]"))
elif os.path.isdir(os.path.join(nvidia_dir, lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
else:
so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
Comment on lines +269 to +274
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logic: the if/elif chain issue (the missing `elif` on line 271) was already reported. Separately, the wildcard pattern `lib*{ext}.*[0-9]` (lines 270, 272, 274) matches ALL versioned shared objects in the selected directory, not just the requested library, so every one of them gets loaded. Either embed `lib_name` in the pattern or filter the results by basename after globbing. Should the glob patterns be narrowed to `lib{lib_name}*{ext}.*[0-9]`, or should the matches be post-filtered by basename?


# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
path_found = len(so_paths) > 0
ctypes_handles = []

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))

return path_found, ctypes_handles


@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
def _load_cuda_library_from_system(lib_name: str):
"""
Attempts to load shared object file installed via system/cuda-toolkit.

`lib_name`: Name of library to load without extension or `lib` prefix.
"""

# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Where to look for the shared lib in decreasing order of preference.
paths = (
os.environ.get(f"{lib_name.upper()}_HOME"),
os.environ.get(f"{lib_name.upper()}_PATH"),
os.environ.get("CUDA_HOME"),
os.environ.get("CUDA_PATH"),
"/usr/local/cuda",
)

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
for path in paths:
if path is None:
continue
libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
Comment on lines +306 to +309
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: the glob uses `lib{lib_name}{ext}*` but no longer filters out stub libraries; the old nvrtc loader filtered stubs explicitly (old code line ~336), and the old curand loader did the same. Consider restoring a `not ("stub" in x)` filter on the glob results before sorting.


# Search in LD_LIBRARY_PATH.
try:
_lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return True, _lib_handle
except OSError:
return False, None


@functools.lru_cache(maxsize=None)
def _load_curand():
"""Load cuRAND shared library."""
# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate cuRAND in Python dist-packages
found, handle = _load_nvidia_cuda_library("curand")
def _load_cuda_library(lib_name: str):
"""
Load given shared library.
Prioritize loading from system/toolkit
before checking python packages.
"""

# Attempt to locate library in system.
found, handle = _load_cuda_library_from_system(lib_name)
if found:
return handle

# Attempt to locate cuRAND via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcurand" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate library in Python dist-packages.
found, handle = _load_cuda_library_from_python(lib_name)
if found:
return handle

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
raise RuntimeError(f"{lib_name} shared object not found.")


@functools.lru_cache(maxsize=None)
Expand All @@ -384,11 +345,16 @@ def _load_core_library():

if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")

# `_load_cuda_library` is used for packages that must be loaded
# during runtime. Both system and pypi packages are searched
# and an error is thrown if not found.
_CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to handle the case where cuDNN is installed in CUDNN_PATH instead of CUDA_HOME?

_NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
_CURAND_LIB_CTYPES = _load_cuda_library("curand")
_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas")
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime")
Comment on lines +355 to +356
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

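logic: these calls skip the system search and only look in Python packages, and they do not raise if nothing is found: `_load_cuda_library_from_python` returns the tuple `(False, [])` in that case. The failure is therefore silent, and the ctypes vars end up holding a `(found, handles)` tuple instead of a library handle.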

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we only load those libraries from Python?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we prioritize system installs of the CUDA Runtime and cuBLAS?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right! Looking at the code, there's an inconsistency in the loading logic. Lines 352-354 use _load_cuda_library() which properly prioritizes system installations over Python packages:

_CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
_NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
_CURAND_LIB_CTYPES = _load_cuda_library("curand")

But lines 355-356 skip the system search entirely and only look in Python packages:

_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas")
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime")

This inconsistency can lead to exactly the bug described in the PR - loading incorrect shared objects with mismatching versions from PyPI packages when a proper system installation exists.

For consistency with the other CUDA libraries and to prioritize system installations, these should be:

Suggested change
_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas")
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime")
_CUBLAS_LIB_CTYPES = _load_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_cuda_library("cuda_runtime")

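This would ensure that cuBLAS and CUDA Runtime follow the same loading priority as the other CUDA libraries: system first, then Python packages as fallback. Caveat: `_load_cuda_library_from_system` globs for `lib{lib_name}{ext}*`, and `_load_cuda_library` raises a `RuntimeError` when nothing is found, so confirm that each name passed in matches the shared object's file name before applying this (the CUDA Runtime library is typically named `libcudart`, not `libcuda_runtime`).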


_TE_LIB_CTYPES = _load_core_library()

# Needed to find the correct headers for NVRTC kernels.
Expand Down