
Commit

fix tunning code. (#683)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Dec 25, 2024
1 parent d240c6e commit 8b08e1c
Showing 2 changed files with 42 additions and 32 deletions.
37 changes: 21 additions & 16 deletions test/kernel/fuse_moe_tuning_bf16.py
@@ -360,17 +360,22 @@ def tuning_configs(
     from lightllm.utils.tuning_utils import mp_tuning
     from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig
 
-    # tuning to get deepseekv2 lite configs and store
+    # tuning to get deepseekv2 lite configs and store tp 1
+    expert_num = 64
+    n = 1408 // 2  # up is n * 2
+    hidden_dim = 2048
+    topk_num = 6
+
     up_dict = {}
     for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]:
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 64,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 1408 // 2,
-                "k": 2048,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": False,
@@ -379,10 +384,10 @@ def tuning_configs(
         )
         up_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=1408,
-        K=2048,
-        topk_num=6,
-        expert_num=64,
+        N=n * 2,
+        K=hidden_dim,
+        topk_num=topk_num,
+        expert_num=expert_num,
         mul_routed_weight=False,
         use_fp8_w8a8=False,
         out_dtype=str(torch.bfloat16),
@@ -394,11 +399,11 @@ def tuning_configs(
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 64,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 1408 // 2,
-                "k": 2048,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": False,
@@ -407,10 +412,10 @@ def tuning_configs(
         )
         down_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=2048,
-        K=1408 // 2,
+        N=hidden_dim,
+        K=n,
         topk_num=1,
-        expert_num=64,
+        expert_num=expert_num,
         mul_routed_weight=True,
         use_fp8_w8a8=False,
         out_dtype=str(torch.bfloat16),
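
For readers following the refactor, here is a condensed sketch of the bf16 up-projection tuning flow as it reads after this change. Names and call signatures come from the diff itself; tuning_configs and the keyword arguments not shown in the diff are defined elsewhere in test/kernel/fuse_moe_tuning_bf16.py and are left as comments rather than guessed.

    import torch

    from lightllm.utils.tuning_utils import mp_tuning
    from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig

    expert_num = 64   # DeepSeek-V2 lite, tp 1
    n = 1408 // 2     # up/gate projection width is n * 2
    hidden_dim = 2048
    topk_num = 6

    up_dict = {}
    for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]:
        up_dict[m] = mp_tuning(
            tuning_configs,  # helper defined earlier in the same file
            {
                "expert_num": expert_num,
                "m": m,
                "n": n,
                "k": hidden_dim,
                "topk": topk_num,
                "dtype": torch.bfloat16,
                "test_count": 20,
                "use_fp8_w8a8": False,
                # ... remaining keys as in the file (not shown in the diff)
            },
        )

    # The saved kernel-config key is now derived from the same constants as
    # the tuning run, so the tuned shape and the stored shape cannot drift apart.
    MoeGroupedGemmKernelConfig.save_config(
        N=n * 2,
        K=hidden_dim,
        topk_num=topk_num,
        expert_num=expert_num,
        mul_routed_weight=False,
        use_fp8_w8a8=False,
        out_dtype=str(torch.bfloat16),
        # ... remaining keyword arguments as in the file (not shown in the diff)
    )

The down-projection block repeats the same pattern with N=hidden_dim, K=n, topk_num=1, and mul_routed_weight=True.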
37 changes: 21 additions & 16 deletions test/kernel/fuse_moe_tuning_fp8.py
@@ -363,17 +363,22 @@ def tuning_configs(
     from lightllm.utils.tuning_utils import mp_tuning
     from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig
 
-    # tuning to get deepseekv2 large configs and store in H800
+    # tuning to get deepseekv2 large configs and store in H800, tp 8
+    expert_num = 160
+    n = 192  # up is n * 2
+    hidden_dim = 5120
+    topk_num = 6
+
     up_dict = {}
     for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]:
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 160,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 192,
-                "k": 5120,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": True,
@@ -382,10 +387,10 @@ def tuning_configs(
         )
         up_dict[m] = ans
     MoeGroupedGemmKernelConfig.save_config(
-        N=192 * 2,
-        K=5120,
-        topk_num=6,
-        expert_num=160,
+        N=n * 2,
+        K=hidden_dim,
+        topk_num=topk_num,
+        expert_num=expert_num,
         mul_routed_weight=False,
         use_fp8_w8a8=True,
         out_dtype=str(torch.bfloat16),
@@ -397,11 +402,11 @@ def tuning_configs(
         ans = mp_tuning(
             tuning_configs,
             {
-                "expert_num": 160,
+                "expert_num": expert_num,
                 "m": m,
-                "n": 192,
-                "k": 5120,
-                "topk": 6,
+                "n": n,
+                "k": hidden_dim,
+                "topk": topk_num,
                 "dtype": torch.bfloat16,
                 "test_count": 20,
                 "use_fp8_w8a8": True,
@@ -411,10 +416,10 @@ def tuning_configs(
         down_dict[m] = ans
 
     MoeGroupedGemmKernelConfig.save_config(
-        N=5120,
-        K=192,
+        N=hidden_dim,
+        K=n,
         topk_num=1,
-        expert_num=160,
+        expert_num=expert_num,
         mul_routed_weight=True,
         use_fp8_w8a8=True,
         out_dtype=str(torch.bfloat16),
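
The fp8 script gets the same treatment with the large DeepSeek-V2 shapes (tp 8 on H800). As a quick sanity check of what the factored-out constants encode, the self-contained snippet below spells out the up/down shape relationship implied by the diff's "up is n * 2" comment; the rationale in the comments is an inference from the diff, not text from the commit.

    # Constants as they appear in the fp8 diff above.
    expert_num, n, hidden_dim, topk_num = 160, 192, 5120, 6

    # Up/gate projection: fused grouped GEMM maps hidden_dim -> 2 * n per expert,
    # with topk_num selected experts per token.
    up_shape = dict(N=n * 2, K=hidden_dim, topk_num=topk_num,
                    expert_num=expert_num, mul_routed_weight=False)

    # Down projection: maps n -> hidden_dim; each row already corresponds to one
    # selected expert and the routed weight is applied here, hence topk_num=1.
    down_shape = dict(N=hidden_dim, K=n, topk_num=1,
                      expert_num=expert_num, mul_routed_weight=True)

    # The two stages chain together: the down projection consumes half of the
    # up/gate output (after gating), so the dimensions must line up.
    assert down_shape["K"] * 2 == up_shape["N"]
    assert down_shape["N"] == up_shape["K"]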
