Skip to content
10 changes: 9 additions & 1 deletion aiter/ops/moe_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def cmdGenFunc_ck_moe_stage(
activation,
quant_type,
mul_routed_weight_stage,
getattr(w1, "is_shuffled", False),
)
return {
"md_name": md_name,
Expand Down Expand Up @@ -266,6 +267,7 @@ def cmdGenFunc_ck_moe_stage2(
activation,
quant_type,
mul_routed_weight_stage,
getattr(w1, "is_shuffled", False),
)
return {
"md_name": md_name,
Expand Down Expand Up @@ -438,6 +440,7 @@ def get_moe_stage_module(
activation,
quant_type,
mul_routed_weight_stage,
preshuffle_mode=False,
):
if isinstance(activation, int):
activation = ActivationType(activation)
Expand All @@ -448,6 +451,10 @@ def get_moe_stage_module(
Bdtype = dtype2str_dict[weight_dtype]
Cdtype = dtype2str_dict[output_dtype]

preshuffle_str = ""
if preshuffle_mode and weight_dtype == dtypes.fp4x2:
preshuffle_str = "--preshuffle"

quant_type = (
QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type
)
Expand All @@ -459,14 +466,15 @@ def get_moe_stage_module(
"module_moe_ck2stages",
Adtype,
Bdtype,
"preshuffle_on" if preshuffle_mode else "preshuffle_off",
Cdtype,
act,
quant_type,
f"mulWeightStage{mul_routed_weight_stage}",
]
)
blob_gen_cmd = [
f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} -w {{}}"
f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} -w {{}}"
]

return md_name, blob_gen_cmd
Expand Down
4 changes: 3 additions & 1 deletion aiter/ops/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te
x_ = x_.permute(0, 1, 3, 4, 2, 5)
x_ = x_.contiguous()
x_ = x_.view(*x.shape)
return x_.view(x_type)
x_ = x_.view(x_type)
x_.is_shuffled = True
return x_


def shuffle_weight_NK(
Expand Down
11 changes: 2 additions & 9 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,13 @@ MoeKernel moe_dispatch(std::string &kernelName, int block_m, int inter_dim, at::
}
std::cout << "[aiter] ck kernel not found: " << kernelName << std::endl;
}

std::string moe_env_value = "0";
if (const char* env = std::getenv("AITER_MXFP4_MOE_SF")) {
moe_env_value = std::string(env);
}
bool use_mxfp4_moe_preshuffle = std::string(moe_env_value) == "1";

if constexpr (stage == 1)
{
return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle);
return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight);
}
else
{
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle);
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight);
}
}

