Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows build fix #2738

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python/build/
python/triton.egg-info/
python/triton/_C/libtriton.pyd
python/triton/_C/libtriton.so
python/triton/_C/triton.dll

# Backends copied from submodules
python/triton/backends/
Expand Down
38 changes: 34 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,17 @@ set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
# used conditionally in this file and by lit tests

# Customized release build type with assertions: TritonRelBuildWithAsserts
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
if(NOT MSVC)
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
else()
set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_CXX_FLAGS_TRITONRELBUILDWITHASSERTS "/Zi /Ob0 /Od /RTC1")
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_STATIC_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
endif()

# Default build type
if(NOT CMAKE_BUILD_TYPE)
Expand All @@ -45,7 +54,15 @@ endif()

# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
if(NOT MSVC)
if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -std=gnu++17 -Wno-deprecated")
endif()
else()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS /wd4244 /wd4624 /wd4715 /wd4530")
endif()

# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
Expand Down Expand Up @@ -103,7 +120,11 @@ endfunction()


# Disable warnings that show up in external code (gtest;pybind11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
if(NOT MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
endif()

include_directories(".")
include_directories(${MLIR_INCLUDE_DIRS})
Expand Down Expand Up @@ -139,6 +160,8 @@ if(TRITON_BUILD_PYTHON_MODULE)

if(PYTHON_INCLUDE_DIRS)
include_directories(${PYTHON_INCLUDE_DIRS})
message(STATUS "PYTHON_LIB_DIRS ${PYTHON_LIB_DIRS}")
link_directories(${PYTHON_LIB_DIRS})
else()
find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
include_directories(${Python3_INCLUDE_DIRS})
Expand Down Expand Up @@ -238,6 +261,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
target_link_libraries(triton PUBLIC ${TRITON_LIBRARIES})
if(WIN32)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
set_target_properties(triton PROPERTIES PREFIX "lib")
else()
target_link_libraries(triton PRIVATE z)
endif()
Expand All @@ -255,6 +280,11 @@ if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
target_link_libraries(triton PRIVATE ${PYTHON_LDFLAGS})
endif()

if(WIN32)
option(CMAKE_USE_WIN32_THREADS_INIT "using WIN32 threads" ON)
option(gtest_disable_pthreads "Disable uses of pthreads in gtest." ON)
endif()

add_subdirectory(bin)
add_subdirectory(test)
add_subdirectory(unittest)
1 change: 1 addition & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ mlir_check_all_link_libraries(triton-lsp)


add_llvm_executable(triton-llvm-opt
PARTIAL_SOURCES_INTENDED
triton-llvm-opt.cpp

DEPENDS
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,7 @@ MfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
return {32, parentShapePerCTA[1]};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -913,8 +913,8 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
lastOp = op;
op = op->getBlock()->getParentOp();
}
return std::distance(lastOp->getBlock()->getParent()->begin(),
lastOp->getBlock()->getIterator());
return (long)std::distance(lastOp->getBlock()->getParent()->begin(),
lastOp->getBlock()->getIterator());
};
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
Expand Down
71 changes: 42 additions & 29 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_thirdparty_packages():
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
version_file_path = os.path.join(package_dir, "version.txt")
if p.syspath_var_name not in os.environ and\
if p.syspath_var_name not in os.environ and p.url and\
(not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url):
try:
shutil.rmtree(package_root_dir)
Expand All @@ -197,6 +197,9 @@ def get_thirdparty_packages():
# write version url to package_dir
with open(os.path.join(package_dir, "version.txt"), "w") as f:
f.write(p.url)
elif p.syspath_var_name not in os.environ and not p.url:
raise RuntimeError(
f'{p.syspath_var_name} not set ! Please install {p.package} manually and set {p.syspath_var_name}.')
if p.include_flag:
thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
if p.lib_flag:
Expand All @@ -210,13 +213,17 @@ def download_and_copy(src_path, variable, version, url_func):
return
base_dir = os.path.dirname(__file__)
system = platform.system()
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
url = url_func(arch, version)
arch = {"x86_64": "64", "AMD64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
supported = {"Linux": "linux", "Windows": "win"}
is_supported = system in supported
if is_supported:
url = url_func(supported[system], arch, version)
tmp_path = os.path.join(triton_cache_path, "nvidia") # path to cache the download
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", src_path) # final binary path
src_path = os.path.join(tmp_path, src_path)
src_path += ".exe" if os.name == "nt" else ""
download = not os.path.exists(src_path)
if os.path.exists(dst_path) and system == "Linux":
if os.path.exists(dst_path) and is_supported:
curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
curr_version = re.search(r"V([.|\d]+)", curr_version).group(1)
download = download or curr_version != version
Expand Down Expand Up @@ -316,6 +323,10 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
if platform.system() == "Windows":
installed_base = sysconfig.get_config_var('installed_base')
py_lib_dirs = os.getenv("PYTHON_LIB_DIRS", os.path.join(installed_base, "libs"))
cmake_args.append("-DPYTHON_LIB_DIRS=" + py_lib_dirs)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand All @@ -325,9 +336,8 @@ def build_extension(self, ext):
build_args = ["--config", cfg]

if platform.system() == "Windows":
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
if sys.maxsize > 2**32:
cmake_args += ["-A", "x64"]
else:
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
max_jobs = os.getenv("MAX_JOBS", str(2 * os.cpu_count()))
Expand Down Expand Up @@ -369,29 +379,32 @@ def build_extension(self, ext):
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)


