Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import pybind11

from build_helpers import get_base_dir, get_cmake_dir
import setup_helper as helper
from setup_tools import setup_helper as helper


@dataclass
Expand Down Expand Up @@ -400,7 +400,6 @@ def run(self):
cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
if (cmake_major, cmake_minor) < (3, 18):
raise RuntimeError("CMake >= 3.18.0 is required")

for ext in self.extensions:
self.build_extension(ext)

Expand Down Expand Up @@ -432,7 +431,6 @@ def build_extension(self, ext):
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
thirdparty_cmake_args += self.get_pybind11_cmake_args()
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
ext_base_dir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# create build directories
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
Expand All @@ -449,6 +447,7 @@ def build_extension(self, ext):
"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
]
cmake_args += helper.get_backend_cmake_args(build_ext=self)
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
cmake_args.extend(thirdparty_cmake_args)
Expand All @@ -472,7 +471,6 @@ def build_extension(self, ext):
"-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
"-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
f"-DCMAKE_INSTALL_PREFIX={ext_base_dir}",
]

# Note that asan doesn't work with binaries that use the GPU, so this is
Expand Down Expand Up @@ -515,6 +513,7 @@ def build_extension(self, ext):
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
subprocess.check_call(["cmake", "--install", "."], cwd=cmake_dir)
helper.install_extension(build_ext=self)


nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
Expand Down Expand Up @@ -652,27 +651,31 @@ class plugin_install(install):
def run(self):
add_links()
install.run(self)
helper.post_install()


class plugin_develop(develop):

def run(self):
add_links()
develop.run(self)
helper.post_install()


class plugin_bdist_wheel(bdist_wheel):

def run(self):
add_links()
bdist_wheel.run(self)
helper.post_install()


class plugin_egginfo(egg_info):

def run(self):
add_links()
egg_info.run(self)
helper.post_install()


# TODO: package_data_tools
Expand Down
4 changes: 4 additions & 0 deletions python/setup_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import setup_helper
from . import utils

__all__ = ["setup_helper", "utils"]
233 changes: 143 additions & 90 deletions python/setup_helper.py → python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,19 @@
import urllib.request
from pathlib import Path
import hashlib
from dataclasses import dataclass
from distutils.sysconfig import get_python_lib
from . import utils

flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
use_triton_shared = False
necessary_third_party = ["" if flagtree_backend == "tsingmicro" else "flir"]
default_backends = ["nvidia", "amd"]
extend_backends = []
default_backends = ["nvidia", "amd"]
plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro"]
ext_sourcedir = "triton/_C/"


@dataclass
class FlagTreeBackend:
name: str
url: str
tag: str


flagtree_backend_info = {
"flir":
FlagTreeBackend(name="flir", url="[email protected]:FlagTree/flir.git",
tag="e72b83ba46a5a9dd6466c7102f93fd600cde909e"),
"triton_shared":
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
tag="5842469a16b261e45a2c67fbfc308057622b03ee"),
"cambricon":
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
}
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
flagtree_backends = utils.flagtree_backends
backend_utils = utils.activate(flagtree_backend)

set_llvm_env = lambda path: set_env({
'LLVM_INCLUDE_DIRS': Path(path) / "include",
Expand All @@ -45,6 +29,101 @@ class FlagTreeBackend:
})


def install_extension(*args, **kargs):
try:
backend_utils.install_extension(*args, **kargs)
except Exception:
pass


def get_backend_cmake_args(*args, **kargs):
try:
return backend_utils.get_backend_cmake_args(*args, **kargs)
except Exception:
return []


def get_device_name():
return device_mapping[flagtree_backend]


def get_extra_packages():
packages = []
try:
packages = backend_utils.get_extra_install_packages()
except Exception:
packages = []
return packages


def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
try:
package_data += backend_utils.get_package_data_tools()
except Exception:
package_data
return package_data


def git_clone(lib, lib_path):
import git
MAX_RETRY = 4
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
if lib.tag is not None:
repo.git.checkout(lib.tag)
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return True
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
return False


def dir_rollback(deep, base_path):
while (deep):
base_path = os.path.dirname(base_path)
deep -= 1
return Path(base_path)


def download_flagtree_third_party(name, condition, required=False, hock=None):
if not condition:
return
backend = None
for _backend in flagtree_backends:
if _backend.name in name:
backend = _backend
break
if backend is None:
return backend
base_dir = dir_rollback(3, __file__) / "third_party"
prelib_path = Path(base_dir) / name
lib_path = Path(base_dir) / _backend.name
if not os.path.exists(prelib_path) and not os.path.exists(lib_path):
succ = git_clone(lib=backend, lib_path=prelib_path)
if not succ and required:
raise RuntimeError("Bad network ! ")
else:
print(f'Found third_party {backend.name} at {lib_path}\n')