Expand Down
12 changes: 6 additions & 6 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def get_gemm1_kernels_list(
QuantType: str,
ActOP: str,
MulRoutedWeight: bool,
preshuffle: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()
if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
if arch == "gfx950":
Expand All @@ -356,7 +356,7 @@ def get_gemm1_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1:
if preshuffle:
tag = "a4w4"
else:
tag = "a4w4_bns"
Expand All @@ -376,7 +376,7 @@ def get_gemm1_kernels_list(
kernel.CDEElementOp = "MulABScaleWint4"
elif tag == "a8w8blkscale":
kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale"
elif tag == "a8w8" or tag == "a4w4":
elif tag == "a8w8" or tag == "a4w4" or tag == "a4w4_bns":
kernel.CDEElementOp = "MulABScale"
elif tag == "a16w16":
if MulRoutedWeight:
Expand All @@ -393,8 +393,8 @@ def get_gemm2_kernels_list(
Nswizzle: bool,
QuantType: str,
MulRoutedWeight: bool,
preshuffle: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()

if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
Expand All @@ -421,7 +421,7 @@ def get_gemm2_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1:
if preshuffle:
tag = "a4w4"
else:
tag = "a4w4_bns"
Expand All @@ -440,7 +440,7 @@ def get_gemm2_kernels_list(
kernel.CDEElementOp = "MulABScaleExpertWeightWin4"
elif tag == "a8w8blkscale":
kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale"
elif tag == "a8w8" or tag == "a4w4":
elif tag == "a8w8" or tag == "a4w4" or tag == "a4w4_bns":
kernel.CDEElementOp = "MulABScaleExpertWeight"
elif tag == "a16w16":
if MulRoutedWeight:
Expand Down
43 changes: 34 additions & 9 deletions csrc/ck_gemm_moe_2stages_codegen/gen_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_moe_ck2stages.h"

MoeKernel moe_stage1_heuristic_dispatch(int block_m, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage, bool b_preshuffle=true)
MoeKernel moe_stage1_heuristic_dispatch(int block_m, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage)
{{
"""

Expand All @@ -54,7 +54,7 @@
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_moe_ck2stages.h"

MoeKernel moe_stage2_heuristic_dispatch(int block_m, int inter_dim, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage, bool b_preshuffle=true)
MoeKernel moe_stage2_heuristic_dispatch(int block_m, int inter_dim, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage)
{{
"""

Expand Down Expand Up @@ -177,8 +177,7 @@
&& dtype_checker<{EDataType}>{{}}(y_dtype)
&& {ActOP} == act_op
&& {MulRoutedWeight} == mul_routed_weight_stage
&& {Quant} == quant
&& b_preshuffle == true)
&& {Quant} == quant)
{{
if (block_m == 32)
{{
Expand Down Expand Up @@ -369,8 +368,7 @@
&& dtype_checker<{B0DataType}>{{}}(w_dtype)
&& dtype_checker<{EDataType}>{{}}(y_dtype)
&& {MulRoutedWeight} == mul_routed_weight_stage
&& {Quant} == quant
&& b_preshuffle == true)
&& {Quant} == quant)
{{
if (inter_dim <= 256)
{{
Expand Down Expand Up @@ -591,6 +589,7 @@ def __init__(
quant_type,
activation,
mul_routed_weight_stage,
preshuffle,
):
self.working_path = working_path
self.a_dtype = a_dtype.upper()
Expand All @@ -600,6 +599,7 @@ def __init__(
self.activation = activation
self.mul_routed_weight_stage = mul_routed_weight_stage
self.nswizzle = False
self.preshuffle = preshuffle

def generate_instance_and_lookUpTable(self):
_, gemm1_kernel_list = get_gemm1_kernels_list(
Expand All @@ -610,6 +610,7 @@ def generate_instance_and_lookUpTable(self):
self.quant_type,
self.activation,
self.mul_routed_weight_stage == 1,
self.preshuffle,
)
tag, gemm2_kernel_list = get_gemm2_kernels_list(
self.a_dtype,
Expand All @@ -618,14 +619,14 @@ def generate_instance_and_lookUpTable(self):
self.nswizzle,
self.quant_type,
self.mul_routed_weight_stage == 2,
self.preshuffle,
)
kernel_list = list(gemm1_kernel_list.values()) + list(
gemm2_kernel_list.values()
)

f_lookUpTable = os.path.join(self.working_path, "gemm_moe_ck2stages_lookup.h")

# breakpoint()
with open(f_lookUpTable, "a") as f_lookup:
for kernel in kernel_list:
## generate instance
Expand Down Expand Up @@ -816,6 +817,13 @@ def generate_instance_and_lookUpTable(self):
help="the path where all the blobs are going to be generated",
)

# Opt-in boolean switch (argparse "store_true"): args.preshuffle is False when
# the flag is absent and True when -p/--preshuffle is given. It takes no value.
parser.add_argument(
"-p",
"--preshuffle",
action="store_true",
help="enable pre-shuffle weight mode",
)

args = parser.parse_args()
args.quant_type = (
"per_1x128" if args.quant_type == "per_128x128" else args.quant_type
Expand All @@ -838,8 +846,21 @@ def generate_instance_and_lookUpTable(self):
acts = ["silu", "gelu"]
routed_weight_l = [1, 2]
general_quant_l = ["per_tensor", "per_token"]
for b_dtype, c_dtype, act, routed_weight, quant in itertools.product(
b_quant_dtypes, c_dtypes, acts, routed_weight_l, general_quant_l
preshuffle_mode_l = [False]
for (
b_dtype,
c_dtype,
act,
routed_weight,
quant,
preshuffle_mode,
) in itertools.product(
b_quant_dtypes,
c_dtypes,
acts,
routed_weight_l,
general_quant_l,
preshuffle_mode_l,
):
a_dtype = b_dtype if b_dtype != "i4" else "f8"
quant = quant if b_dtype != "fp4x2" else "per_1x32"
Expand All @@ -851,6 +872,7 @@ def generate_instance_and_lookUpTable(self):
quant_dict[quant],
act,
routed_weight,
preshuffle_mode,
)
codegen.generate_instance_and_lookUpTable()

Expand All @@ -867,6 +889,7 @@ def generate_instance_and_lookUpTable(self):
quant_dict[quant],
act,
routed_weight,
preshuffle_mode,
)
codegen.generate_instance_and_lookUpTable()

Expand All @@ -890,6 +913,7 @@ def generate_instance_and_lookUpTable(self):
quant_dict["no"],
act,
routed_weight,
preshuffle_mode,
)
codegen.generate_instance_and_lookUpTable()
else:
Expand All @@ -903,6 +927,7 @@ def generate_instance_and_lookUpTable(self):
quant_dict[args.quant_type],
args.activation,
args.mul_routed_weight_stage,
args.preshuffle,
)
codegen.generate_instance_and_lookUpTable()

Expand Down
55 changes: 53 additions & 2 deletions op_tests/test_moe_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_fmoe(
doweight_stage1=False,
hidden_pad=0,
intermediate_pad=0,
preshuffle=False,
):
if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32:
return
Expand Down Expand Up @@ -175,7 +176,7 @@ def weight_per_128x128_quant(weight, quant_dtype):
w1_scale_aiter = shuffle_scale_a16w4(w1_scale, E, True)
w2_qt_aiter = shuffle_weight_a16w4(w2_qt_aiter, 16, False)
w2_scale_aiter = shuffle_scale_a16w4(w2_scale, E, False)
elif WQDType != dtypes.fp4x2 or int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1:
elif WQDType != dtypes.fp4x2 or preshuffle:
w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16))
w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16))
w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale)
Expand Down Expand Up @@ -257,6 +258,15 @@ def weight_per_128x128_quant(weight, quant_dtype):
msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})",
)

