From fb60a73ed22cacd520beffe16e2d0f0637406f0b Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Mon, 10 Nov 2025 12:12:00 +0000 Subject: [PATCH 1/9] is_shuffled --- aiter/ops/moe_op.py | 6 ++++++ aiter/ops/shuffle.py | 4 +++- .../gemm_moe_ck2stages_common.py | 15 +++++++-------- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py | 3 +-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index 07fa8ff94c..d81403f706 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -233,6 +233,7 @@ def cmdGenFunc_ck_moe_stage( activation, quant_type, mul_routed_weight_stage, + getattr(w1, 'is_shuffled', False), ) return { "md_name": md_name, @@ -266,6 +267,7 @@ def cmdGenFunc_ck_moe_stage2( activation, quant_type, mul_routed_weight_stage, + getattr(w1, 'is_shuffled', False), ) return { "md_name": md_name, @@ -438,6 +440,7 @@ def get_moe_stage_module( activation, quant_type, mul_routed_weight_stage, + preslf_mode=False, ): if isinstance(activation, int): activation = ActivationType(activation) @@ -447,6 +450,9 @@ def get_moe_stage_module( Adtype = dtype2str_dict[input_dtype] Bdtype = dtype2str_dict[weight_dtype] Cdtype = dtype2str_dict[output_dtype] + + if not preslf_mode and weight_dtype == dtypes.fp4x2: + Bdtype = Bdtype + "_bns" quant_type = ( QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 1ea0e35ac7..311f69c16e 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -22,7 +22,9 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te x_ = x_.permute(0, 1, 3, 4, 2, 5) x_ = x_.contiguous() x_ = x_.view(*x.shape) - return x_.view(x_type) + x_.view(x_type) + x_.is_shuffled = True + return x_ def shuffle_weight_NK( diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index fe8e461b80..1054992e05 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -317,7 +317,7 @@ def name(self) -> str: bit8_list = ["F8", "I8", "f8", "i8"] bit16_list = ["B16", "F16", "b16", "f16"] -bit4_list = ["I4", "i4", "FP4X2", "fp4x2"] +bit4_list = ["I4", "i4", "FP4X2", "fp4x2", "fp4x2_bns"] QuantType_list = [3, 4] @@ -330,7 +330,6 @@ def get_gemm1_kernels_list( ActOP: str, MulRoutedWeight: bool, ) -> list: - global bns_or_preslf arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: if arch == "gfx950": @@ -356,10 +355,10 @@ def get_gemm1_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1: - tag = "a4w4" - else: + if "bns" in Adtype: tag = "a4w4_bns" + else: + tag = "a4w4" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm1_kernels_dict[tag] @@ -421,10 +420,10 @@ def get_gemm2_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1: - tag = "a4w4" - else: + if "bns" in Adtype: tag = "a4w4_bns" + else: + tag = "a4w4" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm2_kernels_dict[tag] diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index b0343f03ac..240adfe180 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -625,7 +625,6 @@ def generate_instance_and_lookUpTable(self): f_lookUpTable = os.path.join(self.working_path, "gemm_moe_ck2stages_lookup.h") - # breakpoint() with open(f_lookUpTable, "a") as f_lookup: for kernel in kernel_list: ## generate instance @@ -638,7 +637,7 @@ def generate_instance_and_lookUpTable(self): if self.quant_type in [4, 5]: quanttype = "_blockscale" elif "FP4" in self.a_dtype: - if "bns" in tag: + if "bns" in self.a_dtype: quanttype = "_mxfp4_bns" else: quanttype = "_mxfp4" From ce446bfdda3c8566ad6718980c07e98a2b53d25d Mon Sep 17 00:00:00 2001 From: zhimding Date: Tue, 11 Nov 2025 04:37:33 +0000 Subject: [PATCH 2/9] shuffle_weight bugfix --- aiter/ops/moe_op.py | 14 ++++--- .../gemm_moe_ck2stages_common.py | 13 ++++--- .../gen_instances.py | 37 +++++++++++++++++-- 3 files changed, 49 insertions(+), 15 deletions(-) diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index d81403f706..ade9d9a0df 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -233,7 +233,7 @@ def cmdGenFunc_ck_moe_stage( activation, quant_type, mul_routed_weight_stage, - getattr(w1, 'is_shuffled', False), + getattr(w1, "is_shuffled", False), ) return { "md_name": md_name, @@ -267,7 +267,7 @@ def cmdGenFunc_ck_moe_stage2( activation, quant_type, mul_routed_weight_stage, - getattr(w1, 'is_shuffled', False), + getattr(w1, "is_shuffled", False), ) return { "md_name": md_name, @@ -450,9 +450,10 @@ def get_moe_stage_module( Adtype = dtype2str_dict[input_dtype] Bdtype = dtype2str_dict[weight_dtype] Cdtype = dtype2str_dict[output_dtype] - - if not preslf_mode and weight_dtype == dtypes.fp4x2: - Bdtype = Bdtype + "_bns" + + preslf_str = "off" + if preslf_mode and weight_dtype == dtypes.fp4x2: + preslf_str = "on" quant_type = ( QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type @@ -465,6 +466,7 @@ def get_moe_stage_module( "module_moe_ck2stages", Adtype, Bdtype, + "slf" + preslf_str, Cdtype, act, quant_type, @@ -472,7 +474,7 @@ def get_moe_stage_module( ] ) blob_gen_cmd = [ - f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} -w {{}}" + f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} -p {preslf_str} -w {{}}" ] return md_name, blob_gen_cmd diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index 1054992e05..8d70f153d6 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -317,7 +317,7 @@ def name(self) -> str: bit8_list = ["F8", "I8", "f8", "i8"] bit16_list = ["B16", "F16", "b16", "f16"] -bit4_list = ["I4", "i4", "FP4X2", "fp4x2", "fp4x2_bns"] +bit4_list = ["I4", "i4", "FP4X2", "fp4x2"] QuantType_list = [3, 4] @@ -329,6 +329,7 @@ def get_gemm1_kernels_list( QuantType: str, ActOP: str, MulRoutedWeight: bool, + preshuffle: str, ) -> list: arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: @@ -355,7 +356,7 @@ def get_gemm1_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if "bns" in Adtype: + if preshuffle == "off": tag = "a4w4_bns" else: tag = "a4w4" @@ -375,7 +376,7 @@ def get_gemm1_kernels_list( kernel.CDEElementOp = "MulABScaleWint4" elif tag == "a8w8blkscale": kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" - elif tag == "a8w8" or tag == "a4w4": + elif tag == "a8w8" or tag == "a4w4" or tag == "a4w4_bns": kernel.CDEElementOp = "MulABScale" elif tag == "a16w16": if MulRoutedWeight: @@ -392,8 +393,8 @@ def get_gemm2_kernels_list( Nswizzle: bool, QuantType: str, MulRoutedWeight: bool, + preshuffle: str, ) -> list: - global bns_or_preslf arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: @@ -420,7 +421,7 @@ def get_gemm2_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if "bns" in Adtype: + if preshuffle == "off": tag = "a4w4_bns" else: tag = "a4w4" @@ -439,7 +440,7 @@ def get_gemm2_kernels_list( kernel.CDEElementOp = "MulABScaleExpertWeightWin4" elif tag == "a8w8blkscale": kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" - elif tag == "a8w8" or tag == "a4w4": + elif tag == "a8w8" or tag == "a4w4" or tag == "a4w4_bns": kernel.CDEElementOp = "MulABScaleExpertWeight" elif tag == "a16w16": if MulRoutedWeight: diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 240adfe180..0d174db270 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -591,6 +591,7 @@ def __init__( quant_type, activation, mul_routed_weight_stage, + preshuffle, ): self.working_path = working_path self.a_dtype = a_dtype.upper() @@ -600,6 +601,7 @@ def __init__( self.activation = activation self.mul_routed_weight_stage = mul_routed_weight_stage self.nswizzle = False + self.preshuffle = preshuffle def generate_instance_and_lookUpTable(self): _, gemm1_kernel_list = get_gemm1_kernels_list( @@ -610,6 +612,7 @@ def generate_instance_and_lookUpTable(self): self.quant_type, self.activation, self.mul_routed_weight_stage == 1, + self.preshuffle, ) tag, gemm2_kernel_list = get_gemm2_kernels_list( self.a_dtype, @@ -618,6 +621,7 @@ def generate_instance_and_lookUpTable(self): self.nswizzle, self.quant_type, self.mul_routed_weight_stage == 2, + self.preshuffle, ) kernel_list = list(gemm1_kernel_list.values()) + list( gemm2_kernel_list.values() @@ -637,7 +641,7 @@ def generate_instance_and_lookUpTable(self): if self.quant_type in [4, 5]: quanttype = "_blockscale" elif "FP4" in self.a_dtype: - if "bns" in self.a_dtype: + if "bns" in tag: quanttype = "_mxfp4_bns" else: quanttype = "_mxfp4" @@ -815,6 +819,16 @@ def generate_instance_and_lookUpTable(self): help="the path where all the blobs are going to be generated", ) + parser.add_argument( + "-p", + "--preshuffle", + default="off", + required=False, + type=str, + choices=["off", "on"], + help="Choose the weight mode: bns or pre-shuffle.", + ) + args = parser.parse_args() args.quant_type = ( "per_1x128" if args.quant_type == "per_128x128" else args.quant_type @@ -837,8 +851,21 @@ def generate_instance_and_lookUpTable(self): acts = ["silu", "gelu"] routed_weight_l = [1, 2] general_quant_l = ["per_tensor", "per_token"] - for b_dtype, c_dtype, act, routed_weight, quant in itertools.product( - b_quant_dtypes, c_dtypes, acts, routed_weight_l, general_quant_l + preshuffle_mode_l = ["off"] + for ( + b_dtype, + c_dtype, + act, + routed_weight, + quant, + preshuffle_mode, + ) in itertools.product( + b_quant_dtypes, + c_dtypes, + acts, + routed_weight_l, + general_quant_l, + preshuffle_mode_l, ): a_dtype = b_dtype if b_dtype != "i4" else "f8" quant = quant if b_dtype != "fp4x2" else "per_1x32" @@ -850,6 +877,7 @@ def generate_instance_and_lookUpTable(self): quant_dict[quant], act, routed_weight, + preshuffle_mode, ) codegen.generate_instance_and_lookUpTable() @@ -866,6 +894,7 @@ def generate_instance_and_lookUpTable(self): quant_dict[quant], act, routed_weight, + preshuffle_mode, ) codegen.generate_instance_and_lookUpTable() @@ -889,6 +918,7 @@ def generate_instance_and_lookUpTable(self): quant_dict["no"], act, routed_weight, + preshuffle_mode, ) codegen.generate_instance_and_lookUpTable() else: @@ -902,6 +932,7 @@ def generate_instance_and_lookUpTable(self): quant_dict[args.quant_type], args.activation, args.mul_routed_weight_stage, + args.preshuffle, ) codegen.generate_instance_and_lookUpTable() From 91307976c7e7ad72f5b8629d7c56e9ed2ce8b5e1 Mon Sep 17 00:00:00 2001 From: zhimding Date: Tue, 11 Nov 2025 08:09:07 +0000 Subject: [PATCH 3/9] rm AITER_MXFP4_MOE_SF --- .../ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu | 11 ++--------- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py | 10 ++++------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu index 402409914e..051125faa3 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu @@ -32,20 +32,13 @@ MoeKernel moe_dispatch(std::string &kernelName, int block_m, int inter_dim, at:: } std::cout << "[aiter] ck kernel not found: " << kernelName << std::endl; } - - std::string moe_env_value = "0"; - if (const char* env = std::getenv("AITER_MXFP4_MOE_SF")) { - moe_env_value = std::string(env); - } - bool use_mxfp4_moe_preshuffle = std::string(moe_env_value) == "1"; - if constexpr (stage == 1) { - return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle); + return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight); } else { - return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle); + return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight); } } diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 0d174db270..1cf67d8481 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -45,7 +45,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_moe_ck2stages.h" -MoeKernel moe_stage1_heuristic_dispatch(int block_m, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage, bool b_preshuffle=true) +MoeKernel moe_stage1_heuristic_dispatch(int block_m, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage) {{ """ @@ -54,7 +54,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_moe_ck2stages.h" -MoeKernel moe_stage2_heuristic_dispatch(int block_m, int inter_dim, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage, bool b_preshuffle=true) +MoeKernel moe_stage2_heuristic_dispatch(int block_m, int inter_dim, at::ScalarType x_dtype, at::ScalarType w_dtype, at::ScalarType y_dtype, int act_op, int quant, bool mul_routed_weight_stage) {{ """ @@ -177,8 +177,7 @@ && dtype_checker<{EDataType}>{{}}(y_dtype) && {ActOP} == act_op && {MulRoutedWeight} == mul_routed_weight_stage - && {Quant} == quant - && b_preshuffle == true) + && {Quant} == quant) {{ if (block_m == 32) {{ @@ -369,8 +368,7 @@ && dtype_checker<{B0DataType}>{{}}(w_dtype) && dtype_checker<{EDataType}>{{}}(y_dtype) && {MulRoutedWeight} == mul_routed_weight_stage - && {Quant} == quant - && b_preshuffle == true) + && {Quant} == quant) {{ if (inter_dim <= 256) {{ From 59e89c88e1b2ecc3d6c2c3ad53c65a75afd5a9a4 Mon Sep 17 00:00:00 2001 From: zhimding Date: Tue, 11 Nov 2025 16:39:18 +0000 Subject: [PATCH 4/9] preshuffle bugfix --- aiter/ops/shuffle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 311f69c16e..a442a16e68 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -22,7 +22,7 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te x_ = x_.permute(0, 1, 3, 4, 2, 5) x_ = x_.contiguous() x_ = x_.view(*x.shape) - x_.view(x_type) + x_ = x_.view(x_type) x_.is_shuffled = True return x_ From ce3a985538e98a68253ba7f3404c5ac675bee5d9 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 12 Nov 2025 03:12:06 +0000 Subject: [PATCH 5/9] refactor --- aiter/ops/moe_op.py | 12 ++++++------ csrc/ck_gemm_moe_2stages_codegen/gen_instances.py | 7 ++----- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index ade9d9a0df..f3c24e043b 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -440,7 +440,7 @@ def get_moe_stage_module( activation, quant_type, mul_routed_weight_stage, - preslf_mode=False, + preshuffle_mode=False, ): if isinstance(activation, int): activation = ActivationType(activation) @@ -451,9 +451,9 @@ def get_moe_stage_module( Bdtype = dtype2str_dict[weight_dtype] Cdtype = dtype2str_dict[output_dtype] - preslf_str = "off" - if preslf_mode and weight_dtype == dtypes.fp4x2: - preslf_str = "on" + preshuffle_str = "" + if preshuffle_mode and weight_dtype == dtypes.fp4x2: + preshuffle_str = "--preshuffle" quant_type = ( QuantType.per_1x128 if quant_type == QuantType.per_128x128 else quant_type @@ -466,7 +466,7 @@ def get_moe_stage_module( "module_moe_ck2stages", Adtype, Bdtype, - "slf" + preslf_str, + "preshuffle_on" if preshuffle_mode else "preshuffle_off", Cdtype, act, quant_type, @@ -474,7 +474,7 @@ def get_moe_stage_module( ] ) blob_gen_cmd = [ - f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} -p {preslf_str} -w {{}}" + f"{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py -a {Adtype} -b {Bdtype} -c {Cdtype} -q {quant_type} -act {act} -m {mul_routed_weight_stage} {preshuffle_str} -w {{}}" ] return md_name, blob_gen_cmd diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 1cf67d8481..7c73cdbbc6 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -820,11 +820,8 @@ def generate_instance_and_lookUpTable(self): parser.add_argument( "-p", "--preshuffle", - default="off", - required=False, - type=str, - choices=["off", "on"], - help="Choose the weight mode: bns or pre-shuffle.", + action="store_true", + help="enable pre-shuffle weight mode", ) args = parser.parse_args() From acdfc983fa6b60580c6abbd7a388cbdd994cf5f3 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 12 Nov 2025 03:18:45 +0000 Subject: [PATCH 6/9] refactor bugfix --- .../gemm_moe_ck2stages_common.py | 16 ++++++++-------- .../ck_gemm_moe_2stages_codegen/gen_instances.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index 8d70f153d6..cc166a962f 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -329,7 +329,7 @@ def get_gemm1_kernels_list( QuantType: str, ActOP: str, MulRoutedWeight: bool, - preshuffle: str, + preshuffle: bool, ) -> list: arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: @@ -356,10 +356,10 @@ def get_gemm1_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if preshuffle == "off": - tag = "a4w4_bns" - else: + if preshuffle: tag = "a4w4" + else: + tag = "a4w4_bns" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm1_kernels_dict[tag] @@ -393,7 +393,7 @@ def get_gemm2_kernels_list( Nswizzle: bool, QuantType: str, MulRoutedWeight: bool, - preshuffle: str, + preshuffle: bool, ) -> list: arch = get_gfx() @@ -421,10 +421,10 @@ def get_gemm2_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - if preshuffle == "off": - tag = "a4w4_bns" - else: + if preshuffle: tag = "a4w4" + else: + tag = "a4w4_bns" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm2_kernels_dict[tag] diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 7c73cdbbc6..0b98c37c7a 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -846,7 +846,7 @@ def generate_instance_and_lookUpTable(self): acts = ["silu", "gelu"] routed_weight_l = [1, 2] general_quant_l = ["per_tensor", "per_token"] - preshuffle_mode_l = ["off"] + preshuffle_mode_l = [False] for ( b_dtype, c_dtype, From 20f701ee96f789fa8fc1a6ab14a44ccf123fbb0d Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Wed, 12 Nov 2025 11:02:11 +0000 Subject: [PATCH 7/9] add bns/preshuffle moe mxfp4 UT tests --- op_tests/test_moe_2stage.py | 46 +++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 73e2624c2f..ab4a262b09 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -49,6 +49,7 @@ def test_fmoe( doweight_stage1=False, hidden_pad=0, intermediate_pad=0, + preshuffle=False, ): if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return @@ -175,7 +176,7 @@ def weight_per_128x128_quant(weight, quant_dtype): w1_scale_aiter = shuffle_scale_a16w4(w1_scale, E, True) w2_qt_aiter = shuffle_weight_a16w4(w2_qt_aiter, 16, False) w2_scale_aiter = shuffle_scale_a16w4(w2_scale, E, False) - elif WQDType != dtypes.fp4x2 or int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1: + elif WQDType != dtypes.fp4x2 or preshuffle: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) @@ -289,6 +290,7 @@ def weight_per_128x128_quant(weight, quant_dtype): l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] l_doweight_stage1 = [False, True][:1] l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)][1:2] +l_preshuffle = [False, True] parser = argparse.ArgumentParser( @@ -382,6 +384,18 @@ def weight_per_128x128_quant(weight, quant_dtype): e.g.: -k 2""", ) +parser.add_argument( + "-p", + "--preshuffle", + type=dtypes.str2bool, + nargs="?", + const=None, + default=None, + help="""Whether to use pre-shuffle weight mode. Default is [False, True]. + -p f # False. + -p t # True.""", +) + args = parser.parse_args() if args.dtype is None: l_dtype = [dtypes.d_dtypes[key] for key in l_dtype] @@ -402,13 +416,17 @@ def weight_per_128x128_quant(weight, quant_dtype): if args.doweight_stage1 is not None: l_doweight_stage1 = [args.doweight_stage1] +if args.preshuffle is not None: + l_preshuffle = [args.preshuffle] + df = [] for ( dtype, (quant_type, aq_dtype, wq_dtype), (model_dim, inter_dim), doweight_stage1, -) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1): + preshuffle, +) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1, l_preshuffle): if (quant_type, aq_dtype, wq_dtype) == ( aiter.QuantType.per_1x32, dtypes.bf16, @@ -433,6 +451,30 @@ def weight_per_128x128_quant(weight, quant_dtype): intermediate_pad=intermediate_pad, ) df.append(ret) + elif (quant_type, aq_dtype, wq_dtype) == ( + aiter.QuantType.per_1x32, + dtypes.fp4x2, + dtypes.fp4x2, + ): + for preshuffle in l_preshuffle: + for act_type in l_act: + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + act_type, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + preshuffle=preshuffle, + ) + df.append(ret) else: for act_type in l_act: for m in l_tokenNum: From 56666aa21037d6c2de578f9dd53e2fc152b0fc32 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 13 Nov 2025 03:38:17 +0000 Subject: [PATCH 8/9] add L2 verification --- op_tests/test_moe_2stage.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index ab4a262b09..87607f5fe4 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -257,6 +257,15 @@ def weight_per_128x128_quant(weight, quant_dtype): out2_ck, msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) + + def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + logits_diff = calc_diff(out2_ref, out2_ck) + assert logits_diff < 1e-3 return {"us": us2, "err": err} From 0624bbf6a13d9d81dfcc3f8576e9759f44310197 Mon Sep 17 00:00:00 2001 From: xudoyuan Date: Thu, 13 Nov 2025 03:40:31 +0000 Subject: [PATCH 9/9] black op_tests/test_moe_2stage.py --- op_tests/test_moe_2stage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 87607f5fe4..c798d1b569 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -257,13 +257,13 @@ def weight_per_128x128_quant(weight, quant_dtype): out2_ck, msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) - + def calc_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double(), y.double() denominator = (x * x + y * y).sum() sim = 2 * (x * y).sum() / denominator return 1 - sim - + logits_diff = calc_diff(out2_ref, out2_ck) assert logits_diff < 1e-3