if callable(hock):
hock(third_party_base_dir=base_dir, backend=backend, default_backends=default_backends)


def post_install():
try:
backend_utils.post_install()
except Exception:
pass


class FlagTreeCache:

def __init__(self):
Expand Down Expand Up @@ -211,8 +290,13 @@ class CommonUtils:

@staticmethod
def unlink():
cur_path = os.path.dirname(__file__)
backends_dir_path = Path(cur_path) / "triton" / "backends"
cur_path = dir_rollback(2, __file__)
if "editable_wheel" in sys.argv:
installation_dir = cur_path
else:
installation_dir = get_python_lib()
backends_dir_path = Path(installation_dir) / "triton" / "backends"
# raise RuntimeError(backends_dir_path)
if not os.path.exists(backends_dir_path):
return
for name in os.listdir(backends_dir_path):
Expand All @@ -230,15 +314,15 @@ def unlink():
def skip_package_dir(package):
if 'backends' in package or 'profiler' in package:
return True
if flagtree_backend in ['cambricon']:
if package not in ['triton', 'triton/_C']:
return True
return False
try:
return backend_utils.skip_package_dir(package)
except Exception:
return False

@staticmethod
def get_package_dir(packages):
package_dict = {}
if flagtree_backend and flagtree_backend not in ("cambricon", "aipu", "tsingmicro"):
if flagtree_backend and flagtree_backend not in plugin_backends:
connection = []
backend_triton_path = f"../third_party/{flagtree_backend}/python/"
for package in packages:
Expand All @@ -247,70 +331,20 @@ def get_package_dir(packages):
pair = (package, f"{backend_triton_path}{package}")
connection.append(pair)
package_dict.update(connection)
try:
package_dict.update(backend_utils.get_package_dir())
except Exception:
pass
return package_dict

@staticmethod
def download_third_party():
import git
MAX_RETRY = 4
global use_triton_shared, flagtree_backend
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"

def git_clone(lib, lib_path):
global use_triton_shared
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
repo.git.checkout(lib.tag)
if lib.name in flagtree_backend_info:
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")

print(f"Unable to clone third_party {lib.name}")
if lib.name in necessary_third_party:
use_triton_shared = False # TODO
print(f"\n\t{lib.name} is compiled by default, but for "
"some reason we couldn't download triton_shared\n"
"as third_party (most likely for network reasons), "
"so we couldn't compile triton_shared\n")

third_partys = []
if flagtree_backend != "tsingmicro":
third_partys.append(flagtree_backend_info["flir"])
if os.environ.get("USE_TRITON_SHARED", "ON") == "ON":
third_partys.append(flagtree_backend_info["triton_shared"])
else:
use_triton_shared = False
if flagtree_backend in flagtree_backend_info:
third_partys.append(flagtree_backend_info[flagtree_backend])

for lib in third_partys:
lib_path = Path(third_party_base_dir) / lib.name
if not os.path.exists(lib_path):
git_clone(lib=lib, lib_path=lib_path)
else:
print(f'Found third_party {lib.name} at {lib_path}\n')


def handle_flagtree_backend():
global ext_sourcedir
if flagtree_backend:
print(f"flagtree_backend is {flagtree_backend}")
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend not in ("aipu", "tsingmicro"):
if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends:
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
if flagtree_backend != "tsingmicro":
default_backends.append("flir")
if use_triton_shared:
default_backends.append("triton_shared")


def set_env(env_dict: dict):
Expand All @@ -322,8 +356,18 @@ def check_env(env_val):
return os.environ.get(env_val, '') != ''


CommonUtils.download_third_party()
download_flagtree_third_party("triton_shared", hock=utils.default.precompile_hock, condition=(not flagtree_backend))

download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"),
hock=utils.ascend.precompile_hock, required=True)

download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True)

download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock,
required=True)

handle_flagtree_backend()

cache = FlagTreeCache()

# iluvatar
Expand Down Expand Up @@ -375,6 +419,15 @@ def check_env(env_val):
post_hock=set_llvm_env,
)

# ascend
cache.store(
file="ascend-llvm-b5cc222d-ubuntu-arm64",
condition=("ascend" == flagtree_backend),
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz",
pre_hock=lambda: check_env('LLVM_SYSPATH'),
post_hock=set_llvm_env,
)

# aipu
cache.store(
file="aipu-llvm-a66376b0-ubuntu-x64",
Expand Down
Loading