2 changes: 1 addition & 1 deletion aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv
@@ -16,7 +16,7 @@ cu_num,token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,
256,16384,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,425.3491,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,1027.5233,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_persist_sbm128,0.0%,1452.8724,0,1117.44,1216.29,
256,32768,7168,256,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,128,0,669.5871999999999,moe_ck2stages_gemm1_256x128x128x128_1x4_MulABScaleShuffled_v3_Nswizzle0_Quant3_MulRoutedWeight0_silu_FP4X2_FP4X2_B16,0.0%,2017.8465,flydsl_moe2_afp4_wfp4_bf16_t64x256x256_reduce_sbm128,0.0%,2687.4337,0,1208.21,788.65,
256,1,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,16.6555,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,23.1%,8.2174,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.6%,24.8729,0,7.97,113762.52,
-256,2,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.244,flydsl_moe1_afp4_wfp4_bf16_t32x128x256_w3_kb4_go_fp4,20.6%,14.059,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,36.303,0,10.92,77944.67,
+256,2,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,22.244,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3_kb4_go_fp4,20.6%,14.059,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.7%,36.303,0,10.92,77944.67,
256,4,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,28.5005,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3_fp4,19.5%,19.299,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,47.7995,0,16.58,59198.7,
256,8,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,54.3584,flydsl_moe1_afp4_wfp4_bf16_t32x64x256_w3,0.0%,30.4539,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,3.0%,84.8123,0,18.69,33364.91,
256,16,7168,512,257,9,ActivationType.Silu,torch.bfloat16,torch.float4_e2m1fn_x2,torch.float4_e2m1fn_x2,QuantType.per_1x32,1,0,32,0,96.3987,flydsl_moe1_afp4_wfp4_bf16_t32x32x256_w3,0.0%,51.0459,flydsl_moe2_afp4_wfp4_bf16_t32x128x256_atomic,2.9%,147.4446,0,21.51,19193.15,
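For context, the rows this config feeds can be inspected directly; a minimal sketch, assuming the merged CSV keeps the cu_num/kernelName1/kernelName2 columns that the updated test below reads (the flydsl_ substring filter mirrors _iter_csv_cases; cu_num=256 matches the rows shown above):

import pandas as pd

df = pd.read_csv("aiter/configs/model_configs/dsv3_fp4_tuned_fmoe.csv")
# keep rows tuned for this partition's CU count (256 in the rows above)
rows = df[df["cu_num"] == 256]
# keep rows where either stage kernel is a flydsl kernel
mask = rows["kernelName1"].str.contains("flydsl_", na=False) | rows[
    "kernelName2"
].str.contains("flydsl_", na=False)
print(rows[mask][["token", "model_dim", "inter_dim", "kernelName1", "kernelName2"]])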
324 changes: 233 additions & 91 deletions op_tests/test_moe_2stage.py
@@ -6,9 +6,13 @@
import aiter
from aiter import dtypes
from aiter.test_common import checkAllclose, benchmark, run_perftest
-from aiter.int4_utils import *
+from aiter.int4_utils import (
+    rearrange_4bit_elements,
+    convert_int8_to_uint32_int4,
+)
from aiter.utility import fp4_utils
-from aiter.jit.utils.chip_info import get_gfx
+from aiter.jit.core import AITER_CONFIGS
+from aiter.jit.utils.chip_info import get_gfx, get_cu_num
import argparse
import os
import pandas as pd
@@ -52,6 +56,7 @@ def test_fmoe(
hidden_pad=0,
intermediate_pad=0,
preshuffle=True,
+    strict_accuracy=True,
):
if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32:
return
@@ -281,6 +286,14 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
logging.warning(
f"logits_diff: {logits_diff} is too large, please check the implementation"
)
+    if strict_accuracy:
+        assert not (
+            err != 0 and logits_diff > 0.01
+        ), f"accuracy check failed: checkAllclose err={err}, logits_diff={logits_diff}"
+    elif err != 0 and logits_diff > 0.01:
+        logging.warning(
+            f"accuracy check failed (non-strict): err={err}, logits_diff={logits_diff}"
+        )

return {"us": us2, "err": err}

@@ -419,111 +432,240 @@
help="""Hidden intermediate pad.
e.g.: -hip 0,0""",
)
+parser.add_argument(
+    "--no-flydsl-csv",
+    action="store_true",
+    help="Skip validating flydsl shapes from tuned fmoe CSVs.",
+)
+parser.add_argument(
+    "--no-legacy",
+    action="store_true",
+    help="Skip the original hardcoded shape sweep and skinny tests.",
+)

args = parser.parse_args()


l_quant = [l_quant[args.quant]] if args.quant is not None else l_quant


-df = []
-for (
-    dtype,
-    (quant_type, aq_dtype, wq_dtype),
-    (model_dim, inter_dim),
-    doweight_stage1,
-) in itertools.product(args.dtype, l_quant, args.dim, args.doweight_stage1):
-    if (quant_type, aq_dtype, wq_dtype) == (
-        aiter.QuantType.per_1x32,
-        dtypes.bf16,
-        dtypes.fp4x2,
-    ):
-        for hidden_pad, intermediate_pad in args.hidden_intermediate_pad:
-            for m in args.tokenNum:
-                ret = test_fmoe(
-                    dtype,
-                    m,
-                    model_dim,
-                    inter_dim,
-                    args.expert,
-                    args.topk,
-                    aiter.ActivationType.Swiglu,
-                    quant_type,
-                    aq_dtype,
-                    wq_dtype,
-                    use_g1u1=True,
-                    doweight_stage1=doweight_stage1,
-                    hidden_pad=hidden_pad,
-                    intermediate_pad=intermediate_pad,
-                )
-                df.append(ret)
-    elif (quant_type, aq_dtype, wq_dtype) == (
-        aiter.QuantType.per_1x32,
-        dtypes.fp8,
-        dtypes.fp4x2,
-    ):
-        for hidden_pad, intermediate_pad in args.hidden_intermediate_pad:
-            for m in args.tokenNum:
-                ret = test_fmoe(
-                    dtype,
-                    m,
-                    model_dim,
-                    inter_dim,
-                    args.expert,
-                    args.topk,
-                    aiter.ActivationType.Swiglu,
-                    quant_type,
-                    aq_dtype,
-                    wq_dtype,
-                    use_g1u1=True,
-                    doweight_stage1=doweight_stage1,
-                    hidden_pad=hidden_pad,
-                    intermediate_pad=intermediate_pad,
-                )
-                df.append(ret)
-    elif (quant_type, aq_dtype, wq_dtype) == (
-        aiter.QuantType.per_1x32,
-        dtypes.fp4x2,
-        dtypes.fp4x2,
+# ---------------------------------------------------------------------------
+# Both modes (CLI sweep / model-csv) reduce to the same shape:
+#   yield (test_fmoe_kwargs, extras_for_df)
+# A single runner consumes the stream.
+# ---------------------------------------------------------------------------
+# Only kept for dtypes that may not exist as torch attributes in older builds;
+# anything else falls through to getattr(torch, attr).
+_DTYPE_STR_FALLBACK = {
+    "torch.float4_e2m1fn_x2": dtypes.fp4x2,
+    "torch.float8_e8m0fnu": dtypes.fp8_e8m0,
+}
+
+
+def _str2dtype(s):
+    s = s.strip()
+    if s in ("None", "none", ""):
+        return None
+    if s.startswith("torch."):
+        attr = s.split(".", 1)[1]
+        if hasattr(torch, attr):
+            return getattr(torch, attr)
+    if s in _DTYPE_STR_FALLBACK:
+        return _DTYPE_STR_FALLBACK[s]
+    raise ValueError(f"unsupported dtype string: {s!r}")
+
+
+def _str2enum(s, enum_cls):
+    return getattr(enum_cls, s.strip().split(".")[-1])
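A quick illustration of the two parsers above (a sketch, not output captured from a run; the fp4 line assumes a torch build without the float4_e2m1fn_x2 attribute, which is exactly when the fallback table applies):

assert _str2dtype("torch.bfloat16") is torch.bfloat16  # via getattr(torch, "bfloat16")
assert _str2dtype("torch.float4_e2m1fn_x2") is dtypes.fp4x2  # via _DTYPE_STR_FALLBACK
assert _str2dtype("None") is None  # "None"/"none"/"" all map to None
assert _str2enum("QuantType.per_1x32", aiter.QuantType) is aiter.QuantType.per_1x32
assert _str2enum("ActivationType.Silu", aiter.ActivationType) is aiter.ActivationType.Silu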


+def _row_to_kwargs(row):
+    # csv rows store already-effective dims, so pad defaults to 0.
+    q_type = _str2enum(row["q_type"], aiter.QuantType)
+    aq_dtype = _str2dtype(row["q_dtype_a"])
+    wq_dtype = _str2dtype(row["q_dtype_w"])
+    act_type = _effective_act_type(
+        q_type,
+        aq_dtype,
+        wq_dtype,
+        _str2enum(row["act_type"], aiter.ActivationType),
+    )
+    return dict(
+        dtype=_str2dtype(row["dtype"]),
+        token=int(row["token"]),
+        model_dim=int(row["model_dim"]),
+        inter_dim=int(row["inter_dim"]),
+        E=int(row["expert"]),
+        topk=int(row["topk"]),
+        actType=act_type,
+        qType=q_type,
+        AQDType=aq_dtype,
+        WQDType=wq_dtype,
+        use_g1u1=dtypes.str2bool(str(row["use_g1u1"])),
+        doweight_stage1=dtypes.str2bool(str(row["doweight_stage1"])),
+        hidden_pad=0,
+        intermediate_pad=0,
+        preshuffle=True,
+    )
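As a concrete example, the token=2 row visible in the CSV diff above would map to roughly the following kwargs (hand-derived sketch, not captured output; use_g1u1=True and doweight_stage1=False come from the row's 1/0 integer columns):

# dict(dtype=torch.bfloat16, token=2, model_dim=7168, inter_dim=512,
#      E=257, topk=9, actType=aiter.ActivationType.Silu,
#      qType=aiter.QuantType.per_1x32, AQDType=dtypes.fp4x2, WQDType=dtypes.fp4x2,
#      use_g1u1=True, doweight_stage1=False,
#      hidden_pad=0, intermediate_pad=0, preshuffle=True)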


+def _iter_csv_cases():
+    """Yield (kwargs, extras) for every row of every selected model csv."""
+    cu = get_cu_num()
+    merged_csv = AITER_CONFIGS.AITER_CONFIG_FMOE_FILE
+    df_csv = pd.read_csv(merged_csv)
+    rows = df_csv[df_csv["cu_num"] == cu]
+    for _, row in rows.iterrows():
+        kernel_name1 = str(row.get("kernelName1", "") or "")
+        kernel_name2 = str(row.get("kernelName2", "") or "")
+        if "flydsl_" not in kernel_name1 and "flydsl_" not in kernel_name2:
+            continue
+        try:
+            kwargs = _row_to_kwargs(row)
+        except Exception as e:
+            aiter.logger.warning(
+                "skip row token=%s dim=(%s,%s): parse error %s",
+                row.get("token"),
+                row.get("model_dim"),
+                row.get("inter_dim"),
+                e,
+            )
+            continue
+        kwargs["strict_accuracy"] = True
+        yield kwargs, {
+            "kernelName1": kernel_name1,
+            "kernelName2": kernel_name2,
+        }
+
+
+_PER1X32_BF16_FP4 = (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2)
+_PER1X32_FP8_FP4 = (aiter.QuantType.per_1x32, dtypes.fp8, dtypes.fp4x2)
+_PER1X32_FP4_FP4 = (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2)
+
+
+def _effective_act_type(quant_type, aq_dtype, wq_dtype, act_type):
+    if (quant_type, aq_dtype, wq_dtype) in (_PER1X32_BF16_FP4, _PER1X32_FP8_FP4):
+        return aiter.ActivationType.Swiglu
+    return act_type
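Spelled out (hand-worked from the function body above): only the bf16- and fp8-activation FP4 triples are pinned to Swiglu; the pure-FP4 triple and everything else keep the requested activation:

# pinned to Swiglu:
#   _effective_act_type(QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, ActivationType.Silu)
#   -> ActivationType.Swiglu
# passed through unchanged:
#   _effective_act_type(QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2, ActivationType.Silu)
#   -> ActivationType.Silu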


+def _iter_legacy_cases():
+    """Yield (kwargs, extras) for the original CLI-driven sweep."""
+    extras = {"model": "legacy"}
+
+    def _kw(
+        dtype,
+        m,
+        model_dim,
+        inter_dim,
+        quant_type,
+        aq_dtype,
+        wq_dtype,
+        doweight_stage1,
+        act_type,
+        **over,
):
-        for preshuffle in args.preshuffle:
+        return dict(
+            dtype=dtype,
+            token=m,
+            model_dim=model_dim,
+            inter_dim=inter_dim,
+            E=args.expert,
+            topk=args.topk,
+            actType=_effective_act_type(quant_type, aq_dtype, wq_dtype, act_type),
+            qType=quant_type,
+            AQDType=aq_dtype,
+            WQDType=wq_dtype,
+            use_g1u1=True,
+            doweight_stage1=doweight_stage1,
+            strict_accuracy=False,
+            **over,
+        )
+
+    for (
+        dtype,
+        (quant_type, aq_dtype, wq_dtype),
+        (model_dim, inter_dim),
+        doweight_stage1,
+    ) in itertools.product(args.dtype, l_quant, args.dim, args.doweight_stage1):
+        triple = (quant_type, aq_dtype, wq_dtype)
+
+        if triple in (_PER1X32_BF16_FP4, _PER1X32_FP8_FP4):
+            for hidden_pad, intermediate_pad in args.hidden_intermediate_pad:
+                for m in args.tokenNum:
+                    yield _kw(
+                        dtype,
+                        m,
+                        model_dim,
+                        inter_dim,
+                        quant_type,
+                        aq_dtype,
+                        wq_dtype,
+                        doweight_stage1,
+                        aiter.ActivationType.Swiglu,
+                        hidden_pad=hidden_pad,
+                        intermediate_pad=intermediate_pad,
+                    ), extras
+        elif triple == _PER1X32_FP4_FP4:
+            for preshuffle in args.preshuffle:
+                for act_type in args.act:
+                    for m in args.tokenNum:
+                        yield _kw(
+                            dtype,
+                            m,
+                            model_dim,
+                            inter_dim,
+                            quant_type,
+                            aq_dtype,
+                            wq_dtype,
+                            doweight_stage1,
+                            act_type,
+                            preshuffle=preshuffle,
+                            hidden_pad=0,
+                            intermediate_pad=0,
+                        ), extras
+        else:
for act_type in args.act:
for m in args.tokenNum:
-                    ret = test_fmoe(
+                    yield _kw(
dtype,
m,
model_dim,
inter_dim,
-                        args.expert,
-                        args.topk,
-                        act_type,
quant_type,
aq_dtype,
wq_dtype,
-                        use_g1u1=True,
-                        doweight_stage1=doweight_stage1,
-                        preshuffle=preshuffle,
-                        hidden_pad=0,
-                        intermediate_pad=0,
-                    )
-                    df.append(ret)
-    else:
-        for act_type in args.act:
-            for m in args.tokenNum:
-                ret = test_fmoe(
-                    dtype,
-                    m,
-                    model_dim,
-                    inter_dim,
-                    args.expert,
-                    args.topk,
-                    act_type,
-                    quant_type,
-                    aq_dtype,
-                    wq_dtype,
-                    use_g1u1=True,
-                    doweight_stage1=doweight_stage1,
-                )
-                df.append(ret)
+                        doweight_stage1,
+                        act_type,
+                    ), extras


+# ---------------------------------------------------------------------------
+# Run
+# ---------------------------------------------------------------------------
+_case_iters = []
+if not args.no_flydsl_csv:
+    _case_iters.append(_iter_csv_cases())
+if not args.no_legacy:
+    _case_iters.append(_iter_legacy_cases())
+case_iter = itertools.chain(*_case_iters)
+
+df = []
+seen = 0
+for kwargs, extras in case_iter:
+    seen += 1
+    ret = test_fmoe(**kwargs)
+    if ret is None:
+        continue
+    ret.update(extras)
+    df.append(ret)
+
+aiter.logger.info(
+    "moe_2stage: scanned %d cases, recorded %d results (skipped %d)",
+    seen,
+    len(df),
+    seen - len(df),
+)
df = pd.DataFrame(df)
df_md = df.to_markdown(index=False)
aiter.logger.info("moe_2stage summary (markdown):\n%s", df_md)
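Given the two new flags, typical invocations of the updated test would look like this (hypothetical command lines; only the script path and flag names come from the diff):

# default: CSV-driven flydsl shapes plus the legacy sweep
python op_tests/test_moe_2stage.py
# only the flydsl shapes pulled from the tuned fmoe CSVs
python op_tests/test_moe_2stage.py --no-legacy
# only the original hardcoded sweep
python op_tests/test_moe_2stage.py --no-flydsl-csv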