Skip to content

Commit

Permalink
based on Windows support PR #2465 by @andreigh
Browse files Browse the repository at this point in the history
 * based on #2465
 * manually applied, rebased, fix lint errors
 * use set_target_properties(), cleanup for windows
 * remove the '-A' platform option so the Windows Ninja generator can be used
 * remove unknown option '/m'
 * use sysconfig.get_config_var() to get the path of python*.lib
 * use os.name to check dll extension
 * clang fix for windows
 * remove '-fPIC' for windows clang

Original-author-by: Andrei Gheorghe <[email protected]>
Signed-off-by: Won-Kyu Park <[email protected]>
  • Loading branch information
wkpark committed Dec 1, 2023
1 parent 470da87 commit f067770
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 49 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
python/triton/_C/triton.dll

# Python caches
__pycache__/
Expand Down
50 changes: 35 additions & 15 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
# Per-toolchain flags for the custom TritonRelBuildWithAsserts configuration.
# Non-MSVC compilers get "-O2 -g"; MSVC gets its own /-style compiler switches
# plus linker flags for every target type.
if(NOT MSVC)
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
else()
# NOTE(review): /Od and /RTC1 look like debug-style settings (no optimization,
# run-time checks), unlike the -O2 used on the other branch — confirm intended.
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
# NOTE(review): static libraries are produced by the librarian rather than the
# linker — verify that it accepts /debug and /INCREMENTAL.
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -47,7 +56,15 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
# C++ flags by toolchain. Every branch rebuilds CMAKE_CXX_FLAGS from
# CMAKE_C_FLAGS, so any previously set CMAKE_CXX_FLAGS value is replaced.
if(NOT MSVC)
if(NOT WIN32)
# Non-MSVC compiler off Windows: position-independent code and hidden visibility.
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
else()
# Non-MSVC compiler on Windows: same flags without -fPIC, plus -Wno-deprecated.
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated -fvisibility=hidden -fvisibility-inlines-hidden")
endif()
else()
# MSVC: disable warnings C4244, C4624, C4715 and C4530.
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
endif()

if(APPLE)
set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
Expand All @@ -59,7 +76,7 @@ endif()
if(NOT MLIR_DIR)
if(NOT LLVM_LIBRARY_DIR)
if(WIN32)
find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
find_package(LLVM 17 REQUIRED COMPONENTS nvptx amdgpu)

include_directories(${LLVM_INCLUDE_DIRS})
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
Expand Down Expand Up @@ -154,6 +171,8 @@ if(TRITON_BUILD_PYTHON_MODULE)

if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
Expand All @@ -163,16 +182,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
endif()
endif()

# # Triton
# file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
# if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
# set_target_properties(triton PROPERTIES PREFIX "lib")
# else()
# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
# endif()

# MLIR
find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})

Expand All @@ -184,7 +193,11 @@ include(AddLLVM)
include(AddMLIR)

# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
# Append the warning policy to CMAKE_CXX_FLAGS: non-MSVC compilers get -Werror
# (with -Wno-covered-switch-default), while MSVC gets /WX-, i.e. warnings are
# not promoted to errors there.
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
endif()

include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
Expand Down Expand Up @@ -239,6 +252,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
${TRITON_LIBRARIES}
)
# Name the built module "libtriton.pyd": "lib" prefix + target name + ".pyd" suffix.
set_target_properties(triton PROPERTIES
  PREFIX "lib"
  SUFFIX ".pyd"
)
elseif(APPLE)
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
Expand Down Expand Up @@ -275,6 +290,11 @@ if (${CODEGEN_BACKENDS_LEN} GREATER 0)
endforeach()
endif()

# On Windows, default to Win32 threads and turn off gtest's pthreads usage
# before the test subdirectories are added.
# NOTE(review): CMAKE_USE_WIN32_THREADS_INIT is usually a result variable of
# CMake's FindThreads module — confirm that declaring it as an option is intended.
if(WIN32)
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(test)

