diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index 3825273265f0..c1ff6e1d65f3 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -13,6 +13,54 @@ include_dir = [os.path.join(dirname, "include")] +def _find_already_mmapped_dylib_on_linux(lib_name): + import platform + if platform.system() != 'Linux': + return None + + # Use dl_iterate_phdr to walk through the list of shared libraries at runtime. + # See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details. + + import ctypes + from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER + + class DlPhdrInfo(ctypes.Structure): + _fields_ = [ + ('dlpi_addr', c_void_p), + ('dlpi_name', c_char_p), + # We don't care about the remaining fields. + ] + + # callback_t must use POINTER(c_char) to avoid copying. + callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char)) + + # Load libc and get the dl_iterate_phdr symbol. + try: + dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr + except: + return None + # argtypes must use c_char_p to accept create_string_buffer. + dl_iterate_phdr.argtypes = [callback_t, c_char_p] + dl_iterate_phdr.restype = c_int + + max_path_length = 4096 + path = ctypes.create_string_buffer(max_path_length + 1) + + # Define callback to get the loaded dylib path. + def callback(info, size, data): + dlpi_name = info.contents.dlpi_name + p = Path(os.fsdecode(dlpi_name)) + if lib_name in p.name: + # Found the dylib; get its path. + ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name))) + return 1 + return 0 + + if dl_iterate_phdr(callback_t(callback), path): + return os.fsdecode(ctypes.string_at(path)) + return None + + @functools.lru_cache() def _get_path_to_hip_runtime_dylib(): lib_name = "libamdhip64.so" @@ -24,6 +72,13 @@ def _get_path_to_hip_runtime_dylib(): return env_libhip_path raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}") + # If the shared object is already mmapped to address space, use it. + mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name) + if mmapped_path: + if os.path.exists(mmapped_path): + return mmapped_path + raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}") + paths = [] import site