diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index 35407b578bcd..e8b757b31edf 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -155,6 +155,8 @@ def _call(a, b, dot_out_dtype, allow_tf32):
         if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\
            b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
             c_dtype = torch.float16
+        elif a.dtype in [torch.int8] or b.dtype in [torch.int8]:
+            c_dtype = torch.int32
         else:
             c_dtype = get_higher_dtype(a.dtype, b.dtype)
         c = torch.empty((M, N), device=device, dtype=c_dtype)
@@ -174,6 +176,8 @@ def _call(a, b, dot_out_dtype, allow_tf32):
         ab_dtype = True
         if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]:
             ab_dtype = False
+        if a.dtype in [torch.int8] and b.dtype in [torch.int8]:
+            ab_dtype = False
         # launch kernel
         grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
         _kernel[grid](a, b, c, M, N, K,