diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index c60d6e11213d..5dd15d3455d0 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -9,50 +9,50 @@ @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE", itertools.chain( *[[ # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None), # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], # n-stage *[[ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True, None, None), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True, None, None), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] @@ -60,9 +60,9 @@ for STAGES in [4]], # mixed-precision *[[ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None), ] for ADTYPE, BDTYPE in [ ("float8e4nv", "float8e5"), ("float8e4nv", "float8e4nv"), @@ -80,9 +80,9 @@ ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], # mixed-precision block layout *[[ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True, None, None), ] for ADTYPE, BDTYPE in [ ("float8e4nv", "float16"), ("float16", "float8e5"), @@ -91,10 +91,17 @@ ("bfloat16", "float32"), ("float32", "bfloat16"), ] for AT in [False, True] for BT in [False, True]], + # acc-out-dtype and output_dtype + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, + OUTPUT_DTYPE), + (128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE, + OUTPUT_DTYPE), + ] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]], ), ) def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, - F8_FASTACCUM): + F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -142,7 +149,7 @@ def upcast_if_fp8(x, dtype): return f8_to_f16(x, dtype) return x - def init_input(m, n, dtype): + def init_input(m, n, dtype, acc_dtype): if 'float8' in dtype: ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype] sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128 @@ -150,27 +157,30 @@ def init_input(m, n, dtype): return sign | val if dtype == "int8": return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) - dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] - exponents = torch.randint(-10, 0, size=(m, n)) - ret = (2.**exponents).to(dtype).to("cuda") + # Use small range of values to prevent numerical issues. + min_exp = -4 if acc_dtype == "float16" else -10 + exponents = torch.randint(min_exp, 0, size=(m, n)) + ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda") return ret # allocate/transpose inputs - a = init_input(M, K, ADTYPE) - b = init_input(K, N, BDTYPE) + a = init_input(M, K, ADTYPE, ACC_DTYPE) + b = init_input(K, N, BDTYPE, ACC_DTYPE) a = a if not AT else a.T.contiguous().T b = b if not BT else b.T.contiguous().T # run test th_a = upcast_if_fp8(a, ADTYPE) th_b = upcast_if_fp8(b, BDTYPE) ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype) - th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype)) + acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype + output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype + th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype)) try: if is_fp8(ADTYPE): a = triton.reinterpret(a, getattr(tl, ADTYPE)) if is_fp8(BDTYPE): b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM) + tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, output_dtype) torch.testing.assert_close(th_c, tt_c) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 585129f9cc8b..e170091d67e4 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -153,7 +153,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum): + def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -169,7 +169,10 @@ def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum): ab_dtype = get_higher_dtype(a.dtype, b.dtype) # allocates output - c = torch.empty((M, N), device=device, dtype=ab_dtype) + if (output_dtype is None): + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) # Allowed types for acc_type given the types of a and b. supported_acc_dtypes = { @@ -189,6 +192,7 @@ def to_tl_type(ty): acc_dtype = to_tl_type(acc_dtype) ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) # Tensor cores support input with mixed float8 types. if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: @@ -207,8 +211,9 @@ def to_tl_type(ty): return c @staticmethod - def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True): - return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum) + def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None): + return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype) matmul = _matmul.apply