17 changes: 5 additions & 12 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
@@ -2308,17 +2308,11 @@ __global__ void merge_multi_chunks_decoder_kernel(
   using LoadT = AlignedVector<T, vec_size>;
   LoadT load_vec;
   LoadT res_vec;
-  if constexpr (std::is_same<T, half>::value) {
-#pragma unroll
-    for (int i = 0; i < vec_size / 2; ++i) {
-      *((half2 *)(&res_vec) + i) = make_half2(0, 0);
-    }
-  } else {
-#pragma unroll
-    for (int i = 0; i < vec_size / 2; ++i) {
-      *((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
-    }
+
+  for (int i = 0; i < vec_size; ++i) {
+    res_vec[i] = T(0.f);
   }
 
   float m;
   float d = 1.f;
   if constexpr (std::is_same<T, half>::value) {
@@ -2334,8 +2328,7 @@
     const float m_now = multi_m[offset];
     const float d_now = multi_d[offset];
     m = max(m_prev, m_now);
-    offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
-             vid * vec_size;
+    offset = offset * head_dim + vid * vec_size;
     Load<T, vec_size>(&multi_out[offset], &load_vec);
     const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
     const T scale1_T = static_cast<T>(scale1),
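For context, merge_multi_chunks_decoder_kernel combines the partial attention outputs computed per KV chunk using the standard online-softmax rescaling: it keeps a running max m and running denominator d, and rescales both sides by exp(old_max - new_max) before adding, exactly what the scale1/scale2 factors above do. The NumPy sketch below is illustrative only; it assumes each chunk stores an unnormalized partial output together with its local max and sum of exponentials, which this hunk alone does not confirm.

import numpy as np

def merge_chunks(partial_out, partial_m, partial_d):
    # partial_out[i]: sum_j exp(s_ij - partial_m[i]) * v_j for chunk i (a head_dim vector)
    # partial_m[i]:   max attention score seen in chunk i
    # partial_d[i]:   sum_j exp(s_ij - partial_m[i]) for chunk i
    m, d = -np.inf, 0.0
    acc = np.zeros_like(partial_out[0], dtype=np.float32)
    for o_i, m_i, d_i in zip(partial_out, partial_m, partial_d):
        m_new = max(m, m_i)  # same role as m = max(m_prev, m_now) in the kernel
        scale_prev, scale_now = np.exp(m - m_new), np.exp(m_i - m_new)
        acc = acc * scale_prev + o_i * scale_now  # rescale both sides to the new max
        d = d * scale_prev + d_i * scale_now
        m = m_new
    return acc / d  # final normalization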
161 changes: 161 additions & 0 deletions tests/layers/test_ffn.py
@@ -0,0 +1,161 @@
import json
import os
import shutil
import unittest

import numpy as np
import paddle
import paddle.device.cuda.graphs as graphs

from fastdeploy.config import (
CacheConfig,
FDConfig,
GraphOptimizationConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
)
from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import (
BlockWiseFP8Config,
)
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_MLP
from fastdeploy.scheduler import SchedulerConfig
from fastdeploy.worker.worker_process import init_distributed_environment

paddle.set_default_dtype("bfloat16")


class FFNWrapper(paddle.nn.Layer):
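    """Test helper that builds an Ernie4_5_MLP with random bfloat16 weights under a
    block-wise FP8 quantization config."""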
def __init__(self, model_config: ModelConfig):
super().__init__()
self.model_config = model_config

self.intermediate_size = 3584
self.hidden_size = self.model_config.hidden_size
self.prefix = "hahahha"
self.fd_config = FDConfig(
model_config=self.model_config,
parallel_config=ParallelConfig(
{
"tensor_parallel_size": 1,
"expert_parallel_size": 1,
"expert_parallel_rank": 0,
"data_parallel_size": 1,
}
),
quant_config=BlockWiseFP8Config(weight_block_size=[128, 128]),
# quant_config = WINT8Config({}),
scheduler_config=SchedulerConfig({}),
cache_config=CacheConfig({}),
graph_opt_config=GraphOptimizationConfig({}),
load_config=LoadConfig({}),
ips="0.0.0.0",
)
self.fd_config.parallel_config.tp_group = None
self.fd_config.parallel_config.tensor_parallel_rank = 0
self.fd_config.parallel_config.tensor_parallel_size = 1

self.ffn = Ernie4_5_MLP(
fd_config=self.fd_config,
intermediate_size=self.intermediate_size,
prefix=self.prefix,
)

up_gate_proj_weight_shape = [self.hidden_size, self.intermediate_size * 2]
down_proj_weight_shape = [self.intermediate_size, self.hidden_size]

up_gate_proj_weight = paddle.randn(up_gate_proj_weight_shape, paddle.bfloat16)
down_proj_weight = paddle.randn(down_proj_weight_shape, paddle.bfloat16)

state_dict = {
f"{self.prefix}.up_gate_proj.weight": up_gate_proj_weight,
f"{self.prefix}.down_proj.weight": down_proj_weight,
}

self.ffn.load_state_dict(state_dict)


class TestFFN(unittest.TestCase):
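    """Benchmarks the block-wise FP8 FFN under CUDA graph capture/replay for several
    token counts (currently skipped via the early return in test_ffn)."""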
def setUp(self) -> None:
self.architectures = ["Ernie4_5_MoeForCausalLM"]
self.hidden_size = 7168
self.moe_intermediate_size = 1
self.moe_num_experts = 1
self.moe_k = 1
self.hidden_act = "silu"
self.num_attention_heads = 64
self.model_config = self.build_model_config()

def build_model_config(self) -> ModelConfig:
model_name_or_path = self.build_config_json()
return ModelConfig(
{
"model": model_name_or_path,
"max_model_len": 2048,
}
)

def build_config_json(self) -> str:
config_dict = {
"architectures": self.architectures,
"hidden_size": self.hidden_size,
"moe_intermediate_size": self.moe_intermediate_size,
"moe_num_experts": self.moe_num_experts,
"moe_k": self.moe_k,
"hidden_act": self.hidden_act,
"num_attention_heads": self.num_attention_heads,
"dtype": "bfloat16",
}

tmp_dir = f"./tmpefef{paddle.distributed.get_rank()}"
os.makedirs(tmp_dir, exist_ok=True)
with open(f"./{tmp_dir}/config.json", "w") as f:
json.dump(config_dict, f)
self.model_name_or_path = os.path.join(os.getcwd(), tmp_dir)
return self.model_name_or_path

def test_ffn(self):
init_distributed_environment()

ffn = FFNWrapper(self.model_config)

        # (ZKK): disable this test; the CI machine does not support DeepGEMM
        # block-wise FP8 and the kernel fails to compile there.
        return

moe_cuda_graphs = [None] * 100
cache_hidden_states = [None] * 100
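        # For each token count, capture num_layers back-to-back FFN calls into one
        # CUDA graph, then time whole-graph replays below.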
for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]):

cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16)

moe_cuda_graphs[idx] = graphs.CUDAGraph()
moe_cuda_graphs[idx].capture_begin()

num_layers = 80
for _ in range(num_layers):
out = ffn.ffn(cache_hidden_states[idx])

moe_cuda_graphs[idx].capture_end()
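            # Time num_tests graph replays with CUDA events; the first sample is
            # dropped as warm-up when the times array is built.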

num_tests = 20
start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(num_tests)]
for i in range(num_tests):
start_events[i].record()

moe_cuda_graphs[idx].replay()

end_events[i].record()
paddle.device.cuda.synchronize()

times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
print("num_tokens:", num_tokens)
print(times[-5:])

shutil.rmtree(self.model_name_or_path)
return out


if __name__ == "__main__":
unittest.main()
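For context on what FFNWrapper exercises: Ernie4_5_MLP is a SwiGLU-style FFN whose up_gate_proj computes the gate and up branches in a single matmul and whose down_proj maps silu(gate) * up back to hidden_size. A rough, unquantized reference is sketched below; the gate/up split order inside the fused projection is an assumption here, and the real layer dispatches fused block-wise FP8 kernels rather than this plain bfloat16 math.

import paddle
import paddle.nn.functional as F

def reference_ffn(x, up_gate_proj_weight, down_proj_weight):
    """Illustrative unfused SwiGLU FFN reference (not the FastDeploy implementation).

    x:                   [num_tokens, hidden_size]
    up_gate_proj_weight: [hidden_size, 2 * intermediate_size]
    down_proj_weight:    [intermediate_size, hidden_size]
    """
    up_gate = paddle.matmul(x, up_gate_proj_weight)
    # Assumed layout: first half of the fused projection is the gate branch.
    gate, up = paddle.chunk(up_gate, chunks=2, axis=-1)
    return paddle.matmul(F.silu(gate) * up, down_proj_weight)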