Skip to content

Commit b70ca35

Browse files
authored
【Hackathon 9th No.52】add test_dynamic_per_token_scaled_fp8_quant (#4015)
* add test_dynamic_per_token_scaled_fp8_quant * fix * add bfloat16 * ci
1 parent befe463 commit b70ca35

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import unittest
2+
3+
import numpy as np
4+
import paddle
5+
6+
from fastdeploy.model_executor.ops.gpu import dynamic_per_token_scaled_fp8_quant
7+
8+
9+
class TestDynamicPerTokenScaledFp8Quant(unittest.TestCase):
    """Unit tests for the ``dynamic_per_token_scaled_fp8_quant`` GPU operator.

    The operator quantizes each token (row) of the input to FP8 (e4m3fn),
    writing the quantized values and one float32 scale per token in place.
    """

    def setUp(self):
        # Fixed seeds so the random large-input test is reproducible.
        paddle.seed(42)
        np.random.seed(42)

    def _run_dynamic_per_token_scaled_fp8_quant(self, input_data, scale_ub=0.0):
        """
        Run the dynamic per-token scaled FP8 quantization operator.

        Args:
            input_data: Input data (numpy array), 1-D or 2-D.
            scale_ub: Scale upper bound value; 0.0 means no upper bound.

        Returns:
            Tuple of (quantized output, per-token scaling factors) as numpy arrays.
        """
        input_tensor = paddle.to_tensor(input_data)

        # A 1-D input is treated as a single token; otherwise one token per row.
        num_tokens = input_tensor.shape[0] if len(input_tensor.shape) > 1 else 1

        # Output buffer and per-token scales, both written in place by the operator.
        out_tensor = paddle.empty(input_tensor.shape, dtype=paddle.float8_e4m3fn)
        scales_tensor = paddle.empty([num_tokens], dtype="float32")

        # Call the op directly instead of unpacking throwaway dicts
        # (the original `*inputs.values(), *attrs.values()` relied on
        # dict insertion order to produce the same positional arguments).
        dynamic_per_token_scaled_fp8_quant(out_tensor, input_tensor, scales_tensor, scale_ub)

        out_np = out_tensor.cpu().numpy()
        scales_np = scales_tensor.cpu().numpy()

        return out_np, scales_np

    def _verify_results(self, input_data, output_data, scales, scale_ub=0.0, tol=7e-2):
        """
        Verify that the quantization results are correct.

        Args:
            input_data: Original input data.
            output_data: Quantized FP8 output data.
            scales: Per-token scaling factors produced by the operator.
            scale_ub: Scale upper bound used when quantizing (0.0 = none).
            tol: Allowed tolerance for scale and reconstruction checks.
        """
        # Check that the output data type is FP8 (stored as float8_e4m3fn).
        self.assertEqual(output_data.dtype, "float8_e4m3fn")

        # For each token, verify the quantization process.
        num_tokens = input_data.shape[0] if len(input_data.shape) > 1 else 1

        for i in range(num_tokens):
            # Select the current token's row (or the whole array for 1-D input).
            if len(input_data.shape) > 1:
                token_input = input_data[i]
                token_output = output_data[i] if len(output_data.shape) > 1 else output_data
            else:
                token_input = input_data
                token_output = output_data

            token_scale = scales[i] if num_tokens > 1 else scales[0]

            # Expected scale is max|x| / 448 (448 is the FP8 e4m3 max value),
            # clipped to scale_ub when an upper bound is set.
            max_val = np.max(np.abs(token_input))
            if scale_ub > 0:
                expected_scale = min(max_val, scale_ub) / 448.0
            else:
                expected_scale = max_val / 448.0
            self.assertAlmostEqual(token_scale, expected_scale, delta=tol)

            # Dequantized values must be close to the original input within a
            # tolerance relative to the token's magnitude.
            reconstructed = token_output.astype(np.float32) * token_scale
            diff = np.abs(reconstructed - token_input.astype(np.float32))
            self.assertTrue(np.all(diff <= tol * np.max(np.abs(token_input))))

    def test_fp32_input(self):
        """Test FP32 input."""
        input_data = np.array([0.1, -0.2, 0.3, -0.4], dtype=np.float32)

        # Test the case without a scale upper limit.
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)

        # Test the case with a scale upper limit.
        # (Removed leftover debug print of output_data/scales.)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data, scale_ub=1.5)
        self._verify_results(input_data, output_data, scales, scale_ub=1.5)

        # Test the single-token case.
        single_token = input_data[0:1]
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(single_token)
        self._verify_results(single_token, output_data, scales)

    def test_large_values(self):
        """Test large value input."""
        input_data = np.array([100.0, -200.0, 300.0, -320.0], dtype=np.float32)

        # No scale upper limit - should use max_value / 448 as the scaling factor.
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)

        # With scale upper limit below the actual max (320.0).
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data, scale_ub=310.0)
        self._verify_results(input_data, output_data, scales, scale_ub=310.0)

    def test_edge_cases(self):
        """Test edge cases."""
        # All-zero input: scale should be 0 and reconstruction exact.
        zero_input = np.zeros((2, 4), dtype=np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(zero_input)
        self._verify_results(zero_input, output_data, scales)

        # Single-element input.
        single_element = np.array([[5.0]], dtype=np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(single_element)
        self._verify_results(single_element, output_data, scales)

        # Very large number of tokens.
        large_input = np.random.randn(1024, 16).astype(np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(large_input)
        self._verify_results(large_input, output_data, scales)

    def test_dynamic_per_token_scaled_fp8_quant_fp16(self):
        """Test float16 input."""
        input_data = np.array([0.1, -0.2, 0.3, -0.4], dtype="float16")
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)

    def test_dynamic_per_token_scaled_fp8_quant_bf16(self):
        """Test bfloat16 input."""
        input_data = np.array([0.1, -0.2, 0.3, -0.4], dtype="bfloat16")
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)
148+
149+
150+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)