
Commit b3ffd97

[BYOC] Add layout check and update shape check for cublas FP8 BYOC (#16895)
1 parent 857fe61 commit b3ffd97

File tree

2 files changed (+36, -12 lines)


python/tvm/relax/backend/contrib/cublas.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -20,6 +20,7 @@
 from functools import reduce
 
 import tvm
+from tvm import DataType
 from tvm.relax import transform
 from tvm.relax.transform import PatternCheckContext
 
@@ -68,11 +69,30 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             # Rows number must be multiples of 4 for IGEMM
             return False
     elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
-        # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
-        # docs, but it was observed during testing.
-        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+        matmul_rhs_var = matmul_call.args[1]
+        rhs_transposed = False
+        if matmul_rhs_var in context.matched_bindings:
+            matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
+            assert (
+                isinstance(matmul_rhs_call, tvm.relax.Call)
+                and matmul_rhs_call.op.name == "relax.permute_dims"
+            )
+            rhs_transposed = True
+
+        if not rhs_transposed:
+            # cuBLAS FP8 operations require rhs being transposed
             return False
-        if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0:
+
+        # cuBLAS FP8 operations require all tensors being aligned to 16 bytes.
+        if (
+            not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int))
+            or rhs_shape[-1] % (16 // DataType(lhs_dtype).itemsize()) != 0
+        ):
+            return False
+        if (
+            not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int))
+            or rhs_shape[-2] % (16 // DataType(out_dtype).itemsize()) != 0
+        ):
             return False
 
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
```
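The updated shape check derives the required dimension multiple from the element size instead of hard-coding 16. A minimal sketch of that arithmetic (not part of the commit), assuming only the `tvm.DataType` API already used in the patch:

```python
# Sketch only: the required multiple is 16 bytes divided by the element
# size, mirroring the `16 // DataType(dtype).itemsize()` expressions in
# the check above.
from tvm import DataType

for dtype in ["e4m3_float8", "float16", "float32"]:
    multiple = 16 // DataType(dtype).itemsize()
    print(f"{dtype}: dimension must be a multiple of {multiple}")

# Expected:
#   e4m3_float8 (1-byte elements) -> multiple of 16
#   float16     (2-byte elements) -> multiple of 8
#   float32     (4-byte elements) -> multiple of 4
```

This is why the previous blanket `% 16` checks could be relaxed: with a `float16` output, `rhs_shape[-2]` only needs to be a multiple of 8.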

tests/python/relax/test_codegen_cublas.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -269,17 +269,21 @@ def test_matmul_fp8_offload(
 
 
 @pytest.mark.parametrize(
-    "M, N, K, out_dtype, partition_done",
+    "M, N, K, out_dtype, transposed_y, partition_done",
     [
-        (15, 64, 32, "float32", True),
-        (15, 64, 32, "e4m3_float8", True),
-        (15, 64, 32, "e5m2_float8", False),
-        (16, 32, 60, "float32", False),
-        (16, 30, 64, "float32", False),
+        (15, 64, 32, "float32", True, True),
+        (15, 64, 32, "e4m3_float8", True, True),
+        (15, 64, 32, "e5m2_float8", True, False),
+        (16, 32, 60, "float32", True, False),
+        (16, 30, 64, "float32", True, False),
+        (16, 8, 16, "float16", True, True),
+        (16, 16, 16, "float16", False, False),
     ],
 )
-def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
-    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, transposed_y, partition_done):
+    mod = get_relax_matmul_module(
+        (M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=transposed_y
+    )
     mod = partition_for_cublas(mod)
     func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
     assert func_name in mod["main"].script()
```
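The two new `float16` rows exercise each new check in isolation. A small sketch of why they diverge (a hypothetical helper mirroring the `_check_matmul` logic above, not code from the test file):

```python
# Hypothetical mirror of the checks in _check_matmul; itemsizes are in
# bytes (e4m3_float8 inputs = 1 byte). Not a call into TVM, just the
# decision logic, for an rhs of shape (N, K) with transposed layout.
def passes_fp8_checks(N, K, out_itemsize, lhs_itemsize=1, transposed_y=True):
    if not transposed_y:               # layout check: rhs must be transposed
        return False
    if K % (16 // lhs_itemsize) != 0:  # rhs_shape[-1] alignment
        return False
    if N % (16 // out_itemsize) != 0:  # rhs_shape[-2] alignment
        return False
    return True

# (16, 8, 16, "float16", True, True): N=8 is a multiple of 16 // 2 -> offloaded
assert passes_fp8_checks(N=8, K=16, out_itemsize=2)
# (16, 16, 16, "float16", False, False): rejected only because rhs is not transposed
assert not passes_fp8_checks(N=16, K=16, out_itemsize=2, transposed_y=False)
```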