add_subdirectory(unittest)
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ llvm_update_compile_flags(triton-translate)
mlir_check_all_link_libraries(triton-translate)

add_llvm_executable(triton-llvm-opt
PARTIAL_SOURCES_INTENDED
triton-llvm-opt.cpp

DEPENDS
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
)

add_mlir_library(ASMBuilder
PARTIAL_SOURCES_INTENDED
GCNAsmFormat.cpp
PTXAsmFormat.cpp

Expand Down
56 changes: 31 additions & 25 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_thirdparty_packages(triton_cache_path):
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
version_file_path = os.path.join(package_dir, "version.txt")
if p.syspath_var_name not in os.environ and\
if p.syspath_var_name not in os.environ and p.url and\
(not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url):
try:
shutil.rmtree(package_root_dir)
Expand All @@ -128,6 +128,9 @@ def get_thirdparty_packages(triton_cache_path):
# write version url to package_dir
with open(os.path.join(package_dir, "version.txt"), "w") as f:
f.write(p.url)
elif p.syspath_var_name not in os.environ and not p.url:
raise RuntimeError(
f'{p.syspath_var_name} not set ! Please install {p.package} manually and set {p.syspath_var_name}.')
if p.include_flag:
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
if p.lib_flag:
Expand Down Expand Up @@ -262,6 +265,10 @@ def build_extension(self, ext):
"-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
"-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
]
# On Windows, also pass CMake the directory holding the Python import
# libraries (python*.lib). The PYTHON_LIB_DIRS environment variable overrides
# the default of "<installed_base>/libs" taken from sysconfig.
if platform.system() == "Windows":
installed_base = sysconfig.get_config_var('installed_base')
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand All @@ -276,10 +283,8 @@ def build_extension(self, ext):
cmake_args += ["-DTRITON_CODEGEN_BACKENDS=" + all_codegen_backends]

if platform.system() == "Windows":
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
cmake_args += ["-A", "x64"]
build_args += ["--", "/m"]
else:
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
Expand Down Expand Up @@ -321,27 +326,28 @@ def build_extension(self, ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)


download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
# Fetch the NVIDIA command-line tools (ptxas, cuobjdump, nvdisasm) only on
# Linux: every URL below points at a linux-{arch} conda package, so there is
# nothing to download for other platforms.
# NOTE(review): the version "12.3.52" appears both as the `version` argument and
# hard-coded in each URL path — keep the two in sync when bumping.
if platform.system() == "Linux":
download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)

setup(
name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
Expand Down
8 changes: 6 additions & 2 deletions python/triton/common/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def _path_to_binary(binary: str):
]

for p in paths:
bin = p.split(" ")[0]
# On Windows the whole entry is used as the executable path; elsewhere only
# the text before the first space is.
bin = p if os.name == "nt" else p.split(" ")[0]
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
Expand Down Expand Up @@ -152,7 +155,8 @@ def compute_core_version_key():
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
ext = "so" if os.name != "nt" else "pyd"
with open(os.path.join(TRITON_PATH, "_C", "libtriton." + ext), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
Expand Down
23 changes: 18 additions & 5 deletions python/triton/common/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def libcuda_dirs():
if env_libcuda_path:
return [env_libcuda_path]

if os.name == "nt":
return [os.environ.get("CUDA_PATH") + "\\lib\\x64"]

libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
# each line looks like the following:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
Expand Down Expand Up @@ -88,18 +91,28 @@ def _build(name, src, srcdir):
if scheme == 'posix_local':
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
if os.name == "nt":
installed_base = sysconfig.get_config_var('installed_base')
py_libraries_dir = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))

