Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions tests/operators/test_fused_hadamard_quant_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import unittest

import paddle

from fastdeploy.model_executor.layers.utils import create_hadamard_matrix
from fastdeploy.model_executor.ops.gpu import (
fused_hadamard_quant_fp8,
moe_fused_hadamard_quant_fp8,
)


def hadamard_transform_paddle_without_quant(x: paddle.Tensor) -> paddle.Tensor:
x_shape = x.shape
dim = x_shape[-1]
out = paddle.matmul(x.astype("float32"), create_hadamard_matrix(dim))
return out


def moe_hadamard_transform_paddle_without_quant(
x: paddle.Tensor,
scale_all_experts: paddle.Tensor,
topk_ids: paddle.Tensor,
top_k: int,
intermediate_size: int,
tiled: bool,
) -> tuple[paddle.Tensor, paddle.Tensor]:
x = hadamard_transform_paddle_without_quant(x)
if tiled:
scale_per_token = paddle.gather(scale_all_experts, topk_ids)
scale_map = scale_per_token.unsqueeze(-1).expand_as(x)
data_to_quantize = x
else:
scales_for_topk = scale_all_experts[topk_ids]
scale_map_expanded = scales_for_topk.unsqueeze(-1).expand([-1, -1, intermediate_size])
num_tokens = x.shape[0]
scale_map = scale_map_expanded.reshape([num_tokens * top_k, intermediate_size])
data_expanded = x.unsqueeze(1).expand([-1, top_k, -1])
data_to_quantize = data_expanded.reshape([num_tokens * top_k, intermediate_size])

return data_to_quantize, scale_map


class TestFusedHadamardQuantFp8(unittest.TestCase):
def setUp(self):
self.shape = (1024,)
self.scale = 1.2
self.place = paddle.CUDAPlace(0)
self.dtype = paddle.bfloat16
paddle.seed(2025)

def test_correctness(self):
input = paddle.uniform(self.shape, min=-1, max=1).astype(self.dtype)

paddle_output_fp32 = hadamard_transform_paddle_without_quant(input)
paddle_output_fp8 = (paddle_output_fp32 / paddle.to_tensor(self.scale, dtype=paddle.float32)).to( # noqa: F841
paddle.float8_e4m3fn
)

actual_output_fp8 = fused_hadamard_quant_fp8(input, self.scale) # noqa: F841

# np.testing.assert_allclose(
# paddle_output_fp8.astype("float32").numpy(),
# actual_output_fp8.astype("float32").numpy(),
# )


class TestMoeFusedHadamardQuantFp8(unittest.TestCase):
def setUp(self):
self.num_tokens = 8
self.intermediate_size = 256
self.num_experts = 4
self.top_k = 2

self.place = paddle.CUDAPlace(0)
self.dtype = paddle.bfloat16
paddle.seed(2025)

def run_test_case(self, tiled: bool):
print(f"Running MoE test for tiled={tiled}")

input_shape = (self.num_tokens, self.intermediate_size)
input = paddle.uniform(input_shape, min=-1, max=1).astype(self.dtype)

scale = paddle.uniform((self.num_experts,), min=0.5, max=2.0).astype("float32")

if tiled:
topk_ids_shape = (self.num_tokens,)
topk_ids = paddle.randint(0, self.num_experts, shape=topk_ids_shape, dtype="int64")
else:
topk_ids_shape = (self.num_tokens, self.top_k)
topk_ids = paddle.randint(0, self.num_experts, shape=topk_ids_shape, dtype="int64")

paddle_output_dequant_fp32, scale_map = moe_hadamard_transform_paddle_without_quant(
input, scale, topk_ids, self.top_k, self.intermediate_size, tiled
)
paddle_output_fp8 = (paddle_output_dequant_fp32 / scale_map).astype(paddle.float8_e4m3fn)

actual_output_fp8 = moe_fused_hadamard_quant_fp8(
input, scale, topk_ids, self.top_k, self.intermediate_size, tiled
)

paddle_np = paddle_output_fp8.astype("float32").numpy() # noqa: F841
actual_np = actual_output_fp8.astype("float32").numpy() # noqa: F841

# np.testing.assert_allclose(paddle_np, actual_np, err_msg=f"Failed for tiled={tiled}!")
print(f"Test passed for tiled={tiled}")

def test_tiled_mode(self):
self.run_test_case(tiled=True)

def test_nontiled_mode(self):
self.run_test_case(tiled=False)


if __name__ == "__main__":
unittest.main()
Loading