Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 59 additions & 49 deletions python/test/unit/operators/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,60 +9,60 @@


@pytest.mark.parametrize(
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM",
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE",
itertools.chain(
*[[
# 1 warp
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
# 2 warp
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
# 4 warp
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
# 8 warp
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True, None, None),
# variable input
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]],
# n-stage
*[[
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True),
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True, None, None),
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True, None, None),
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True, None, None),
]
for DTYPE in ["float16", "bfloat16", "float32"]
for AT in [False, True]
for BT in [False, True]
for STAGES in [4]],
# mixed-precision
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM),
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM, None, None),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float8e5"),
("float8e4nv", "float8e4nv"),
Expand All @@ -80,9 +80,9 @@
] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]],
# mixed-precision block layout
*[[
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True),
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True, None, None),
] for ADTYPE, BDTYPE in [
("float8e4nv", "float16"),
("float16", "float8e5"),
Expand All @@ -91,10 +91,17 @@
("bfloat16", "float32"),
("float32", "bfloat16"),
] for AT in [False, True] for BT in [False, True]],
# acc-out-dtype and output_dtype
*[[
(32, 32, 32, 1, 1, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE,
OUTPUT_DTYPE),
(128, 256, 32, 1, 8, 2, None, None, None, False, False, "float16", "float16", True, True, ACC_DTYPE,
OUTPUT_DTYPE),
] for ACC_DTYPE in [None, "float16", "float32"] for OUTPUT_DTYPE in [None, "float16", "float32"]],
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32,
F8_FASTACCUM):
F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE):
capability = torch.cuda.get_device_capability()
if capability[0] < 7:
pytest.skip("Only test tl.dot() on devices with sm >= 70")
Expand Down Expand Up @@ -142,35 +149,38 @@ def upcast_if_fp8(x, dtype):
return f8_to_f16(x, dtype)
return x

def init_input(m, n, dtype):
def init_input(m, n, dtype, acc_dtype):
if 'float8' in dtype:
ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype]
sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
return sign | val
if dtype == "int8":
return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
exponents = torch.randint(-10, 0, size=(m, n))
ret = (2.**exponents).to(dtype).to("cuda")
# Use small range of values to prevent numerical issues.
min_exp = -4 if acc_dtype == "float16" else -10
exponents = torch.randint(min_exp, 0, size=(m, n))
ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda")
return ret

# allocate/transpose inputs
a = init_input(M, K, ADTYPE)
b = init_input(K, N, BDTYPE)
a = init_input(M, K, ADTYPE, ACC_DTYPE)
b = init_input(K, N, BDTYPE, ACC_DTYPE)
a = a if not AT else a.T.contiguous().T
b = b if not BT else b.T.contiguous().T
# run test
th_a = upcast_if_fp8(a, ADTYPE)
th_b = upcast_if_fp8(b, BDTYPE)
ab_dtype = triton.ops.get_higher_dtype(th_a.dtype, th_b.dtype)
th_c = torch.matmul(th_a.to(ab_dtype), th_b.to(ab_dtype))
acc_dtype = getattr(torch, ACC_DTYPE) if ACC_DTYPE else ab_dtype
output_dtype = getattr(torch, OUTPUT_DTYPE) if OUTPUT_DTYPE else ab_dtype
th_c = torch.matmul(th_a.to(output_dtype), th_b.to(output_dtype))
try:
if is_fp8(ADTYPE):
a = triton.reinterpret(a, getattr(tl, ADTYPE))
if is_fp8(BDTYPE):
b = triton.reinterpret(b, getattr(tl, BDTYPE))
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM)
tt_c = triton.ops.matmul(a, b, acc_dtype if ACC_DTYPE else None, ALLOW_TF32, F8_FASTACCUM, output_dtype)
torch.testing.assert_close(th_c, tt_c)
except triton.OutOfResources as e:
pytest.skip(str(e))
13 changes: 9 additions & 4 deletions python/triton/ops/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class _matmul(torch.autograd.Function):
_locks = {}

@staticmethod
def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum):
def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum, output_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
Expand All @@ -169,7 +169,10 @@ def _call(a, b, acc_dtype, allow_tf32, fp8_fast_accum):
ab_dtype = get_higher_dtype(a.dtype, b.dtype)

# allocates output
c = torch.empty((M, N), device=device, dtype=ab_dtype)
if (output_dtype is None):
output_dtype = ab_dtype

c = torch.empty((M, N), device=device, dtype=output_dtype)

# Allowed types for acc_type given the types of a and b.
supported_acc_dtypes = {
Expand All @@ -189,6 +192,7 @@ def to_tl_type(ty):

acc_dtype = to_tl_type(acc_dtype)
ab_dtype = to_tl_type(ab_dtype)
output_dtype = to_tl_type(output_dtype)

# Tensor cores support input with mixed float8 types.
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
Expand All @@ -207,8 +211,9 @@ def to_tl_type(ty):
return c

@staticmethod
def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True):
return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum)
def forward(ctx, a, b, acc_dtype=None, allow_tf32=True, fp8_fast_accum=True, output_dtype=None):
return _matmul._call(a, b, acc_dtype=acc_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum,
output_dtype=output_dtype)


matmul = _matmul.apply