Skip to content

Commit 2d64107

Browse files
authored
【Hackathon 9th No.20】add unit tests for masked_per_token_quant (#4111)
* test: add unit tests for masked_per_token_quant * apply review
1 parent 584d116 commit 2d64107

File tree

1 file changed

+250
-0
lines changed

1 file changed

+250
-0
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import os
2+
import unittest
3+
4+
import numpy as np
5+
import paddle
6+
7+
from fastdeploy.model_executor.ops.gpu import masked_per_token_quant
8+
9+
10+
def masked_per_token_quant_ref(input_tensor, recv_expert_count, block_size):
    """
    Paddle API reference implementation of masked_per_token_quant.

    Every token's hidden vector is split into contiguous blocks of
    ``block_size`` elements. Each block gets its own scale
    (``max(|block|) / 448.0``) and is divided by that scale before being cast
    to ``float8_e4m3fn``. Token slots whose index is not below the owning
    expert's entry in ``recv_expert_count`` are treated as padding: their
    quantized values and their scales are returned as zeros.

    If the environment variable ``PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE``
    is set to a non-zero integer, the per-block maximum is multiplied by 7.0
    before the scale is derived.

    Args:
        input_tensor: Input tensor with shape
            [num_local_expert, num_max_tokens_per_expert, hidden_size]
        recv_expert_count: Expert token count tensor with shape [num_local_expert]
        block_size: Quantization block size; must be a positive divisor of
            hidden_size

    Returns:
        Tuple of (quantized_tensor, scale_tensor):
            quantized_tensor: float8_e4m3fn tensor with the input's shape
            scale_tensor: float32 tensor with shape
                [num_local_expert, num_max_tokens_per_expert, hidden_size // block_size]

    Raises:
        ValueError: If block_size is not positive or does not evenly divide
            hidden_size, or if the environment variable holds a non-integer
            string (raised by ``int``).
    """
    MAX_VALUE = 448.0
    epsilon = 1e-10

    # Get dimensions
    input_shape = input_tensor.shape
    num_local_expert = input_shape[0]
    num_max_tokens_per_expert = input_shape[1]
    hidden_size = input_shape[2]

    # The block-wise reshape below only works for an exact split; fail early
    # with a readable message instead of an opaque reshape / division error.
    if block_size <= 0 or hidden_size % block_size != 0:
        raise ValueError(
            f"block_size ({block_size}) must be a positive divisor of hidden_size ({hidden_size})"
        )

    # Per the original author's note, the CUDA kernel derives this with
    # integer division as well.
    hidden_size_scale = hidden_size // block_size

    # Check environment variable for fine-grained range
    use_finegrained_range = False
    env_var = os.getenv("PER_TOKEN_QUANT_FP8_USE_FINEGRAINED_RANGE")
    if env_var:
        use_finegrained_range = bool(int(env_var))

    # Mask of valid token slots: slot t of expert e is valid iff t < recv_expert_count[e]
    token_indices = paddle.arange(num_max_tokens_per_expert, dtype="int32").unsqueeze(
        0
    )  # [1, num_max_tokens_per_expert]
    expert_counts = recv_expert_count.unsqueeze(1)  # [num_local_expert, 1]
    valid_mask = token_indices < expert_counts  # [num_local_expert, num_max_tokens_per_expert]
    token_mask = valid_mask.unsqueeze(2)  # [num_local_expert, num_max_tokens_per_expert, 1]
    block_mask = token_mask.unsqueeze(3)  # [num_local_expert, num_max_tokens_per_expert, 1, 1]

    # Reshape input for block-wise processing
    # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
    reshaped_input = paddle.reshape(
        input_tensor, [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, block_size]
    ).astype("float32")

    # Max absolute value per block, floored at epsilon so the division below
    # never sees a zero scale.
    max_abs_val = paddle.max(
        paddle.abs(reshaped_input), axis=-1, keepdim=True
    )  # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale, 1]
    max_abs_val = paddle.clip(max_abs_val, min=epsilon)

    # Padding slots use epsilon as their block maximum
    max_abs_val = paddle.where(block_mask, max_abs_val, paddle.to_tensor(epsilon))

    # Apply fine-grained range if enabled
    if use_finegrained_range:
        max_abs_val *= 7.0

    # Calculate scale and quantize in float32
    scale = max_abs_val / MAX_VALUE
    quanted_value = paddle.reshape(
        reshaped_input / scale, [num_local_expert, num_max_tokens_per_expert, hidden_size]
    )

    # Zero the padding slots while still in float32, then cast once. This
    # yields the same result as casting to float8 first and round-tripping
    # through float32 for the masking step.
    quanted_value = paddle.where(token_mask, quanted_value, paddle.zeros_like(quanted_value))
    quanted_x = quanted_value.astype("float8_e4m3fn")

    # Prepare scale output - squeeze the trailing singleton and zero padding slots
    quanted_scale = paddle.squeeze(scale, axis=-1)  # [num_local_expert, num_max_tokens_per_expert, hidden_size_scale]
    quanted_scale = paddle.where(token_mask, quanted_scale, paddle.zeros_like(quanted_scale))

    return quanted_x, quanted_scale
92+
93+
94+
class TestMaskedPerTokenQuant(unittest.TestCase):
    """Compare the masked_per_token_quant op with the Paddle reference implementation.

    Subclasses override ``setUp`` to vary dtype, shapes and per-expert token
    counts; the comparison in ``test_masked_per_token_quant_basic`` is inherited.
    """

    def setUp(self) -> None:
        """Build a seeded bfloat16 input and the matching reference outputs."""
        paddle.seed(2024)
        self.num_local_expert = 2
        self.num_max_tokens_per_expert = 4
        self.hidden_size = 256
        self.block_size = 128
        self.dtype = paddle.bfloat16

        self.input_tensor = paddle.randn(
            [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size], dtype=self.dtype
        )
        self.recv_expert_count = paddle.to_tensor([3, 2], dtype="int32")

        # Get reference results from paddle implementation
        self.quanted_x_ref, self.quanted_scale_ref = masked_per_token_quant_ref(
            self.input_tensor, self.recv_expert_count, self.block_size
        )

    def _mask_invalid_tokens(self, quanted_x, quanted_scale, recv_expert_count):
        """Zero out the token slots at or beyond each expert's count.

        Args:
            quanted_x: Quantized values; cast to float32 for masking and back
                to float8_e4m3fn afterwards.
            quanted_scale: Scales; the mask is broadcast over the last axis.
            recv_expert_count: Per-expert valid token counts.

        Returns:
            Tuple of (masked quantized values, masked scales).
        """
        token_indices = paddle.arange(self.num_max_tokens_per_expert, dtype="int32").unsqueeze(0)
        expert_counts = recv_expert_count.unsqueeze(1)
        # Same mask serves both tensors: it broadcasts over the trailing axis.
        valid_mask = (token_indices < expert_counts).unsqueeze(2)

        # Apply mask to quantized values - convert to float32 first
        quanted_x_float32 = quanted_x.astype("float32")
        quanted_x_masked_float32 = paddle.where(
            valid_mask, quanted_x_float32, paddle.zeros_like(quanted_x_float32)
        )
        quanted_x_masked = quanted_x_masked_float32.astype("float8_e4m3fn")

        # Apply mask to scale values
        quanted_scale_masked = paddle.where(valid_mask, quanted_scale, paddle.zeros_like(quanted_scale))

        return quanted_x_masked, quanted_scale_masked

    def test_masked_per_token_quant_basic(self):
        """Test basic functionality against CUDA kernel"""
        quanted_x_cuda, quanted_scale_cuda = masked_per_token_quant(
            self.input_tensor, self.recv_expert_count, self.block_size
        )

        # Check shapes and dtypes before anything broadcasts against them, so
        # a regression is reported by these assertions rather than by a
        # shape error inside the masking helper.
        self.assertEqual(quanted_x_cuda.shape, self.quanted_x_ref.shape)
        self.assertEqual(quanted_scale_cuda.shape, self.quanted_scale_ref.shape)
        self.assertEqual(quanted_x_cuda.dtype, paddle.float8_e4m3fn)
        self.assertEqual(quanted_scale_cuda.dtype, paddle.float32)

        # Only valid token slots are compared; padding is zeroed on the kernel
        # output to match the reference.
        # NOTE(review): presumably the kernel leaves padding slots unspecified - confirm.
        quanted_x_cuda_masked, quanted_scale_cuda_masked = self._mask_invalid_tokens(
            quanted_x_cuda, quanted_scale_cuda, self.recv_expert_count
        )

        # Compare scale values: kernel output is "actual", reference is "desired"
        np.testing.assert_allclose(
            quanted_scale_cuda_masked.numpy(), self.quanted_scale_ref.numpy(), rtol=1e-5, atol=1e-6
        )

        # Compare quantized values in float32 via mean absolute error relative
        # to the reference's mean magnitude.
        quanted_x_ref_float32 = self.quanted_x_ref.astype("float32")
        quant_diff = paddle.mean(
            paddle.abs(quanted_x_cuda_masked.astype("float32") - quanted_x_ref_float32)
        ) / paddle.mean(paddle.abs(quanted_x_ref_float32) + 1e-9)
        diff_val = float(quant_diff.numpy().item())
        self.assertLess(diff_val, 0.01, msg="Quantized values should be close")
162+
163+
164+
class TestMaskedPerTokenQuantCase1(TestMaskedPerTokenQuant):
    """Float16 variant: 3 experts, 6 token slots each, hidden size 512."""

    def setUp(self) -> None:
        """Build a seeded float16 input and the matching reference outputs."""
        paddle.seed(2024)
        self.dtype = paddle.float16
        self.block_size = 128
        self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size = 3, 6, 512

        input_shape = [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size]
        self.input_tensor = paddle.randn(input_shape, dtype=self.dtype)
        self.recv_expert_count = paddle.to_tensor([4, 2, 5], dtype="int32")

        # Expected outputs come from the Paddle reference implementation.
        reference = masked_per_token_quant_ref(self.input_tensor, self.recv_expert_count, self.block_size)
        self.quanted_x_ref, self.quanted_scale_ref = reference
183+
184+
185+
class TestMaskedPerTokenQuantCase2(TestMaskedPerTokenQuant):
    """Bfloat16 variant with a hidden size of 384 (three blocks of 128)."""

    def setUp(self) -> None:
        """Build a seeded bfloat16 input and the matching reference outputs."""
        paddle.seed(2024)
        self.dtype = paddle.bfloat16
        self.block_size = 128
        # hidden_size is 3 * 128
        self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size = 4, 8, 384

        input_shape = [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size]
        self.input_tensor = paddle.randn(input_shape, dtype=self.dtype)
        self.recv_expert_count = paddle.to_tensor([6, 3, 7, 1], dtype="int32")

        # Expected outputs come from the Paddle reference implementation.
        reference = masked_per_token_quant_ref(self.input_tensor, self.recv_expert_count, self.block_size)
        self.quanted_x_ref, self.quanted_scale_ref = reference
204+
205+
206+
class TestMaskedPerTokenQuantCase3(TestMaskedPerTokenQuant):
    """Variant in which every expert fills all of its token slots (no padding)."""

    def setUp(self) -> None:
        """Build a seeded bfloat16 input and the matching reference outputs."""
        paddle.seed(2024)
        self.dtype = paddle.bfloat16
        self.block_size = 128
        self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size = 2, 4, 256

        input_shape = [self.num_local_expert, self.num_max_tokens_per_expert, self.hidden_size]
        self.input_tensor = paddle.randn(input_shape, dtype=self.dtype)
        # Each count equals num_max_tokens_per_expert, so no slot is masked out.
        self.recv_expert_count = paddle.to_tensor([4, 4], dtype="int32")

        # Expected outputs come from the Paddle reference implementation.
        reference = masked_per_token_quant_ref(self.input_tensor, self.recv_expert_count, self.block_size)
        self.quanted_x_ref, self.quanted_scale_ref = reference
226+
227+
228+
class TestMaskedPerTokenQuantEdgeCases(unittest.TestCase):
    """Edge-case checks run against the Paddle reference implementation."""

    def test_zero_tokens_expert(self):
        """An expert that received no tokens yields all-zero values and scales."""
        paddle.seed(2024)
        block_size = 128
        activations = paddle.randn([2, 4, 256], dtype="bfloat16")
        # Expert 0 receives nothing; expert 1 receives two tokens.
        token_counts = paddle.to_tensor([0, 2], dtype="int32")

        ref_values, ref_scales = masked_per_token_quant_ref(activations, token_counts, block_size)

        # float8 values are cast to float32 before being compared.
        empty_expert_values = ref_values[0].astype("float32")
        values_all_zero = bool(paddle.all(empty_expert_values == 0))
        self.assertTrue(values_all_zero, "Expert with zero tokens should be all zero")
        scales_all_zero = bool(paddle.all(ref_scales[0] == 0))
        self.assertTrue(scales_all_zero, "Expert with zero tokens should have zero scales")

        # The two valid slots of expert 1 must carry some non-zero value.
        active_values = ref_values[1, :2].astype("float32")
        has_nonzero = bool(paddle.any(active_values != 0))
        self.assertTrue(has_nonzero, "Expert with tokens should have non-zero values")
247+
248+
249+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)