From 23e8f5b157f5743e211e3663918eb7561e392478 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 28 May 2024 13:07:11 -0700 Subject: [PATCH 1/2] [BACKEND][CPU] Make it buildable and runnable in a different environment --- include/triton/Conversion/CMakeLists.txt | 5 ++-- lib/Conversion/CMakeLists.txt | 5 ++-- python/src/passes.cc | 3 --- python/triton/runtime/build.py | 3 +++ third_party/cpu/backend/compiler.py | 2 +- third_party/cpu/backend/driver.py | 24 +++++++++++-------- .../include/TritonCPUToLLVM/CMakeLists.txt | 2 +- 7 files changed, 25 insertions(+), 19 deletions(-) diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt index ae31ac930b7e..3b8a95e1ecf7 100644 --- a/include/triton/Conversion/CMakeLists.txt +++ b/include/triton/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ -add_subdirectory(TritonCPUToLLVM) +# TODO(minjang): I will remove these scratches soon. +# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) -add_subdirectory(TritonToTritonCPU) +# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 83db4ae41607..426b22a42ef6 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ -#add_subdirectory(TritonToTritonCPU) +# TODO(minjang): I will remove these scratches soon. +# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -#add_subdirectory(TritonCPUToLLVM) +# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/python/src/passes.cc b/python/src/passes.cc index df7d9faa9052..b8e1b643dfe6 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -6,7 +6,6 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" -#include "triton/Conversion/TritonToTritonCPU/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -43,8 +42,6 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - // createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index d7baeb2868b0..2d565539176c 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -45,6 +45,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs] + # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. + if src.endswith(".cpp") or src.endswith(".cc"): + cc_cmd += ["-std=c++17"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 344cdd2f05ae..3daf83eaac28 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -20,7 +20,7 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee",) + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) allow_fp8e4nv: bool = False enable_fp_fusion: bool = True diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3fe243fc262d..594e72c19f7e 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,11 +1,12 @@ -import os import hashlib +import os import tempfile from pathlib import Path + +from triton.backends.compiler import GPUTarget +from triton.backends.driver import DriverBase from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager -from triton.backends.driver import DriverBase -from triton.backends.compiler import GPUTarget dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") @@ -74,6 +75,7 @@ "LLVMSupport", "LLVMDemangle", "stdc++", + "z", ] @@ -90,6 +92,7 @@ def compile_module_from_src(src, name): with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{name}.so", binary=True) import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -124,7 +127,7 @@ def get_device_properties(self, *args): def ty_to_cpp(ty): - if ty[0] == '*': + if ty[0] == "*": return "void*" return { "i1": "int32_t", @@ -148,10 +151,10 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids): # Record the end of regular arguments; # subsequent arguments are architecture-specific descriptors. - arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == '*': + if ty[0] == "*": return "PyObject*" return ty_to_cpp(ty) @@ -171,12 +174,13 @@ def format_of(ty): "uint64_t": "K", }[ty] - args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiOKOOOO" + args_format - arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + arg_ptrs_list = (", ".join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else "") kernel_fn_args = [i for i in signature.keys() if i not in constants] - kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' - kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + kernel_fn_args_list = (", ".join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else "") + kernel_fn_arg_types = (", ".join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + + ", " if len(signature) > 0 else "") + "uint32_t, uint32_t, uint32_t" # generate glue code src = f""" diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt index 64b36523d35d..0936dff12d91 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) +add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) From 0a73b485c233d2d79caad00a0c04ed1fbe98cb38 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 28 May 2024 14:23:55 -0700 Subject: [PATCH 2/2] Revert seemingly inconsistent python code formatting --- third_party/cpu/backend/driver.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 594e72c19f7e..5783a0342dbd 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,12 +1,11 @@ -import hashlib import os +import hashlib import tempfile from pathlib import Path - -from triton.backends.compiler import GPUTarget -from triton.backends.driver import DriverBase from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase +from triton.backends.compiler import GPUTarget dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") @@ -92,7 +91,6 @@ def compile_module_from_src(src, name): with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{name}.so", binary=True) import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) @@ -127,7 +125,7 @@ def get_device_properties(self, *args): def ty_to_cpp(ty): - if ty[0] == "*": + if ty[0] == '*': return "void*" return { "i1": "int32_t", @@ -151,10 +149,10 @@ def ty_to_cpp(ty): def make_launcher(constants, signature, ids): # Record the end of regular arguments; # subsequent arguments are architecture-specific descriptors. - arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): - if ty[0] == "*": + if ty[0] == '*': return "PyObject*" return ty_to_cpp(ty) @@ -174,13 +172,13 @@ def format_of(ty): "uint64_t": "K", }[ty] - args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()]) + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiOKOOOO" + args_format - arg_ptrs_list = (", ".join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else "") + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' kernel_fn_args = [i for i in signature.keys() if i not in constants] - kernel_fn_args_list = (", ".join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else "") - kernel_fn_arg_types = (", ".join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + - ", " if len(signature) > 0 else "") + "uint32_t, uint32_t, uint32_t" + kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' + kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" # generate glue code src = f"""