 from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R

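+# ml_dtypes registers the float8 dtypes (e.g. "float8_e4m3fn") with NumPy;
+# guard the import so the fp8 tests below are skipped when it is unavailable.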
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+

 @pytest.fixture(autouse=True)
 def reset_seed():
@@ -226,6 +231,60 @@ def test_matmul_igemm_offload( |
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float16"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_fp8_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "e4m3_float8"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
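+    # TVM calls the input dtype "e4m3_float8"; ml_dtypes exposes the same
+    # 8-bit format to NumPy under the name "float8_e4m3fn".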
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
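+    # When transpose_y is set, y is laid out as (N, K); swap the last two axes
+    # so the NumPy reference matmul sees a (K, N) operand.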
+    z = np.swapaxes(y, -2, -1) if transpose_y else y
+    args = (x, y)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)
+
+
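+# Partitioning should only offload a matmul when cuBLAS supports the fp8
+# dtype combination and the (M, N, K) shape; otherwise R.matmul must remain.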
+@pytest.mark.parametrize(
+    "M, N, K, out_dtype, partition_done",
+    [
+        (15, 64, 32, "float32", True),
+        (15, 64, 32, "e4m3_float8", True),
+        (15, 64, 32, "e5m2_float8", False),
+        (16, 32, 60, "float32", False),
+        (16, 30, 64, "float32", False),
+    ],
+)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
+    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+    mod = partition_for_cublas(mod)
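+    # A successful partition wraps the matmul in a composite function whose
+    # name contains "relax_matmul_cublas"; otherwise the plain R.matmul stays.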
+    func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
+    assert func_name in mod["main"].script()
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))