Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 94 additions & 15 deletions python/test/unit/tools/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,25 @@

import triton
from triton.backends.compiler import GPUTarget
from triton.backends.nvidia.driver import include_dirs, library_dirs
from triton._internal_testing import is_cuda, is_hip

# Pick the backend-specific pieces needed to build the C test harness:
# `include_dirs` (header search paths), `library_dirs()` (library search paths)
# and `library_names()` (libraries passed to the linker with `-l`).
if is_cuda():
    from triton.backends.nvidia.driver import include_dirs, library_dirs

    def library_names():
        """Return the library names the CUDA test binary links against."""
        return ["cuda"]

elif is_hip():
    from triton.backends.amd.driver import include_dirs, _get_path_to_hip_runtime_dylib

    def library_dirs():
        """Return the directory that contains the HIP runtime shared library."""
        return [os.path.dirname(_get_path_to_hip_runtime_dylib())]

    def library_names():
        """Return the library names the HIP test binary links against."""
        return ["amdhip64"]


kernel_utils_src = """
import triton

Expand Down Expand Up @@ -86,17 +102,26 @@ def kernel(
"""


test_utils_src = """
if is_cuda():
test_utils_src = """
#include <cuda.h>

// Forward declaration for backward compatibility with CUDA 12.x and 13.x
CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
"""
elif is_hip():
test_utils_src = """
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
"""

test_utils_src += """
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>
#include "kernel.h"

// Forward declaration for backward compatibility with CUDA 12.x and 13.x
CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);

static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {
FILE *file = fopen(filename, "w");
if (file == NULL) {
Expand Down Expand Up @@ -142,7 +167,8 @@ def gen_kernel_library(dir, libname):


def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f"""
if is_cuda():
test_src = f"""
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};

Expand Down Expand Up @@ -195,6 +221,61 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
cuMemFree(C);
cuCtxDestroy(ctx);
}}
"""
elif is_hip():
test_src = f"""
int main(int argc, char **argv) {{
int M = {M}, N = {N}, K = {K};

// initialize hip handles
hipDevice_t dev;
// hipCtx_t ctx;
hipStream_t stream;
hipDeviceptr_t A, B, C;
hipError_t err = 0;
hipInit(0);
hipDeviceGet(&dev, 0);
// hipCtxCreate(&ctx, 0, dev);
hipMalloc(&A, M * K * 2);
hipMalloc(&B, K * N * 2);
hipMalloc(&C, M * N * 4);
hipStreamCreateWithFlags(&stream, 0);
load_matmul_fp16();

// initialize input data
int16_t hA[M*K];
int16_t hB[K*N];
memset(hA, 0, M*K*2);
memset(hB, 0, K*N*2);
read_csv_to_buffer(argv[1], hA, M*K);
read_csv_to_buffer(argv[2], hB, K*N);
hipMemcpyHtoD(A, hA, M*K*2);
hipMemcpyHtoD(B, hB, K*N*2);

// launch kernel
hipError_t ret;
int algo_id = {algo_id};
if (algo_id == 0) {{
ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1);
}} else {{
ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
}}
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
assert(ret == 0);

// read data
int32_t hC[M*N];
memset(hC, 0, M*N*4);
hipMemcpyDtoH(hC, C, M*N*4);
write_buffer_to_csv(argv[3], hC, M*N);

// free hip handles
unload_matmul_fp16();
hipFree(A);
hipFree(B);
hipFree(C);
// hipCtxDestroy(ctx);
}}
"""
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
Expand All @@ -205,7 +286,9 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
command.extend(["-I", inc_dir])
for lib_dir in library_dirs():
command.extend(["-L", lib_dir])
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])
for lib_name in library_names():
command.extend(["-l", lib_name])
command.extend(["-L", dir, "-l", "kernel", "-o", exe])
subprocess.run(command, check=True, cwd=dir)


Expand Down Expand Up @@ -294,12 +377,12 @@ def generate_matmul_test_data(dir, M, N, K):
def check_hasco_binary_str(tmp_dir: str, dtype: str):
# Linking is not yet enabled on HIP backend so just check compilation for now.
h_files = glob.glob(f"matmul_{dtype}.*.h", root_dir=tmp_dir)
cpp_files = glob.glob(f"matmul_{dtype}.*.cpp", root_dir=tmp_dir)
c_files = glob.glob(f"matmul_{dtype}.*.c", root_dir=tmp_dir)
assert len(h_files) == 1, "Expected one .h file"
assert len(cpp_files) == 1, "Expected one .cpp file"
assert len(c_files) == 1, "Expected one .c file"
pattern = re.compile(r'HSACO_NAME\[(\d+)\]')
with open(os.path.join(tmp_dir, cpp_files[0]), "r") as cpp_file:
content = cpp_file.read()
with open(os.path.join(tmp_dir, c_files[0]), "r") as c_file:
content = c_file.read()
matches = pattern.findall(content)
assert len(matches) == 1, "Expected one HSACO_NAME definition"
assert int(matches[0]) > 16, "Expected valid HSACO object binary string"
Expand All @@ -317,7 +400,6 @@ def test_compile_link_matmul_no_specialization():
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
if is_hip():
check_hasco_binary_str(tmp_dir, dtype)
return

link_aot_kernels(tmp_dir)

Expand Down Expand Up @@ -352,7 +434,6 @@ def test_compile_link_matmul():
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":16", ":16")])
if is_hip():
check_hasco_binary_str(tmp_dir, dtype)
return
link_aot_kernels(tmp_dir)

# compile test case
Expand Down Expand Up @@ -386,7 +467,6 @@ def test_launcher_has_no_available_kernel():
compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":1", ":1")])
if is_hip():
check_hasco_binary_str(tmp_dir, dtype)
return

link_aot_kernels(tmp_dir)

Expand Down Expand Up @@ -414,7 +494,6 @@ def test_launcher_has_no_available_kernel():
assert "kernel launch failed" in result.stderr


@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
def test_compile_link_autotune_matmul():
np.random.seed(3)

Expand Down
3 changes: 2 additions & 1 deletion python/triton/tools/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def constexpr(s):
hex_ = str(binascii.hexlify(asm))[2:-1]

ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
backend_name = target.backend

params = {
"kernel_name": func_name,
Expand All @@ -192,9 +193,9 @@ def constexpr(s):
"gridZ": grid[2],
"_placeholder": "",
"warp_size": target.warp_size,
"backend_name": backend_name,
}
output_files = []
backend_name = target.backend
template_dir = Path(__file__).parent / "extra" / backend_name
for template_path in template_dir.glob('compile.*'):
ext = template_path.suffix
Expand Down
37 changes: 24 additions & 13 deletions python/triton/tools/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ def __init__(self) -> None:
self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
# [d|c]
self.arg_suffix = re.compile("[c,d]")
# [backend_name]
self.backend_name_re = re.compile("//[\\s]*tt-linker-backend:[\\s]*([\\w]+)")

self.kernels = defaultdict(list)
self.backend_name = None

def extract_linker_meta(self, header: str):
for ln in header.splitlines():
Expand All @@ -64,6 +67,14 @@ def extract_linker_meta(self, header: str):
num_specs=num_specs,
),
)
else:
m = self.backend_name_re.match(ln)
if _exists(m):
backend_name = m.group(1)
if self.backend_name is None:
self.backend_name = backend_name
elif self.backend_name != backend_name:
raise RuntimeError(f"differing backend {self.backend_name} vs. {backend_name}")

def _match_name(self, ker_name: str):
m = self.kernel_name.match(ker_name)
Expand Down Expand Up @@ -135,7 +146,7 @@ def gen_signature(m):
# generate declarations of kernels with meta-parameter and constant values
def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
    """Return the C declarations for the dispatcher called ``name``.

    The emitted text declares the launcher ``name`` plus its ``load_<name>``
    and ``unload_<name>`` helpers. The launcher's argument list is taken from
    the last entry of ``metas``. Result and stream types are spelled with the
    backend-neutral ``TT_ResultTy`` / ``TT_StreamTy`` names, so the generated
    header does not mention CUDA-specific types.

    Args:
        name: C identifier of the dispatcher being declared.
        metas: Kernel metadata entries for this dispatcher; must be non-empty.

    Returns:
        The declarations as a single string with leading and trailing newline.
    """
    return f"""
TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])});
void load_{name}();
void unload_{name}();
"""
Expand All @@ -144,16 +155,16 @@ def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
# generate declarations of kernels with meta-parameter and constant values
def make_global_decl(meta: KernelLinkerMeta) -> str:
    """Return the C declarations of the public entry points for one kernel.

    Declares, for ``meta.orig_kernel_name``:

    * ``<name>_default`` - launcher without an explicit algorithm selector,
    * ``<name>`` - launcher taking a trailing ``int algo_id``,
    * ``load_<name>`` / ``unload_<name>`` helpers.

    Result and stream types use the backend-neutral ``TT_ResultTy`` /
    ``TT_StreamTy`` names rather than CUDA-specific ones.

    Args:
        meta: Metadata of the kernel whose entry points are declared.

    Returns:
        The declarations as a single string with leading and trailing newline.
    """
    return f"""
TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)});
TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id);
void load_{meta.orig_kernel_name}();
void unload_{meta.orig_kernel_name}();
"""


# generate dispatcher function for kernels with different meta-parameter and constant values
def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
    """Return the C definition of ``<kernel>_default``.

    The generated function forwards all of its arguments to the launcher
    ``<kernel>`` and passes ``0`` as the trailing algorithm id.

    Args:
        meta: Metadata of the kernel; supplies the kernel name and the
            argument names forwarded to the launcher.

    Returns:
        C source text of the ``<kernel>_default`` function.
    """
    # Function header, typed with the backend-neutral result/stream aliases.
    src = f"TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)}){{\n"
    # Body: delegate to the algo-id launcher, selecting algorithm 0.
    src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
    src += "}\n"
    return src
Expand All @@ -163,14 +174,14 @@ def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
src = f"// launcher for: {name}\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
src += f"TT_ResultTy {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(TT_StreamTy stream, {gen_signature(meta)});\n"
src += "\n"

src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
src += (f"TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])}){{")
src += "\n"
for meta in sorted(metas, key=lambda m: -m.num_specs):
cond_fn = ( #
lambda val, hint: f"({val} % {hint} == 0)" #
lambda val, hint: f"((uintptr_t){val} % {hint} == 0)" #
if hint == 16 #
else f"({val} == {hint})" #
if hint == 1 #
Expand All @@ -185,7 +196,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -
arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
src += "\n"
src += " return CUDA_ERROR_INVALID_VALUE;\n"
src += " return TT_ERROR_INVALID_VALUE;\n"
src += "}\n"

for mode in ["load", "unload"]:
Expand All @@ -202,7 +213,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -

# generate dispatcher function for kernels with different meta-parameter and constant values
def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
src = f"TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n"
src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n"
src += "}\n"
Expand All @@ -212,7 +223,7 @@ def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
# generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values
def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str:
# the table of hint dispatchers
src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n"
src = f"typedef TT_ResultTy (*kernel_func_t)(TT_StreamTy stream, {gen_signature_with_full_args(meta)});\n"
src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n"
for name in names:
src += f" {name},\n"
Expand Down Expand Up @@ -287,8 +298,9 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
meta = meta_lists[0][0]
get_num_algos_decl = make_get_num_algos_decl(meta)
global_decl = make_global_decl(meta)
backend_prelude = (Path(__file__).parent / "extra" / parser.backend_name / "link.h").read_text()
with args.out.with_suffix(".h").open("w") as fp:
out = "#include <cuda.h>\n"
out = backend_prelude
out += "\n".join(algo_decls)
out += "\n"
out += get_num_algos_decl
Expand All @@ -305,8 +317,7 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
get_num_algos_def = make_get_num_algos_def(meta)
default_algo_kernel = make_default_algo_kernel(meta)
with args.out.with_suffix(".c").open("w") as fp:
out = ""
out += "#include <cuda.h>\n"
out = backend_prelude
out += "#include <stdint.h>\n"
out += "#include <assert.h>\n"
out += "\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#define __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>

// helpers to check for hip errors
Expand All @@ -28,8 +29,8 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {{

// globals
#define HSACO_NAME {kernel_name}_hsaco
hipModule_t {kernel_name}_mod = nullptr;
hipFunction_t {kernel_name}_func = nullptr;
hipModule_t {kernel_name}_mod = NULL;
hipFunction_t {kernel_name}_func = NULL;
unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }};


Expand All @@ -50,7 +51,7 @@ void load_{kernel_name}() {{
{kernel_docstring}
*/
hipError_t {kernel_name}(hipStream_t stream, {signature}) {{
if ({kernel_name}_func == nullptr)
if ({kernel_name}_func == NULL)
load_{kernel_name}();
unsigned int gX = {gridX};
unsigned int gY = {gridY};
Expand All @@ -61,7 +62,7 @@ hipError_t {kernel_name}(hipStream_t stream, {signature}) {{

// TODO: shared memory
if(gX * gY * gZ > 0)
return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * {warp_size}, 1, 1, {shared}, stream, args, nullptr);
return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * {warp_size}, 1, 1, {shared}, stream, args, NULL);
else
return hipErrorInvalidValue;
}}
5 changes: 5 additions & 0 deletions third_party/amd/tools/hip/compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@

#pragma once

#define __HIP_PLATFORM_AMD__

#include <hip/hip_runtime.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

// tt-linker-backend: {backend_name}

void unload_{kernel_name}(void);
void load_{kernel_name}(void);
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature});
14 changes: 14 additions & 0 deletions third_party/amd/tools/hip/link.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef TT_LINK_INCLUDES
#define TT_LINK_INCLUDES

// HIP prelude for linked Triton AOT kernels: maps the backend-neutral
// TT_* names used by the generated dispatcher code onto the HIP runtime types.

#include <stdint.h>

// The HIP headers need the target platform macro when they are not compiled
// through hipcc. Only define it when the build has not already done so, so a
// command-line definition with a value (e.g. -D__HIP_PLATFORM_AMD__=1) does
// not trigger a macro-redefinition diagnostic.
#ifndef __HIP_PLATFORM_AMD__
#define __HIP_PLATFORM_AMD__
#endif
#include <hip/hip_runtime.h>

// Stream handle and status code types used in generated launcher signatures.
typedef hipStream_t TT_StreamTy;
typedef hipError_t TT_ResultTy;

// Status value for invalid arguments, spelled backend-neutrally.
#define TT_ERROR_INVALID_VALUE hipErrorInvalidValue

#endif
Loading
Loading