download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvcc/12.3.52/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/12.3.52/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.3.52",
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/12.3.52/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)

backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()]
if platform.system() in ["Linux", "Windows"]:
download_and_copy(
src_path="bin/ptxas",
variable="TRITON_PTXAS_PATH",
version="12.3.52",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version="12.3.52",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
src_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version="12.3.52",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
backends = ["nvidia", "amd"]
if os.name == "nt":
backends = ["nvidia"]
backends = [*BackendInstaller.copy(backends), *BackendInstaller.copy_externals()]


def add_link_to_backends():
Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_nested1_change():


def write_and_load_module(code, num_extra_lines):
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f:
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f:
f.write(('# extra line\n' * num_extra_lines) + code)
f.flush()
spec = importlib.util.spec_from_file_location("module.name", f.name)
Expand Down
8 changes: 6 additions & 2 deletions python/test/unit/runtime/test_subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr):


def test_compile_in_subproc() -> None:
import os
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())

multiprocessing.set_start_method('fork')
if os.name == "nt":
multiprocessing.set_start_method('spawn')
else:
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(target=compile_fn, args=(config, cc))
proc.start()
proc.join()
Expand All @@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None:
capability = major * 10 + minor
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())

assert multiprocessing.get_start_method() == 'fork'
assert multiprocessing.get_start_method() in ['fork', 'spawn']
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability))
proc.start()
proc.join()
Expand Down
6 changes: 5 additions & 1 deletion python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ def __init__(self, target: tuple) -> None:

@staticmethod
def _path_to_binary(binary: str):
binary += ".exe" if os.name == "nt" else ""
base_dir = os.path.join(os.path.dirname(__file__), os.pardir)
paths = [
os.environ.get(f"TRITON_{binary.upper()}_PATH", ""),
os.path.join(base_dir, "third_party", "cuda", "bin", binary),
]
for p in paths:
bin = p.split(" ")[0]
if os.name != "nt":
bin = p.split(" ")[0]
else:
bin = p
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
if result is not None:
Expand Down
3 changes: 2 additions & 1 deletion python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def triton_key():
contents += [hashlib.sha1(f.read()).hexdigest()]
# backend
libtriton_hash = hashlib.sha1()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
ext = "so" if os.name != "nt" else "pyd"
with open(os.path.join(TRITON_PATH, "_C", "libtriton." + ext), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
Expand Down
27 changes: 22 additions & 5 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@ def quiet():
sys.stdout, sys.stderr = old_stdout, old_stderr


def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
if cc in ["cl", "clang-cl"]:
cc_cmd = [cc, src, "/nologo", "/O2", "/LD"]
cc_cmd += [f"/I{dir}" for dir in include_dirs]
cc_cmd += ["/link"]
cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
cc_cmd += [f'{lib}.lib' for lib in libraries]
cc_cmd += [f"/OUT:{out}"]
else:
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC"]
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs]
cc_cmd += ["-o", out]

if os.name == "nt": cc_cmd.pop(cc_cmd.index("-fPIC"))

return cc_cmd


def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
Expand All @@ -41,10 +61,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
scheme = 'posix_prefix'
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs]
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
Expand All @@ -58,7 +75,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
language='c',
sources=[src],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args + ['-O3'],
extra_compile_args=extra_compile_args + ['-O3' if "-O3" in cc_cmd else "/O2"],
extra_link_args=extra_link_args,
library_dirs=library_dirs,
libraries=libraries,
Expand Down
Loading