Skip to content

Commit

Permalink
simplified _build()
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Dec 2, 2023
1 parent 5f07810 commit 85916b4
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions python/triton/common/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def cuda_include_dir():
return os.path.join(cuda_path, "include")


def _cc_cmd(cc, src, out, include_dirs, library_dirs):
if cc == "cl":
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
cc_cmd += [f"/I{dir}" for dir in include_dirs]
cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
cc_cmd += ["/link", "cuda.lib", f"/OUT:{out}"]
else:
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"]
cc_cmd += [f"-I{dir}" for dir in include_dirs]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += ["-lcuda", "-o", out]

if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))

return cc_cmd


def _build(name, src, srcdir):
if is_hip():
hip_lib_dir = os.path.join(rocm_path_dir(), "lib")
Expand Down Expand Up @@ -91,35 +108,25 @@ def _build(name, src, srcdir):
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
py_lib_dirs = []
if os.name == "nt":
installed_base = sysconfig.get_config_var('installed_base')
py_libraries_dir = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
py_lib_dirs = [os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))]

if is_hip():
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
if cc == "cl":
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", f"/I{cu_include_dir}", f"/I{py_include_dir}", f"/I{srcdir}"]
cc_cmd += ["/link", "cuda.lib", f"/OUT:{so}"]
cc_cmd += [f"/LIBPATH:{dir}" for dir in cuda_lib_dirs]
cc_cmd += [f"/LIBPATH:{py_libraries_dir}"]
else:
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
"-lcuda", "-o", so
]
if os.name == "nt": cc_cmd.pop(cc.cmd.index("-fPIC"))
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
cc_cmd = _cc_cmd(cc, src, so, [cu_include_dir, py_include_dir, srcdir], [*cuda_lib_dirs, *py_lib_dirs])
ret = subprocess.check_call(cc_cmd)

if ret == 0:
return so
# fallback on setuptools
extra_compile_args = []
library_dirs = cuda_lib_dirs
library_dirs = [*cuda_lib_dirs, *py_lib_dirs]
include_dirs = [srcdir, cu_include_dir]
libraries = ['cuda']
# extra arguments
Expand Down

0 comments on commit 85916b4

Please sign in to comment.