From 2cc6e7ec9af6bb5f33ec59e28a368a0f65a5aee7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 27 Aug 2023 16:38:16 -0700 Subject: [PATCH 01/41] poking around a bit --- python/setup.py | 4 +- python/src/triton.cc | 32 +++++++ python/triton/interpreter/new_interpreter.py | 88 ++++++++++++++++++++ python/triton/runtime/jit.py | 26 +++--- 4 files changed, 133 insertions(+), 17 deletions(-) create mode 100644 python/triton/interpreter/new_interpreter.py diff --git a/python/setup.py b/python/setup.py index 18764ec13165..6269a8b2c231 100644 --- a/python/setup.py +++ b/python/setup.py @@ -58,8 +58,8 @@ class Package(NamedTuple): def get_pybind11_package_info(): - name = "pybind11-2.10.0" - url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz" + name = "pybind11-2.11.1" + url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz" return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") # llvm diff --git a/python/src/triton.cc b/python/src/triton.cc index 1a947a8a7df4..9768d6395c9a 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -65,6 +65,7 @@ #include #include +#include namespace py = pybind11; PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy); @@ -1978,11 +1979,42 @@ void init_triton_translation(py::module &m) { ret::take_ownership); } +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + m.def("load_ptrs", + [](py::array_t ptrs, py::dtype ret_dtype) -> py::array { + py::gil_scoped_release allow_threads; + int numel = ptrs.size(); + auto shape = + std::vector(ptrs.shape(), ptrs.shape() + ptrs.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptrs.at(i)), + ret_dtype.itemsize()); + return ret.reshape(shape); + }); + + m.def("store_ptrs", [](py::array_t ptrs, py::array values) { + py::gil_scoped_release allow_threads; + int numel = ptrs.size(); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + auto reshaped_values = values.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) { + memcpy(reinterpret_cast(reshaped_ptrs.at(i)), + reshaped_values.mutable_data(i), values.dtype().itemsize()); + } + }); +} + void init_triton(py::module &m) { py::module subm = m.def_submodule("triton"); init_triton_env_vars(subm); // init_triton_codegen(subm.def_submodule("code_gen")); init_triton_runtime(subm.def_submodule("runtime")); init_triton_ir(subm.def_submodule("ir")); + init_triton_interpreter(subm.def_submodule("interpreter")); init_triton_translation(subm); } diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py new file mode 100644 index 000000000000..a6b80520a78a --- /dev/null +++ b/python/triton/interpreter/new_interpreter.py @@ -0,0 +1,88 @@ +import inspect + +import numpy as np + +import triton +import triton.language as tl + + +# TODO: duplicate +def str_to_ty(name): + language = tl + if name[0] == "*": + ty = str_to_ty(name[1:]) + return language.pointer_type(ty) + tys = { + "fp8e4nv": language.float8e4nv, + "fp8e5": language.float8e5, + "fp8e4b15": language.float8e4b15, + "fp8e4b15x4": language.float8e4b15x4, + "fp16": language.float16, + "bf16": language.bfloat16, + "fp32": language.float32, + "fp64": language.float64, + "i1": language.int1, + "i8": language.int8, + "i16": language.int16, + "i32": language.int32, + "i64": language.int64, + "u8": language.uint8, + "u16": language.uint16, + "u32": language.uint32, + "u64": language.uint64, + "B": language.int1, + } + return tys[name] + + +def make_handle(arg, ty): + if ty.is_ptr(): + return np.array([arg.data_ptr()], dtype=np.uint64) + assert False + + +class Builder: + + def __init__(self) -> None: + pass + + +class Interpreter: + + def _make_wrapper(self, arg): + ty_str = triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)) + ty = str_to_ty(ty_str) + handle = make_handle(arg, ty) + return tl.tensor(handle, ty) + + def _patch_triton_functions(self, fn): + builder = Builder() + for key, obj in fn.__globals__.items(): + if obj is not tl: + continue + for name, member in inspect.getmembers(obj): + if tl.core.is_builtin(member): + new_member = lambda *args, member=member, **kwargs: (member(*args, **kwargs, _builder=builder)) + setattr(obj, name, new_member) + fn.__globals__[key] = obj + return fn + + def __init__(self, fn, grid) -> None: + self.fn = self._patch_triton_functions(fn) + self.grid = grid + + def __call__(self, *args, **kwargs): + args = [self._make_wrapper(arg) for arg in args] + self.fn(*args, **kwargs) + + +class InterpretedFunction: + + def __init__(self, fn) -> None: + self.fn = fn + + def __getitem__(self, grid): + return Interpreter(self.fn, grid) + + def __call__(self, *args, **kwargs): + return Interpreter(self.fn, None)(*args, **kwargs) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index aad0f57a66f0..e4e3da8fe51e 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -268,10 +268,6 @@ def _type_of(key): tys[v] = v return key if isinstance(key, str) else f"*{tys[dtype_str]}" - def _make_signature(self, sig_key): - signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) - return signature - def _make_constants(self, constexpr_key): constants = dict(zip(self.constexprs, constexpr_key)) return constants @@ -588,17 +584,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - if interpret: - from ..interpreter.interpreter import GridSelector - return GridSelector(fn) - else: - return JITFunction( - fn, - version=version, - do_not_specialize=do_not_specialize, - debug=debug, - noinline=noinline, - ) + # if interpret: + from ..interpreter.new_interpreter import InterpretedFunction + return InterpretedFunction(fn) + # else: + # return JITFunction( + # fn, + # version=version, + # do_not_specialize=do_not_specialize, + # debug=debug, + # noinline=noinline, + # ) if fn is not None: return decorator(fn) From a0b3e6126581410d2062b7fe055fcae41a2ca8d3 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 27 Aug 2023 19:04:00 -0700 Subject: [PATCH 02/41] very very basic POC --- python/src/triton.cc | 2 - python/triton/interpreter/new_interpreter.py | 39 ++++++++++++++++++-- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 9768d6395c9a..9d6f207f7eb3 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1984,7 +1984,6 @@ void init_triton_interpreter(py::module &&m) { m.def("load_ptrs", [](py::array_t ptrs, py::dtype ret_dtype) -> py::array { - py::gil_scoped_release allow_threads; int numel = ptrs.size(); auto shape = std::vector(ptrs.shape(), ptrs.shape() + ptrs.ndim()); @@ -1998,7 +1997,6 @@ void init_triton_interpreter(py::module &&m) { }); m.def("store_ptrs", [](py::array_t ptrs, py::array values) { - py::gil_scoped_release allow_threads; int numel = ptrs.size(); py::array_t reshaped_ptrs = ptrs.reshape({numel}); auto reshaped_values = values.reshape({numel}); diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index a6b80520a78a..a94c0dd70139 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from .._C.libtriton.triton import interpreter as _interpreter # TODO: duplicate @@ -35,6 +36,18 @@ def str_to_ty(name): return tys[name] +def to_numpy_ty(ty): + typemap = { + tl.float16: np.float16, + tl.bfloat16: np.float16, + tl.float32: np.float32, + tl.float64: np.float64, + tl.int8: np.int8, + tl.uint8: np.uint8, + } + return typemap[ty] + + def make_handle(arg, ty): if ty.is_ptr(): return np.array([arg.data_ptr()], dtype=np.uint64) @@ -46,6 +59,15 @@ class Builder: def __init__(self) -> None: pass + def create_load(self, ptr, _0, _1, isVolatile): + return _interpreter.load_ptrs(ptr, np.dtype(np.float32)) + + def create_store(self, ptr, val, _0, _1): + return _interpreter.store_ptrs(ptr, val) + + def create_fadd(self, lhs, rhs): + return lhs + rhs + class Interpreter: @@ -55,15 +77,21 @@ def _make_wrapper(self, arg): handle = make_handle(arg, ty) return tl.tensor(handle, ty) + def patch_member(sef, obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, **kwargs, _builder=builder)) + setattr(obj, name, new_member) + def _patch_triton_functions(self, fn): builder = Builder() for key, obj in fn.__globals__.items(): if obj is not tl: continue + for name, member in inspect.getmembers(getattr(obj, 'tensor')): + if tl.core.is_builtin(member): + self.patch_member(getattr(obj, 'tensor'), name, member, builder) for name, member in inspect.getmembers(obj): if tl.core.is_builtin(member): - new_member = lambda *args, member=member, **kwargs: (member(*args, **kwargs, _builder=builder)) - setattr(obj, name, new_member) + self.patch_member(obj, name, member, builder) fn.__globals__[key] = obj return fn @@ -72,8 +100,11 @@ def __init__(self, fn, grid) -> None: self.grid = grid def __call__(self, *args, **kwargs): - args = [self._make_wrapper(arg) for arg in args] - self.fn(*args, **kwargs) + cpu_args = [arg.cpu() for arg in args] + wrapped_args = [self._make_wrapper(arg) for arg in cpu_args] + self.fn(*wrapped_args, **kwargs) + for arg, new_arg in zip(args, cpu_args): + arg.copy_(new_arg.to(arg.device)) class InterpretedFunction: From 692bf30fbace8aac6bb295027410a30b6091c399 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 27 Aug 2023 20:59:56 -0700 Subject: [PATCH 03/41] cleanup --- python/triton/interpreter/new_interpreter.py | 71 +++++++++++--------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index a94c0dd70139..e602185e6ced 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -48,12 +48,6 @@ def to_numpy_ty(ty): return typemap[ty] -def make_handle(arg, ty): - if ty.is_ptr(): - return np.array([arg.data_ptr()], dtype=np.uint64) - assert False - - class Builder: def __init__(self) -> None: @@ -71,40 +65,53 @@ def create_fadd(self, lhs, rhs): class Interpreter: - def _make_wrapper(self, arg): - ty_str = triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)) - ty = str_to_ty(ty_str) - handle = make_handle(arg, ty) - return tl.tensor(handle, ty) - - def patch_member(sef, obj, name, member, builder): + @staticmethod + def patch_attr(obj, name, member, builder): new_member = lambda *args, member=member, **kwargs: (member(*args, **kwargs, _builder=builder)) setattr(obj, name, new_member) - def _patch_triton_functions(self, fn): + @staticmethod + def _patch_lang_tensor(tensor, builder): + for name, member in inspect.getmembers(tensor): + if tl.core.is_builtin(member): + Interpreter.patch_attr(tensor, name, member, builder) + + @staticmethod + def _patch_lang_core(lang, builder): + for name, member in inspect.getmembers(lang): + if tl.core.is_builtin(member): + Interpreter.patch_attr(lang, name, member, builder) + + @staticmethod + def _patch_lang(fn): builder = Builder() - for key, obj in fn.__globals__.items(): - if obj is not tl: - continue - for name, member in inspect.getmembers(getattr(obj, 'tensor')): - if tl.core.is_builtin(member): - self.patch_member(getattr(obj, 'tensor'), name, member, builder) - for name, member in inspect.getmembers(obj): - if tl.core.is_builtin(member): - self.patch_member(obj, name, member, builder) - fn.__globals__[key] = obj - return fn + lang = [value for key, value in fn.__globals__.items() if value is tl] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + Interpreter._patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + Interpreter._patch_lang_core(lang[0], builder) def __init__(self, fn, grid) -> None: - self.fn = self._patch_triton_functions(fn) self.grid = grid - - def __call__(self, *args, **kwargs): - cpu_args = [arg.cpu() for arg in args] - wrapped_args = [self._make_wrapper(arg) for arg in cpu_args] + self.fn = fn + Interpreter._patch_lang(fn) + + @staticmethod + def _implicit_cvt(arg): + if hasattr(arg, 'data_ptr'): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + return tl.tensor(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return arg + + def __call__(self, *args_dev, **kwargs): + # we need to copy arguments to the host for the interpreter + args_hst = [arg.cpu() for arg in args_dev] + # implicitly convert tensor arguments to their base pointers + wrapped_args = [self._implicit_cvt(arg) for arg in args_hst] + # run function self.fn(*wrapped_args, **kwargs) - for arg, new_arg in zip(args, cpu_args): - arg.copy_(new_arg.to(arg.device)) + # copy arguments back to propagate side-effects + for arg_dev, arg_hst in zip(args_dev, args_hst): + arg_dev.copy_(arg_hst.to(arg_dev.device)) class InterpretedFunction: From 9c9cf0b746a0c4b63e650ed6e43227a7e966dcf4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 27 Aug 2023 22:26:20 -0700 Subject: [PATCH 04/41] more work --- python/triton/interpreter/new_interpreter.py | 343 ++++++++++++++++++- 1 file changed, 327 insertions(+), 16 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index e602185e6ced..054e5e02ef90 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -36,31 +36,341 @@ def str_to_ty(name): return tys[name] -def to_numpy_ty(ty): - typemap = { - tl.float16: np.float16, - tl.bfloat16: np.float16, - tl.float32: np.float32, - tl.float64: np.float64, - tl.int8: np.int8, - tl.uint8: np.uint8, - } - return typemap[ty] +class TensorHandle: + + def __init__(self, data, dtype): + self.data = data + self.dtype = dtype + + +def wrap_ret(compute_ret_ty): + def wrapper(fn): + def wrapped(*args, **kwargs): + ret = fn(*args, **kwargs) + return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs)) + return wrapped + return wrapper class Builder: + def np_dtype(self, tt_dtype): + np_types = { + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + } + return np_types[tt_dtype] + def __init__(self) -> None: pass - def create_load(self, ptr, _0, _1, isVolatile): - return _interpreter.load_ptrs(ptr, np.dtype(np.float32)) + def create_load(self, ptr, _0, _1, volatile): + dtype_tt = ptr.dtype.element_ty + dtype_np = self.np_dtype(dtype_tt) + ret = _interpreter.load_ptrs(ptr.data, dtype_np) + return TensorHandle(ret, dtype_tt) def create_store(self, ptr, val, _0, _1): - return _interpreter.store_ptrs(ptr, val) + return _interpreter.store_ptrs(ptr.data, val.data) def create_fadd(self, lhs, rhs): - return lhs + rhs + # assert lhs.dtype is not tl.bfloat16 + return TensorHandle(lhs.data + rhs.data, lhs.dtype) + + # casting ops + def cast_impl(self, src, dst_type): + return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) + + def create_fp_to_fp(self, src, dst_type): + pass + + def create_si_to_fp(self, src, dst_type): + return self.cast_impl(src, dst_type) + + def create_ui_to_fp(self, src, dst_type): + pass + + def create_fp_to_si(self, src, dst_type): + pass + + def create_fp_to_ui(self, src, dst_type): + pass + + def create_fp_ext(self, src, dst_type): + pass + + def create_fp_trunc(self, src, dst_type): + pass + + def create_int_cast(self, src, dst_type, is_signed): + pass + + def create_bitcast(self, src, dst_type): + pass + + def create_to_index(self, input): + pass + + def create_index_to_si(self, input): + pass + + def create_fmul(self, lhs, rhs): + pass + + def create_fdiv(self, lhs, rhs): + pass + + def create_frem(self, lhs, rhs): + pass + + def create_fsub(self, lhs, rhs): + pass + + def create_mul(self, lhs, rhs): + pass + + def create_sdiv(self, lhs, rhs): + pass + + def create_udiv(self, lhs, rhs): + pass + + def create_srem(self, lhs, rhs): + pass + + def create_urem(self, lhs, rhs): + pass + + def create_add(self, lhs, rhs): + pass + + def create_sub(self, lhs, rhs): + pass + + def create_shl(self, lhs, rhs): + pass + + def create_lshr(self, lhs, rhs): + pass + + def create_ashr(self, lhs, rhs): + pass + + def create_minsi(self, lhs, rhs): + pass + + def create_minui(self, lhs, rhs): + pass + + def create_minf(self, lhs, rhs): + pass + + def create_maxsi(self, lhs, rhs): + pass + + def create_maxui(self, lhs, rhs): + pass + + def create_maxf(self, lhs, rhs): + pass + + def create_addptr(self, ptr, offset): + pass + + def create_icmpSLE(self, lhs, rhs): + pass + + def create_icmpSLT(self, lhs, rhs): + pass + + def create_icmpSGE(self, lhs, rhs): + pass + + def create_icmpSGT(self, lhs, rhs): + pass + + def create_icmpULE(self, lhs, rhs): + pass + + def create_icmpULT(self, lhs, rhs): + pass + + def create_icmpUGE(self, lhs, rhs): + pass + + def create_icmpUGT(self, lhs, rhs): + pass + + def create_icmpEQ(self, lhs, rhs): + pass + + def create_icmpNE(self, lhs, rhs): + pass + + def create_fcmpOLT(self, lhs, rhs): + pass + + def create_fcmpOGT(self, lhs, rhs): + pass + + def create_fcmpOLE(self, lhs, rhs): + pass + + def create_fcmpOGE(self, lhs, rhs): + pass + + def create_fcmpOEQ(self, lhs, rhs): + pass + + def create_fcmpONE(self, lhs, rhs): + pass + + def create_fcmpULT(self, lhs, rhs): + pass + + def create_fcmpUGT(self, lhs, rhs): + pass + + def create_fcmpULE(self, lhs, rhs): + pass + + def create_fcmpUGE(self, lhs, rhs): + pass + + def create_fcmpUEQ(self, lhs, rhs): + pass + + def create_fcmpUNE(self, lhs, rhs): + pass + + def create_and(self, lhs, rhs): + pass + + def create_xor(self, lhs, rhs): + pass + + def create_or(self, lhs, rhs): + pass + + def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): + pass + + def create_tensor_pointer_store(self, ptr, value, boundaryCheck, cacheModifier, evictionPolicy): + pass + + def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): + pass + + def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): + pass + + def create_view(self, arg, shape): + pass + + def create_expand_dims(self, arg, axis): + pass + + def create_cat(self, lhs, rhs): + pass + + def create_trans(self, arg): + pass + + def create_broadcast(self, arg, shape): + pass + + def create_splat(self, arg, shape): + pass + + def create_atomic_cas(self, ptr, cmp, val, sem): + pass + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): + pass + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + pass + + def create_get_program_id(self, axis): + pass + + def create_get_num_programs(self, axis): + pass + + def create_dot(self, a, b, c, allowTF32): + pass + + def create_exp(self, val): + pass + + def create_cos(self, val): + pass + + def create_sin(self, val): + pass + + def create_log(self, val): + pass + + def create_sqrt(self, val): + pass + + def create_fabs(self, val): + pass + + def create_iabs(self, val): + pass + + def create_reduce(self, operands, axis): + pass + + def create_reduce_ret(self, args): + pass + + def create_scan(self, operands, axis): + pass + + def create_scan_ret(self, args): + pass + + def create_ptr_to_int(self, val, type): + pass + + def create_int_to_ptr(self, val, type): + pass + + def create_select(self, condition, trueValue, falseValue): + pass + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + pass + + def create_print(self, prefix, values): + pass + + def create_assert(self, condition, message, fileName, funcName, lineNo): + pass + + def create_undef(self, type): + pass + + def create_barrier(self): + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensorShape, order): + pass + + def create_advance(self, ptr, offsets): + pass class Interpreter: @@ -85,7 +395,7 @@ def _patch_lang_core(lang, builder): @staticmethod def _patch_lang(fn): builder = Builder() - lang = [value for key, value in fn.__globals__.items() if value is tl] + lang = [value for _, value in fn.__globals__.items() if value is tl] assert len(lang) == 1, "triton.language must be visible from within jit'd function" Interpreter._patch_lang_tensor(getattr(lang[0], 'tensor'), builder) Interpreter._patch_lang_core(lang[0], builder) @@ -99,7 +409,8 @@ def __init__(self, fn, grid) -> None: def _implicit_cvt(arg): if hasattr(arg, 'data_ptr'): ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) - return tl.tensor(np.array([arg.data_ptr()], dtype=np.uint64), ty) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) return arg def __call__(self, *args_dev, **kwargs): From 24e60b57461514b7b36c8f7c5eaf016e3cb54a37 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 27 Aug 2023 22:35:59 -0700 Subject: [PATCH 05/41] . --- python/triton/interpreter/new_interpreter.py | 224 +++++-------------- 1 file changed, 60 insertions(+), 164 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 054e5e02ef90..6551c21ddfb4 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -89,178 +89,74 @@ def create_fadd(self, lhs, rhs): # casting ops def cast_impl(self, src, dst_type): return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) def create_fp_to_fp(self, src, dst_type): - pass - - def create_si_to_fp(self, src, dst_type): - return self.cast_impl(src, dst_type) - - def create_ui_to_fp(self, src, dst_type): - pass - - def create_fp_to_si(self, src, dst_type): - pass - - def create_fp_to_ui(self, src, dst_type): - pass - - def create_fp_ext(self, src, dst_type): - pass - - def create_fp_trunc(self, src, dst_type): - pass - - def create_int_cast(self, src, dst_type, is_signed): - pass + assert "float8 not NotImplemented yet" def create_bitcast(self, src, dst_type): - pass - - def create_to_index(self, input): - pass - - def create_index_to_si(self, input): - pass - - def create_fmul(self, lhs, rhs): - pass - - def create_fdiv(self, lhs, rhs): - pass - - def create_frem(self, lhs, rhs): - pass - - def create_fsub(self, lhs, rhs): - pass - - def create_mul(self, lhs, rhs): - pass - - def create_sdiv(self, lhs, rhs): - pass - - def create_udiv(self, lhs, rhs): - pass - - def create_srem(self, lhs, rhs): - pass - - def create_urem(self, lhs, rhs): - pass - - def create_add(self, lhs, rhs): - pass - - def create_sub(self, lhs, rhs): - pass - - def create_shl(self, lhs, rhs): - pass - - def create_lshr(self, lhs, rhs): - pass - - def create_ashr(self, lhs, rhs): - pass - - def create_minsi(self, lhs, rhs): - pass - - def create_minui(self, lhs, rhs): - pass - - def create_minf(self, lhs, rhs): - pass - - def create_maxsi(self, lhs, rhs): - pass - - def create_maxui(self, lhs, rhs): - pass - - def create_maxf(self, lhs, rhs): - pass + return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + # pointer arithmetic def create_addptr(self, ptr, offset): pass - def create_icmpSLE(self, lhs, rhs): - pass - - def create_icmpSLT(self, lhs, rhs): - pass - - def create_icmpSGE(self, lhs, rhs): - pass - - def create_icmpSGT(self, lhs, rhs): - pass - - def create_icmpULE(self, lhs, rhs): - pass - - def create_icmpULT(self, lhs, rhs): - pass - - def create_icmpUGE(self, lhs, rhs): - pass - - def create_icmpUGT(self, lhs, rhs): - pass - - def create_icmpEQ(self, lhs, rhs): - pass - - def create_icmpNE(self, lhs, rhs): - pass - - def create_fcmpOLT(self, lhs, rhs): - pass - - def create_fcmpOGT(self, lhs, rhs): - pass - - def create_fcmpOLE(self, lhs, rhs): - pass - - def create_fcmpOGE(self, lhs, rhs): - pass - - def create_fcmpOEQ(self, lhs, rhs): - pass - - def create_fcmpONE(self, lhs, rhs): - pass - - def create_fcmpULT(self, lhs, rhs): - pass - - def create_fcmpUGT(self, lhs, rhs): - pass - - def create_fcmpULE(self, lhs, rhs): - pass - - def create_fcmpUGE(self, lhs, rhs): - pass - - def create_fcmpUEQ(self, lhs, rhs): - pass - - def create_fcmpUNE(self, lhs, rhs): - pass - - def create_and(self, lhs, rhs): - pass - - def create_xor(self, lhs, rhs): - pass - - def create_or(self, lhs, rhs): - pass - def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): pass From 2f6fe6562fab247b82554ab0a18dff78fe74837d Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 3 Sep 2023 23:53:05 -0700 Subject: [PATCH 06/41] progress --- python/triton/interpreter/new_interpreter.py | 186 ++++++++++--------- python/triton/language/core.py | 2 +- 2 files changed, 95 insertions(+), 93 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 6551c21ddfb4..4283e169d4bb 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -55,6 +55,8 @@ def wrapped(*args, **kwargs): class Builder: def np_dtype(self, tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) np_types = { tl.float16: np.dtype(np.float16), tl.float32: np.dtype(np.float32), @@ -73,6 +75,18 @@ def np_dtype(self, tt_dtype): def __init__(self) -> None: pass + # constants + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + # programming model + def create_get_program_id(self, axis): + pass + + def create_get_num_programs(self, axis): + pass + + # memory ops def create_load(self, ptr, _0, _1, volatile): dtype_tt = ptr.dtype.element_ty dtype_np = self.np_dtype(dtype_tt) @@ -82,10 +96,6 @@ def create_load(self, ptr, _0, _1, volatile): def create_store(self, ptr, val, _0, _1): return _interpreter.store_ptrs(ptr.data, val.data) - def create_fadd(self, lhs, rhs): - # assert lhs.dtype is not tl.bfloat16 - return TensorHandle(lhs.data + rhs.data, lhs.dtype) - # casting ops def cast_impl(self, src, dst_type): return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) @@ -106,6 +116,8 @@ def create_bitcast(self, src, dst_type): # binary operators def binary_op(self, lhs, rhs, op): return TensorHandle(op(lhs.data, rhs.data), lhs.dtype) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) @@ -152,128 +164,117 @@ def binary_op(self, lhs, rhs, op): create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_fabs = lambda self, arg: self.unary_op(arg, np.abs) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + + # tensor operators + create_dot = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.dot) + create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype) + create_trans = lambda self, arg: self.unary_op(arg, np.transpose) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + # pointer arithmetic def create_addptr(self, ptr, offset): - pass + dtype_tt = ptr.dtype.element_ty + return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data, ptr.dtype) - def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): - pass + # def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): + # pass - def create_tensor_pointer_store(self, ptr, value, boundaryCheck, cacheModifier, evictionPolicy): - pass + # def create_tensor_pointer_store(self, ptr, value, boundaryCheck, cacheModifier, evictionPolicy): + # pass - def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): - pass + # def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): + # pass - def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): - pass - - def create_view(self, arg, shape): - pass + # def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): + # pass def create_expand_dims(self, arg, axis): - pass - - def create_cat(self, lhs, rhs): - pass - - def create_trans(self, arg): - pass + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype) def create_broadcast(self, arg, shape): - pass - - def create_splat(self, arg, shape): - pass - - def create_atomic_cas(self, ptr, cmp, val, sem): - pass + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype) - def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): - pass + # def create_cat(self, lhs, rhs): + # pass - def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): - pass + # def create_broadcast(self, arg, shape): + # pass - def create_get_program_id(self, axis): - pass - - def create_get_num_programs(self, axis): - pass - - def create_dot(self, a, b, c, allowTF32): - pass - - def create_exp(self, val): - pass - - def create_cos(self, val): - pass - - def create_sin(self, val): - pass - - def create_log(self, val): - pass - - def create_sqrt(self, val): - pass + def create_splat(self, arg, shape): + return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype) - def create_fabs(self, val): - pass + # def create_atomic_cas(self, ptr, cmp, val, sem): + # pass - def create_iabs(self, val): - pass + # def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): + # pass - def create_reduce(self, operands, axis): - pass + # def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + # pass - def create_reduce_ret(self, args): - pass + # def create_reduce(self, operands, axis): + # pass - def create_scan(self, operands, axis): - pass + # def create_reduce_ret(self, args): + # pass - def create_scan_ret(self, args): - pass + # def create_scan(self, operands, axis): + # pass - def create_ptr_to_int(self, val, type): - pass + # def create_scan_ret(self, args): + # pass - def create_int_to_ptr(self, val, type): - pass + # def create_ptr_to_int(self, val, type): + # pass - def create_select(self, condition, trueValue, falseValue): - pass + # def create_int_to_ptr(self, val, type): + # pass - def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): - pass + # def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + # pass - def create_print(self, prefix, values): - pass + # def create_print(self, prefix, values): + # pass - def create_assert(self, condition, message, fileName, funcName, lineNo): - pass + # def create_assert(self, condition, message, fileName, funcName, lineNo): + # pass - def create_undef(self, type): - pass + # def create_undef(self, type): + # pass - def create_barrier(self): - pass + # def create_barrier(self): + # pass - def create_make_block_ptr(self, base, shape, strides, offsets, tensorShape, order): - pass + # def create_make_block_ptr(self, base, shape, strides, offsets, tensorShape, order): + # pass - def create_advance(self, ptr, offsets): - pass + # def create_advance(self, ptr, offsets): + # pass class Interpreter: @staticmethod def patch_attr(obj, name, member, builder): - new_member = lambda *args, member=member, **kwargs: (member(*args, **kwargs, _builder=builder)) + new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) setattr(obj, name, new_member) @staticmethod @@ -311,14 +312,15 @@ def _implicit_cvt(arg): def __call__(self, *args_dev, **kwargs): # we need to copy arguments to the host for the interpreter - args_hst = [arg.cpu() for arg in args_dev] + args_hst = [arg.cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # implicitly convert tensor arguments to their base pointers wrapped_args = [self._implicit_cvt(arg) for arg in args_hst] # run function self.fn(*wrapped_args, **kwargs) # copy arguments back to propagate side-effects for arg_dev, arg_hst in zip(args_dev, args_hst): - arg_dev.copy_(arg_hst.to(arg_dev.device)) + if hasattr(arg_dev, 'data_ptr'): + arg_dev.copy_(arg_hst.to(arg_dev.device)) class InterpretedFunction: diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3caf8ee3fea5..28317f717446 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -747,7 +747,7 @@ def __getitem__(self, slices, _builder=None): slices = [slices] ret = self for dim, sl in enumerate(slices): - if isinstance(sl, constexpr) and sl.value is None: + if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: pass From c405f28b98cb4342b8ea02e2f0131b4595e8c53d Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 4 Sep 2023 11:02:10 -0700 Subject: [PATCH 07/41] more cleaning --- python/triton/interpreter/new_interpreter.py | 87 +++++++++++++------- 1 file changed, 55 insertions(+), 32 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 4283e169d4bb..35ccf701dfe0 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -54,6 +54,17 @@ def wrapped(*args, **kwargs): class Builder: + def __init__(self, grid_dim) -> None: + assert len(grid_dim) == 3 + self.grid_idx = None + self.grid_dim = grid_dim + + def set_grid_idx(self, x, y, z): + assert x < self.grid_dim[0] + assert y < self.grid_dim[1] + assert z < self.grid_dim[2] + self.grid_idx = (x, y, z) + def np_dtype(self, tt_dtype): if isinstance(tt_dtype, tl.pointer_type): return np.dtype(np.uint64) @@ -72,19 +83,17 @@ def np_dtype(self, tt_dtype): } return np_types[tt_dtype] - def __init__(self) -> None: - pass - # constants def get_int32(self, value): return TensorHandle(np.array([value], dtype=np.int32), tl.int32) # programming model def create_get_program_id(self, axis): - pass + assert self.grid_idx is not None + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) def create_get_num_programs(self, axis): - pass + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) # memory ops def create_load(self, ptr, _0, _1, volatile): @@ -192,7 +201,7 @@ def create_make_range(self, start, stop): def create_addptr(self, ptr, offset): dtype_tt = ptr.dtype.element_ty - return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data, ptr.dtype) + return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) # def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): # pass @@ -270,37 +279,28 @@ def create_splat(self, arg, shape): # pass -class Interpreter: +def patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) + setattr(new_member, '__triton_builtin__', True) + setattr(obj, name, new_member) - @staticmethod - def patch_attr(obj, name, member, builder): - new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) - setattr(obj, name, new_member) - @staticmethod - def _patch_lang_tensor(tensor, builder): - for name, member in inspect.getmembers(tensor): - if tl.core.is_builtin(member): - Interpreter.patch_attr(tensor, name, member, builder) +def _patch_lang_tensor(tensor, builder): + for name, member in inspect.getmembers(tensor): + if tl.core.is_builtin(member): + patch_attr(tensor, name, member, builder) - @staticmethod - def _patch_lang_core(lang, builder): - for name, member in inspect.getmembers(lang): - if tl.core.is_builtin(member): - Interpreter.patch_attr(lang, name, member, builder) - @staticmethod - def _patch_lang(fn): - builder = Builder() - lang = [value for _, value in fn.__globals__.items() if value is tl] - assert len(lang) == 1, "triton.language must be visible from within jit'd function" - Interpreter._patch_lang_tensor(getattr(lang[0], 'tensor'), builder) - Interpreter._patch_lang_core(lang[0], builder) +def _patch_lang_core(lang, builder): + for name, member in inspect.getmembers(lang): + if tl.core.is_builtin(member): + patch_attr(lang, name, member, builder) - def __init__(self, fn, grid) -> None: - self.grid = grid + +class Interpreter: + + def __init__(self, fn) -> None: self.fn = fn - Interpreter._patch_lang(fn) @staticmethod def _implicit_cvt(arg): @@ -323,13 +323,36 @@ def __call__(self, *args_dev, **kwargs): arg_dev.copy_(arg_hst.to(arg_dev.device)) +class GridExecutor: + + def __init__(self, fn, grid): + assert len(grid) <= 3 + self.fn = fn + self.grid = tuple(grid) + (1,) * (3 - len(grid)) + + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value is tl] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + + def __call__(self, *args, **kwargs): + builder = Builder(self.grid) + self._patch_lang(builder) + for x in range(self.grid[0]): + for y in range(self.grid[1]): + for z in range(self.grid[2]): + builder.set_grid_idx(x, y, z) + Interpreter(self.fn)(*args, **kwargs) + + class InterpretedFunction: def __init__(self, fn) -> None: self.fn = fn def __getitem__(self, grid): - return Interpreter(self.fn, grid) + return GridExecutor(self.fn, grid) def __call__(self, *args, **kwargs): return Interpreter(self.fn, None)(*args, **kwargs) From e4257a1fb9314552e086ca7249b875fbdcc6d4c7 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 4 Sep 2023 11:06:24 -0700 Subject: [PATCH 08/41] more cleaning --- python/triton/interpreter/new_interpreter.py | 46 ++++++++------------ 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 35ccf701dfe0..1116cb211bf7 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -297,30 +297,12 @@ def _patch_lang_core(lang, builder): patch_attr(lang, name, member, builder) -class Interpreter: - - def __init__(self, fn) -> None: - self.fn = fn - - @staticmethod - def _implicit_cvt(arg): - if hasattr(arg, 'data_ptr'): - ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) - handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) - return tl.tensor(handle, ty) - return arg - - def __call__(self, *args_dev, **kwargs): - # we need to copy arguments to the host for the interpreter - args_hst = [arg.cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] - # implicitly convert tensor arguments to their base pointers - wrapped_args = [self._implicit_cvt(arg) for arg in args_hst] - # run function - self.fn(*wrapped_args, **kwargs) - # copy arguments back to propagate side-effects - for arg_dev, arg_hst in zip(args_dev, args_hst): - if hasattr(arg_dev, 'data_ptr'): - arg_dev.copy_(arg_hst.to(arg_dev.device)) +def _implicit_cvt(arg): + if hasattr(arg, 'data_ptr'): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg class GridExecutor: @@ -336,14 +318,24 @@ def _patch_lang(self, builder): _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) _patch_lang_core(lang[0], builder) - def __call__(self, *args, **kwargs): + def __call__(self, *args_dev, **kwargs): builder = Builder(self.grid) + # remaps core language functions to interpreted ones self._patch_lang(builder) + # we need to copy arguments to the host for the interpreter + args_hst = [arg.cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + # implicitly convert tensor arguments to their base pointers + wrapped_args = [_implicit_cvt(arg) for arg in args_hst] + # iterate through grid for x in range(self.grid[0]): for y in range(self.grid[1]): for z in range(self.grid[2]): builder.set_grid_idx(x, y, z) - Interpreter(self.fn)(*args, **kwargs) + self.fn(*wrapped_args, **kwargs) + # copy arguments back to propagate side-effects + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, 'data_ptr'): + arg_dev.copy_(arg_hst.to(arg_dev.device)) class InterpretedFunction: @@ -355,4 +347,4 @@ def __getitem__(self, grid): return GridExecutor(self.fn, grid) def __call__(self, *args, **kwargs): - return Interpreter(self.fn, None)(*args, **kwargs) + return self.fn(*args, **kwargs) From 7457e0095bf46998975f6faeba86f3553cb2b5dc Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 4 Sep 2023 11:14:31 -0700 Subject: [PATCH 09/41] remove interpret flag --- python/triton/runtime/jit.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index e4e3da8fe51e..e5ebeb0a4c9c 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -562,7 +562,6 @@ def jit( do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, - interpret: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. @@ -584,17 +583,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - # if interpret: - from ..interpreter.new_interpreter import InterpretedFunction - return InterpretedFunction(fn) - # else: - # return JITFunction( - # fn, - # version=version, - # do_not_specialize=do_not_specialize, - # debug=debug, - # noinline=noinline, - # ) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from ..interpreter.new_interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + ) if fn is not None: return decorator(fn) From b7cf36fb897872970716b7f94490bbfaf501505f Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 4 Sep 2023 15:15:25 -0700 Subject: [PATCH 10/41] . --- python/test/unit/language/test_core.py | 1 - python/triton/interpreter/new_interpreter.py | 80 +++++++++++++++----- python/triton/language/core.py | 9 +++ python/triton/language/semantic.py | 2 + python/triton/runtime/jit.py | 22 +++--- python/tutorials/03-matrix-multiplication.py | 49 ++++++------ 6 files changed, 112 insertions(+), 51 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index da329be6267b..570cadce37cf 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2295,7 +2295,6 @@ def kernel(X, stride_xm, stride_xk, if epilogue == 'chain-dot': z_ref = np.matmul(z_ref, w) # compare - # print(z_ref[:,0], z_tri[:,0]) if in_dtype == 'float32': # XXX: Somehow there's a larger difference when we use float32 np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 1116cb211bf7..c7ec2a235b24 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -42,6 +42,9 @@ def __init__(self, data, dtype): self.data = data self.dtype = dtype + def __bool__(self): + return bool(self.data.all()) + def wrap_ret(compute_ret_ty): def wrapper(fn): @@ -54,10 +57,9 @@ def wrapped(*args, **kwargs): class Builder: - def __init__(self, grid_dim) -> None: - assert len(grid_dim) == 3 - self.grid_idx = None - self.grid_dim = grid_dim + def __init__(self) -> None: + self.arch = None + # pass def set_grid_idx(self, x, y, z): assert x < self.grid_dim[0] @@ -65,6 +67,9 @@ def set_grid_idx(self, x, y, z): assert z < self.grid_dim[2] self.grid_idx = (x, y, z) + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + def np_dtype(self, tt_dtype): if isinstance(tt_dtype, tl.pointer_type): return np.dtype(np.uint64) @@ -84,9 +89,27 @@ def np_dtype(self, tt_dtype): return np_types[tt_dtype] # constants + def get_half_ty(self): + return tl.float16 + + def get_float_ty(self): + return tl.float32 + + def get_block_ty(self, dtype, shape): + return tl.tensor(shape, dtype) + def get_int32(self, value): return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type) + # programming model def create_get_program_id(self, axis): assert self.grid_idx is not None @@ -105,9 +128,18 @@ def create_load(self, ptr, _0, _1, volatile): def create_store(self, ptr, val, _0, _1): return _interpreter.store_ptrs(ptr.data, val.data) + def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): + return self.create_load(ptrs, None, None, None) + + def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): + return self.create_store(ptrs, value, None, None) + # casting ops def cast_impl(self, src, dst_type): + if isinstance(dst_type, tl.tensor): + dst_type = dst_type.dtype return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) @@ -132,8 +164,8 @@ def binary_op(self, lhs, rhs, op): create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) - create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) - create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) + create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) @@ -194,6 +226,9 @@ def unary_op(self, arg, op): create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype) create_trans = lambda self, arg: self.unary_op(arg, np.transpose) + def create_dot(self, a, b, d, allow_tf32): + return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype) + def create_make_range(self, start, stop): return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) @@ -209,12 +244,6 @@ def create_addptr(self, ptr, offset): # def create_tensor_pointer_store(self, ptr, value, boundaryCheck, cacheModifier, evictionPolicy): # pass - # def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): - # pass - - # def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): - # pass - def create_expand_dims(self, arg, axis): return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype) @@ -281,7 +310,6 @@ def create_splat(self, arg, shape): def patch_attr(obj, name, member, builder): new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) - setattr(new_member, '__triton_builtin__', True) setattr(obj, name, new_member) @@ -305,6 +333,15 @@ def _implicit_cvt(arg): return arg +def _unwrap(tensor): + if isinstance(tensor, triton.TensorWrapper): + return tensor.base + return tensor + + +builder = Builder() + + class GridExecutor: def __init__(self, fn, grid): @@ -313,20 +350,22 @@ def __init__(self, fn, grid): self.grid = tuple(grid) + (1,) * (3 - len(grid)) def _patch_lang(self, builder): - lang = [value for _, value in self.fn.__globals__.items() if value is tl] + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] assert len(lang) == 1, "triton.language must be visible from within jit'd function" _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) _patch_lang_core(lang[0], builder) def __call__(self, *args_dev, **kwargs): - builder = Builder(self.grid) + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in ['num_warps', 'num_stages', 'num_ctas']} # remaps core language functions to interpreted ones self._patch_lang(builder) # we need to copy arguments to the host for the interpreter - args_hst = [arg.cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # implicitly convert tensor arguments to their base pointers wrapped_args = [_implicit_cvt(arg) for arg in args_hst] # iterate through grid + builder.set_grid_dim(*self.grid) for x in range(self.grid[0]): for y in range(self.grid[1]): for z in range(self.grid[2]): @@ -335,11 +374,17 @@ def __call__(self, *args_dev, **kwargs): # copy arguments back to propagate side-effects for arg_dev, arg_hst in zip(args_dev, args_hst): if hasattr(arg_dev, 'data_ptr'): - arg_dev.copy_(arg_hst.to(arg_dev.device)) + _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) class InterpretedFunction: + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + def __init__(self, fn) -> None: self.fn = fn @@ -347,4 +392,5 @@ def __getitem__(self, grid): return GridExecutor(self.fn, grid) def __call__(self, *args, **kwargs): + self._patch_lang(builder) return self.fn(*args, **kwargs) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 28317f717446..d3c0faa0a4e0 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -552,6 +552,7 @@ def __add__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.add(self, other, _builder) + @builtin def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @@ -560,6 +561,7 @@ def __sub__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.sub(self, other, _builder) + @builtin def __rsub__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.sub(other, self, _builder) @@ -569,6 +571,7 @@ def __mul__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.mul(self, other, _builder) + @builtin def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @@ -577,6 +580,7 @@ def __truediv__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.truediv(self, other, _builder) + @builtin def __rtruediv__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.truediv(other, self, _builder) @@ -741,6 +745,9 @@ def logical_or(self, other, _builder=None): def __not__(self, _builder=None): return semantic.not_(self, _builder) + def __bool__(self): + return bool(self.handle) + @builtin def __getitem__(self, slices, _builder=None): if isinstance(slices, slice): @@ -832,6 +839,8 @@ def arange(start, end, _builder=None): def _shape_check_impl(shape): shape = _constexpr_to_value(shape) for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) if not isinstance(d, constexpr): raise TypeError(f"Shape element {i} must have type `constexpr`") if not isinstance(d.value, int): diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 8984cb4da4f8..649fab151053 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1504,6 +1504,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) if isinstance(elem, tl.constexpr): return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value) elif isinstance(elem, tl.tensor): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index e5ebeb0a4c9c..4bf32d8426fa 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -583,17 +583,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - if os.getenv("TRITON_INTERPRET", "0") == "1": - from ..interpreter.new_interpreter import InterpretedFunction - return InterpretedFunction(fn) - else: - return JITFunction( - fn, - version=version, - do_not_specialize=do_not_specialize, - debug=debug, - noinline=noinline, - ) + # if os.getenv("TRITON_INTERPRET", "0") == "1": + from ..interpreter.new_interpreter import InterpretedFunction + return InterpretedFunction(fn) + # else: + # return JITFunction( + # fn, + # version=version, + # do_not_specialize=do_not_specialize, + # debug=debug, + # noinline=noinline, + # ) if fn is not None: return decorator(fn) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 8bcae2007abd..e97eca457ed9 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -161,19 +161,19 @@ # meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try # - An auto-tuning *key* whose change in values will trigger evaluation of all the # provided configs -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - ], - key=['M', 'N', 'K'], -) +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), +# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), +# ], +# key=['M', 'N', 'K'], +# ) @triton.jit def matmul_kernel( # Pointers to matrices @@ -233,7 +233,7 @@ def matmul_kernel( a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator += tl.dot(a, b) + accumulator += tl.dot(a, b, allow_tf32=False) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -241,7 +241,7 @@ def matmul_kernel( # while the accumulator is still in FP32! if ACTIVATION == "leaky_relu": accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float16) + c = accumulator.to(tl.float32) # ----------------------------------------------------------- # Write back the block of the output matrix C with masks. @@ -274,16 +274,20 @@ def matmul(a, b, activation=""): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul_kernel[grid]( + # grid = lambda META: ( + # triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + # ) + matmul_kernel[(M // 128 * N // 128,)]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - ACTIVATION=activation + ACTIVATION=activation, + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=128, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, ) return c @@ -295,8 +299,8 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). torch.manual_seed(0) -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +a = torch.randn((256, 64), device='cuda', dtype=torch.float32) +b = torch.randn((64, 256), device='cuda', dtype=torch.float32) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") @@ -305,6 +309,7 @@ def matmul(a, b, activation=""): print("✅ Triton and Torch match") else: print("❌ Triton and Torch differ") +exit(1) # %% # Benchmark From 7595ad4ec2493102f5cb0b2eb6678ef40f522f66 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Mon, 4 Sep 2023 17:04:02 -0700 Subject: [PATCH 11/41] . --- python/src/triton.cc | 25 +++++++++++++------- python/triton/interpreter/new_interpreter.py | 25 ++++++++++++-------- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 9d6f207f7eb3..2acd86f6cc41 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1982,27 +1982,36 @@ void init_triton_translation(py::module &m) { void init_triton_interpreter(py::module &&m) { using ret = py::return_value_policy; - m.def("load_ptrs", - [](py::array_t ptrs, py::dtype ret_dtype) -> py::array { + m.def("load", + [](py::array_t ptrs, py::array_t masks, py::array other, + py::dtype ret_dtype) -> py::array { int numel = ptrs.size(); auto shape = std::vector(ptrs.shape(), ptrs.shape() + ptrs.ndim()); py::array ret(ret_dtype, py::array::ShapeContainer{numel}); py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = masks.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); for (size_t i = 0; i < ptrs.size(); ++i) - memcpy(ret.mutable_data(i), - reinterpret_cast(reshaped_ptrs.at(i)), - ret_dtype.itemsize()); + if (reshaped_masks.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptrs.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), other.data(i), ret_dtype.itemsize()); return ret.reshape(shape); }); - m.def("store_ptrs", [](py::array_t ptrs, py::array values) { + m.def("store", [](py::array_t ptrs, py::array values, + py::array_t mask) { int numel = ptrs.size(); py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = mask.reshape({numel}); auto reshaped_values = values.reshape({numel}); for (size_t i = 0; i < ptrs.size(); ++i) { - memcpy(reinterpret_cast(reshaped_ptrs.at(i)), - reshaped_values.mutable_data(i), values.dtype().itemsize()); + if (reshaped_masks.at(i)) + memcpy(reinterpret_cast(reshaped_ptrs.at(i)), + reshaped_values.mutable_data(i), values.dtype().itemsize()); } }); } diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index c7ec2a235b24..190856b141d2 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -119,20 +119,25 @@ def create_get_num_programs(self, axis): return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) # memory ops - def create_load(self, ptr, _0, _1, volatile): - dtype_tt = ptr.dtype.element_ty - dtype_np = self.np_dtype(dtype_tt) - ret = _interpreter.load_ptrs(ptr.data, dtype_np) - return TensorHandle(ret, dtype_tt) + def create_load(self, ptr, _0, _1, is_volatile): + mask = np.ones_like(ptr.data, dtype=np.bool) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) def create_store(self, ptr, val, _0, _1): - return _interpreter.store_ptrs(ptr.data, val.data) + mask = np.ones_like(ptr.data, dtype=np.bool) + return self.create_masked_store(ptr, val, mask, None, None) - def create_masked_load(self, ptrs, mask, other, cacheModifier, evictionPolicy, isVolatile): - return self.create_load(ptrs, None, None, None) + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.dtype.element_ty + dtype_np = self.np_dtype(dtype_tt) + if other is None: + other = np.ones_like(ptrs.data, dtype=dtype_np) + ret = _interpreter.load(ptrs.data, mask.data, other, dtype_np) + return TensorHandle(ret, dtype_tt) - def create_masked_store(self, ptrs, value, mask, cacheModifier, evictionPolicy): - return self.create_store(ptrs, value, None, None) + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) # casting ops def cast_impl(self, src, dst_type): From d1636a8e49004b664dc2e29609c2d92948eecf1e Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 20:33:44 -0700 Subject: [PATCH 12/41] flash attention runs but produces incorrect result --- python/src/triton.cc | 6 +- python/triton/interpreter/new_interpreter.py | 115 +++++++++++++++++-- python/triton/language/standard.py | 2 +- python/tutorials/03-matrix-multiplication.py | 28 ++--- python/tutorials/06-fused-attention.py | 34 +++--- 5 files changed, 141 insertions(+), 44 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 2acd86f6cc41..14659e507746 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1992,13 +1992,15 @@ void init_triton_interpreter(py::module &&m) { py::array_t reshaped_ptrs = ptrs.reshape({numel}); py::array_t reshaped_masks = masks.reshape({numel}); py::array reshaped_others = other.reshape({numel}); - for (size_t i = 0; i < ptrs.size(); ++i) + for (size_t i = 0; i < ptrs.size(); ++i) { if (reshaped_masks.at(i)) memcpy(ret.mutable_data(i), reinterpret_cast(reshaped_ptrs.at(i)), ret_dtype.itemsize()); else - memcpy(ret.mutable_data(i), other.data(i), ret_dtype.itemsize()); + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } return ret.reshape(shape); }); diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 190856b141d2..02bf87942e84 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -46,6 +46,17 @@ def __bool__(self): return bool(self.data.all()) +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def wrap_ret(compute_ret_ty): def wrapper(fn): def wrapped(*args, **kwargs): @@ -104,6 +115,9 @@ def get_int32(self, value): def get_int64(self, value): return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + def get_fp32(self, value): return TensorHandle(np.array([value], dtype=np.float32), tl.float32) @@ -120,12 +134,12 @@ def create_get_num_programs(self, axis): # memory ops def create_load(self, ptr, _0, _1, is_volatile): - mask = np.ones_like(ptr.data, dtype=np.bool) + mask = np.ones_like(ptr.data, dtype=bool) other = None return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) def create_store(self, ptr, val, _0, _1): - mask = np.ones_like(ptr.data, dtype=np.bool) + mask = np.ones_like(ptr.data, dtype=bool) return self.create_masked_store(ptr, val, mask, None, None) def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): @@ -243,11 +257,33 @@ def create_addptr(self, ptr, offset): dtype_tt = ptr.dtype.element_ty return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) - # def create_tensor_pointer_load(self, ptr, boundaryCheck, paddingOption, cacheModifier, evictionPolicy, isVolatile): - # pass - - # def create_tensor_pointer_store(self, ptr, value, boundaryCheck, cacheModifier, evictionPolicy): - # pass + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + ptrs = ptr.base.data + shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] + masks = np.ones(shapes, dtype=bool) + # padding_value = {None: 0., "zero": 0., "nan": float('nan')}[padding_option] + for dim in range(len(shapes)): + bcast_dims = [1] * len(shapes) + bcast_dims[dim] = shapes[dim] + off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) + ptrs = ptrs + off * ptr.strides[dim].data + masks = np.logical_and(masks, off < ptr.shape[dim].data) + # other = np.full(shapes, padding_value, dtype=self.np_dtype(ptr.base.dtype.element_ty)) + ptrs = TensorHandle(ptrs, ptr.base.dtype) + return self.create_masked_load(ptrs, masks, None, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs = ptr.base.data + shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] + masks = np.ones(shapes, dtype=bool) + for dim in range(len(shapes)): + bcast_dims = [1] * len(shapes) + bcast_dims[dim] = shapes[dim] + off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) + ptrs = ptrs + off * ptr.strides[dim].data + masks = np.logical_and(masks, off < ptr.shape[dim].data) + ptrs = TensorHandle(ptrs, ptr.base.dtype) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) def create_expand_dims(self, arg, axis): return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype) @@ -306,11 +342,15 @@ def create_splat(self, arg, shape): # def create_barrier(self): # pass - # def create_make_block_ptr(self, base, shape, strides, offsets, tensorShape, order): - # pass + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order) - # def create_advance(self, ptr, offsets): - # pass + def create_advance(self, ptr, offsets): + assert len(ptr.offsets) == len(offsets) + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, ptr.offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret def patch_attr(obj, name, member, builder): @@ -328,6 +368,58 @@ def _patch_lang_core(lang, builder): for name, member in inspect.getmembers(lang): if tl.core.is_builtin(member): patch_attr(lang, name, member, builder) + # reduce is better off with a separate patch due to how + # the builder currently interfaces with custom functions + + def _new_reduce(input, axis, combine_fn): + fn = combine_fn.fn.__name__ + mapping = { + 'maximum': np.max, + '_sum_combine': np.sum, + } + ret = mapping[fn](input.handle.data, axis=axis) + ret_type = tl.block_type(input.dtype, ret.shape) + return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type) + + lang.reduce = _new_reduce + + +def _patch_lang_math(lang, builder): + math = lang.math + mapping = { + 'abs': 'abs', + 'acos': 'arccos', + 'asin': 'arcsin', + 'exp2': 'exp2', + 'log2': 'log2', + 'max': 'maximum', + } + + def make_numpy(name): + def impl(*args, **kwargs): + ret_type = args[0].type # TODO: incorrect + ret_dtype = args[0].dtype # TODO: incorrect + args = [arg.handle.data for arg in args] + kwargs = {k: v.handle.data for k, v in kwargs.items()} + ret = getattr(np, mapping[name])(*args, **kwargs) + ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type) + return ret + return impl + + def make_fallback(name): + def fallback(*args, **kwargs): + raise NotImplementedError(f""" +{name} not supported in interpreter mode: no known numpy implementation. +If you think that {name} in fact does have a numpy implementation, please add it +to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math. +""") + return fallback + + for name, member in inspect.getmembers(math): + if name in mapping: + setattr(math, name, make_numpy(name)) + else: + setattr(math, name, make_fallback(name)) def _implicit_cvt(arg): @@ -359,6 +451,7 @@ def _patch_lang(self, builder): assert len(lang) == 1, "triton.language must be visible from within jit'd function" _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) _patch_lang_core(lang[0], builder) + _patch_lang_math(lang[0], builder) def __call__(self, *args_dev, **kwargs): # removes reserved keywords from kwargs diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 8acc4261585f..8ef52cb9cfd6 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr else: return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast) else: - if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): if core.constexpr(input.dtype.is_floating()): input = input.to(core.float32) else: diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index e97eca457ed9..3c3de43c5162 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -161,19 +161,19 @@ # meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try # - An auto-tuning *key* whose change in values will trigger evaluation of all the # provided configs -# @triton.autotune( -# configs=[ -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), -# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), -# ], -# key=['M', 'N', 'K'], -# ) +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + ], + key=['M', 'N', 'K'], +) @triton.jit def matmul_kernel( # Pointers to matrices @@ -233,7 +233,7 @@ def matmul_kernel( a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator += tl.dot(a, b, allow_tf32=False) + accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 71aa9651b7a8..73f17cd6fd7b 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -11,7 +11,7 @@ """ -import pytest +# import pytest import torch import triton @@ -96,6 +96,7 @@ def _fwd_kernel( alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- + # breakpoint() acc_scale = l_i * 0 + alpha # workaround some compiler bug acc *= acc_scale[:, None] acc += tl.dot(p.to(tl.float16), v) @@ -306,15 +307,15 @@ def backward(ctx, do): attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)]) -@pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ', [(6, 9, 1024, 64, 128)]) +# @pytest.mark.parametrize('causal', [False, True]) def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() sm_scale = 0.5 - dout = torch.randn_like(q) + # dout = torch.randn_like(q) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale @@ -323,21 +324,21 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): p = torch.softmax(p.float(), dim=-1).half() # p = torch.exp(p) ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None + # ref_out.backward(dout) + # ref_dv, v.grad = v.grad.clone(), None + # ref_dk, k.grad = k.grad.clone(), None + # ref_dq, q.grad = q.grad.clone(), None # triton implementation tri_out = attention(q, k, v, causal, sm_scale).half() - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None + # tri_out.backward(dout) + # tri_dv, v.grad = v.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dq, q.grad = q.grad.clone(), None # compare assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) - assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) - assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) - assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) + # assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) + # assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) + # assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) try: @@ -410,4 +411,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype # only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path='.', print_data=True) +# bench_flash_attention.run(save_path='.', print_data=True) +test_op(1, 1, 128, 64, 128, False, torch.float16) From 145d70f4b3d8e326f43abd1d1d32dddb92f1399e Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 20:51:48 -0700 Subject: [PATCH 13/41] bugfix --- python/triton/interpreter/new_interpreter.py | 4 ++-- python/tutorials/06-fused-attention.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 02bf87942e84..fc42d6bf131b 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -258,8 +258,8 @@ def create_addptr(self, ptr, offset): return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): - ptrs = ptr.base.data shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] + ptrs = np.broadcast_to(ptr.base.data, shapes) masks = np.ones(shapes, dtype=bool) # padding_value = {None: 0., "zero": 0., "nan": float('nan')}[padding_option] for dim in range(len(shapes)): @@ -273,8 +273,8 @@ def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_ return self.create_masked_load(ptrs, masks, None, cache_modifier, eviction_policy, is_volatile) def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): - ptrs = ptr.base.data shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] + ptrs = np.broadcast_to(ptr.base.data, shapes) masks = np.ones(shapes, dtype=bool) for dim in range(len(shapes)): bcast_dims = [1] * len(shapes) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 73f17cd6fd7b..cfd4488624b4 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -311,9 +311,9 @@ def backward(ctx, do): # @pytest.mark.parametrize('causal', [False, True]) def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() + k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() + v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() sm_scale = 0.5 # dout = torch.randn_like(q) # reference implementation @@ -335,6 +335,8 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): # tri_dk, k.grad = k.grad.clone(), None # tri_dq, q.grad = q.grad.clone(), None # compare + print(ref_out) + print(tri_out) assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) # assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) # assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) From d304a201254075e046b1fa73504290ba51d55bd8 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 21:05:36 -0700 Subject: [PATCH 14/41] progress --- python/triton/interpreter/new_interpreter.py | 8 ++++++-- python/tutorials/06-fused-attention.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index fc42d6bf131b..e17ad0e9103f 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -258,6 +258,8 @@ def create_addptr(self, ptr, offset): return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptr.base.dtype.element_ty + n_bytes = (dtype_tt.primitive_bitwidth // 8) shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] ptrs = np.broadcast_to(ptr.base.data, shapes) masks = np.ones(shapes, dtype=bool) @@ -266,13 +268,15 @@ def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_ bcast_dims = [1] * len(shapes) bcast_dims[dim] = shapes[dim] off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) - ptrs = ptrs + off * ptr.strides[dim].data + ptrs = ptrs + (n_bytes * off * ptr.strides[dim].data).astype(np.uint64) masks = np.logical_and(masks, off < ptr.shape[dim].data) # other = np.full(shapes, padding_value, dtype=self.np_dtype(ptr.base.dtype.element_ty)) ptrs = TensorHandle(ptrs, ptr.base.dtype) return self.create_masked_load(ptrs, masks, None, cache_modifier, eviction_policy, is_volatile) def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + dtype_tt = ptr.base.dtype.element_ty + n_bytes = (dtype_tt.primitive_bitwidth // 8) shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] ptrs = np.broadcast_to(ptr.base.data, shapes) masks = np.ones(shapes, dtype=bool) @@ -280,7 +284,7 @@ def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier bcast_dims = [1] * len(shapes) bcast_dims[dim] = shapes[dim] off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) - ptrs = ptrs + off * ptr.strides[dim].data + ptrs = ptrs + (n_bytes * off * ptr.strides[dim].data).astype(np.uint64) masks = np.logical_and(masks, off < ptr.shape[dim].data) ptrs = TensorHandle(ptrs, ptr.base.dtype) return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index cfd4488624b4..7a58945510ab 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -77,6 +77,7 @@ def _fwd_kernel( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout + print("Q") q = tl.load(Q_block_ptr) q = (q * qk_scale).to(tl.float16) # loop over k, v and update accumulator From 5882703a1ee06edcb9ca7ab21a54a43e60d7a510 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 21:08:27 -0700 Subject: [PATCH 15/41] flash attention fwd pass working --- python/tutorials/06-fused-attention.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 7a58945510ab..3ff6a59e3fa5 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -77,7 +77,6 @@ def _fwd_kernel( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout - print("Q") q = tl.load(Q_block_ptr) q = (q * qk_scale).to(tl.float16) # loop over k, v and update accumulator @@ -312,9 +311,9 @@ def backward(ctx, do): # @pytest.mark.parametrize('causal', [False, True]) def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() - k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() - v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=1., std=0.).requires_grad_() + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() sm_scale = 0.5 # dout = torch.randn_like(q) # reference implementation From 2ef140e82c6f2895a118ef1f30b9a6b83032057c Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 22:26:03 -0700 Subject: [PATCH 16/41] flash bwd also works --- python/src/triton.cc | 4 ++-- python/tutorials/06-fused-attention.py | 26 ++++++++++++-------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 14659e507746..9c4b9daa8208 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -2012,8 +2012,8 @@ void init_triton_interpreter(py::module &&m) { auto reshaped_values = values.reshape({numel}); for (size_t i = 0; i < ptrs.size(); ++i) { if (reshaped_masks.at(i)) - memcpy(reinterpret_cast(reshaped_ptrs.at(i)), - reshaped_values.mutable_data(i), values.dtype().itemsize()); + memcpy(reinterpret_cast(reshaped_ptrs.mutable_at(i)), + reshaped_values.data(i), values.dtype().itemsize()); } }); } diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 3ff6a59e3fa5..126cd4c08abb 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -315,7 +315,7 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() sm_scale = 0.5 - # dout = torch.randn_like(q) + dout = torch.randn_like(q) # reference implementation M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale @@ -324,23 +324,21 @@ def test_op(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16): p = torch.softmax(p.float(), dim=-1).half() # p = torch.exp(p) ref_out = torch.matmul(p, v) - # ref_out.backward(dout) - # ref_dv, v.grad = v.grad.clone(), None - # ref_dk, k.grad = k.grad.clone(), None - # ref_dq, q.grad = q.grad.clone(), None + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None # triton implementation tri_out = attention(q, k, v, causal, sm_scale).half() - # tri_out.backward(dout) - # tri_dv, v.grad = v.grad.clone(), None - # tri_dk, k.grad = k.grad.clone(), None - # tri_dq, q.grad = q.grad.clone(), None + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None # compare - print(ref_out) - print(tri_out) assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0) - # assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) - # assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) - # assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) + assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0) + assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) + assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) try: From 5d3d9161e3da8f237dfc28db5e60387a085bea33 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 10 Sep 2023 23:56:47 -0700 Subject: [PATCH 17/41] cleanup --- python/triton/interpreter/new_interpreter.py | 48 +++++++++----------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index e17ad0e9103f..a55bee85e081 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -56,6 +56,22 @@ def __init__(self, base, shape, strides, offsets, tensor_shape, order): self.tensor_shape = tensor_shape self.order = order + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.dtype.element_ty + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype) + return ptrs, masks + def wrap_ret(compute_ret_ty): def wrapper(fn): @@ -258,35 +274,13 @@ def create_addptr(self, ptr, offset): return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): - dtype_tt = ptr.base.dtype.element_ty - n_bytes = (dtype_tt.primitive_bitwidth // 8) - shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] - ptrs = np.broadcast_to(ptr.base.data, shapes) - masks = np.ones(shapes, dtype=bool) - # padding_value = {None: 0., "zero": 0., "nan": float('nan')}[padding_option] - for dim in range(len(shapes)): - bcast_dims = [1] * len(shapes) - bcast_dims[dim] = shapes[dim] - off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) - ptrs = ptrs + (n_bytes * off * ptr.strides[dim].data).astype(np.uint64) - masks = np.logical_and(masks, off < ptr.shape[dim].data) - # other = np.full(shapes, padding_value, dtype=self.np_dtype(ptr.base.dtype.element_ty)) - ptrs = TensorHandle(ptrs, ptr.base.dtype) - return self.create_masked_load(ptrs, masks, None, cache_modifier, eviction_policy, is_volatile) + ptrs, masks = ptr.materialize_pointers(boundary_check) + assert padding_option is None + other = None + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): - dtype_tt = ptr.base.dtype.element_ty - n_bytes = (dtype_tt.primitive_bitwidth // 8) - shapes = [int(ptr.tensor_shape[dim]) for dim in range(len(ptr.tensor_shape))] - ptrs = np.broadcast_to(ptr.base.data, shapes) - masks = np.ones(shapes, dtype=bool) - for dim in range(len(shapes)): - bcast_dims = [1] * len(shapes) - bcast_dims[dim] = shapes[dim] - off = (ptr.offsets[dim].data + np.arange(shapes[dim])).reshape(bcast_dims) - ptrs = ptrs + (n_bytes * off * ptr.strides[dim].data).astype(np.uint64) - masks = np.logical_and(masks, off < ptr.shape[dim].data) - ptrs = TensorHandle(ptrs, ptr.base.dtype) + ptrs, masks = ptr.materialize_pointers(boundary_check) return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) def create_expand_dims(self, arg, axis): From 2f1cab78e1ef3b7d96085baaeb023d2e2ae15ecc Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Fri, 15 Sep 2023 15:05:29 -0700 Subject: [PATCH 18/41] . --- python/triton/interpreter/new_interpreter.py | 9 ++++++++ python/triton/language/core.py | 2 -- python/triton/runtime/jit.py | 22 ++++++++++---------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index a55bee85e081..17a3511972e1 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -122,6 +122,12 @@ def get_half_ty(self): def get_float_ty(self): return tl.float32 + def get_int64_ty(self): + return tl.int64 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + def get_block_ty(self, dtype, shape): return tl.tensor(shape, dtype) @@ -289,6 +295,8 @@ def create_expand_dims(self, arg, axis): def create_broadcast(self, arg, shape): return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype) + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty) # def create_cat(self, lhs, rhs): # pass @@ -360,6 +368,7 @@ def _patch_lang_tensor(tensor, builder): for name, member in inspect.getmembers(tensor): if tl.core.is_builtin(member): patch_attr(tensor, name, member, builder) + tensor.__index__ = lambda self: int(self.handle.data) def _patch_lang_core(lang, builder): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index d3c0faa0a4e0..4bd90fa44dba 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -672,8 +672,6 @@ def __rrshift__(self, other, _builder=None): else: return semantic.lshr(other, self, _builder) - # comparison operators - # > @builtin def __gt__(self, other, _builder=None): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 4bf32d8426fa..e5ebeb0a4c9c 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -583,17 +583,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - # if os.getenv("TRITON_INTERPRET", "0") == "1": - from ..interpreter.new_interpreter import InterpretedFunction - return InterpretedFunction(fn) - # else: - # return JITFunction( - # fn, - # version=version, - # do_not_specialize=do_not_specialize, - # debug=debug, - # noinline=noinline, - # ) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from ..interpreter.new_interpreter import InterpretedFunction + return InterpretedFunction(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + ) if fn is not None: return decorator(fn) From 1af9397bef69a30448be3361821c6f26b88412f4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 19:48:04 -0700 Subject: [PATCH 19/41] . --- python/triton/interpreter/new_interpreter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/interpreter/new_interpreter.py index 17a3511972e1..0c146fd1aad8 100644 --- a/python/triton/interpreter/new_interpreter.py +++ b/python/triton/interpreter/new_interpreter.py @@ -492,6 +492,8 @@ def _patch_lang(self, builder): def __init__(self, fn) -> None: self.fn = fn + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] def __getitem__(self, grid): return GridExecutor(self.fn, grid) From 84f7a0c1cee28558e70202a816c67b4736dfe24a Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 19:52:33 -0700 Subject: [PATCH 20/41] . --- python/triton/interpreter/__init__.py | 0 python/triton/interpreter/core.py | 9 - python/triton/interpreter/interpreter.py | 171 ----- python/triton/interpreter/memory_map.py | 102 --- python/triton/interpreter/tl_lang.py | 641 ------------------ python/triton/interpreter/torch_wrapper.py | 18 - .../interpreter.py} | 0 python/triton/runtime/jit.py | 2 +- 8 files changed, 1 insertion(+), 942 deletions(-) delete mode 100644 python/triton/interpreter/__init__.py delete mode 100644 python/triton/interpreter/core.py delete mode 100644 python/triton/interpreter/interpreter.py delete mode 100644 python/triton/interpreter/memory_map.py delete mode 100644 python/triton/interpreter/tl_lang.py delete mode 100644 python/triton/interpreter/torch_wrapper.py rename python/triton/{interpreter/new_interpreter.py => runtime/interpreter.py} (100%) diff --git a/python/triton/interpreter/__init__.py b/python/triton/interpreter/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/triton/interpreter/core.py b/python/triton/interpreter/core.py deleted file mode 100644 index 82f3f43a25a0..000000000000 --- a/python/triton/interpreter/core.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Tuple - -import dataclasses - - -@dataclasses.dataclass -class ExecutionContext: - program_id: Tuple[int] - program_size: Tuple[int] diff --git a/python/triton/interpreter/interpreter.py b/python/triton/interpreter/interpreter.py deleted file mode 100644 index 001b80ec9855..000000000000 --- a/python/triton/interpreter/interpreter.py +++ /dev/null @@ -1,171 +0,0 @@ -import itertools -import random -from typing import Tuple - -from .. import language as tl -# import .language.core as lcore -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap -from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, - debugger_constexpr) - -torch = torch_wrapper.torch -tl_method_backup = {} - - -def get_proxy_method(proxy, name): - method = getattr(proxy, name) - - def fun(*args, **kwarg): - return method(*args, **kwarg) - - return fun - - -def attach_triton(module, proxy): - method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] - for name in method_list: - if hasattr(module, name): - attr = getattr(module, name) - tl_method_backup[name] = attr - if callable(attr): - setattr(module, name, get_proxy_method(proxy, name)) - else: - setattr(module, name, getattr(proxy, name)) - - -def detach_triton(module): - for name, method in tl_method_backup.items(): - setattr(module, name, method) - - -def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]: - # reverse the grid dimensions and generate the range for each dimension - reversed_grid = reversed(grid) - ranges_for_each_dimension = [range(dim) for dim in reversed_grid] - - # gen all combinations - index_combinations = list(itertools.product(*ranges_for_each_dimension)) - random.shuffle(index_combinations) - - for index_combination in index_combinations: - yield index_combination - - -class DebuggerFunction: - def __init__(self, func, grid=(1,)): - self.func = func - self.grid = grid - - def _is_constexpr(self, name): - return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr - - def _get_constexpr(self): - result = [] - for name, annotation in self.func.__annotations__.items(): - if annotation is lcore.constexpr: - result.append(name) - return result - - def _assert_constexpr(self, **kwargs): - constexp = self._get_constexpr() - missing = [i for i in constexp if i not in kwargs.keys()] - assert len(missing) == 0, f"You must specify constexpr {missing}" - - def _get_grid(self, **kwargs): - if callable(self.grid): - return self.grid(kwargs) - else: - return self.grid - - def __call__(self, *args, **kwargs): - self._assert_constexpr(**kwargs) - - memory = MemoryMap() - - def convert_arg(v): - name, arg = v - if torch.is_tensor(arg): - ptr = memory.add_tensor(arg) - return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) - if self._is_constexpr(name): - return debugger_constexpr(arg) - return WrappedTensor(_primitive_to_tensor(arg)) - - new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) - new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} - - grid = self._get_grid(**kwargs) - for program_id in program_ids_from_grid(grid): - proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) - attach_triton(tl, proxy) - self.func(*new_args, **new_kwargs) - detach_triton(tl) - - -class GridSelector: - """ - Entry point of the debugger - """ - - def __init__(self, func): - version = torch.__version__ - assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" - self.func = func - - def __getitem__(self, grid): - return DebuggerFunction(self.func, grid) - - def __call__(self, *args, **kwargs): - return DebuggerFunction(self.func)(*args, **kwargs) - - -class AutotuneGridSelector: - def __init__(self, func, autotune_params): - self.func = func - self.autotune_params = autotune_params - - def __getitem__(self, grid): - return AutotuneRunner(self.func, self.autotune_params, grid) - - def __call__(self, *args, **kwargs): - return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) - - -class AutotuneRunner: - def __init__(self, func, autotune_params, grid=None): - self.func = func - self.autotune_params = autotune_params - self.grid = grid - - def __call__(self, *args, **kwargs): - assert len(self.autotune_params["configs"]) >= 1 - - for config in self.autotune_params["configs"][1:]: - - def convert_arg(v): - if torch.is_tensor(v): - return torch.clone(v) - return v - - new_args = tuple(map(convert_arg, args)) - new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} - if self.grid: - self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) - else: - self.func(*new_args, **new_kwargs, **config.kwargs) - - main_config = self.autotune_params["configs"][0] - if self.grid: - self.func[self.grid](*args, **kwargs, **main_config.kwargs) - else: - self.func(*args, **kwargs, **main_config.kwargs) - - -def triton_debug_autotune(**kwars): - def wrapper(func): - return AutotuneGridSelector(func, kwars) - - return wrapper diff --git a/python/triton/interpreter/memory_map.py b/python/triton/interpreter/memory_map.py deleted file mode 100644 index d0ff732a74b9..000000000000 --- a/python/triton/interpreter/memory_map.py +++ /dev/null @@ -1,102 +0,0 @@ -from __future__ import annotations - -import dataclasses - -from . import torch_wrapper - -torch = torch_wrapper.torch - - -@dataclasses.dataclass -class RegisteredStorage: - storage: torch.Storage - dtype: torch.dtype - size: int - ptr: int - - @property - def end_ptr(self) -> int: - return self.ptr + self.size - - @property - def access_tensor(self) -> torch.Tensor: - return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) - - def ensure_immutable(self): - assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size - - -class MemoryMap: - storages: [RegisteredStorage] - - def __init__(self): - self.storages = [] - - def _get_registered_storage(self, pointer: torch.Tensor): - max_pointer = torch.max(pointer).item() - min_pointer = torch.min(pointer).item() - - registered_storage = next( - filter( - lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages - ), - None, - ) - if registered_storage is None: - raise Exception("Storage not found or pointers spanning multiple tensors") - registered_storage.ensure_immutable() - return registered_storage - - def add_tensor(self, t: torch.Tensor): - storage = t.untyped_storage() - self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) - return t.data_ptr() - - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - ): - assert pointer.is_cuda - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert mask.is_cuda - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - # Todo: The type is wrong here, we can't determine the correct type - return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - index_tensor = pointer - registered_storage.ptr - - block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") - block[mask] = access_tensor[index_tensor[mask]] - return block - - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - return - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - index_tensor = pointer - registered_storage.ptr - access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/python/triton/interpreter/tl_lang.py b/python/triton/interpreter/tl_lang.py deleted file mode 100644 index e2a578fa580f..000000000000 --- a/python/triton/interpreter/tl_lang.py +++ /dev/null @@ -1,641 +0,0 @@ -from __future__ import annotations - -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap - -torch = torch_wrapper.torch - - -def _primitive_to_tensor(x): - """ - Converts various Python primitive data types to PyTorch tensor. - """ - tensor_args = {"device": "cuda"} - if isinstance(x, bool): - return torch.tensor([x], dtype=torch.bool, **tensor_args) - elif isinstance(x, int): - if -(2**31) <= x < 2**31: - return torch.tensor([x], dtype=torch.int32, **tensor_args) - elif -(2**63) <= x < 2**63: - return torch.tensor([x], dtype=torch.int64, **tensor_args) - else: - raise RuntimeError(f"Nonrepresentable integer {x}.") - elif isinstance(x, float): - return torch.tensor([x], dtype=torch.float32, **tensor_args) - elif torch.is_tensor(x): - return x - elif isinstance(x, WrappedTensor): - return x - elif isinstance(x, debugger_constexpr): - if x.value is None: - return None - return _primitive_to_tensor(x.value) - elif x is None: - return None - assert False, f"cannot convert {x} of type {type(x)} to tensor" - - -def _infer_tensor(func): - """ - A decorator function to harmonize function args: - - converts primitives to PyTorch tensors - - wraps PyTorch tensors with WrappedTensors - """ - def wrapper(*args): - new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) - new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) - - return func(*new_args) - - return wrapper - - -def _tensor_operation(func): - """ - A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. - Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). - """ - def wrapper(*args, **kwargs): - for arg in args: - assert not torch.is_tensor(arg), "unexpected tensor argument" - - def unwrap_tensor(v): - if isinstance(v, WrappedTensor): - return v.tensor - if isinstance(v, debugger_constexpr): - return v.value - return v - - new_args = tuple(map(unwrap_tensor, args)) - new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} - - result = func(args[0], *new_args[1:], **new_kwargs) - return WrappedTensor(result) if torch.is_tensor(result) else result - - return wrapper - - -class debugger_constexpr: - def __init__(self, value): - if isinstance(value, debugger_constexpr): - self.value = value.value - else: - self.value = value - - def __str__(self) -> str: - return "debugger_constexpr(" + str(self.value) + ")" - - def __index__(self) -> int: - return self.value - - def __bool__(self): - return bool(self.value) - - def __ge__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value >= other - - def __gt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value > other - - def __le__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value <= other - - def __lt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value < other - - def __eq__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value == other - - def __or__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __ror__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __and__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def __rand__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def to(self, dtype, bitcast=False, _builder=None): - if dtype in [torch.int64]: - ret_ty = int - elif dtype == torch.bool: - ret_ty = bool - elif dtype in [torch.float64]: - ret_ty = float - else: - raise ValueError("dtype not supported in debugger") - return debugger_constexpr(ret_ty(self.value)) - - -class WrappedTensor: - def __init__(self, tensor): - self.tensor = tensor - - def __index__(self) -> int: - return self.tensor.item() - - def __str__(self) -> str: - return "wrapped_" + str(self.tensor) - - def __bool__(self) -> bool: - return torch.all(self.tensor == True).item() # noqa: E712 - - @property - def dtype(self): - return self.tensor.dtype - - @_infer_tensor - @_tensor_operation - def __add__(self, other): - return torch.add(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __radd__(self, other): - return self.__add__(other) - - @_infer_tensor - @_tensor_operation - def __sub__(self, other): - return torch.sub(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rsub__(self, other): - return torch.sub(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mul__(self, other): - return torch.mul(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmul__(self, other): - return self.__mul__(other) - - @_infer_tensor - @_tensor_operation - def __truediv__(self, other): - return torch.div(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rtruediv__(self, other): - return torch.div(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __floordiv__(self, other): - return torch.floor_divide(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rfloordiv__(self, other): - return torch.floor_divide(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mod__(self, other): - return torch.remainder(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmod__(self, other): - return torch.remainder(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __neg__(self): - return -self.tensor - - @_infer_tensor - @_tensor_operation - def __invert__(self): - return ~self.tensor - - @_infer_tensor - @_tensor_operation - def __and__(self, other): - return torch.bitwise_and(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __or__(self, other): - return torch.bitwise_or(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __xor__(self, other): - return torch.bitwise_xor(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __lshift__(self, other): - return torch.bitwise_left_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rshift__(self, other): - return torch.bitwise_right_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __gt__(self, other): - return self.tensor > other - - @_infer_tensor - @_tensor_operation - def __rgt__(self, other): - return other > self.tensor - - @_infer_tensor - @_tensor_operation - def __ge__(self, other): - return self.tensor >= other - - @_infer_tensor - @_tensor_operation - def __rge__(self, other): - return other >= self.tensor - - @_infer_tensor - @_tensor_operation - def __lt__(self, other): - return self.tensor < other - - @_infer_tensor - @_tensor_operation - def __rlt__(self, other): - return other < self.tensor - - @_infer_tensor - @_tensor_operation - def __le__(self, other): - return self.tensor <= other - - @_infer_tensor - @_tensor_operation - def __rle__(self, other): - return other <= self.tensor - - @_infer_tensor - @_tensor_operation - def __eq__(self, other): - return torch.equal(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __ne__(self, other): - return not torch.equal(self.tensor, other) - - @_tensor_operation - def __getitem__(self, slices): - return self.tensor.__getitem__(slices) - # if isinstance(slices, slice): - # slices = [slices] - # src_shape = self.shape - # dst_shape = [] - # curr = 0 - # for sl in slices: - # if isinstance(sl, constexpr) and sl.value is None: - # dst_shape.append(1) - # elif sl == slice(None, None, None): - # dst_shape.append(src_shape[curr].value) - # curr += 1 - # ret = torch.reshape(self.tensor, dst_shape, ) - # return ret - - @_tensor_operation - def to(self, dtype, bitcast=False): - return self.tensor.to(dtype) - # if isinstance(bitcast, constexpr): - # bitcast = bitcast.value - # if bitcast: - # return semantic.bitcast(self, dtype, ) - # return semantic.cast(self, dtype, ) - - -def _constexpr_to_value(v): - if isinstance(v, debugger_constexpr): - return v.value - return v - - -class TritonLangProxy: - _memory_map: MemoryMap - _context: ExecutionContext - - def __init__(self, memory_map: MemoryMap, context: ExecutionContext): - self._memory_map = memory_map - self._context = context - - # Types - # Removed void, int1, float8, uint16, uint32, uint64, pi32_t - - # constexpr = debugger_constexpr - - # Program functions - - @_tensor_operation - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - cache_modifier="", - eviction_policy="", - volatile=False, - ): - return self._memory_map.load(pointer, mask, other) - - @_tensor_operation - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - return self._memory_map.store(pointer, value, mask) - - @_tensor_operation - def program_id(self, axis): - assert axis < len(self._context.program_id) - return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def num_programs(self, axis): - assert axis < len(self._context.program_size) - return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def arange(self, start, end): - return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") - - @_tensor_operation - def zeros(self, shape, dtype): - for i, d in enumerate(shape): - if not isinstance(d, debugger_constexpr): - raise TypeError(f"Shape element {i} must have type `constexpr`") - if not isinstance(d.value, int): - raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") - shape = [x.value for x in shape] - if isinstance(dtype, lcore.dtype): - if dtype.is_fp32(): - dtype = torch.float32 - elif dtype.is_fp16(): - dtype = torch.float16 - elif dtype.is_bf16(): - dtype = torch.bfloat16 - elif dtype.is_int32(): - dtype = torch.int32 - elif dtype.is_int16(): - dtype = torch.int16 - elif dtype.is_int8(): - dtype = torch.int8 - else: - raise TypeError(f"Unsupported dtype {dtype}") - return torch.zeros(size=shape, dtype=dtype, device="cuda") - - @_tensor_operation - def dequantize(self, input, scale, shift, nbit, dst_ty=None): - if dst_ty is None: - dst_ty = torch.float16 - raise NotImplementedError() - - @_tensor_operation - def broadcast(self, input, other): - raise NotImplementedError() - - @_tensor_operation - def broadcast_to(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def cat(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def reshape(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): - assert input.dtype == other.dtype - if trans_a: - input = input.T - if trans_b: - other = other.T - return torch.matmul(input=input, other=other) - - @_tensor_operation - def atomic_cas(self, pointer, cmp, val): - stored = self._memory_map.load(pointer, None, 0.0) - if not isinstance(cmp, torch.Tensor): - cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") - if not isinstance(val, torch.Tensor): - val = torch.tensor([val], dtype=stored.dtype, device="cuda") - if stored == cmp: - self._memory_map.store(pointer, val, None) - return stored - - @_tensor_operation - def atomic_xchg(self, pointer, val, mask=None): - if isinstance(val, int): - val = torch.tensor([val], dtype=torch.int32, device="cuda") - stored = self._memory_map.load(pointer, mask, 0.0) - self._memory_map.store(pointer, val, mask) - return stored - - @_tensor_operation - def atomic_add(self, pointer, val, mask=None): - # arbitrary other value as it will masked during storing - stored = self._memory_map.load(pointer, mask, 0.0) - result = stored + val - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_max(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.maximum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_min(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.minimum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_and(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_and(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_or(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_or(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_xor(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_xor(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def where(self, condition, x, y): - condition = _primitive_to_tensor(condition) - x = _primitive_to_tensor(x) - y = _primitive_to_tensor(y) - return torch.where(condition, x, y) - - @_tensor_operation - def umulhi(self, x, y): - raise NotImplementedError() - - @_tensor_operation - def fdiv(self, x, y, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def exp(self, x): - return torch.exp(x) - - @_tensor_operation - def log(self, x): - return torch.log(x) - - @_tensor_operation - def cos(self, x): - return torch.cos(x) - - @_tensor_operation - def sin(self, x): - return torch.sin(x) - - @_tensor_operation - def sqrt(self, x): - return torch.sqrt(x) - - @_tensor_operation - def globaltimer(self): - raise NotImplementedError() - - @_tensor_operation - def clock(self): - raise NotImplementedError() - - @_tensor_operation - def debug_barrier(self): - raise NotImplementedError() - - @_tensor_operation - def multiple_of(self, input, values): - return input - - @_tensor_operation - def max_contiguous(self, input, values): - return input - - @_tensor_operation - def max_constancy(self, input, values): - return input - - @_tensor_operation - def abs(self, x): - return torch.abs(x) - - @_tensor_operation - def cdiv(self, x, div): - return (x + div - 1) // div - - @_tensor_operation - def minimum(self, x, y): - if isinstance(x, int): - x = torch.tensor(x, device="cuda") - if isinstance(y, int): - y = torch.tensor(y, device="cuda") - return torch.minimum(x, y) - - @_tensor_operation - def maximum(self, x, y): - return torch.maximum(x, y) - - @_tensor_operation - def sigmoid(self, x): - raise NotImplementedError() - - @_tensor_operation - def softmax(self, x, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def ravel(self, x): - raise NotImplementedError() - - @_tensor_operation - def swizzle2d(self, i, j, size_i, size_j, size_g): - raise NotImplementedError() - - @_tensor_operation - def zeros_like(self, input): - raise NotImplementedError() - - @_tensor_operation - def max(self, input, axis=None): - if axis is None: - return torch.max(input) - return torch.max(input, dim=axis).values - - @_tensor_operation - def argmax(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def min(self, input, axis=None): - if axis is None: - return torch.min(input) - return torch.min(input, dim=axis).values - - @_tensor_operation - def argmin(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def sum(self, input, axis=None): - if axis is None: - return torch.sum(input) - return torch.sum(input, dim=axis) - - @_tensor_operation - def xor_sum(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def cumsum(self, input, axis=None): - if axis is None: - return torch.cumsum(input) - return torch.cumsum(input, dim=axis) - - @_tensor_operation - def cumprod(self, input, axis=None): - if axis is None: - return torch.cumprod(input) - return torch.cumprod(input, dim=axis) diff --git a/python/triton/interpreter/torch_wrapper.py b/python/triton/interpreter/torch_wrapper.py deleted file mode 100644 index 44aa17eb1355..000000000000 --- a/python/triton/interpreter/torch_wrapper.py +++ /dev/null @@ -1,18 +0,0 @@ -try: - import torch as _torch -except ImportError: - _torch = None - - -class TorchWrapper: - """ - Helps in making torch an optional dependency - """ - - def __getattr__(self, name): - if _torch is None: - raise ImportError("Triton requires PyTorch to be installed") - return getattr(_torch, name) - - -torch = TorchWrapper() diff --git a/python/triton/interpreter/new_interpreter.py b/python/triton/runtime/interpreter.py similarity index 100% rename from python/triton/interpreter/new_interpreter.py rename to python/triton/runtime/interpreter.py diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0d4f6963fe9b..be40eaf8888b 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -586,7 +586,7 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": - from ..interpreter.new_interpreter import InterpretedFunction + from .interpreter import InterpretedFunction return InterpretedFunction(fn) else: return JITFunction( From 9d2858a0e72aba3bd223ca1e89090a2cbca0705e Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 20:13:33 -0700 Subject: [PATCH 21/41] . --- python/tutorials/03-matrix-multiplication.py | 21 ++++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 3c3de43c5162..8bcae2007abd 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -241,7 +241,7 @@ def matmul_kernel( # while the accumulator is still in FP32! if ACTIVATION == "leaky_relu": accumulator = leaky_relu(accumulator) - c = accumulator.to(tl.float32) + c = accumulator.to(tl.float16) # ----------------------------------------------------------- # Write back the block of the output matrix C with masks. @@ -274,20 +274,16 @@ def matmul(a, b, activation=""): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - # grid = lambda META: ( - # triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - # ) - matmul_kernel[(M // 128 * N // 128,)]( + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + matmul_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - ACTIVATION=activation, - BLOCK_SIZE_M=128, - BLOCK_SIZE_N=128, - BLOCK_SIZE_K=32, - GROUP_SIZE_M=8, + ACTIVATION=activation ) return c @@ -299,8 +295,8 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). torch.manual_seed(0) -a = torch.randn((256, 64), device='cuda', dtype=torch.float32) -b = torch.randn((64, 256), device='cuda', dtype=torch.float32) +a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +b = torch.randn((512, 512), device='cuda', dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") @@ -309,7 +305,6 @@ def matmul(a, b, activation=""): print("✅ Triton and Torch match") else: print("❌ Triton and Torch differ") -exit(1) # %% # Benchmark From 9d60d6aa30a4ca76e13ca1e97173fa96a966dd76 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 21:57:13 -0700 Subject: [PATCH 22/41] more fixes --- python/src/triton.cc | 7 ++-- python/triton/runtime/autotuner.py | 2 + python/triton/runtime/interpreter.py | 41 +++++++++++++------- python/tutorials/03-matrix-multiplication.py | 5 ++- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 4962f94f1c31..0068a23f8006 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1991,12 +1991,13 @@ void init_triton_interpreter(py::module &&m) { py::array_t mask) { int numel = ptrs.size(); py::array_t reshaped_ptrs = ptrs.reshape({numel}); - py::array_t reshaped_masks = mask.reshape({numel}); - auto reshaped_values = values.reshape({numel}); + py::array_t reshaped_masks = mask.reshape({numel}); + py::array reshaped_values = values.reshape({numel}); for (size_t i = 0; i < ptrs.size(); ++i) { - if (reshaped_masks.at(i)) + if (reshaped_masks.at(i)) { memcpy(reinterpret_cast(reshaped_ptrs.mutable_at(i)), reshaped_values.data(i), values.dtype().itemsize()); + } } }); } diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index e3f2794f7d46..23192df0487a 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -36,6 +36,8 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] else: self.configs = configs + print(fn) + self.configs = configs[:1] self.key_idx = [arg_names.index(k) for k in key] self.cache = {} # hook to reset all required tensor to zeros before relaunching a kernel diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 0c146fd1aad8..b6bd6b072711 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -156,12 +156,12 @@ def create_get_num_programs(self, axis): # memory ops def create_load(self, ptr, _0, _1, is_volatile): - mask = np.ones_like(ptr.data, dtype=bool) + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) other = None return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) def create_store(self, ptr, val, _0, _1): - mask = np.ones_like(ptr.data, dtype=bool) + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) return self.create_masked_store(ptr, val, mask, None, None) def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): @@ -169,7 +169,7 @@ def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, dtype_np = self.np_dtype(dtype_tt) if other is None: other = np.ones_like(ptrs.data, dtype=dtype_np) - ret = _interpreter.load(ptrs.data, mask.data, other, dtype_np) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) return TensorHandle(ret, dtype_tt) def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): @@ -267,7 +267,7 @@ def unary_op(self, arg, op): create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype) create_trans = lambda self, arg: self.unary_op(arg, np.transpose) - def create_dot(self, a, b, d, allow_tf32): + def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc): return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype) def create_make_range(self, start, stop): @@ -369,6 +369,7 @@ def _patch_lang_tensor(tensor, builder): if tl.core.is_builtin(member): patch_attr(tensor, name, member, builder) tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: True def _patch_lang_core(lang, builder): @@ -445,13 +446,15 @@ def _unwrap(tensor): builder = Builder() +RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization'] + class GridExecutor: - def __init__(self, fn, grid): - assert len(grid) <= 3 + def __init__(self, fn, arg_names, grid): self.fn = fn - self.grid = tuple(grid) + (1,) * (3 - len(grid)) + self.arg_names = arg_names + self.grid = grid def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] @@ -462,7 +465,7 @@ def _patch_lang(self, builder): def __call__(self, *args_dev, **kwargs): # removes reserved keywords from kwargs - kwargs = {k: v for k, v in kwargs.items() if k not in ['num_warps', 'num_stages', 'num_ctas']} + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} # remaps core language functions to interpreted ones self._patch_lang(builder) # we need to copy arguments to the host for the interpreter @@ -470,10 +473,15 @@ def __call__(self, *args_dev, **kwargs): # implicitly convert tensor arguments to their base pointers wrapped_args = [_implicit_cvt(arg) for arg in args_hst] # iterate through grid - builder.set_grid_dim(*self.grid) - for x in range(self.grid[0]): - for y in range(self.grid[1]): - for z in range(self.grid[2]): + grid_args = {name: val for name, val in zip(self.arg_names, args_dev)} + grid_args.update(kwargs) + grid = self.grid(grid_args) if callable(self.grid) else self.grid + assert len(grid) <= 3 + grid = grid + (1,) * (3 - len(grid)) + builder.set_grid_dim(*grid) + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): builder.set_grid_idx(x, y, z) self.fn(*wrapped_args, **kwargs) # copy arguments back to propagate side-effects @@ -492,11 +500,18 @@ def _patch_lang(self, builder): def __init__(self, fn) -> None: self.fn = fn + + def run(*args, **kwargs): + grid = kwargs['grid'] + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']} + + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + self.run = run signature = inspect.signature(fn) self.arg_names = [v.name for v in signature.parameters.values()] def __getitem__(self, grid): - return GridExecutor(self.fn, grid) + return GridExecutor(self.fn, self.arg_names, grid) def __call__(self, *args, **kwargs): self._patch_lang(builder) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 8bcae2007abd..867c6d4aa7ad 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -295,12 +295,13 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). torch.manual_seed(0) -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +a = torch.randn((32, 256), device='cuda', dtype=torch.float16) +b = torch.randn((256, 32), device='cuda', dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") +print((triton_output - torch_output).abs().max()) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): print("✅ Triton and Torch match") else: From dfa4d2206aa9afc946fb7924b825a1fccbc9c3e1 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 21:59:01 -0700 Subject: [PATCH 23/41] . --- python/tutorials/03-matrix-multiplication.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 867c6d4aa7ad..8bcae2007abd 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -295,13 +295,12 @@ def matmul(a, b, activation=""): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). torch.manual_seed(0) -a = torch.randn((32, 256), device='cuda', dtype=torch.float16) -b = torch.randn((256, 32), device='cuda', dtype=torch.float16) +a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +b = torch.randn((512, 512), device='cuda', dtype=torch.float16) triton_output = matmul(a, b) torch_output = torch.matmul(a, b) print(f"triton_output={triton_output}") print(f"torch_output={torch_output}") -print((triton_output - torch_output).abs().max()) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): print("✅ Triton and Torch match") else: From 3d95cc76a31ae936e83012ccf349264691e6d1a5 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:00:30 -0700 Subject: [PATCH 24/41] . --- python/setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index da1d8a71ffce..a92812cfcb84 100644 --- a/python/setup.py +++ b/python/setup.py @@ -296,7 +296,6 @@ def build_extension(self, ext): "triton/_C", "triton/common", "triton/compiler", - "triton/interpreter", "triton/language", "triton/language/extra", "triton/ops", From 9d679f90c447fff1d70d43a6af9161a995fec6e4 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:03:28 -0700 Subject: [PATCH 25/41] . --- python/triton/runtime/jit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index be40eaf8888b..1809ce36cf50 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -14,6 +14,7 @@ from .._C.libtriton.triton import TMAInfos from ..common.backend import get_backend, path_to_ptxas from ..language.core import dtype +from .interpreter import InterpretedFunction TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TRITON_VERSION = "2.1.0" @@ -586,7 +587,6 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) if os.getenv("TRITON_INTERPRET", "0") == "1": - from .interpreter import InterpretedFunction return InterpretedFunction(fn) else: return JITFunction( From 86b49d6780d3ca80c332425e9c6c9ddd2ac7e41d Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:19:43 -0700 Subject: [PATCH 26/41] . --- .../test/unit/interpreter/test_interpreter.py | 69 ------------------- 1 file changed, 69 deletions(-) delete mode 100644 python/test/unit/interpreter/test_interpreter.py diff --git a/python/test/unit/interpreter/test_interpreter.py b/python/test/unit/interpreter/test_interpreter.py deleted file mode 100644 index b6bb6b79c206..000000000000 --- a/python/test/unit/interpreter/test_interpreter.py +++ /dev/null @@ -1,69 +0,0 @@ -import random - -import torch - -import triton -import triton.language as tl -from triton.interpreter.interpreter import program_ids_from_grid - - -def test_addition(): - - @triton.jit(interpret=True) - def add_kernel( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - a = torch.rand((128,), device="cuda") - b = torch.rand((128,), device="cuda") - expected = a + b - output = torch.empty((128,), device="cuda") - - def grid(meta): - return (triton.cdiv(128, meta["BLOCK_SIZE"]),) - - add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32) - - assert torch.allclose(expected, output, atol=1e-2, rtol=0) - - -def test_program_ids_from_grid(): - random.seed(123) - grid = (3, 4) - expected_combinations = 3 * 4 - unique_combinations = set(program_ids_from_grid(grid)) - assert len(unique_combinations) == expected_combinations - - first_run = list(program_ids_from_grid(grid)) - second_run = list(program_ids_from_grid(grid)) - assert first_run != second_run - - -def test_atomic(): - @triton.jit(interpret=True) - def atomic( - x_ptr, - ): - pid = tl.program_id(axis=0) - tl.atomic_add(x_ptr + pid, 1) - t = tl.atomic_xchg(x_ptr + pid, 3) - t += 1 # 2 - tl.atomic_cas(x_ptr + pid, 3, t) # match - tl.atomic_cas(x_ptr + pid, 40, 9) # no match - nb_dim = 16 - a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda") - - atomic[(nb_dim, )](a) - assert torch.allclose(a, torch.full_like(a, 2)) From 6bbd9fb3dcb4f18751e40f0ce76705b6c3365140 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:22:06 -0700 Subject: [PATCH 27/41] . --- python/triton/runtime/autotuner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 23192df0487a..e3f2794f7d46 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -36,8 +36,6 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] else: self.configs = configs - print(fn) - self.configs = configs[:1] self.key_idx = [arg_names.index(k) for k in key] self.cache = {} # hook to reset all required tensor to zeros before relaunching a kernel From 56888b4a10d3f6d6ed3c3dfce2520e378bd58efe Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:23:13 -0700 Subject: [PATCH 28/41] . --- python/triton/language/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7334bb01df76..12bf676100a0 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -741,9 +741,6 @@ def logical_or(self, other, _builder=None): def __not__(self, _builder=None): return semantic.not_(self, _builder) - def __bool__(self): - return bool(self.handle) - @builtin def __getitem__(self, slices, _builder=None): if isinstance(slices, (slice, constexpr)): From a977b51eb179eb9f329f961a7400f4f3cd7d07d2 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 22:27:36 -0700 Subject: [PATCH 29/41] . --- python/triton/language/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 12bf676100a0..150d3936018f 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -542,8 +542,8 @@ def __init__(self, handle, type: dtype): self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: - # ex. "float32[3,4]" - return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + # ex. "float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' @builtin def __add__(self, other, _builder=None): From 4cdfadd3a1d9d4fefe3e00f82e01cbf2694fe4d1 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:16:28 -0700 Subject: [PATCH 30/41] . --- python/test/unit/operators/test_flash_attention.py | 8 ++++---- python/triton/runtime/interpreter.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index b6f74f2fc33d..ba45fdab36b3 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -5,10 +5,10 @@ import triton.ops -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16), - (4, 48, 1024, 32), - (4, 48, 1024, 64), - (4, 48, 1024, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128)]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index b6bd6b072711..c29827ea1b88 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -168,7 +168,7 @@ def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, dtype_tt = ptrs.dtype.element_ty dtype_np = self.np_dtype(dtype_tt) if other is None: - other = np.ones_like(ptrs.data, dtype=dtype_np) + other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt) ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) return TensorHandle(ret, dtype_tt) @@ -430,7 +430,12 @@ def fallback(*args, **kwargs): setattr(math, name, make_fallback(name)) +# TODO: wrap everything in triton tensors def _implicit_cvt(arg): + if isinstance(arg, int): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg], dtype=np.int32), ty) + return tl.tensor(handle, ty) if hasattr(arg, 'data_ptr'): ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) From 41d4c79159cc399762c5480d58dfc8363d598d58 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:33:59 -0700 Subject: [PATCH 31/41] interpreter test workflow --- .github/workflows/integration-tests.yml | 43 +++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 74967580b4e3..289e4c5bb51c 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -33,6 +33,49 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi + Integration-Tests-Interpreter: + runs-on: ubuntu-22.04-x64, + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + submodules: 'true' + + - name: Set CUDA ENV + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} + run: | + echo "BACKEND=CUDA" >> "${GITHUB_ENV}" + echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" + echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}" + echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + + - name: Install Triton + if: ${{ env.BACKEND == 'CUDA'}} + run: | + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install --no-build-isolation -vvv '.[tests]' + python3 -m pip install pytest-xdist + + - name: Run interpreter tests + env: + TRITON_INTERPRET: 1 + run: | + cd python/test/unit + python3 -m pytest -vs operators/test_flash_attention.py + + Integration-Tests-Nvidia: needs: Runner-Preparation From 27f3c14b3fb1950d12d76b24f64c0051885585e0 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:36:39 -0700 Subject: [PATCH 32/41] . --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 289e4c5bb51c..972a7ea676b5 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -34,7 +34,7 @@ jobs: fi Integration-Tests-Interpreter: - runs-on: ubuntu-22.04-x64, + runs-on: ubuntu-latest steps: - name: Checkout From fbdb451951b6d3b0f1af0b57be9102b863ef3244 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:42:39 -0700 Subject: [PATCH 33/41] . --- .github/workflows/integration-tests.yml | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 972a7ea676b5..04e593776ec5 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -42,18 +42,6 @@ jobs: with: submodules: 'true' - - name: Set CUDA ENV - if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} - run: | - echo "BACKEND=CUDA" >> "${GITHUB_ENV}" - echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" - echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}" - echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" - - - name: Update PATH - run: | - echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" - - name: Check pre-commit run: | python3 -m pip install --upgrade pre-commit @@ -67,6 +55,7 @@ jobs: python3 -m pip install cmake==3.24 python3 -m pip install --no-build-isolation -vvv '.[tests]' python3 -m pip install pytest-xdist + python3 -m pip instal pytest - name: Run interpreter tests env: From 8bec7e0de0e0d225372950c98b0545598be657ba Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:44:44 -0700 Subject: [PATCH 34/41] . --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 04e593776ec5..982e13fda8fa 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -55,7 +55,7 @@ jobs: python3 -m pip install cmake==3.24 python3 -m pip install --no-build-isolation -vvv '.[tests]' python3 -m pip install pytest-xdist - python3 -m pip instal pytest + python3 -m pip install pytest - name: Run interpreter tests env: From 4301e4622e293db84b36424b831c0344fc52e4d6 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sat, 16 Sep 2023 23:46:42 -0700 Subject: [PATCH 35/41] . --- .github/workflows/integration-tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 982e13fda8fa..1d06ad32479f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -48,13 +48,11 @@ jobs: python3 -m pre_commit run --all-files --verbose - name: Install Triton - if: ${{ env.BACKEND == 'CUDA'}} run: | cd python python3 -m pip install --upgrade pip python3 -m pip install cmake==3.24 python3 -m pip install --no-build-isolation -vvv '.[tests]' - python3 -m pip install pytest-xdist python3 -m pip install pytest - name: Run interpreter tests From 1853b8c22866554508bd1ffb4c76ef82712e48d6 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 00:07:13 -0700 Subject: [PATCH 36/41] . --- .github/workflows/integration-tests.yml | 36 ++++--------------- .../unit/operators/test_flash_attention.py | 2 ++ 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1d06ad32479f..8a2747749382 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -33,35 +33,6 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi - Integration-Tests-Interpreter: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - submodules: 'true' - - - name: Check pre-commit - run: | - python3 -m pip install --upgrade pre-commit - python3 -m pre_commit run --all-files --verbose - - - name: Install Triton - run: | - cd python - python3 -m pip install --upgrade pip - python3 -m pip install cmake==3.24 - python3 -m pip install --no-build-isolation -vvv '.[tests]' - python3 -m pip install pytest - - - name: Run interpreter tests - env: - TRITON_INTERPRET: 1 - run: | - cd python/test/unit - python3 -m pytest -vs operators/test_flash_attention.py - Integration-Tests-Nvidia: needs: Runner-Preparation @@ -148,6 +119,13 @@ jobs: run: | rm -rf ~/.triton + - name: Run interpreter tests + env: + TRITON_INTERPRET: 1 + run: | + cd python/test/unit + python3 -m pytest -vs operators/test_flash_attention.py + - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} run: | diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index ba45fdab36b3..76209e77300a 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -55,3 +55,5 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) + if os.getenv("TRITON_INTERPRET") == 1: + assert False From 78325e9984b89cba34f83438190efc9347b0daa1 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 00:11:43 -0700 Subject: [PATCH 37/41] . --- .github/workflows/integration-tests.yml | 3 ++- python/test/unit/operators/test_flash_attention.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 8a2747749382..ce7fa8225d4e 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -121,7 +121,8 @@ jobs: - name: Run interpreter tests env: - TRITON_INTERPRET: 1 + # TRITON_INTERPRET: "1" + CUA_VISIBLE_DEVICES: "" run: | cd python/test/unit python3 -m pytest -vs operators/test_flash_attention.py diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 76209e77300a..ba45fdab36b3 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -55,5 +55,3 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) - if os.getenv("TRITON_INTERPRET") == 1: - assert False From 602597eba72aa5171ddd37508f05f684091373b8 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 00:38:02 -0700 Subject: [PATCH 38/41] trying to clean --- python/triton/runtime/interpreter.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index c29827ea1b88..01f98110efbd 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -457,9 +457,12 @@ def _unwrap(tensor): class GridExecutor: def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize self.fn = fn self.arg_names = arg_names self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = {name for name in arg_names if __annotations__.get(name) == 'constexpr'} def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] @@ -476,11 +479,10 @@ def __call__(self, *args_dev, **kwargs): # we need to copy arguments to the host for the interpreter args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # implicitly convert tensor arguments to their base pointers - wrapped_args = [_implicit_cvt(arg) for arg in args_hst] + args = inspect.getcallargs(self.fn, *args_dev, **kwargs) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} # iterate through grid - grid_args = {name: val for name, val in zip(self.arg_names, args_dev)} - grid_args.update(kwargs) - grid = self.grid(grid_args) if callable(self.grid) else self.grid + grid = self.grid(args) if callable(self.grid) else self.grid assert len(grid) <= 3 grid = grid + (1,) * (3 - len(grid)) builder.set_grid_dim(*grid) @@ -488,7 +490,7 @@ def __call__(self, *args_dev, **kwargs): for y in range(grid[1]): for z in range(grid[2]): builder.set_grid_idx(x, y, z) - self.fn(*wrapped_args, **kwargs) + self.fn(**args) # copy arguments back to propagate side-effects for arg_dev, arg_hst in zip(args_dev, args_hst): if hasattr(arg_dev, 'data_ptr'): From 213815ed8ae7047d616e421641c5e3f6dc620067 Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 00:43:40 -0700 Subject: [PATCH 39/41] . --- python/triton/runtime/interpreter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 01f98110efbd..a1a72dfdceaf 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -477,9 +477,10 @@ def __call__(self, *args_dev, **kwargs): # remaps core language functions to interpreted ones self._patch_lang(builder) # we need to copy arguments to the host for the interpreter - args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + # args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # implicitly convert tensor arguments to their base pointers - args = inspect.getcallargs(self.fn, *args_dev, **kwargs) + orig_args = inspect.getcallargs(self.fn, *args_dev, **kwargs) + args = {name: _unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for name, arg in orig_args.items()} args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} # iterate through grid grid = self.grid(args) if callable(self.grid) else self.grid @@ -492,7 +493,7 @@ def __call__(self, *args_dev, **kwargs): builder.set_grid_idx(x, y, z) self.fn(**args) # copy arguments back to propagate side-effects - for arg_dev, arg_hst in zip(args_dev, args_hst): + for arg_dev, arg_hst in zip(orig_args.values(), args.values()): if hasattr(arg_dev, 'data_ptr'): _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) From 2a091eca4ca925cff898e97d2beff761d244df3c Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 00:49:29 -0700 Subject: [PATCH 40/41] . --- python/triton/runtime/interpreter.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index a1a72dfdceaf..218208d86aa7 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -462,7 +462,7 @@ def __init__(self, fn, arg_names, grid): self.arg_names = arg_names self.grid = grid __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} - self.constexprs = {name for name in arg_names if __annotations__.get(name) == 'constexpr'} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr'] def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] @@ -472,15 +472,14 @@ def _patch_lang(self, builder): _patch_lang_math(lang[0], builder) def __call__(self, *args_dev, **kwargs): + args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # removes reserved keywords from kwargs kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} # remaps core language functions to interpreted ones self._patch_lang(builder) # we need to copy arguments to the host for the interpreter - # args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] # implicitly convert tensor arguments to their base pointers - orig_args = inspect.getcallargs(self.fn, *args_dev, **kwargs) - args = {name: _unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for name, arg in orig_args.items()} + args = inspect.getcallargs(self.fn, *args_hst, **kwargs) args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} # iterate through grid grid = self.grid(args) if callable(self.grid) else self.grid @@ -493,7 +492,7 @@ def __call__(self, *args_dev, **kwargs): builder.set_grid_idx(x, y, z) self.fn(**args) # copy arguments back to propagate side-effects - for arg_dev, arg_hst in zip(orig_args.values(), args.values()): + for arg_dev, arg_hst in zip(args_dev, args_hst): if hasattr(arg_dev, 'data_ptr'): _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) From 9a23a1be9b6c4020e006a01d38396e0356c7ffeb Mon Sep 17 00:00:00 2001 From: Phil Tillet Date: Sun, 17 Sep 2023 01:03:43 -0700 Subject: [PATCH 41/41] . --- python/test/unit/operators/test_flash_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index ba45fdab36b3..75da98e5044a 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): pytest.skip('Segmentation fault') capability = torch.cuda.get_device_capability() - if capability[0] < 8: + interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] + if not interpreter and capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()