Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions third_party/amd/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,54 @@
include_dir = [os.path.join(dirname, "include")]


def _find_already_mmapped_dylib_on_linux(lib_name):
import platform
if platform.system() != 'Linux':
return None

# Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
# See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.

import ctypes
from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER

class DlPhdrInfo(ctypes.Structure):
_fields_ = [
('dlpi_addr', c_void_p),
('dlpi_name', c_char_p),
# We don't care about the remaining fields.
]

# callback_t must use POINTER(c_char) to avoid copying.
callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))

# Load libc and get the dl_iterate_phdr symbol.
try:
dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
except:
return None
# argtypes must use c_char_p to accept create_string_buffer.
dl_iterate_phdr.argtypes = [callback_t, c_char_p]
dl_iterate_phdr.restype = c_int

max_path_length = 4096
path = ctypes.create_string_buffer(max_path_length + 1)

# Define callback to get the loaded dylib path.
def callback(info, size, data):
dlpi_name = info.contents.dlpi_name
p = Path(os.fsdecode(dlpi_name))
if lib_name in p.name:
# Found the dylib; get its path.
ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
return 1
return 0

if dl_iterate_phdr(callback_t(callback), path):
return os.fsdecode(ctypes.string_at(path))
return None


@functools.lru_cache()
def _get_path_to_hip_runtime_dylib():
lib_name = "libamdhip64.so"
Expand All @@ -24,6 +72,13 @@ def _get_path_to_hip_runtime_dylib():
return env_libhip_path
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")

# If the shared object is already mmapped to address space, use it.
mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
if mmapped_path:
if os.path.exists(mmapped_path):
return mmapped_path
raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

paths = []

import site
Expand Down