
Commit 276f73c

WanRui37 and luotao1 authored
【Hackathon 9th No.28】add test_cutlass_fp8_fp8_fp8_dual_gemm_fused (#3935)
* add test_cutlass_fp8_fp8_fp8_dual_gemm_fused
* fix the version
* fix code style

---------

Co-authored-by: Tao Luo <[email protected]>
1 parent d3e4ae3 commit 276f73c
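
For context, the new test validates cutlass_fp8_fp8_fp8_dual_gemm_fused, which fuses two FP8 GEMMs that share the same activation input and gates the first product with a sigmoid of the second. Below is a minimal bfloat16 sketch of the semantics the test checks against, distilled from the reference path in the diff that follows; the helper name dual_gemm_reference is illustrative, and the fused kernel additionally clips and quantizes the result back to float8_e4m3fn:

import paddle

def dual_gemm_reference(x, y0, y1, scale0, scale1, trans_x=False, trans_y=True):
    # Two GEMMs over the same input x, each with its own output scale.
    gemm0 = paddle.matmul(x, y0, transpose_x=trans_x, transpose_y=trans_y) * scale0
    gemm1 = paddle.matmul(x, y1, transpose_x=trans_x, transpose_y=trans_y) * scale1
    # Gated combination: gemm0 modulated elementwise by sigmoid(gemm1).
    return gemm0 * paddle.nn.functional.sigmoid(gemm1)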

File tree: 1 file changed, 120 additions, 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from itertools import product

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import cutlass_fp8_fp8_fp8_dual_gemm_fused


class TestFp8Fp8Fp8DualGemm(unittest.TestCase):
    def setUp(self):
        """
        Initialize the test environment,
        including setting random seeds and environment variables.
        """
        paddle.seed(2024)
        self.prop = paddle.device.cuda.get_device_properties()
        self.sm_version = self.prop.major * 10 + self.prop.minor
        print(f"sm version: {self.sm_version}")
        # Largest finite value representable in float8_e4m3fn.
        self.E4M3_MAX_POS = 448.0
        os.environ["FLAGS_cuda_core_fp8_gemm"] = "1"
        print(paddle.device.cuda.get_device_properties())
        print(paddle.__git_commit__)

    def test_dual_gemm_case(self):
        """
        Check that cutlass_fp8_fp8_fp8_dual_gemm_fused matches a bfloat16 reference.
        """
        if self.sm_version < 90:
            self.skipTest("cutlass_fp8_fp8_fp8_dual_gemm_fused only supports sm90+")
        # Representative (n, k) weight shapes; m is the number of activation rows.
        nks = [
            [2048, 2048],
            [2048, 5504],
            [6144, 2048],
            [4096, 4096],
            [4096, 12800],
            [6144, 4096],
            [5120, 5120],
            [5120, 13824],
            [15360, 5120],
        ]
        m_values = [1, 2, 3, 4]
        transpose_combinations = [(False, True)]
        activation_types = [""]

        combinations = product(m_values, nks, transpose_combinations, activation_types)
        for m, (n, k), (trans_x, trans_y), act_type in combinations:
            # Random FP8 inputs, clipped to the E4M3 range before the cast.
            x = (
                paddle.rand([m, k] if not trans_x else [k, m])
                .clip(min=-self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
                .to(paddle.float8_e4m3fn)
            )

            y0 = (
                paddle.rand([k, n] if not trans_y else [n, k])
                .clip(min=-self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
                .to(paddle.float8_e4m3fn)
            )

            y1 = (
                paddle.rand([k, n] if not trans_y else [n, k])
                .clip(min=-self.E4M3_MAX_POS, max=self.E4M3_MAX_POS)
                .to(paddle.float8_e4m3fn)
            )

            scale0 = 1.2
            scale1 = 0.8
            scale_out = 1.0

            # Reference path: run both GEMMs in bfloat16, apply the per-GEMM
            # scales, gate gemm0 with sigmoid(gemm1), and cast back to FP8.
            x_bf16 = x.astype("bfloat16")
            y0_bf16 = y0.astype("bfloat16")
            y1_bf16 = y1.astype("bfloat16")

            gemm0 = paddle.matmul(x_bf16, y0_bf16, transpose_x=trans_x, transpose_y=trans_y)
            gemm1 = paddle.matmul(x_bf16, y1_bf16, transpose_x=trans_x, transpose_y=trans_y)

            gemm0 = gemm0 * scale0
            gemm1 = gemm1 * scale1

            if act_type == "" or act_type == "swiglu":
                ref_out = gemm0 * paddle.nn.functional.sigmoid(gemm1)

            ref_out = ref_out.clip(min=-self.E4M3_MAX_POS, max=self.E4M3_MAX_POS).to(paddle.float8_e4m3fn)

            # Fused FP8 dual-GEMM kernel under test.
            result = cutlass_fp8_fp8_fp8_dual_gemm_fused(
                x,
                y0,
                y1,
                bias0=None,
                bias1=None,
                transpose_x=trans_x,
                transpose_y=trans_y,
                scale0=scale0,
                scale1=scale1,
                scale_out=scale_out,
                activation_type=act_type,
            )

            np.testing.assert_allclose(
                ref_out.astype("float32").numpy(), result.astype("float32").numpy(), rtol=5e-3, atol=5e-3
            )


if __name__ == "__main__":
    unittest.main()
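
Because the file ends in a unittest.main() guard, the test can be run directly on an SM90+ GPU (it self-skips on older architectures); the file name below is illustrative, since the commit page does not show the path:

python test_cutlass_fp8_fp8_fp8_dual_gemm_fused.py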
