@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
     return tvm.IRModule({"main": func})
 
 
+def get_relax_matmul_multiply_module(
+    x_shape,
+    y_shape,
+    z_shape,
+    in_dtype,
+    acc_dtype,
+    out_dtype,
+    transposed_y=False,
+):
146+ """Create a matmul op followd by multiply operations."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+            scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+            scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+            with R.dataflow() as frame:
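+                # When y is supplied transposed (e.g. (N, K)), swap its last two
+                # axes so matmul sees (K, N).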
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
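+                # Fold the two scale factors into one tensor so the matmul is
+                # followed by a single elementwise multiply.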
+                z = R.emit(R.multiply(scaleA, scaleB))
+                result = R.emit(R.multiply(result, z))
+                if acc_dtype != out_dtype:
+                    result = R.emit(R.astype(result, out_dtype))
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    z_shape = (1,)
+    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+    mod = get_relax_matmul_multiply_module(
+        x_shape,
+        y_shape,
+        z_shape,
+        in_dtype,
+        acc_dtype,
+        "float16",
+        transposed_y=True,
+    )
+
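+    # "float8_e4m3fn" is the ml_dtypes/NumPy name corresponding to TVM's
+    # "e4m3_float8" dtype.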
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    args = (x, y, scaleA, scaleB)
+
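+    # Check the cuBLAS-offloaded output against a reference built with the
+    # legalized LLVM lowering.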
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
     assert len(mod["main"].body.blocks[0].bindings) == num_bindings
 
 
+def test_cublas_partition_fp8_matmul_multiply():
+    M, N, K = (32, 64, 128)
+    mod = get_relax_matmul_multiply_module(
+        (M, K),
+        (N, K),
+        (1,),
+        "e4m3_float8",
+        "float32",
+        "float16",
+        transposed_y=True,
+    )
+    mod = partition_for_cublas(mod)
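+    # The matmul and both multiplies should be fused into one offloaded
+    # function, leaving a single binding in "main".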
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))