145 changes: 145 additions & 0 deletions tests/operators/test_dynamic_per_token_scaled_fp8_quant.py
@@ -0,0 +1,145 @@
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import dynamic_per_token_scaled_fp8_quant


class TestDynamicPerTokenScaledFp8Quant(unittest.TestCase):
    def setUp(self):
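        # Fix the random seeds so the test data is reproducible.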
        paddle.seed(42)
        np.random.seed(42)

    def _run_dynamic_per_token_scaled_fp8_quant(self, input_data, scale_ub=0.0):
        """
        Run the dynamic per-token scaled FP8 quantization operator.
Collaborator

Please switch the comments to English.

Contributor Author

Done, the comments have been switched to English.

        Args:
            input_data: input data (numpy array)
            scale_ub: upper bound applied to the scale

        Returns:
            the quantized output and the per-token scale factors
        """
        input_tensor = paddle.to_tensor(input_data)

        # Determine the number of tokens (a 1-D input is treated as a single token)
        num_tokens = input_tensor.shape[0] if len(input_tensor.shape) > 1 else 1

        # Create the FP8 output tensor
        out_tensor = paddle.empty(input_tensor.shape, dtype=paddle.float8_e4m3fn)

        # Create the scales tensor (one float32 scale per token)
        scales_tensor = paddle.empty([num_tokens], dtype="float32")

        inputs = {"out": out_tensor, "input": input_tensor, "scale": scales_tensor}
        attrs = {"scale_ub": scale_ub}
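        # The operator is expected to write the quantized values into out_tensor
        # and the per-token scales into scales_tensor in place.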
        dynamic_per_token_scaled_fp8_quant(*inputs.values(), *attrs.values())

        out_np = out_tensor.cpu().numpy()
        scales_np = scales_tensor.cpu().numpy()

        return out_np, scales_np

    def _verify_results(self, input_data, output_data, scales, scale_ub=0.0, tol=7e-2):
        """
        Verify that the quantization results are correct.

        Args:
            input_data: original input data
            output_data: quantized output data
            scales: scale factors that were used
            scale_ub: upper bound applied to the scale
            tol: allowed tolerance
        """
        # Check that the output data type is FP8
        self.assertEqual(output_data.dtype, "float8_e4m3fn")  # FP8 is stored as float8_e4m3fn

        # Verify the quantization process for each token
        num_tokens = input_data.shape[0] if len(input_data.shape) > 1 else 1

        for i in range(num_tokens):
            # Get the input and output of the current token
            if len(input_data.shape) > 1:
                token_input = input_data[i]
                token_output = output_data[i] if len(output_data.shape) > 1 else output_data
            else:
                token_input = input_data
                token_output = output_data

            # Get the scale factor of the current token
            token_scale = scales[i] if num_tokens > 1 else scales[0]

            # If a scale upper bound is given, check that it is respected
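            # Reference scale: amax / 448.0, where 448 is the maximum representable
            # magnitude of the FP8 E4M3 format.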
            if scale_ub > 0:
                max_val = np.max(np.abs(token_input))
                expected_scale = min(max_val, scale_ub) / 448.0
                self.assertAlmostEqual(token_scale, expected_scale, delta=tol)
            else:
                max_val = np.max(np.abs(token_input))
                expected_scale = max_val / 448.0
                self.assertAlmostEqual(token_scale, expected_scale, delta=tol)

            # Verify that the quantized values are reasonable:
            # dequantizing (output * scale) should reconstruct the input within
            # a tolerance relative to the token's maximum magnitude.
            reconstructed = token_output.astype(np.float32) * token_scale
            diff = np.abs(reconstructed - token_input.astype(np.float32))
            self.assertTrue(np.all(diff <= tol * np.max(np.abs(token_input))))

    def test_fp32_input(self):
        """Test FP32 input."""
        input_data = np.array([0.1, -0.2, 0.3, -0.4], dtype=np.float32)

        # Test without a scale upper bound
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)

        # Test with a scale upper bound
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data, scale_ub=1.5)
        self._verify_results(input_data, output_data, scales, scale_ub=1.5)

        # Test the single-token case
        single_token = input_data[0:1]
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(single_token)
        self._verify_results(single_token, output_data, scales)

    def test_large_values(self):
        """Test large-magnitude input."""
        input_data = np.array([100.0, -200.0, 300.0, -320.0], dtype=np.float32)

        # Without a scale upper bound, max(|x|) / 448 should be used as the scale
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)

        # With a scale upper bound
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data, scale_ub=310.0)
        self._verify_results(input_data, output_data, scales, scale_ub=310.0)

    def test_edge_cases(self):
        """Test edge cases."""
        # All-zero input
        zero_input = np.zeros((2, 4), dtype=np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(zero_input)
        self._verify_results(zero_input, output_data, scales)

        # Single-element input
        single_element = np.array([[5.0]], dtype=np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(single_element)
        self._verify_results(single_element, output_data, scales)

        # A very large number of tokens
        large_input = np.random.randn(1024, 16).astype(np.float32)
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(large_input)
        self._verify_results(large_input, output_data, scales)

    def test_dynamic_per_token_scaled_fp8_quant_fp16(self):
        # Test float16 input
        input_data = np.array([0.1, -0.2, 0.3, -0.4], dtype="float16")
        output_data, scales = self._run_dynamic_per_token_scaled_fp8_quant(input_data)
        self._verify_results(input_data, output_data, scales)


if __name__ == "__main__":
    unittest.main()