diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py
index 642b0982b45a..d43bd96b81e2 100644
--- a/python/test/unit/operators/test_matmul.py
+++ b/python/test/unit/operators/test_matmul.py
@@ -8,79 +8,62 @@
 import triton.ops
 
 
-def f8_to_f16(x, dtype):
-
-    @triton.jit
-    def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
-        pid = tl.program_id(0)
-        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = offs < N
-        x = tl.load(X + offs, mask=mask)
-        tl.store(Y + offs, x, mask=mask)
-
-    ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
-    grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
-    dtype = getattr(tl, dtype)
-    kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
-    return ret
-
-
 @pytest.mark.parametrize(
-    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE",
     itertools.chain(
         *[
             [
                 # 1 warp
-                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
                 # 2 warp
-                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
                 # 4 warp
-                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
                 # 8 warp
-                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
-                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
                 # variable input
-                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
+                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
             ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
         ],
         # n-stage
         *[
             [
-                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
-                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
-                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
-                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
+                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
+                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
            ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
         ],
         # mixed-precision
         *[
             [
-                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
-                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
-                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
             ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
                                      ("float8e4nv", "float8e4nv"),
                                      ("float8e5", "float8e4nv"),
@@ -88,17 +71,26 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
                                      ("float8e4b15", "float8e4b15"),
                                      ("float8e4nv", "float16"),
                                      ("float16", "float8e5"),
+                                     ("int8", "bfloat16"),
+                                     ("float16", "int8"),
                                      ("float16", "float32"),
                                      ("float32", "float16"),
                                      ("bfloat16", "float32"),
                                      ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]
         ],
+        # acc-out-dtype and output_dtype
+        *[
+            [
+                (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, OUTPUT_DTYPE),
+                (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, OUTPUT_DTYPE),
+            ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]
+        ],
         # mixed-precision block layout
         *[
             [
-                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
-                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
-                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
             ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
                                      ("float16", "float8e5"),
                                      ("float16", "float32"),
@@ -108,7 +100,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
         ],
     ),
 )
-def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM):
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE):
     capability = torch.cuda.get_device_capability()
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -132,12 +124,35 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     N = BLOCK_N if N is None else N
     K = BLOCK_K * SPLIT_K if K is None else K
 
-    def maybe_upcast(x, dtype, is_float8):
-        if is_float8:
-            return f8_to_f16(x, dtype)
+    def is_fp8(dtype):
+        return "float8" in dtype
+
+    def f8_to_f16(x, dtype):
+
+        @triton.jit
+        def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr):
+            pid = tl.program_id(0)
+            offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+            mask = offs < N
+            x = tl.load(X + offs, mask=mask)
+            tl.store(Y + offs, x, mask=mask)
+
+        ret = torch.empty(x.shape, dtype=torch.float16, device=x.device)
+        grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)
+        dtype = getattr(tl, dtype)
+        kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024)
+        return ret
+
+    def upcast_if_fp8(x, dtype, is_transpose):
+        if is_fp8(dtype):
+            th = f8_to_f16(x, dtype)
+            if (is_transpose):
+                # TODO(karupayun): I do not understand why we are doing this.
+                return th.view(th.shape[::-1]).T
+            return th
         return x
 
-    def init_input(m, n, dtype):
+    def init_input(m, n, dtype, acc_dtype):
         if 'float8' in dtype:
             ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype]
             sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
@@ -145,38 +160,33 @@ def init_input(m, n, dtype):
             return sign | val
         if dtype == "int8":
             return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
-        dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
-        exponents = torch.randint(-10, 0, size=(m, n))
-        ret = (2. ** exponents).to(dtype).to("cuda")
+        # The epsilon of "float16" is around 1e-3, far larger than the maximum expected absolute error (1e-5).
+        # To keep the comparison precise enough, we need to use small integer values.
+        if (ACC_DTYPE == "float16"):
+            exponents = torch.randint(0, 3, size=(m, n))
+        else:
+            exponents = torch.randint(-10, 0, size=(m, n))
+        ret = (2. ** exponents).to(getattr(torch, dtype)).to("cuda")
         return ret
 
     # allocate/transpose inputs
-    a = init_input(M, K, ADTYPE)
-    b = init_input(K, N, BDTYPE)
+    a = init_input(M, K, ADTYPE, ACC_DTYPE)
+    b = init_input(K, N, BDTYPE, ACC_DTYPE)
     a = a if not AT else a.T.contiguous().T
     b = b if not BT else b.T.contiguous().T
     # run test
-    a_fp8 = "float8" in ADTYPE
-    b_fp8 = "float8" in BDTYPE
-    th_a = maybe_upcast(a, ADTYPE, a_fp8)
-    if AT and a_fp8:
-        th_a = th_a.view(th_a.shape[::-1]).T
-    th_b = maybe_upcast(b, BDTYPE, b_fp8)
-    if BT and b_fp8:
-        th_b = th_b.view(th_b.shape[::-1]).T
-    if th_a.is_floating_point():
-        ab_dtype = th_a.dtype if th_a.element_size() > th_b.element_size() else th_b.dtype
-    else:
-        ab_dtype = torch.float32
-    th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype))
-    if ADTYPE == "int8" or BDTYPE == "int8":
-        th_c = th_c.to(torch.int8)
+    th_a = upcast_if_fp8(a, ADTYPE, AT)
+    th_b = upcast_if_fp8(b, BDTYPE, BT)
+    ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype)
+    acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype
+    output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype
+    th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype))
     try:
-        if a_fp8:
+        if is_fp8(ADTYPE):
             a = triton.reinterpret(a, getattr(tl, ADTYPE))
-        if b_fp8:
+        if is_fp8(BDTYPE):
             b = triton.reinterpret(b, getattr(tl, BDTYPE))
-        tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM)
+        tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, output_dtype)
         torch.testing.assert_close(th_c, tt_c)
     except triton.OutOfResources as e:
         pytest.skip(str(e))
diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py
index 6ceec8b56a00..b353cb26b365 100644
--- a/python/triton/ops/__init__.py
+++ b/python/triton/ops/__init__.py
@@ -2,7 +2,7 @@
 from . import blocksparse
 from .cross_entropy import _cross_entropy, cross_entropy
 from .flash_attention import attention
-from .matmul import _matmul, matmul
+from .matmul import _matmul, get_higher_dtype, matmul
 
 __all__ = [
     "blocksparse",
@@ -11,4 +11,5 @@
     "_matmul",
     "matmul",
     "attention",
+    "get_higher_dtype"
 ]
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 9bbeb3650ac5..7b9f2ce42519 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -4,10 +4,18 @@
 from .. import language as tl
 from .matmul_perf_model import early_config_prune, estimate_matmul_time
 
-_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]
+_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]
+
+
+def upcast_if_fp8(a):
+    if "fp8" in str(a):
+        return torch.float16
+    return a
 
 
 def get_higher_dtype(a, b):
+    a = upcast_if_fp8(a)
+    b = upcast_if_fp8(b)
     if a is b:
         return a
 
@@ -80,7 +88,7 @@ def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
-            dot_out_dtype: tl.constexpr,
+            acc_dtype: tl.constexpr,
             allow_tf32: tl.constexpr,
             fp8_fast_accum: tl.constexpr,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
@@ -106,7 +114,7 @@ def _kernel(A, B, C, M, N, K,
     # pointers
     A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
     B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
     for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
         if EVEN_K:
             a = tl.load(A)
@@ -116,13 +124,13 @@ def _kernel(A, B, C, M, N, K,
             _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
             a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
             b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
-        if AB_DTYPE:
-            a = a.to(C.dtype.element_ty)
-            b = b.to(C.dtype.element_ty)
+        if AB_DTYPE is not None:
+            a = a.to(AB_DTYPE)
+            b = b.to(AB_DTYPE)
         if fp8_fast_accum:
-            acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+            acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
         else:
-            acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+            acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)
         A += BLOCK_K * SPLIT_K * stride_ak
         B += BLOCK_K * SPLIT_K * stride_bk
     acc = acc.to(C.dtype.element_ty)
@@ -144,7 +152,7 @@ class _matmul(torch.autograd.Function):
     _locks = {}
 
     @staticmethod
-    def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
+    def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):
         device = a.device
         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
@@ -155,48 +163,55 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum):
         assert a.shape[1] == b.shape[0], "incompatible dimensions"
         M, K = a.shape
         _, N = b.shape
+
+        # common type between a and b
+        ab_dtype = get_higher_dtype(a.dtype, b.dtype)
+
         # allocates output
-        if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\
-           b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
-            c_dtype = torch.float16
-        elif a.dtype in [torch.int8] or b.dtype in [torch.int8]:
-            c_dtype = torch.int32
-        else:
-            c_dtype = get_higher_dtype(a.dtype, b.dtype)
-        c = torch.empty((M, N), device=device, dtype=c_dtype)
-        if dot_out_dtype is None:
-            if c_dtype in [torch.float16, torch.float32, torch.bfloat16]:
-                dot_out_dtype = tl.float32
-            else:
-                dot_out_dtype = tl.int32
+        if (output_dtype is None):
+            output_dtype = ab_dtype
+
+        c = torch.empty((M, N), device=device, dtype=output_dtype)
+
+        # Allowed types for acc_dtype given the types of a and b.
+        supported_acc_dtypes = {
+            torch.float16: (torch.float32, torch.float16),
+            torch.bfloat16: (torch.float32, torch.bfloat16),
+            torch.float32: (torch.float32,),
+            torch.int8: (torch.int32,)
+        }
+
+        if acc_dtype is None:
+            acc_dtype = supported_acc_dtypes[ab_dtype][0]
         else:
-            assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
-            if dot_out_dtype == torch.float16:
-                dot_out_dtype = tl.float16
-            elif dot_out_dtype in [torch.float32, torch.bfloat16]:
-                dot_out_dtype = tl.float32
-            else:
-                dot_out_dtype = tl.int32
-        ab_dtype = True
+            assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype"
+            assert acc_dtype in supported_acc_dtypes[a.dtype], "acc_dtype not compatible with the type of a"
+            assert acc_dtype in supported_acc_dtypes[b.dtype], "acc_dtype not compatible with the type of b"
+
+        def to_tl_type(ty):
+            return getattr(tl, str(ty).split(".")[-1])
+        acc_dtype = to_tl_type(acc_dtype)
+        ab_dtype = to_tl_type(ab_dtype)
+        output_dtype = to_tl_type(output_dtype)
+
+        # Tensor cores support input with mixed float8 types.
         if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
-            ab_dtype = False
-        if a.dtype in [torch.int8] and b.dtype in [torch.int8]:
-            ab_dtype = False
+            ab_dtype = None
         # launch kernel
         grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
         _kernel[grid](a, b, c, M, N, K,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       c.stride(0), c.stride(1),
-                      dot_out_dtype=dot_out_dtype,
+                      acc_dtype=acc_dtype,
                       allow_tf32=allow_tf32,
                       fp8_fast_accum=fp8_fast_accum,
                       GROUP_M=8, AB_DTYPE=ab_dtype)
         return c
 
     @staticmethod
-    def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True):
-        return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)
+    def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None):
+        return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, output_dtype=output_dtype)
 
 
 matmul = _matmul.apply
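Usage sketch (not part of the patch): the snippet below is a minimal illustration of how the acc_dtype and output_dtype arguments introduced above might be exercised. It assumes a CUDA device and the patched triton.ops.matmul; the shapes and dtypes are arbitrary choices, and arguments are passed positionally (as in the updated test) because matmul is _matmul.apply.

import torch
import triton.ops

# Hypothetical inputs: any two CUDA tensors with compatible inner dimensions.
a = torch.randn((128, 64), device="cuda", dtype=torch.float16)
b = torch.randn((64, 256), device="cuda", dtype=torch.float16)

# Positional order mirrors _matmul.forward:
#   matmul(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype)
# Accumulate in float32 while emitting a float16 result.
c = triton.ops.matmul(a, b, torch.float32, True, True, torch.float16)
assert c.dtype == torch.float16

# Leaving both dtype arguments as None falls back to the common input type
# chosen by get_higher_dtype (float16 here), with the default accumulator
# taken from supported_acc_dtypes (float32 for float16 inputs).
d = triton.ops.matmul(a, b, None, True, True, None)
assert d.dtype == torch.float16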