def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global dissimilarity score between two tensors.

    The score is ``1 - 2 * sum(x * y) / sum(x * x + y * y)``, evaluated in
    float64. It is 0 when ``x`` and ``y`` are identical and grows as they
    diverge (it reaches 2 when ``y == -x``).

    Args:
        x: First tensor; converted to float64 before the reduction.
        y: Second tensor; must be broadcast-compatible with ``x``.

    Returns:
        A 0-dim float64 tensor holding the score.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Both inputs are entirely zero, so the ratio below would be 0/0 = NaN
        # and any `diff < tol` check on it would fail. Two all-zero tensors
        # are identical, so report zero difference instead.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

logits_diff = calc_diff(out2_ref, out2_ck)
assert logits_diff < 1e-3

return {"us": us2, "err": err}


Expand Down Expand Up @@ -289,6 +299,7 @@ def weight_per_128x128_quant(weight, quant_dtype):
l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1]
l_doweight_stage1 = [False, True][:1]
l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)][1:2]
l_preshuffle = [False, True]


parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -382,6 +393,18 @@ def weight_per_128x128_quant(weight, quant_dtype):
e.g.: -k 2""",
)

# Tri-state option: the value stays None when -p is omitted, and also when -p
# is given without a value (nargs="?" with const=None). The code after
# parse_args() only narrows l_preshuffle to a single mode when the value is
# not None, so None means "run both [False, True]".
# NOTE(review): the accepted spellings ("f"/"t") depend on dtypes.str2bool,
# which is not visible here -- confirm against that helper.
parser.add_argument(
"-p",
"--preshuffle",
type=dtypes.str2bool,
nargs="?",
const=None,
default=None,
help="""Whether to use pre-shuffle weight mode. Default is [False, True].
-p f # False.
-p t # True.""",
)

args = parser.parse_args()
if args.dtype is None:
l_dtype = [dtypes.d_dtypes[key] for key in l_dtype]
Expand All @@ -402,13 +425,17 @@ def weight_per_128x128_quant(weight, quant_dtype):
if args.doweight_stage1 is not None:
l_doweight_stage1 = [args.doweight_stage1]

if args.preshuffle is not None:
l_preshuffle = [args.preshuffle]

df = []
for (
dtype,
(quant_type, aq_dtype, wq_dtype),
(model_dim, inter_dim),
doweight_stage1,
) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1):
preshuffle,
) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1, l_preshuffle):
if (quant_type, aq_dtype, wq_dtype) == (
aiter.QuantType.per_1x32,
dtypes.bf16,
Expand All @@ -433,6 +460,30 @@ def weight_per_128x128_quant(weight, quant_dtype):
intermediate_pad=intermediate_pad,
)
df.append(ret)
elif (quant_type, aq_dtype, wq_dtype) == (
aiter.QuantType.per_1x32,
dtypes.fp4x2,
dtypes.fp4x2,
):
for preshuffle in l_preshuffle:
for act_type in l_act:
for m in l_tokenNum:
ret = test_fmoe(
dtype,
m,
model_dim,
inter_dim,
args.expert,
args.topk,
act_type,
quant_type,
aq_dtype,
wq_dtype,
use_g1u1=True,
doweight_stage1=doweight_stage1,
preshuffle=preshuffle,
)
df.append(ret)
else:
for act_type in l_act:
for m in l_tokenNum:
Expand Down