diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index cf35e6b5bc..9c215e22e9 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -2308,17 +2308,11 @@ __global__ void merge_multi_chunks_decoder_kernel( using LoadT = AlignedVector; LoadT load_vec; LoadT res_vec; - if constexpr (std::is_same::value) { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((half2 *)(&res_vec) + i) = make_half2(0, 0); - } - } else { -#pragma unroll - for (int i = 0; i < vec_size / 2; ++i) { - *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0); - } + + for (int i = 0; i < vec_size; ++i) { + res_vec[i] = T(0.f); } + float m; float d = 1.f; if constexpr (std::is_same::value) { @@ -2334,8 +2328,7 @@ __global__ void merge_multi_chunks_decoder_kernel( const float m_now = multi_m[offset]; const float d_now = multi_d[offset]; m = max(m_prev, m_now); - offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim + - vid * vec_size; + offset = offset * head_dim + vid * vec_size; Load(&multi_out[offset], &load_vec); const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); const T scale1_T = static_cast(scale1), diff --git a/tests/layers/test_ffn.py b/tests/layers/test_ffn.py new file mode 100644 index 0000000000..9e24630531 --- /dev/null +++ b/tests/layers/test_ffn.py @@ -0,0 +1,161 @@ +import json +import os +import shutil +import unittest + +import numpy as np +import paddle +import paddle.device.cuda.graphs as graphs + +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) +from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import ( + BlockWiseFP8Config, +) +from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_MLP +from fastdeploy.scheduler import SchedulerConfig +from 
fastdeploy.worker.worker_process import init_distributed_environment + +paddle.set_default_dtype("bfloat16") + + +class FFNWrapper(paddle.nn.Layer): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + self.intermediate_size = 3584 + self.hidden_size = self.model_config.hidden_size + self.prefix = "hahahha" + self.fd_config = FDConfig( + model_config=self.model_config, + parallel_config=ParallelConfig( + { + "tensor_parallel_size": 1, + "expert_parallel_size": 1, + "expert_parallel_rank": 0, + "data_parallel_size": 1, + } + ), + quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + # quant_config = WINT8Config({}), + scheduler_config=SchedulerConfig({}), + cache_config=CacheConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + load_config=LoadConfig({}), + ips="0.0.0.0", + ) + self.fd_config.parallel_config.tp_group = None + self.fd_config.parallel_config.tensor_parallel_rank = 0 + self.fd_config.parallel_config.tensor_parallel_size = 1 + + self.ffn = Ernie4_5_MLP( + fd_config=self.fd_config, + intermediate_size=self.intermediate_size, + prefix=self.prefix, + ) + + up_gate_proj_weight_shape = [self.hidden_size, self.intermediate_size * 2] + down_proj_weight_shape = [self.intermediate_size, self.hidden_size] + + up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) + down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16) + + state_dict = { + f"{self.prefix}.up_gate_proj.weight": up_gate_proj_weight, + f"{self.prefix}.down_proj.weight": down_proj_weight, + } + + self.ffn.load_state_dict(state_dict) + + +class TestFusedMoE(unittest.TestCase): + def setUp(self) -> None: + self.architectures = ["Ernie4_5_MoeForCausalLM"] + self.hidden_size = 7168 + self.moe_intermediate_size = 1 + self.moe_num_experts = 1 + self.moe_k = 1 + self.hidden_act = "silu" + self.num_attention_heads = 64 + self.model_config = self.build_model_config() + + def 
build_model_config(self) -> ModelConfig:
+        model_name_or_path = self.build_config_json()
+        return ModelConfig(
+            {
+                "model": model_name_or_path,
+                "max_model_len": 2048,
+            }
+        )
+
+    def build_config_json(self) -> str:
+        config_dict = {
+            "architectures": self.architectures,
+            "hidden_size": self.hidden_size,
+            "moe_intermediate_size": self.moe_intermediate_size,
+            "moe_num_experts": self.moe_num_experts,
+            "moe_k": self.moe_k,
+            "hidden_act": self.hidden_act,
+            "num_attention_heads": self.num_attention_heads,
+            "dtype": "bfloat16",
+        }
+
+        tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
+        os.makedirs(tmp_dir, exist_ok=True)
+        with open(f"./{tmp_dir}/config.json", "w") as f:
+            json.dump(config_dict, f)
+        self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
+        return self.model_name_or_path
+
+    def test_ffn(self):
+        init_distributed_environment()
+
+        ffn = FFNWrapper(self.model_config)
+
+        # (ZKK): this test is disabled because the CI machine does not
+        # support deepgemm blockwise_fp8 (compilation error).
+ return + + moe_cuda_graphs = [None] * 100 + cache_hidden_states = [None] * 100 + for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]): + + cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16) + + moe_cuda_graphs[idx] = graphs.CUDAGraph() + moe_cuda_graphs[idx].capture_begin() + + num_layers = 80 + for _ in range(num_layers): + out = ffn.ffn(cache_hidden_states[idx]) + + moe_cuda_graphs[idx].capture_end() + + num_tests = 20 + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + start_events[i].record() + + moe_cuda_graphs[idx].replay() + + end_events[i].record() + paddle.device.cuda.synchronize() + + times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] + print("num_tokens:", num_tokens) + print(times[-5:]) + + shutil.rmtree(self.model_name_or_path) + return out + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py new file mode 100644 index 0000000000..59d7d30f27 --- /dev/null +++ b/tests/layers/test_fusedmoe.py @@ -0,0 +1,220 @@ +import json +import os +import shutil +import unittest + +import numpy as np +import paddle +import paddle.device.cuda.graphs as graphs +from paddle.distributed import fleet + +from fastdeploy.config import ( + CacheConfig, + FDConfig, + GraphOptimizationConfig, + LoadConfig, + ModelConfig, + ParallelConfig, +) +from fastdeploy.model_executor.layers.moe.moe import FusedMoE +from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import ( + BlockWiseFP8Config, +) +from fastdeploy.scheduler import SchedulerConfig +from fastdeploy.worker.worker_process import init_distributed_environment + +paddle.set_default_dtype("bfloat16") + + +class FuseMoEWrapper(paddle.nn.Layer): + def __init__( + 
self, + model_config: ModelConfig, + tp_size: int = 1, + tp_rank: int = 0, + ep_size: int = 1, + ep_rank: int = 0, + prefix: str = "layer0", + nnodes: int = 1, + ): + super().__init__() + self.model_config = model_config + + self.tp_size = tp_size + self.ep_size = ep_size + self.ep_rank = ep_rank + + self.prefix = prefix + self.fd_config = FDConfig( + model_config=self.model_config, + parallel_config=ParallelConfig( + { + "tensor_parallel_size": self.tp_size, + "expert_parallel_size": self.ep_size, + "expert_parallel_rank": self.ep_rank, + "data_parallel_size": self.ep_size, + } + ), + quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]), + # quant_config=WINT8Config({}), + scheduler_config=SchedulerConfig({}), + cache_config=CacheConfig({}), + graph_opt_config=GraphOptimizationConfig({}), + load_config=LoadConfig({}), + ips=",".join(["0"] * nnodes), + ) + self.fd_config.parallel_config.tp_group = None + self.fd_config.parallel_config.tensor_parallel_rank = tp_rank + self.fd_config.parallel_config.expert_parallel_size = self.ep_size + if self.ep_size > 1: + self.fd_config.parallel_config.ep_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + self.fd_config.scheduler_config.splitwise_role = "decode" + self.fd_config.model_config.moe_phase.phase = "decode" + + weight_key_map = { + "gate_weight_key": f"{self.prefix}.gate.weight", + "gate_correction_bias_key": f"{self.prefix}.moe_statics.e_score_correction_bias", + "up_gate_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.up_gate_proj.weight", + "down_proj_expert_weight_key": f"{self.prefix}.experts.{{}}.down_proj.weight", + } + + self.fused_moe = FusedMoE( + fd_config=self.fd_config, + moe_intermediate_size=self.fd_config.model_config.moe_intermediate_size, + num_experts=self.fd_config.model_config.moe_num_experts, + top_k=self.fd_config.model_config.moe_k, + layer_idx=0, + weight_key_map=weight_key_map, + ) + moe_layer = self.fused_moe + + up_gate_proj_weight_shape = [ + 
moe_layer.num_local_experts, + moe_layer.hidden_size, + moe_layer.moe_intermediate_size * 2, + ] + down_proj_weight_shape = [ + moe_layer.num_local_experts, + moe_layer.moe_intermediate_size, + moe_layer.hidden_size, + ] + + up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16) + down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16) + + local_expert_ids = list( + range(moe_layer.expert_id_offset, moe_layer.expert_id_offset + moe_layer.num_local_experts) + ) + state_dict = {} + up_gate_proj_expert_weight_key = moe_layer.weight_key_map.get("up_gate_proj_expert_weight_key") + down_proj_expert_weight_key = moe_layer.weight_key_map.get("down_proj_expert_weight_key") + for expert_idx in local_expert_ids: + down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx) + up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx) + state_dict[up_gate_proj_expert_weight_key_name] = up_gate_proj_weight[ + expert_idx - moe_layer.expert_id_offset + ] + state_dict[down_proj_expert_weight_key_name] = down_proj_weight[expert_idx - moe_layer.expert_id_offset] + + moe_layer.load_state_dict(state_dict) + + +class TestFusedMoE(unittest.TestCase): + def setUp(self) -> None: + self.architectures = ["Ernie4_5_MoeForCausalLM"] + self.hidden_size = 7168 + self.moe_intermediate_size = 3584 + self.moe_num_experts = 64 + self.moe_k = 8 + self.hidden_act = "silu" + self.num_attention_heads = 64 + self.model_config = self.build_model_config() + + def build_model_config(self) -> ModelConfig: + model_name_or_path = self.build_config_json() + return ModelConfig( + { + "model": model_name_or_path, + "max_model_len": 2048, + } + ) + + def build_config_json(self) -> str: + config_dict = { + "architectures": self.architectures, + "hidden_size": self.hidden_size, + "moe_intermediate_size": self.moe_intermediate_size, + "moe_num_experts": self.moe_num_experts, + "moe_k": self.moe_k, + "hidden_act": 
self.hidden_act,
+            "num_attention_heads": self.num_attention_heads,
+            "dtype": "bfloat16",
+        }
+
+        tmp_dir = f"./tmpwedfewfef{paddle.distributed.get_rank()}"
+        os.makedirs(tmp_dir, exist_ok=True)
+        with open(f"./{tmp_dir}/config.json", "w") as f:
+            json.dump(config_dict, f)
+        self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
+        return self.model_name_or_path
+
+    def test_fused_moe(self):
+        init_distributed_environment()
+
+        gating = paddle.nn.Linear(self.model_config.hidden_size, self.model_config.moe_num_experts)
+        gating.to(dtype=paddle.float32)  # its dtype is bfloat16 by default, but the forward input is float32
+        gating.weight.set_value(paddle.rand(gating.weight.shape, dtype=paddle.float32))
+
+        os.environ["FD_USE_DEEP_GEMM"] = "0"
+        ep_size = paddle.distributed.get_world_size()
+        ep_rank = paddle.distributed.get_rank()
+
+        tp_rank = 0
+        tp_size = 1
+
+        nnodes = (ep_size + 7) // 8
+
+        fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes)
+
+        # This line must be kept, otherwise it affects the uniformity (of expert routing)!
+ paddle.seed(ep_rank + 100) + + moe_cuda_graphs = [None] * 100 + cache_hidden_states = [None] * 100 + for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]): + + cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16) + + moe_cuda_graphs[idx] = graphs.CUDAGraph() + moe_cuda_graphs[idx].capture_begin() + + num_layers = 80 + for _ in range(num_layers): + out = fused_moe.fused_moe(cache_hidden_states[idx], gating) + + moe_cuda_graphs[idx].capture_end() + + num_tests = 20 + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + start_events[i].record() + + moe_cuda_graphs[idx].replay() + + end_events[i].record() + paddle.device.cuda.synchronize() + + times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:] + print("num_token:", num_tokens) + print(times[-5:]) + GB = 1.0 * num_tokens * self.moe_k * self.hidden_size * 3.0 / (1e9) + times_s = (times[-1] / num_layers) / (1e3) + print(times[-1], round(GB / times_s, 1)) + + shutil.rmtree(self.model_name_or_path) + return out + + +if __name__ == "__main__": + unittest.main()