From b255a1c09bc045d5eb04cb78066f2c23754a3c94 Mon Sep 17 00:00:00 2001 From: Pablo Zimmermann Date: Mon, 30 Oct 2023 18:23:25 +0100 Subject: [PATCH 1/2] Add a matmul test from int8, bf16 In this PR we are adding a matmul test from int8, bf16. I had a few issues in the test so I refactored the file a bit. - First I included two new params: - Dot_out_dtype: So users of the test class can specify the type used internally in the dot, and not the one set by default given the two types. There are several restrictions for these types anyway. - C_dtype: The return type of the matmul. I included a few tests in the case of making a dot with two float16. - I had to modify test_matmul to use small integers when testing with two float16 since torch used float32 internally in this case and we were having precision issues when comparing the results with triton in the case that dot_out_dtype was float16. - I also needed to include torch.int8 in the possible datatypes. Finally I tried to simplify a bit the logic of the matmul/test_matmul because after adding these two parameters it was a bit hard to follow why we needed every part of the code, so I included a type_preference_list for the allowed dot_out_dtype given the types of the operands a and b. --- python/test/unit/operators/test_matmul.py | 127 ++++++++++++---------- python/triton/ops/__init__.py | 3 +- python/triton/ops/matmul.py | 69 +++++++----- 3 files changed, 113 insertions(+), 86 deletions(-) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 642b0982b45a..76b06781e4b7 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -26,61 +26,61 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, DOT_OUT_DTYPE, C_DTYPE", itertools.chain( *[ [ # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] ], # mixed-precision *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), ("float8e4nv", "float8e4nv"), ("float8e5", "float8e4nv"), @@ -88,17 +88,26 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ("float8e4b15", "float8e4b15"), ("float8e4nv", "float16"), ("float16", "float8e5"), + ("int8", "bfloat16"), + ("float16", "int8"), ("float16", "float32"), ("float32", "float16"), ("bfloat16", "float32"), ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] ], + # Dot-out-dtype and ab_dtype + *[ + [ + (32, 32, 32, 1, 1, 2, None, None, None, False, False, DTYPE, DTYPE, True, True, DOT_OUT_DTYPE, C_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, DTYPE, DTYPE, True, True, DOT_OUT_DTYPE, C_DTYPE), + ] for DTYPE in ["float16"] for DOT_OUT_DTYPE in [None, "float16", "float32"] for C_DTYPE in [None, "float16", "float32"] + ], # mixed-precision block layout *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True, None, None), ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), ("float16", "float8e5"), ("float16", "float32"), @@ -108,7 +117,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, DOT_OUT_DTYPE, C_DTYPE): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -137,22 +146,34 @@ def maybe_upcast(x, dtype, is_float8): return f8_to_f16(x, dtype) return x - def init_input(m, n, dtype): + string_to_torch_type = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32} + + def init_input(m, n, dtype, only_small_integers=False): if 'float8' in dtype: ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth return sign | val if dtype == "int8": + if only_small_integers: + return torch.randint(-2, 2, (m, n), device="cuda", dtype=torch.int8) return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) - dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] - exponents = torch.randint(-10, 0, size=(m, n)) - ret = (2. ** exponents).to(dtype).to("cuda") + if only_small_integers: + exponents = torch.randint(0, 3, size=(m, n)) + else: + exponents = torch.randint(-10, 0, size=(m, n)) + ret = (2. ** exponents).to(string_to_torch_type[dtype]).to("cuda") return ret # allocate/transpose inputs - a = init_input(M, K, ADTYPE) - b = init_input(K, N, BDTYPE) + use_small_integers = False + # We can't force torch to use float16 for internal computations (it will use float32), and + # therefore we need to use small integers if we want the comparation to be precise enough + # in case of a and b being float16. + if (ADTYPE == "float16" and BDTYPE == "float16"): + use_small_integers = True + a = init_input(M, K, ADTYPE, use_small_integers) + b = init_input(K, N, BDTYPE, use_small_integers) a = a if not AT else a.T.contiguous().T b = b if not BT else b.T.contiguous().T # run test @@ -164,19 +185,15 @@ def init_input(m, n, dtype): th_b = maybe_upcast(b, BDTYPE, b_fp8) if BT and b_fp8: th_b = th_b.view(th_b.shape[::-1]).T - if th_a.is_floating_point(): - ab_dtype = th_a.dtype if th_a.element_size() > th_b.element_size() else th_b.dtype - else: - ab_dtype = torch.float32 - th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype)) - if ADTYPE == "int8" or BDTYPE == "int8": - th_c = th_c.to(torch.int8) + dot_dtype = string_to_torch_type[DOT_OUT_DTYPE] if DOT_OUT_DTYPE else triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) + c_dtype = string_to_torch_type[C_DTYPE] if C_DTYPE else dot_dtype + th_c = torch.matmul(th_a.to(c_dtype), th_b.to(c_dtype)) try: if a_fp8: a = triton.reinterpret(a, getattr(tl, ADTYPE)) if b_fp8: b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM) + tt_c = triton.ops.matmul(a, b, dot_dtype if DOT_OUT_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, c_dtype) torch.testing.assert_close(th_c, tt_c) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py index 6ceec8b56a00..b353cb26b365 100644 --- a/python/triton/ops/__init__.py +++ b/python/triton/ops/__init__.py @@ -2,7 +2,7 @@ from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy from .flash_attention import attention -from .matmul import _matmul, matmul +from .matmul import _matmul, get_higher_dtype, matmul __all__ = [ "blocksparse", @@ -11,4 +11,5 @@ "_matmul", "matmul", "attention", + "get_higher_dtype" ] diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 9bbeb3650ac5..3e961553d6a9 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -4,7 +4,7 @@ from .. import language as tl from .matmul_perf_model import early_config_prune, estimate_matmul_time -_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32] +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] def get_higher_dtype(a, b): @@ -84,7 +84,8 @@ def _kernel(A, B, C, M, N, K, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, C_DTYPE: tl.constexpr ): # matrix multiplication pid = tl.program_id(0) @@ -116,16 +117,16 @@ def _kernel(A, B, C, M, N, K, _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - if AB_DTYPE: - a = a.to(C.dtype.element_ty) - b = b.to(C.dtype.element_ty) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) if fp8_fast_accum: acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) else: acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(C.dtype.element_ty) + acc = acc.to(C_DTYPE) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -144,7 +145,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): + def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum, c_dtype): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -155,33 +156,41 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): assert a.shape[1] == b.shape[0], "incompatible dimensions" M, K = a.shape _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype if a.dtype not in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] else torch.float16, + b.dtype if b.dtype not in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] else torch.float16) + # allocates output - if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\ - b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]: - c_dtype = torch.float16 - elif a.dtype in [torch.int8] or b.dtype in [torch.int8]: - c_dtype = torch.int32 - else: - c_dtype = get_higher_dtype(a.dtype, b.dtype) + if (c_dtype is None): + c_dtype = ab_dtype + c = torch.empty((M, N), device=device, dtype=c_dtype) + + torch_tl_type_converter = {torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32, torch.int8: tl.int8, torch.int32: tl.int32} + + # Allowed types for dot_out_type given the types of a and b. + type_preference_list = {} + type_preference_list[torch.float16] = [torch.float32, torch.float16] + type_preference_list[torch.bfloat16] = [torch.float32, torch.bfloat16] + type_preference_list[torch.float32] = [torch.float32] + type_preference_list[torch.int8] = [torch.int32] if dot_out_dtype is None: - if c_dtype in [torch.float16, torch.float32, torch.bfloat16]: - dot_out_dtype = tl.float32 - else: - dot_out_dtype = tl.int32 + dot_out_dtype = type_preference_list[c_dtype][0] else: assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype" - if dot_out_dtype == torch.float16: - dot_out_dtype = tl.float16 - elif dot_out_dtype in [torch.float32, torch.bfloat16]: - dot_out_dtype = tl.float32 - else: - dot_out_dtype = tl.int32 - ab_dtype = True + assert dot_out_dtype in type_preference_list[a.dtype], "dot_out_dtype not compatible with the type of a" + assert dot_out_dtype in type_preference_list[b.dtype], "dot_out_dtype not compatible with the type of b" + + dot_out_dtype = torch_tl_type_converter[dot_out_dtype] + ab_dtype = torch_tl_type_converter[ab_dtype] + c_dtype = torch_tl_type_converter[c_dtype] + + # In this cases, there is no need to manually convert types to ab_dtype. if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: - ab_dtype = False + ab_dtype = None if a.dtype in [torch.int8] and b.dtype in [torch.int8]: - ab_dtype = False + ab_dtype = None # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, @@ -191,12 +200,12 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, - GROUP_M=8, AB_DTYPE=ab_dtype) + GROUP_M=8, AB_DTYPE=ab_dtype, C_DTYPE=c_dtype) return c @staticmethod - def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True): - return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum) + def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True, c_dtype=None): + return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, c_dtype=c_dtype) matmul = _matmul.apply From 8fce141c48ae3a159d8a210846b1a134dd063440 Mon Sep 17 00:00:00 2001 From: Pablo Zimmermann Date: Fri, 3 Nov 2023 15:34:33 +0100 Subject: [PATCH 2/2] Fixup: Addressing comments --- python/test/unit/operators/test_matmul.py | 103 ++++++++++------------ python/triton/ops/matmul.py | 76 ++++++++-------- 2 files changed, 89 insertions(+), 90 deletions(-) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 76b06781e4b7..d43bd96b81e2 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -8,25 +8,8 @@ import triton.ops -def f8_to_f16(x, dtype): - - @triton.jit - def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offs < N - x = tl.load(X + offs, mask=mask) - tl.store(Y + offs, x, mask=mask) - - ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) - grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) - dtype = getattr(tl, dtype) - kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) - return ret - - @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, DOT_OUT_DTYPE, C_DTYPE", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", itertools.chain( *[ [ @@ -95,12 +78,12 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ("bfloat16", "float32"), ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] ], - # Dot-out-dtype and ab_dtype + # acc-out-dtype and output_dtype *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, False, False, DTYPE, DTYPE, True, True, DOT_OUT_DTYPE, C_DTYPE), - (128, 256, 32, 1, 8, 2, None, None, None, False, False, DTYPE, DTYPE, True, True, DOT_OUT_DTYPE, C_DTYPE), - ] for DTYPE in ["float16"] for DOT_OUT_DTYPE in [None, "float16", "float32"] for C_DTYPE in [None, "float16", "float32"] + (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, OUTPUT_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, OUTPUT_DTYPE), + ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"] ], # mixed-precision block layout *[ @@ -117,7 +100,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): ], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, DOT_OUT_DTYPE, C_DTYPE): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -141,59 +124,69 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, N = BLOCK_N if N is None else N K = BLOCK_K * SPLIT_K if K is None else K - def maybe_upcast(x, dtype, is_float8): - if is_float8: - return f8_to_f16(x, dtype) - return x + def is_fp8(dtype): + return "float8" in dtype - string_to_torch_type = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32} + def f8_to_f16(x, dtype): + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + def upcast_if_fp8(x, dtype, is_transpose): + if is_fp8(dtype): + th = f8_to_f16(x, dtype) + if (is_transpose): + # TODO(karupayun): I do not understand why we are doing this. + return th.view(th.shape[::-1]).T + return th + return x - def init_input(m, n, dtype, only_small_integers=False): + def init_input(m, n, dtype, acc_dtype): if 'float8' in dtype: ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth return sign | val if dtype == "int8": - if only_small_integers: - return torch.randint(-2, 2, (m, n), device="cuda", dtype=torch.int8) return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) - if only_small_integers: + # The epsilon of "float16" is around ~1e-3, far away than the maximum absolute expected error (1e-5). + # If we want the comparation to be precise enough we need to use small integers + if (ACC_DTYPE == "float16"): exponents = torch.randint(0, 3, size=(m, n)) else: exponents = torch.randint(-10, 0, size=(m, n)) - ret = (2. ** exponents).to(string_to_torch_type[dtype]).to("cuda") + ret = (2. ** exponents).to(getattr(torch, dtype)).to("cuda") return ret # allocate/transpose inputs - use_small_integers = False - # We can't force torch to use float16 for internal computations (it will use float32), and - # therefore we need to use small integers if we want the comparation to be precise enough - # in case of a and b being float16. - if (ADTYPE == "float16" and BDTYPE == "float16"): - use_small_integers = True - a = init_input(M, K, ADTYPE, use_small_integers) - b = init_input(K, N, BDTYPE, use_small_integers) + a = init_input(M, K, ADTYPE, ACC_DTYPE) + b = init_input(K, N, BDTYPE, ACC_DTYPE) a = a if not AT else a.T.contiguous().T b = b if not BT else b.T.contiguous().T # run test - a_fp8 = "float8" in ADTYPE - b_fp8 = "float8" in BDTYPE - th_a = maybe_upcast(a, ADTYPE, a_fp8) - if AT and a_fp8: - th_a = th_a.view(th_a.shape[::-1]).T - th_b = maybe_upcast(b, BDTYPE, b_fp8) - if BT and b_fp8: - th_b = th_b.view(th_b.shape[::-1]).T - dot_dtype = string_to_torch_type[DOT_OUT_DTYPE] if DOT_OUT_DTYPE else triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) - c_dtype = string_to_torch_type[C_DTYPE] if C_DTYPE else dot_dtype - th_c = torch.matmul(th_a.to(c_dtype), th_b.to(c_dtype)) + th_a = upcast_if_fp8(a, ADTYPE, AT) + th_b = upcast_if_fp8(b, BDTYPE, BT) + ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) + acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype + output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype + th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) try: - if a_fp8: + if is_fp8(ADTYPE): a = triton.reinterpret(a, getattr(tl, ADTYPE)) - if b_fp8: + if is_fp8(BDTYPE): b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b, dot_dtype if DOT_OUT_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, c_dtype) + tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, output_dtype) torch.testing.assert_close(th_c, tt_c) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 3e961553d6a9..7b9f2ce42519 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -7,7 +7,15 @@ _ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) if a is b: return a @@ -80,12 +88,11 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, - dot_out_dtype: tl.constexpr, + acc_dtype: tl.constexpr, allow_tf32: tl.constexpr, fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, - AB_DTYPE: tl.constexpr, C_DTYPE: tl.constexpr + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr ): # matrix multiplication pid = tl.program_id(0) @@ -107,7 +114,7 @@ def _kernel(A, B, C, M, N, K, # pointers A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): if EVEN_K: a = tl.load(A) @@ -121,12 +128,12 @@ def _kernel(A, B, C, M, N, K, a = a.to(AB_DTYPE) b = b.to(AB_DTYPE) if fp8_fast_accum: - acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) else: - acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk - acc = acc.to(C_DTYPE) + acc = acc.to(C.dtype.element_ty) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -145,7 +152,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum, c_dtype): + def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -158,54 +165,53 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum, c_dtype): _, N = b.shape # common type between a and b - ab_dtype = get_higher_dtype(a.dtype if a.dtype not in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] else torch.float16, - b.dtype if b.dtype not in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] else torch.float16) + ab_dtype = get_higher_dtype(a.dtype, b.dtype) # allocates output - if (c_dtype is None): - c_dtype = ab_dtype + if (output_dtype is None): + output_dtype = ab_dtype - c = torch.empty((M, N), device=device, dtype=c_dtype) + c = torch.empty((M, N), device=device, dtype=output_dtype) - torch_tl_type_converter = {torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, torch.float32: tl.float32, torch.int8: tl.int8, torch.int32: tl.int32} + # Allowed types for acc_type given the types of a and b. + supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,) + } - # Allowed types for dot_out_type given the types of a and b. - type_preference_list = {} - type_preference_list[torch.float16] = [torch.float32, torch.float16] - type_preference_list[torch.bfloat16] = [torch.float32, torch.bfloat16] - type_preference_list[torch.float32] = [torch.float32] - type_preference_list[torch.int8] = [torch.int32] - if dot_out_dtype is None: - dot_out_dtype = type_preference_list[c_dtype][0] + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] else: - assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype" - assert dot_out_dtype in type_preference_list[a.dtype], "dot_out_dtype not compatible with the type of a" - assert dot_out_dtype in type_preference_list[b.dtype], "dot_out_dtype not compatible with the type of b" + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a" + assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b" - dot_out_dtype = torch_tl_type_converter[dot_out_dtype] - ab_dtype = torch_tl_type_converter[ab_dtype] - c_dtype = torch_tl_type_converter[c_dtype] + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) - # In this cases, there is no need to manually convert types to ab_dtype. + # Tensor cores support input with mixed float8 types. if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: ab_dtype = None - if a.dtype in [torch.int8] and b.dtype in [torch.int8]: - ab_dtype = None # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - dot_out_dtype=dot_out_dtype, + acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, - GROUP_M=8, AB_DTYPE=ab_dtype, C_DTYPE=c_dtype) + GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod - def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True, c_dtype=None): - return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, c_dtype=c_dtype) + def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, output_dtype=output_dtype) matmul = _matmul.apply