if is_hip():
ret = subprocess.check_call([
cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
])
else:
cc_cmd = [
cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
"-o", so
]
cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
# Build the host-compiler command line. MSVC's cl takes /-style switches;
# every other compiler is driven with GCC/Clang-style flags.
if cc == "cl":
    cc_cmd = [cc, src, "/nologo", "/O2", "/LD", f"/I{cu_include_dir}", f"/I{py_include_dir}", f"/I{srcdir}"]
    # Arguments after /link are handed to the linker.
    cc_cmd += ["/link", "cuda.lib", f"/OUT:{so}"]
    cc_cmd += [f"/LIBPATH:{dir}" for dir in cuda_lib_dirs]
    cc_cmd += [f"/LIBPATH:{py_libraries_dir}"]
else:
    cc_cmd = [
        cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
        "-lcuda", "-o", so
    ]
    # -fPIC is dropped for non-cl compilers on Windows. The list is cc_cmd;
    # the previous `cc.cmd` was an attribute lookup on the compiler-name
    # string and raised AttributeError.
    if os.name == "nt":
        cc_cmd.remove("-fPIC")
    cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
ret = subprocess.check_call(cc_cmd)

if ret == 0:
Expand Down
3 changes: 3 additions & 0 deletions python/triton/compiler/backends/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..utils import get_ids_of_tensormaps, parse_tma_info
from ..make_launcher import make_stub
import hashlib
import os


def get_kernel_name(src: str, pattern: str) -> str:
Expand Down Expand Up @@ -196,6 +197,8 @@ def make_ptx(src, metadata, opt, capability):
def make_cubin(src, metadata, opt, capability):
metadata["name"] = get_kernel_name(src, pattern='// .globl')
ptxas, _ = path_to_ptxas()
if os.name == 'nt':
ptxas = f'"{ptxas}"'
return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion)

def add_stages(self, stages, compiler_options, linker_options):
Expand Down
7 changes: 7 additions & 0 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import dataclass
from .code_generator import ast_to_ttir
from pathlib import Path
import os
import re


Expand Down Expand Up @@ -216,6 +217,12 @@ def __init__(self, so_path, metadata_path):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
# spec_from_file_location() may return None for the launcher library; on
# Windows, build the module spec by hand with an explicit extension loader.
# NOTE(review): presumably needed because the ".dll" suffix is not recognised
# as an extension module — confirm.
if spec is None and os.name == "nt":
    import importlib.machinery
    loader = importlib.machinery.ExtensionFileLoader("__triton_launcher", so_path)
    spec = importlib.machinery.ModuleSpec(
        name="__triton_launcher",
        loader=loader,
        origin=so_path,
    )
assert spec is not None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.c_wrapper = getattr(mod, "launch")
Expand Down
2 changes: 1 addition & 1 deletion python/triton/compiler/make_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def make_stub(name, signature, constants, ids, **kwargs):
# name of files that are cached
so_cache_key = make_so_cache_key(get_cuda_version_key(), signature, constants, ids, **kwargs)
so_cache_manager = get_cache_manager(so_cache_key)
so_name = f"{name}.so"
so_name = f'{name}.{"so" if os.name != "nt" else "dll"}'
# retrieve stub from cache if it exists
cache_path = so_cache_manager.get_file(so_name)
if cache_path is None:
Expand Down
8 changes: 7 additions & 1 deletion python/triton/runtime/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self):
src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
fname = "cuda_utils.so"
fname = "cuda_utils." + ("so" if os.name != "nt" else "dll")
cache_path = cache.get_file(fname)
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -51,6 +51,12 @@ def __init__(self):
import importlib.util

spec = importlib.util.spec_from_file_location("cuda_utils", cache_path)
# spec_from_file_location() may return None for the cached cuda_utils library;
# on Windows, build the module spec by hand with an explicit extension loader.
# NOTE(review): presumably needed because the ".dll" suffix is not recognised
# as an extension module — confirm.
if spec is None and os.name == "nt":
    import importlib.machinery
    loader = importlib.machinery.ExtensionFileLoader("cuda_utils", cache_path)
    spec = importlib.machinery.ModuleSpec(
        name="cuda_utils",
        loader=loader,
        origin=cache_path,
    )
assert spec is not None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
Expand Down

0 comments on commit f067770

Please sign in to comment.