diff --git a/python/setup.py b/python/setup.py index bcdc5faa3107..e2c8d9ff96c6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -125,15 +125,15 @@ def get_thirdparty_packages(triton_cache_path): # ---- package data --- -def download_and_copy_ptxas(): - +def download_and_copy(src_path, version, url_func): base_dir = os.path.dirname(__file__) - src_path = "bin/ptxas" - version = "12.1.105" + # src_path = "bin/ptxas" + # version = "12.1.105" arch = platform.machine() if arch == "x86_64": arch = "64" - url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" + url = url_func(arch, version) + # url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) @@ -156,9 +156,9 @@ def download_and_copy_ptxas(): shutil.copy(src_path, dst_path) return dst_suffix - # ---- cmake extension ---- + def get_base_dir(): return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) @@ -280,8 +280,9 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy_ptxas() - +download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") +download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2") +download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2") setup( name="triton", diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 2823cf9299b2..fc73f2bf374a 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -6,6 +6,7 @@ import triton import triton.language as tl +from triton.common.backend import path_to_nvdisasm @triton.jit @@ -50,10 +51,11 @@ def kernel_multi_files(X, Y, BLOCK: tl.constexpr): def extract_file_lines(asm): + nvdisasm, _ = path_to_nvdisasm() fd, path = tempfile.mkstemp() with open(fd, 'wb') as cubin: cubin.write(asm) - asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8") + asm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8") file_lines = [] lines = asm.splitlines() for line in lines: @@ -80,7 +82,7 @@ def check_file_lines(file_lines, file_name, lineno): @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): try: - subprocess.check_output(["nvdisasm", "-h"]) + _, _ = path_to_nvdisasm() except BaseException: pytest.skip("nvdisasm is not available") @@ -99,20 +101,20 @@ def test_line_info(func: str): file_lines = extract_file_lines(kernel_info.asm["cubin"]) if func == "single": - assert (check_file_lines(file_lines, "test_line_info.py", 15)) assert (check_file_lines(file_lines, "test_line_info.py", 16)) + assert (check_file_lines(file_lines, "test_line_info.py", 17)) elif func == "call": - assert (check_file_lines(file_lines, "test_line_info.py", 28)) - assert (check_file_lines(file_lines, "test_line_info.py", 21)) - assert (check_file_lines(file_lines, "test_line_info.py", 30)) + assert (check_file_lines(file_lines, "test_line_info.py", 29)) + assert (check_file_lines(file_lines, "test_line_info.py", 22)) + assert (check_file_lines(file_lines, "test_line_info.py", 31)) elif func == "call_noinline": - assert (check_file_lines(file_lines, "test_line_info.py", 42)) - assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 43)) assert (check_file_lines(file_lines, "test_line_info.py", 36)) assert (check_file_lines(file_lines, "test_line_info.py", 37)) + assert (check_file_lines(file_lines, "test_line_info.py", 38)) elif func == "multi_files": - assert (check_file_lines(file_lines, "test_line_info.py", 47)) - assert (check_file_lines(file_lines, "test_line_info.py", 49)) + assert (check_file_lines(file_lines, "test_line_info.py", 48)) + assert (check_file_lines(file_lines, "test_line_info.py", 50)) assert (check_file_lines(file_lines, "standard.py", 33)) assert (check_file_lines(file_lines, "standard.py", 34)) assert (check_file_lines(file_lines, "standard.py", 36)) diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index 5b60c1377caa..edbcdde12c66 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -101,20 +101,34 @@ def get_backend(device_type: str): return _backends[device_type] if device_type in _backends else None -@functools.lru_cache() -def path_to_ptxas(): +def _path_to_binary(binary: str): base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get("TRITON_PTXAS_PATH", ""), - os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas") + os.path.join(base_dir, "third_party", "cuda", "bin", binary) ] - for ptxas in paths: - ptxas_bin = ptxas.split(" ")[0] - if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin): - result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT) + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: - return ptxas, version.group(1) - raise RuntimeError("Cannot find ptxas") + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def path_to_ptxas(): + return _path_to_binary("ptxas") + + +@functools.lru_cache() +def path_to_cuobjdump(): + return _path_to_binary("cuobjdump") + + +@functools.lru_cache() +def path_to_nvdisasm(): + return _path_to_binary("nvdisasm") diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index f5a5d941160b..211821bbb7ae 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -5,7 +5,6 @@ import json import os import re -import tempfile from collections import namedtuple from pathlib import Path from typing import Any @@ -24,7 +23,7 @@ from ..runtime.driver import driver from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability, version_key) -from ..tools.disasm import extract +from ..tools.disasm import get_sass from .code_generator import ast_to_ttir from .make_launcher import make_stub from .utils import (InfoFromBackendForTensorMap, TensorMapManager, @@ -500,7 +499,6 @@ def compile(fn, **kwargs): metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) else: metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - fn_cache_manager.put(next_module, ir_filename) fn_dump_manager.put(next_module, ir_filename) if (enable_override and fn_override_manager.has_file(ir_filename)): print(f"\nOverriding kernel with file {ir_filename}") @@ -517,6 +515,11 @@ def compile(fn, **kwargs): if ir_name == "cubin": asm[ir_name] = next_module + sass_ir = "sass" + sass_fname = f"{name}.{sass_ir}" + asm[sass_ir] = get_sass(next_module) + metadata_group[sass_fname] = fn_cache_manager.put(asm[sass_ir], sass_fname) + elif ir_name == "amdgcn": asm[ir_name] = str(next_module[0]) else: @@ -669,16 +672,3 @@ def runner(*args, stream=None): self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner - - def get_sass(self, fun=None): - if 'sass' in self.asm: - return self.asm['sass'] - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 24a0787c5c16..032b726682f5 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -20,8 +20,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import os import re import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') @@ -60,11 +64,25 @@ def processSassLines(fline, sline, labels): return (f'{ctrl}', f'{asm}') +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm if fun is None: - sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) else: - sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) sass_lines = sass_str.splitlines() line_idx = 0 while line_idx < len(sass_lines):