From 1d4ad9b5cfb6727a20be53fae29809a701b8daed Mon Sep 17 00:00:00 2001 From: solin Date: Tue, 4 Nov 2025 03:49:39 +0000 Subject: [PATCH 01/20] draft of cktile moe --- aiter/jit/optCompilerConfig.json | 31 +- aiter/ops/shuffle.py | 67 ++ .../ck_tile_gemm_moe_2stages/gen_instances.py | 578 ++++++++++++++++++ .../moe_cktile2stages.cu | 184 ++++++ .../moe_cktile2stages.h | 72 +++ .../moe_cktile2stages_common.cuh | 332 ++++++++++ .../moe_cktile2stages_common.py | 448 ++++++++++++++ csrc/include/rocm_ops.hpp | 38 ++ csrc/pybind/moe_cktile_2stages_pybind.cu | 9 + op_tests/test_moe_2stage.py | 446 +++++++++++--- 10 files changed, 2099 insertions(+), 106 deletions(-) create mode 100644 csrc/ck_tile_gemm_moe_2stages/gen_instances.py create mode 100644 csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu create mode 100644 csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h create mode 100644 csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh create mode 100644 csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py create mode 100644 csrc/pybind/moe_cktile_2stages_pybind.cu diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 46494488ac..2de0314dab 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -273,14 +273,16 @@ "srcs": [ "f'{AITER_CSRC_DIR}/pybind/deepgemm_pybind.cu'", "f'{AITER_CSRC_DIR}/ck_deepgemm/deepgemm.cu'" - ], "flags_extra_cc": [], "flags_extra_hip": [], "md_name": "'module_deepgemm'", "extra_ldflags": "None", - "extra_include": ["f'{CK_DIR}/example/ck_tile/18_flatmm'", "f'{AITER_CSRC_DIR}/ck_deepgemm/include'"], - "verbose": "False", + "extra_include": [ + "f'{CK_DIR}/example/ck_tile/18_flatmm'", + "f'{AITER_CSRC_DIR}/ck_deepgemm/include'" + ], + "verbose": "False", "is_python_module": "True", "is_standalone": "False", "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", @@ -392,6 +394,24 @@ "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py --working_path {{}}'" }, + "module_moe_cktile2stages": { + "srcs": [ + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh'", + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages.h'", + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu'", + "f'{AITER_CSRC_DIR}/pybind/moe_cktile_2stages_pybind.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "md_name": "'module_moe_cktile2stages'", + "extra_ldflags": "None", + "extra_include": [], + "verbose": "False", + "is_python_module": "True", + "is_standalone": "False", + "hip_clang_path": "os.environ.get('FLATMM_HIP_CLANG_PATH')", + "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/gen_instances.py --working_path {{}}'" + }, "module_moe_sorting": { "srcs": [ "f'{AITER_CSRC_DIR}/py_itfs_ck/moe_sorting_kernels.cu'", @@ -966,7 +986,8 @@ "module_mla_reduce": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/mla_reduce_pybind.cu'", - "f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'"], + "f'{AITER_CSRC_DIR}/kernels/mla/reduce.cu'" + ], "flags_extra_cc": [], "flags_extra_hip": [], "extra_ldflags": "None", @@ -974,4 +995,4 @@ "verbose": "False", "blob_gen_cmd": "''" } -} +} \ No newline at end of file diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 3d10076cd1..6ddde3af0b 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -23,3 +23,70 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te x_ = x_.contiguous() x_ = x_.view(*x.shape) return x_.view(x_type) + + +def shuffle_weight_NK(x: torch.Tensor, inst_N: int, inst_K: int, use_int4=False) -> torch.Tensor: + kPerLane = inst_K // (64 // inst_N) + if(use_int4): + kPerLane *= 2 + assert x.shape[-2] % inst_N == 0, f"{x.shape[-2]} % {inst_N} == {x.shape[-2] % N_WARP_TILE }" + assert x.shape[-1] % inst_K == 0, f"{x.shape[-1]} % {inst_K} == {x.shape[-1] % K_WARP_TILE }" + + x_ = x + x_ = x_.view(-1, x.shape[-2] // inst_N, inst_N, x.shape[-1] // inst_K, 64 // inst_N, kPerLane) + x_ = x_.permute(0, 1, 3, 4, 2, 5).contiguous() + return x_.view(*x.shape) + + +def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: + """ + src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 + Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] + """ + # print("gemm shape:", src.shape) + experts_cnt, N, K_pk = src.shape + if gate_up: + N = N // 2 + KPack = 16 + KLane = 64 // NLane #4 + N0 = N // NLane + K0 = K_pk // (KLane * KPack) + if (gate_up): + src_reshaped = src.view(experts_cnt, 2, N0, NLane, K0, KLane, KPack) # [E,2, N0, NLane ,K0, KLane, KPack] + src_reshaped = src_reshaped.permute(0, 2, 1, 4, 5, 3, 6).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] + interleaved = src_reshaped.view(*src.shape) + else: + src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) + interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) + # print("interleaved shape:", interleaved.shape) + return interleaved.contiguous() + + +def shuffle_mxfp4_scale(src: torch.Tensor, gate_up: bool) -> torch.Tensor: + n_experts, n_, k_ = src.shape + # n_ = n_experts // experts_cnt + # MXFP4 constants + K_Pack = 2 + N_Pack = 2 + N_Lane = 16 + K_Lane = 64 // N_Lane # 4 + + # Basic dimensions + K1 = k_ // K_Pack // K_Lane # k_ // 8 + N1 = n_ // N_Lane // N_Pack # n_ // 32 + real_k =32 * k_ * K_Pack * K_Lane # 1x32 quant + assert real_k >= 256, f"K {real_k} must be larger than Tile_K(256)" + # print("src shape", src.shape) + # Reshape based on moe_kind + if gate_up: + # Reshape to: [E, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane] + shfl_scale = src.view(n_experts, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 2, 4, 6, 3, 5, 1).contiguous() + else: + # Reshape to: [E, K1, K_Pack, K_Lane, N1, N_Pack, N_Lane] + shfl_scale = src.view(n_experts, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 1, 4, 6, 3, 5, 2).contiguous() + # print("shf_scale shape:", shfl_scale.shape) + return shfl_scale.view((n_experts * n_, k_)).contiguous() diff --git a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py new file mode 100644 index 0000000000..c1373ecae9 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import os +import argparse +from pathlib import Path +import shutil +import re +from moe_cktile2stages_common import ( + kernelInstance, + get_gemm1_kernels_list, + get_gemm2_kernels_list, + get_heuristic_dispatch_template, +) +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) +AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") +if os.path.exists(os.path.join(AITER_CORE_DIR, "aiter_meta")): + AITER_CORE_DIR = os.path.join(AITER_CORE_DIR, "aiter/jit/utils") # pip install mode +else: + AITER_CORE_DIR = os.path.abspath( + f"{this_dir}/../../aiter/jit/utils" + ) # develop mode +sys.path.insert(0, AITER_CORE_DIR) + + +class cktile_moe_2stage_gemm_codegen: + def __init__( + self, + working_path, + ab_dtype, + acc_dtype, + c_dtype, + quant_type, + activation, + mul_routed_weight_stage, + istune=False, + ): + self.working_path = working_path + self.impl_path = os.path.join(working_path, "impl") + self.instances_path = os.path.join(working_path, "instances") + self.istune = istune + self.ab_dtype = ab_dtype.lower() + self.acc_dtype = acc_dtype.lower() + self.c_dtype = c_dtype.lower() + self.quant_type = quant_type + self.activation = activation + self.mul_routed_weight_stage = mul_routed_weight_stage + + def get_suffix(self, stage: int) -> str: + return ("_").join( + element + for element in [ + self.quant_type, + "MulRoutedWeight" if self.mul_routed_weight_stage == stage else "", + "" if (stage == 2) else self.activation, + ] + if element != "" + ) + + def gen_instance(self, k: kernelInstance): + INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages_common.cuh" + +template +torch::Tensor +{k.name}( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros = 0, + std::optional k_padded_zeros = 0, + std::optional topk_weight = std::nullopt, + std::optional x_scale = std::nullopt, + std::optional w_scale = std::nullopt, + std::optional exp_bias = std::nullopt) +{{{{ + // The smallest kernel we have available. Works well for memory bound shapes. + int NumTokens = XQ.size(0); + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int E = WQ.size(0); + int KBatch = 1; + int stride_A = K; + int stride_B = K; + int stride_C = N / {3 - k.stage}; //gemm1 gate+up need / 2. + void *sorted_weights_ptr = topk_weight.has_value() ? topk_weight.value().data_ptr() : nullptr; + + {{INSTANCE_CONTENT}} + return Y; +}}}} + +""" + # default no quant + scaleGranA = "-1" + scaleGranB = "-1" + biasGran = "-1" + xptr = "nullptr" + wptr = "nullptr" + biasptr = "nullptr" + if k.QuantType == "per_tenser": + scaleGranA = "0" + scaleGranB = "0" + xptr = "static_cast(x_scale.value().data_ptr()[0])" + wptr = "static_cast(w_scale.value().data_ptr()[0])" + elif k.QuantType == "per_token": + scaleGranA = "1" + scaleGranB = "1" + xptr = "static_cast(x_scale.value().data_ptr())" + wptr = "static_cast(w_scale.value().data_ptr())" + elif k.QuantType == "1x32": + scaleGranA = "-1" + scaleGranB = "1, 32" + biasGran = "1" + xptr = "nullptr" + wptr = "static_cast(w_scale.value().data_ptr())" + biasptr = "static_cast(exp_bias.value().data_ptr())" + + INSTANCE_CONTENT = f"""auto per_a_scale_dev_ptr = ck_tile::FlatmmScalePointer<{scaleGranA}>{{{xptr}}}; + auto per_b_scale_dev_ptr = ck_tile::FlatmmScalePointer<{scaleGranB}>{{{wptr}}}; + auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<{biasGran}>{{{biasptr}}}; + ck_tile::MoeFlatmmHostArgs kernel_args{{ + reinterpret_cast(sorted_ids.data_ptr()), + sorted_weights_ptr, + reinterpret_cast(sorted_expert_ids.data_ptr()), + reinterpret_cast(max_token_ids.data_ptr()), + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(Y.data_ptr()), + NumTokens, + E, + topk, + 1, // k_batch + M, + N, + K, + stride_A, + stride_B, + stride_C, + n_padded_zeros.value(), + k_padded_zeros.value(), + per_a_scale_dev_ptr, + per_b_scale_dev_ptr, + exp_bias_dev_ptr + }}; + using TileConfig = MoeFlatmmConfig; + // Run kernel instance. + auto stream_config = ck_stream_config{{at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream()}}; + moe_gemm, + AccDataType, + CDataType, + row_major, + col_major, + ck_tile::tuple<>, + row_major, + {"ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up" if k.stage == 1 else "ck_tile::MoeFlatmmKind::kFFN_gemm2"}, + ck_tile::element_wise::PassThrough + >(kernel_args, stream_config); +""" + + INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT=(INSTANCE_CONTENT)) + + Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text( + INSTANCE_IMPL_str + ) + + INSTANCE_template = """// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "{name}.cuh" + +template torch::Tensor +{name}<{dtypes}>( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias); + +""" + + # if self.istune: + # INSTANCE_abI8_dBF16_eBF16 = INSTANCE_template.format( + # name=k.name, dtypes="I8, B16" + # ) + # Path( + # os.path.join(self.instances_path, f"{k.name}_abI8_dB16_eB16.cpp") + # ).write_text(INSTANCE_abI8_dBF16_eBF16) + # else: + def fill_template(name, a_type, b_type, acc_type, c_type): + nonlocal self + intsance = INSTANCE_template.format( + name=name, dtypes=f"{a_type}, {b_type}, {acc_type}, {c_type}" + ) + Path( + os.path.join( + self.instances_path, + f"{name}_a{a_type}_b{b_type}_acc{acc_type}_C{c_type}.cpp", + ) + ).write_text(intsance) + + if (k.QuantType == "1x32") and (self.ab_dtype in ["bf16", "fp16"]): + fill_template(k.name, self.ab_dtype, "pk_fp4", self.acc_dtype, self.c_dtype) + else: + for CDtype in ["bf16", "fp16"]: + for ABDtype in ["fp8"]: # "bf16", "fp16", + for AccDtype in ["float"]: + fill_template(k.name, ABDtype, ABDtype, AccDtype, CDtype) + # intsance = INSTANCE_template.format( + # name=k.name, dtypes=f"{ABDtype}, {AccDtype}, {CDtype}" + # ) + # Path( + # os.path.join( + # self.instances_path, + # f"{k.name}_ab{ABDtype}_acc{AccDtype}_C{CDtype}.cpp", + # ) + # ).write_text(intsance) + + """genarete heuristic dispatch""" + + def gen_heuristic_dispatch(self, tag, kernels_dict): + HEURISTIC_template = get_heuristic_dispatch_template(tag) + # print(HEURISTIC_template) + + def validate_and_format(template: str, mapping: dict) -> str: + # check all format element in dict. + str_mapping = {str(key): value.name for key, value in mapping.items()} + cleaned_template = template.replace("{{", "").replace("}}", "") + placeholders = re.findall(r"\{([^{}]*)\}", cleaned_template) + missing = [p for p in placeholders if p not in str_mapping] + # print(placeholders) + # print(str_mapping) + if missing: + raise KeyError(f"Missing keys in mapping: {missing}") + result = template + # for placeholder in placeholders: + # result = result.replace(placeholder, str_mapping[placeholder]) + # return result + return template.format(**{k: v for k, v in str_mapping.items()}) + + # create heuristic heirarchy + with open( + os.path.join(self.working_path, "moe_cktile2stages_heuristic_dispatch.h"), + "w", + ) as f: + f.write(validate_and_format(HEURISTIC_template, kernels_dict)) + # arch = get_gfx() + # inst_k = "32" if self.quant_type == "1x32" else ("128" if arch == "gfx950" else "64") + # f.write( + # HEURISTIC_template.format( + # inst_k=inst_k, + # suffix1 = self.get_suffix(1), + # suffix2 = self.get_suffix(2) + # ) + # ) + + """generate lookup.h linking MNK/datatype to specific instance""" + + def gen_lookup_dict(self, kernels_dict): + LOOKUP_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +// #ifdef USE_ROCM + +#define GENERATE_LOOKUP_TABLE(ABTYPE, ACCTYPE, CTYPE) \\ + { \\""" + + LOOKUP_template = """ + {{{MNK}, \\ + {kernel_name}}}, \\""" + + LOOKUP_end = """ + } + +// #endif // USE_ROCM +""" + with open( + os.path.join(self.working_path, "moe_cktile2stages_lookup.h"), "w" + ) as f: + f.write(LOOKUP_head) + for mnk, k in kernels_dict.items(): + print(":", k.name) + # if not tunning, tuned mnk = {stage, m, n, k} + if not self.istune and ( + isinstance(mnk, tuple) and (len(mnk) == 4) and mnk[1] > 0 + ): + f.write( + LOOKUP_template.format( + MNK="{" + + (", ").join(map(lambda x: str(x), list(mnk))) + + "}", + kernel_name=k.name, + ) + ) + # if tunning, mnk = -1,-2..... + elif self.istune and isinstance(mnk, int): + f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name)) + f.write(LOOKUP_end) + + """generate manifest.h for instance header""" + + def gen_manifest_head(self, kernels_dict): + MAINFEST_head = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// #ifdef USE_ROCM + +#include + +#include +""" + MAINFEST_template = """ +template +torch::Tensor +{kernel_name}( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias); +""" + MAINFEST_end = """ + +// endif // USE_ROCM +""" + + with open( + os.path.join(self.working_path, "moe_cktile2stages_manifest.h"), "w" + ) as f: + f.write(MAINFEST_head) + for mnk, k in kernels_dict.items(): + f.write(MAINFEST_template.format(kernel_name=k.name)) + f.write(MAINFEST_end) + + """generate all instances and headers""" + + def gen_instances(self, tag, kernels_dict): + if os.path.exists(self.impl_path): + shutil.rmtree(self.impl_path) + os.mkdir(self.impl_path) + if os.path.exists(self.instances_path): + shutil.rmtree(self.instances_path) + os.mkdir(self.instances_path) + + for mnk, k in kernels_dict.items(): + self.gen_instance(k) + + self.gen_lookup_dict(kernels_dict) + self.gen_manifest_head(kernels_dict) + self.gen_heuristic_dispatch(tag, kernels_dict) + + +# def get_tune_dict(tune_dict_csv): +# tune_dict = default_kernels_dict +# if os.path.exists(tune_dict_csv): +# tune_df = pd.read_csv(tune_dict_csv) +# if torch.cuda.is_available(): +# gpu = torch.cuda.current_device() +# device_properties = torch.cuda.get_device_properties(gpu) +# cu_num = device_properties.multi_processor_count +# tune_df = tune_df[tune_df["cu_num"] == cu_num].reset_index() +# for i in range(len(tune_df)): +# M = tune_df.loc[i, "M"] +# N = tune_df.loc[i, "N"] +# K = tune_df.loc[i, "K"] +# kid = tune_df.loc[i, "kernelId"] +# tune_dict[(M, N, K)] = kernels_list[kid] +# return tune_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate ck_tile 2stage gemm instance." + ) + + # Add arguments + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated", + ) + + parser.add_argument( + "-f", + "--tune_file", + default="aiter/configs/a8w8_tuned_gemm.csv", + required=False, + help="tune_file include the result after run gemm_a8w8_tune.py", + ) + + parser.add_argument( + "-a", + "--a_dtype", + nargs="*", + required=False, + type=str, + choices=["f8", "i8", "f16", "b16"], + help="select input dtype", + ) + + parser.add_argument( + "-b", + "--b_dtype", + nargs="*", + required=False, + type=str, + choices=["f8", "i8", "f16", "b16", "i4"], + help="select weight dtype", + ) + + parser.add_argument( + "-c", + "--c_dtype", + default="b16", + required=False, + type=str, + choices=["f16", "b16"], + help="select out dtype", + ) + + parser.add_argument( + "-q", + "--quant_type", + default="per_tensor", + required=False, + type=str, + choices=[ + "per_tensor", + "per_token", + "1x32", + "128x128", + "no", + ], + help="select quant_type", + ) + + parser.add_argument( + "-act", + "--activation", + default="silu", + required=False, + type=str, + choices=["silu", "gelu"], + help="select activation", + ) + + parser.add_argument( + "-m", + "--mul_routed_weight_stage", + default=2, + required=False, + type=int, + choices=[1, 2], + help="select quant_type", + ) + + args = parser.parse_args() + + # # build all + # if args.b_dtype is None: + # # quanted moe + # b_quant_dtypes = ["f8"] + # c_dtypes = ["f16", "b16"] + # acts = ["silu"] #, "gelu"] + # general_quant_l = ["per_tensor", "per_token"] + # for b_dtype, c_dtype, act, quant in itertools.product( + # b_quant_dtypes, c_dtypes, acts, general_quant_l + # ): + # a_dtype = b_dtype + # codegen = cktile_moe_2stage_gemm_codegen( + # args.working_path, + # a_dtype, + # b_dtype, + # c_dtype, + # quant, + # act, + # ) + # codegen.generate_instance_and_lookUpTable() + + # # no-quant moe + # b_quant_dtypes = [ + # "f16", + # "b16", + # ] + # for ( + # b_dtype, + # act, + # ) in itertools.product(b_quant_dtypes, acts): + # c_dtype = a_dtype = b_dtype + + # codegen = cktile_moe_2stage_gemm_codegen( + # args.working_path, + # a_dtype, + # b_dtype, + # c_dtype, + # "no", + # act, + # ) + # codegen.generate_instance_and_lookUpTable() + # else: + + # single UT + # a_type = "fp8" + # b_type = "fp8" + # quant_type = "per_token" + + a_type = "bf16" + b_type = "fp4" + quant_type = "1x32" + + acc_type = "float" + c_type = "bf16" + act_type = "silu" + codegen = cktile_moe_2stage_gemm_codegen( + args.working_path, a_type, acc_type, c_type, quant_type, act_type, 2, False + ) + # gen all instances for gemm1 and gemm2 + _, gemm1_kernel_list = get_gemm1_kernels_list( + a_type, + b_type, + quant_type, + act_type, + False, + ) + tag, gemm2_kernel_list = get_gemm2_kernels_list( + a_type, + b_type, + quant_type, + "", + True, + ) + # merge gemm1/gemm2 dict with key = {stage, key} + kernel_dict_merge = { + **{(1, key): value for key, value in gemm1_kernel_list.items()}, + **{(2, key): value for key, value in gemm2_kernel_list.items()}, + } + # print(kernel_dict_merge) + codegen.gen_instances(tag, kernel_dict_merge) diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu new file mode 100644 index 0000000000..c6432e9c17 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages_common.cuh" +#include "moe_cktile2stages_lookup.h" +#include "moe_cktile2stages_manifest.h" +#include "moe_cktile2stages_heuristic_dispatch.h" +#include +#include "py_itfs_common.h" + +template +MoeKernel moe_dispatch(int M, int N, int K, int block_m) +{ + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. + + // static const auto lookup = [&] + // { + // return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(ABDataType, AccDataType, CDataType)}; + // }(); + + // // First check if this shape(M,N,K) is available in the direct lookup. + // auto it = lookup.find({M, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } + + // int padded_m = M; + // if (M > 1 && M <= 16) + // { + // padded_m = 16; + // } + // else if (M <= 16384) + // { + // padded_m = nextPow2(M); + // } + // else if (M <= 20480) + // { + // padded_m = 20480; + // } + // // Second check if this shape(padded_m,N,K) is available in the direct lookup. + // it = lookup.find({padded_m, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } + // Otherwise, use heuristics. + if (stage == 1){ + return moe_gemm1_heuristic_dispatch(M, N, K, block_m); + } + else{ + return moe_gemm2_heuristic_dispatch(M, N, K, block_m); + } +} + + + +torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) +{ + + TORCH_CHECK(Y.dtype() == at::ScalarType::BFloat16 || Y.dtype() == at::ScalarType::Half, + "Out dtype only support BFloat16/Float16!"); + if (x_scale != std::nullopt && w_scale != std::nullopt){ + TORCH_CHECK(x_scale.value().dtype() == w_scale.value().dtype(), + "Scales should have the same dtype!"); + } + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int MPerBlock = block_m.value(); + + // const at::cuda::OptionalCUDAGuard device_guard(device_of(Y)); + // at::cuda::getCurrentCUDAStream().stream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); + at::hip::getCurrentHIPStream(); + // if (!XQ || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) + // { + // std::cerr << "detect null ptr !" << std::endl; + // return; + // } + + if (XQ.dtype() == torch_fp8) + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + // if (Y.dtype() == at::ScalarType::BFloat16) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + } + else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == at::ScalarType::Byte)) //a16w4 + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + if (Y.dtype() == at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); + } + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} + +torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) +{ + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); + int MPerBlock = block_m.value(); + + // const at::cuda::OptionalCUDAGuard device_guard(device_of(Y)); + // at::cuda::getCurrentCUDAStream().stream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); + at::hip::getCurrentHIPStream(); + // if (!XQ. || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) + // { + // std::cerr << "detect null ptr !" << std::endl; + // return; + // } + + if (XQ.dtype() == torch_fp8) + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + // if (Y.dtype() == at::ScalarType::BFloat16) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + } + else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == at::ScalarType::Byte)) //a16w4 + { + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } + if (Y.dtype() == at::ScalarType::BFloat16) + { + moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); + } + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + return Y; +} \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h new file mode 100644 index 0000000000..7be1db64d6 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h @@ -0,0 +1,72 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +// #include "moe_flatmm.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/moe_flatmm.hpp" +#include "py_itfs_common.h" +// #include +// #include +#include +#include +#include + +#include +#include +#include + +using MoeKernel = std::function< + torch::Tensor(torch::Tensor &, torch::Tensor &, + torch::Tensor &, torch::Tensor &, + torch::Tensor &, torch::Tensor &, + int, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional)>; +using ck_stream_config = ck_tile::stream_config; +using row_major = ck_tile::tensor_layout::gemm::RowMajor; +using col_major = ck_tile::tensor_layout::gemm::ColumnMajor; +using bf16 = ck_tile::bf16_t; +using fp16 = ck_tile::half_t; +using fp8 = ck_tile::fp8_t; +using pk_fp4 = ck_tile::pk_fp4_t; + +__attribute__((visibility("default"))) torch::Tensor +cktile_moe_gemm1(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros = 0, + std::optional k_padded_zeros = 0, + std::optional topk_weight = std::nullopt, + std::optional x_scale = std::nullopt, + std::optional w_scale = std::nullopt, + std::optional exp_bias = std::nullopt, + std::optional block_m = 32); + +__attribute__((visibility("default"))) torch::Tensor +cktile_moe_gemm2(torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros = 0, + std::optional k_padded_zeros = 0, + std::optional topk_weight = std::nullopt, + std::optional x_scale = std::nullopt, + std::optional w_scale = std::nullopt, + std::optional exp_bias = std::nullopt, + std::optional block_m = 32); \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh new file mode 100644 index 0000000000..c5040f126f --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/moe_flatmm.hpp" +#include "moe_cktile2stages.h" +#include +#include +#include + +#include +#include +#include +#include + +// #include +// #include +// #include +#include +#include +#include +#include + +template +struct MoeFlatmmConfig +{ + static constexpr ck_tile::index_t M_Tile = M_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t K_Tile = K_Tile_; + + static constexpr ck_tile::index_t M_Warp = M_Warp_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = kBlockPerCu_; + static constexpr int TileParitionerGroupNum = 1; + static constexpr int TileParitionerM01 = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; + +template +void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) +{ + using CodegenFlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; // Preshuffle_ + + constexpr bool MXFP4_Pipeline = std::is_same_v; + + if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up) + { + static_assert( + FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0, + "requires NRepeat is multiple of 2 for FFN_gemm1_gate_up"); + } + + using ComputeDataType = ADataType; + static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), + "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); + + using GemmPipelineProblem = ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using CodegenPipelineProblem = + std::conditional_t, + ck_tile::FlatmmPipelineProblem>; + + constexpr int BlockedXDLN_PerWarp = + (MXFP4_Pipeline || (moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)) + ? 2 + : 1; // determined by scale shuffle pattern + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPipeline = std::conditional_t< + MXFP4_Pipeline, + ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1, + ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1>; + + using FusedAct = + std::conditional_t; + + using Kernel = ck_tile::MoeFlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + // if(!Kernel::IsSupportedArgument(kargs)) + // { + // throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + // } + + // if(s.log_level_ > 0) + // { + // std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" + // << "Shape: " << CodegenFlatmmShape::GetName() << "\n" + // << "problem: " << CodegenPipelineProblem::GetName() << "\n" + // << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" + // << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + // << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + // << std::endl; + // } + // + // if(s.flush_cache_) + // { + // std::cout << "Flushing cache..." << std::endl; + // static constexpr ck_tile::index_t APackedSize = + // std::is_same_v ? 2 : 1; + // static constexpr ck_tile::index_t BPackedSize = + // std::is_same_v ? 2 : 1; + + // ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + // moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK + // : args.NumTokens, + // args.K, + // args.stride_A, + // is_row_major(ALayout{}))); + // ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + // args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{}))); + + // const int outputN = + // moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N; + + // auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + // auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + // ck_tile::RotatingMemWrapper rotating_mem( + // kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + // rotating_mem.Print(); + + // auto run_flush_cache = [&]() { + // // flush icache + // ck_tile::flush_icache(); + // // rotating mem + // rotating_mem.Next(); + // // clear c mem + // if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2) + // hipGetErrorString(hipMemsetAsync( + // args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), + // s.stream_id_)); + // else if(args.k_batch > 1) + // hipGetErrorString( + // hipMemsetAsync(args.e_ptr, + // 0, + // args.NumTokens * args.TopK * outputN * sizeof(CDataType), + // s.stream_id_)); + // }; + // ave_time = ck_tile::launch_kernel_preprocess( + // s, + // run_flush_cache, + // ck_tile::make_kernel( + // Kernel{}, grids, blocks, 0, kargs)); + // } + // else + // { + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + // } + // return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +} \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py new file mode 100644 index 0000000000..f1be74edd8 --- /dev/null +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py @@ -0,0 +1,448 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +from dataclasses import dataclass +import os +import sys + +this_dir = os.path.dirname(os.path.abspath(__file__)) +AITER_CORE_DIR = os.path.abspath(f"{this_dir}/../../../") +if os.path.exists(os.path.join(AITER_CORE_DIR, "aiter_meta")): + AITER_CORE_DIR = os.path.join(AITER_CORE_DIR, "aiter/jit/utils") # pip install mode +else: + AITER_CORE_DIR = os.path.abspath( + f"{this_dir}/../../aiter/jit/utils" + ) # develop mode +sys.path.insert(0, AITER_CORE_DIR) + +from chip_info import get_gfx # noqa: E402 + + +@dataclass +class kernelInstance: + stage: int + BLOCK_SIZE: int + MPerBlock: int + NPerBlock: int + KPerBlock: int + WAVE_TILE_M: int + WAVE_TILE_N: int + WAVE_TILE_K: int + WAVE_MAP_M: int + WAVE_MAP_N: int + Block_Per_CU: int = 1 + MulRoutedWeight: bool = False + ActOP: str = "silu" + QuantType: str = "per_tensor" + + @property + def name(self) -> str: + return ("_").join( + element + for element in [ + f"moe_cktile2stages_gemm{self.stage}", + ("x").join( + map( + lambda x: str(x), + [ + self.BLOCK_SIZE, + self.MPerBlock, + self.NPerBlock, + self.KPerBlock, + ], + ) + ), + ("x").join(map(lambda x: str(x), [self.WAVE_MAP_M, self.WAVE_MAP_N])), + ("x").join( + map( + lambda x: str(x), + [self.WAVE_TILE_M, self.WAVE_TILE_N, self.WAVE_TILE_K], + ) + ), + str(self.Block_Per_CU) + "perCU", + self.QuantType, + "MulRoutedWeight" if self.MulRoutedWeight else "", + "" if (self.stage == 2) else self.ActOP, + ] + if element != "" + ) + + +# fmt: off +# gemm1 out:bf16/fp16 AB:fp8/i8 +a8w8_gemm1_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 1, 256, 32, 64, 256, 16, 16, 128, 1, 4,), + 1: kernelInstance( 1, 256, 32, 128, 128, 16, 16, 128, 1, 4,), + 2: kernelInstance( 1, 256, 64, 128, 128, 16, 16, 128, 1, 4,), + 4: kernelInstance( 1, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + 4: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 5: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 6: kernelInstance( 1, 256, 256, 128, 128, 16, 16, 128, 1, 4,), +} + +# gemm2 out:bf16/fp16 AB:fp8/i8 +a8w8_gemm2_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + 0: kernelInstance( 2, 256, 32, 128, 256, 16, 16, 128, 1, 4,), + 1: kernelInstance( 2, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + 2: kernelInstance( 2, 256, 128, 128, 128, 16, 16, 128, 1, 4,), + 3: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 128, 1, 4,), + 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 128, 1, 4,), +} + + +#a8w8 +a8w8_gemm1_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 1, 256, 32, 64, 256, 16, 16, 64, 1, 4,), + # 1: kernelInstance( 1, 256, 32, 64, 128, 16, 16, 64, 1, 4,), + # 2: kernelInstance( 1, 256, 64, 64, 256, 16, 16, 64, 2, 2,), + # 3: kernelInstance( 1, 256, 64, 64, 128, 16, 16, 64, 1, 4,), + 3: kernelInstance( 1, 256, 64, 128, 128, 16, 16, 64, 1, 4), + # 4: kernelInstance( 1, 256, 128, 64, 128, 16, 16, 64, 1, 4,), + # 5: kernelInstance( 1, 256, 128, 128, 128, 16, 16, 64, 1, 4,), + # 6: kernelInstance( 1, 256, 256, 128, 128, 16, 16, 64, 1, 4,), +} +a8w8_gemm2_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| + # 0: kernelInstance( 2, 256, 32, 64, 256, 16, 16, 64, 1, 4,), + # 1: kernelInstance( 2, 256, 64, 64, 256, 16, 16, 64, 1, 4,), + # 2: kernelInstance( 2, 256, 128, 64, 128, 16, 16, 64, 1, 4,), + # 3: kernelInstance( 2, 256, 256, 64, 128, 16, 16, 64, 1, 4,), + # 4: kernelInstance( 2, 256, 64, 128, 256, 16, 16, 128, 1, 4,), + # 5: kernelInstance( 2, 256, 128, 128, 128, 16, 16, 64, 1, 4,), + # 6: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 64, 1, 4,), + # 7: kernelInstance( 2, 256, 32, 64, 128, 16, 16, 64, 1, 4,), + 8: kernelInstance( 2, 256, 64, 128, 128, 16, 16, 64, 1, 4,), +} + + +# gemm1 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm1_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N|| BlockPerCU| + 0: kernelInstance( 1, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 1, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 1, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 1, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 1, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), +} +# gemm1 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm1_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N|| BlockPerCU| + 0: kernelInstance( 1, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 1, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 1, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 1, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 1, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), +} +# gemm2 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm2_kernels_list= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + 0: kernelInstance( 2, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 2, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 2, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} +# gemm2 out:bf16/fp16 AB:bf16/fp4 +a16w4_gemm2_kernels_list_gfx950= { + # kernel: stage| BLOCK_SIZE|MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_TILE_K| WAVE_MAP_M| WAVE_MAP_N| BlockPerCU| + 0: kernelInstance( 2, 256, 16, 128, 256, 16, 16, 32, 1, 4, 2,), + # 5: kernelInstance( 2, 256, 16, 512, 256, 16, 16, 32, 1, 4, 4,), + 1: kernelInstance( 2, 256, 32, 256, 256, 16, 16, 32, 1, 4, 2,), + 3: kernelInstance( 2, 256, 64, 256, 256, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 128, 256, 128, 16, 16, 32, 1, 4, 1,), + # 4: kernelInstance( 2, 256, 256, 256, 256, 16, 16, 32, 1, 4,), + # 4: kernelInstance( 2, 256, 256, 128, 128, 16, 16, 32, 1, 4,), +} + +# fmt: on +gemm1_kernels_dict = { + "a8w8_gfx950": a8w8_gemm1_kernels_list_gfx950, + "a8w8": a8w8_gemm1_kernels_list, + "a16w4_gfx950": a16w4_gemm1_kernels_list_gfx950, + "a16w4": a16w4_gemm1_kernels_list, +} + +gemm2_kernels_dict = { + "a8w8_gfx950": a8w8_gemm2_kernels_list_gfx950, + "a8w8": a8w8_gemm2_kernels_list, + "a16w4_gfx950": a16w4_gemm2_kernels_list_gfx950, + "a16w4": a16w4_gemm2_kernels_list, +} + + +a8w8_gfx950_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 2)}; + }} + //else if (block_m == 128) + //{{ + // return {(1, 4)}; + //}} + //else if (block_m == 256) + //{{ + // return {(1, 6)}; + //}} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 32) + {{ + return {(2, 0)}; + }} + else if (block_m == 64) + {{ + return {(2, 1)}; + }} + //else if (block_m == 128) + //{{ + // return {(2, 2)}; + //}} + //else if (block_m == 256) + //{{ + // return {(2, 3)}; + //}} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm1 heuristic dispatch: ", + block_m); + }} +}} +""" + +a16w4_gfx950_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(1, 0)}; + }} + else if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(2, 0)}; + }} + else if (block_m == 32) + {{ + return {(2, 1)}; + }} + else if (block_m == 64) + {{ + return {(2, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm2 heuristic dispatch: ", + block_m); + }} +}} +""" + +a16w4_heuristic_dispatch = """#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_cktile2stages.h" + +template +MoeKernel moe_gemm1_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(1, 0)}; + }} + else if (block_m == 32) + {{ + return {(1, 1)}; + }} + else if (block_m == 64) + {{ + return {(1, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_geem1 heuristic dispatch: ", + block_m); + }} +}} + +template +MoeKernel moe_gemm2_heuristic_dispatch(int M, int N, int K, int block_m) +{{ + // Apply shape heuristics to find a suitable kernel implementation. + if (block_m == 16) + {{ + return {(2, 0)}; + }} + else if (block_m == 32) + {{ + return {(2, 1)}; + }} + else if (block_m == 64) + {{ + return {(2, 3)}; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe_gemm2 heuristic dispatch: ", + block_m); + }} +}} +""" + +heuristic_dispatch_dict = { + "a8w8_gfx950": a8w8_gfx950_heuristic_dispatch, + # "a8w8": a8w8_gemm2_kernels_list, + "a16w4_gfx950": a16w4_gfx950_heuristic_dispatch, + "a16w4": a16w4_heuristic_dispatch, +} + + +bit8_list = ["f8", "i8", "fp8"] +bit16_list = ["b16", "f16", "bf16", "fp16"] +bit4_list = ["i4", "fp4x2", "fp4"] +QuantType_list = ["no", "per_tensor", "per_token", "per_1x128", "per_1x32"] + + +def get_gemm1_kernels_list( + Adtype: str, + Bdtype: str, + QuantType: str = "none", + ActOP: str = "silu", + MulRoutedWeight: bool = False, +) -> list: + arch = get_gfx() + if Adtype.lower() in bit8_list and Bdtype.lower() in bit8_list and Adtype == Bdtype: + if arch == "gfx950": + tag = "a8w8_gfx950" + else: + tag = "a8w8" + elif Adtype in bit16_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a16w4_gfx950" + else: + tag = "a16w4" + else: + raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") + kernels_list = gemm1_kernels_dict[tag] + for id, kernel in kernels_list.items(): + kernel.MulRoutedWeight = MulRoutedWeight + kernel.ActOP = ActOP + kernel.QuantType = QuantType + # if tag == "a8w4": + # kernel.CDEElementOp = "MulABScaleWint4" + # elif tag == "a8w8blkscale": + # kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" + # elif tag == "a8w8" or tag == "a4w4": + # kernel.CDEElementOp = "MulABScale" + # elif tag == "a16w16": + # if MulRoutedWeight: + # kernel.CDEElementOp = "TypeCastExpertWeight" + # else: + # kernel.CDEElementOp = "TypeCast" + return tag, kernels_list + + +def get_gemm2_kernels_list( + Adtype: str, + Bdtype: str, + QuantType: str = "", + ActOP: str = "", + MulRoutedWeight: bool = True, +) -> list: + arch = get_gfx() + if Adtype in bit8_list and Bdtype in bit8_list and Adtype == Bdtype: + if arch == "gfx950": + tag = "a8w8_gfx950" + else: + tag = "a8w8" + elif Adtype in bit16_list and Bdtype in bit4_list: + if arch == "gfx950": + tag = "a16w4_gfx950" + else: + tag = "a16w4" + else: + raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") + kernels_list = gemm2_kernels_dict[tag] + for id, kernel in kernels_list.items(): + kernel.MulRoutedWeight = MulRoutedWeight + kernel.ActOP = "" + kernel.QuantType = QuantType + # if tag == "a8w4": + # kernel.CDEElementOp = "MulABScaleExpertWeightWin4" + # elif tag == "a8w8blkscale": + # kernel.CDEElementOp = "MulABScaleExpertWeightA8W8blkscale" + # elif tag == "a8w8" or tag == "a4w4": + # kernel.CDEElementOp = "MulABScaleExpertWeight" + # elif tag == "a16w16": + # if MulRoutedWeight: + # kernel.CDEElementOp = "TypeCastExpertWeight" + # else: + # kernel.CDEElementOp = "TypeCast" + return tag, kernels_list + + +def get_heuristic_dispatch_template(tag): + if tag not in heuristic_dispatch_dict.keys(): + raise ValueError(f"Unsupported type for heuristic_dispatch: {tag}") + return heuristic_dispatch_dict[tag] diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 7926085f17..94bb0b0434 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -732,6 +732,44 @@ namespace py = pybind11; py::arg("quant_type") = 0, \ py::arg("activation") = 0); +#define MOE_CKTILE_2STAGES_PYBIND \ + m.def("cktile_moe_gemm1", \ + &cktile_moe_gemm1, \ + "cktile_moe_gemm1", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); \ + \ + \ + m.def("cktile_moe_gemm2", \ + &cktile_moe_gemm2, \ + "cktile_moe_gemm2", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); + #define MHA_VARLEN_FWD_PYBIND \ m.def("mha_varlen_fwd", \ &aiter::torch_itfs::mha_varlen_fwd, \ diff --git a/csrc/pybind/moe_cktile_2stages_pybind.cu b/csrc/pybind/moe_cktile_2stages_pybind.cu new file mode 100644 index 0000000000..35bc1ebd04 --- /dev/null +++ b/csrc/pybind/moe_cktile_2stages_pybind.cu @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "rocm_ops.hpp" +#include "moe_cktile2stages.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + MOE_CKTILE_2STAGES_PYBIND; +} diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index eb39528619..02a5d82908 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -11,6 +11,7 @@ from aiter.jit.utils.chip_info import get_gfx import argparse import pandas as pd +import numpy as np from aiter.fused_moe import ( fused_topk, @@ -22,7 +23,7 @@ ) -from aiter.ops.shuffle import shuffle_weight +from aiter.ops.shuffle import shuffle_weight, shuffle_weight_NK from aiter import ActivationType torch.int4 = getattr(torch, "int4", torch.uint32) @@ -95,11 +96,12 @@ def ck_moe_stage2( D = w2.shape[1] # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - out = torch.zeros( + out = torch.empty( (token_num, D), dtype=dtype, device=hidden_states.device, ) + out.fill_(0) aiter.ck_moe_stage2_fwd( hidden_states, w1, @@ -119,6 +121,157 @@ def ck_moe_stage2( ) return out +def cktile_moe_stage1( + hidden_states, + w1, # [E, inter_dim*2, model_dim] + w2, # [E, model_dim, inter_dim] + sorted_token_ids, # [max_num_tokens_padded] + sorted_expert_ids, # [max_num_m_blocks] + num_valid_ids, # [1] + w1_scale, + a1_scale, + exp_bias1, + dtype, + topk, + n_pad_zeros = 0, + k_pad_zeros = 0, + block_size=32, + Activation=ActivationType.Silu, + quant_type=aiter.QuantType.No, + sorted_weights=None, # [max_num_tokens_padded] +): + token_num = hidden_states.shape[0] + _, n1, k1 = w1.shape + _, k2, n2 = w2.shape + D = n2 if k2 == k1 else n2*2 #bit4 format + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + if w1.dtype is torch.uint32: + D = D * 8 + out = torch.empty((token_num, topk, D), dtype=dtype) + # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) + aiter.moe_cktile2stages_gemm1( + hidden_states, + w1, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a1_scale, + w1_scale, + exp_bias1, + block_size, + ) + return out + +def cktile_moe_stage2( + hidden_states, + w1, # [E, inter_dim*2, model_dim] + w2, # [E, model_dim, inter_dim] + sorted_token_ids, # [max_num_tokens_padded] + sorted_expert_ids, # [max_num_m_blocks] + num_valid_ids, # [1] + w2_scale, + a2_scale, + exp_bias2, + dtype, + topk, + n_pad_zeros = 0, + k_pad_zeros = 0, + block_size=32, + Activation=ActivationType.Silu, + quant_type=aiter.QuantType.No, + sorted_weights=None, # [max_num_tokens_padded] + zeros_out = False +): + token_num = hidden_states.shape[0] + D = w2.shape[1] + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + out = torch.empty( + (token_num, D), + dtype=dtype, + device=hidden_states.device, + ) + if zeros_out: + out.fill_(0) + # print("Run cktile_moe_stage2: M=%d, N=%d, K=%d, topk=%d, expert=%d"%(hidden_states.shape[0]*hidden_states.shape[1], w2.shape[1], hidden_states.shape[2], topk, w2.shape[0])) + aiter.moe_cktile2stages_gemm2( + hidden_states, + w2, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a2_scale, + w2_scale, + exp_bias2, + block_size, + ) + return out + + +def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: + """ + src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 + Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] + """ + # print("gemm shape:", src.shape) + experts_cnt, N, K_pk = src.shape + if gate_up: + N = N // 2 + KPack = 16 + KLane = 64 // NLane #4 + N0 = N // NLane + K0 = K_pk // (KLane * KPack) + assert KLane * KPack * K0 == K_pk, f"K({K_pk}) is not a divisble of 64." + assert NLane * N0 == N, f"N({K_pk}) is not a divisble of 16." + if (gate_up): + src_reshaped = src.view(experts_cnt, 2, N0, NLane, K0, KLane, KPack) # [E,2, N0, NLane ,K0, KLane, KPack] + src_reshaped = src_reshaped.permute(0, 2, 1, 4, 5, 3, 6).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] + interleaved = src_reshaped.view(*src.shape) + else: + src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) + interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) + # print("interleaved shape:", interleaved.shape) + return interleaved.contiguous() + +def shuffle_mxfp4_scale(src: torch.Tensor, experts_cnt: int, gate_up: bool) -> torch.Tensor: + n_experts, k_ = src.shape + n_ = n_experts // experts_cnt + # MXFP4 constants + K_Pack = 2 + N_Pack = 2 + N_Lane = 16 + K_Lane = 64 // N_Lane # 4 + + # Basic dimensions + K1 = k_ // K_Pack // K_Lane # k_ // 8 + N1 = n_ // N_Lane // N_Pack # n_ // 32 + real_k =32 * k_ * K_Pack * K_Lane # 1x32 quant + assert K1 * K_Pack * K_Lane == k_, f"K {k_*32} must be divisible of 256" + # print("src shape", src.shape) + # Reshape based on moe_kind + if gate_up: + # Reshape to: [E, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane] + shfl_scale = src.view(experts_cnt, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 2, 4, 6, 3, 5, 1).contiguous() + else: + # Reshape to: [E, K1, K_Pack, K_Lane, N1, N_Pack, N_Lane] + shfl_scale = src.view(experts_cnt, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) + # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] + shfl_scale = shfl_scale.permute(0, 1, 4, 6, 3, 5, 2).contiguous() + # print("shf_scale shape:", shfl_scale.shape) + return shfl_scale.view(*src.shape).contiguous() @benchmark() def test_fmoe( @@ -140,12 +293,24 @@ def test_fmoe( torch_quant = aiter.get_torch_quant(qType) torch_act = aiter.get_torch_act(actType) input = torch.randn((token, model_dim), dtype=dtype) + need_pad = qType == aiter.QuantType.per_1x32 + npad0 = 192 + kpad0 = 128 if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) + if need_pad: + w1[:,:,-kpad0:] = 0 + w1[:,-npad0:,:] = 0 + w1[:,inter_dim-npad0:inter_dim,:] = 0 + exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) + exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) - + if need_pad: + w2[:,:,-kpad0:] = 0 + w2[:,-npad0:,:] = 0 + exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) # rand topk_weights, topk_ids = fused_topk(input, score, topk, True) @@ -155,9 +320,10 @@ def test_fmoe( M, _ = topk_ids.shape - BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) + # BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) + BLOCK_SIZE_M = 32 if M > 1024 else 16 if qType == aiter.QuantType.per_128x128: - BLOCK_SIZE_M = 64 + BLOCK_SIZE_M = 64 if M > 64 else 16 sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M ) @@ -218,25 +384,23 @@ def weight_per_128x128_quant(weight, quant_dtype): ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]): #a16w4 + a1_qt = input.to(AQDType) + a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - # w1_scale = w1_scale.fill_(1) - # a1_scale = a1_scale.fill_(1) - out1_ref = torch_moe_stage1( - a1_qt, - w1_qt, - w2_qt, - topk_weights, - topk_ids, - dtype=dtype, - activation=actType, - quant_type=qType, - a1_scale=a1_scale, - w1_scale=w1_scale, - doweight=doweight_stage1, - ) + #bias dtype convert + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + exp_bias1_aiter = exp_bias1.to(dtypes.fp32) + exp_bias2_aiter = exp_bias2.to(dtypes.fp32) + else: + exp_bias1_aiter = exp_bias1 = None + exp_bias2_aiter = exp_bias2 = None + #pre-shuffle + w1_scale_aiter = w1_scale + w2_scale_aiter = w2_scale if WQDType == torch.int4: # int4 w quant w1_qt_aiter = rearrange_4bit_elements( convert_int8_to_uint32_int4( @@ -248,13 +412,39 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16), use_int4=True) ) ) + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) + w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) + w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) + w2_scale_aiter = shuffle_mxfp4_scale(w2_scale, E, False) + elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]) and (qType != aiter.QuantType.per_128x128): + inst_K = 128 // w1_qt_aiter.element_size() + w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) + w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) elif WQDType != dtypes.fp4x2: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) + + # # ######################## stage 1 start ########### + out1_ref = torch_moe_stage1( + a1_qt, + w1_qt, + w2_qt, + topk_weights, + topk_ids, + dtype=dtype, + activation=actType, + quant_type=qType, + a1_scale=a1_scale, + w1_scale=w1_scale, + w1_bias=exp_bias1, + doweight=doweight_stage1, + ) + # # ######################## ck stage 1 start ########### - # # a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - # # out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - # out1_ck, us = run_perftest( + out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) + + # out1_ck, us1 = run_perftest( # ck_moe_stage1, # a1_qt, # w1_qt_aiter, @@ -273,11 +463,39 @@ def weight_per_128x128_quant(weight, quant_dtype): # needTrace=True, # ) + # cktile_2stage + out1_ck, us1 = run_perftest( + cktile_moe_stage1, + a1_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w1_scale_aiter, + a1_scale, + exp_bias1_aiter, + dtype, + topk, + npad0 * 2, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type=qType, + sorted_weights=sorted_weights if doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) # checkAllclose( - # out1_ref, - # out1_ck, - # msg=f"[perf] ck_moe_stage1:{us:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", + # out1_ref[:,:-npad0] if need_pad else out1_ref, + # out1_ck[:,:-npad0] if need_pad else out1_ck, + # msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", # ) + # diff = torch.abs(out1_ref - out1_ck) + # max_value= diff.max() + # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) + # print("max_diff", max_value.item(), ",ref=", out1_ref[multi_index].item(), ",ck=", out1_ck[multi_index].item()) # ######################## stage 1 end ########### # if WQDType != torch.int4: @@ -316,6 +534,9 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ref.view(token, -1, 128), quant_dtype=AQDType ) a2_scale = a2_scale.view(token, topk, -1) + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]): + a2_qt = out1_ref + a2_scale = None else: a2_qt, a2_scale = torch_quant(out1_ref, quant_dtype=AQDType) a2_qt = a2_qt.view(token, topk, -1) @@ -330,6 +551,7 @@ def weight_per_128x128_quant(weight, quant_dtype): quant_type=qType, w2_scale=w2_scale, a2_scale=a2_scale, + w2_bias=exp_bias2, doweight=not doweight_stage1, ) # # out_ref = torch_moe( @@ -343,103 +565,125 @@ def weight_per_128x128_quant(weight, quant_dtype): # # ) # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - # out2_ck, us = run_perftest( - # ck_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale, - # a2_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # ) + out2_ck = torch.empty((token, model_dim), dtype=dtype) - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"[perf] ck_moe_stage2:{us:>8.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + # # cktil2stage + _, us2 = run_perftest( + cktile_moe_stage2, + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) + out2_ck = cktile_moe_stage2( + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + True + ) + + checkAllclose( + out1_ref[:,:-npad0] if need_pad else out1_ref, + out1_ck[:,:-npad0] if need_pad else out1_ck, + msg=f"[stage1:perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + + checkAllclose( + out2_ref, + out2_ck, + msg=f"[stage2:perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + # diff = torch.abs(out2_ref - out2_ck) + # max_value= diff.max() + # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) + # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) # ######################## stage 2 end ########### - # # ######################## fused 2 stage ######### - # out2_ck, us = run_perftest( - # ck_moe_2stages, + # # # ######################## fused 2 stage ######### + # us1=0 + # out2_ck, us2 = run_perftest( + # fused_moe, # input, # w1_qt_aiter, # w2_qt_aiter, # topk_weights, # topk_ids, + # w1_scale=fp4_utils.e8m0_shuffle( + # w1_scale + # ), # e8m0_shuffle will do nothing if it's a fp32 + # w2_scale=fp4_utils.e8m0_shuffle(w2_scale), # quant_type=qType, - # fc1_scale=w1_scale, # [expert(local_expert:EP), inter_dim, 1] - # fc2_scale=w2_scale, # [expert(local_expert:EP), model_dim, 1] - # block_size=BLOCK_SIZE_M, # activation=actType, # doweight_stage1=doweight_stage1, # ) # checkAllclose( # out2_ref, # out2_ck, - # msg=f"ck_moe_2stages:{us:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", + # msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", # ) - if dtype == dtypes.bf16: - out2_aiter, us_fuse = run_perftest( - fused_moe, - input, - w1_qt_aiter, - w2_qt_aiter, - topk_weights, - topk_ids, - w1_scale=fp4_utils.e8m0_shuffle( - w1_scale - ), # e8m0_shuffle will do nothing if it's a fp32 - w2_scale=fp4_utils.e8m0_shuffle(w2_scale), - quant_type=qType, - activation=actType, - doweight_stage1=doweight_stage1, - ) - - err = checkAllclose( - out2_ref, - out2_aiter, - msg=f"aiter_all_stages:{us_fuse:>8.2f} us......", - ) - - return {"us": us_fuse, "err": err} - - + return {"gemm1(us)": us1, "gemm2(us)": us2} +# seed = 1 +# torch.manual_seed(seed) +# torch.cuda.manual_seed_all(seed) l_dtype = ["bf16", "fp16"][:1] -l_dim = [(6144, 4096)] +# l_dim = [(6144, 4096)] +l_dim = [(7168, 256)] l_tokenNum = [ - 1, - 3, - 5, - 16, - 32, - 64, - 128, - 256, - 1024, - 4096, - 163840, + # 1, + # 3, + # 5, + 8, + # 16, + # 32, + # 64, + # 128, + # 256, + # 1024, + # 4096, + # 163840, ] l_quant = [ - (aiter.QuantType.No, None, None), # a16w16 - (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 - (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 - (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 - (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 - (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + # (aiter.QuantType.No, None, None), # a16w16 + # (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 + # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 + # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 + # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 + # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] -l_doweight_stage1 = [False, True] +l_doweight_stage1 = [False, True][:1] parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, @@ -491,7 +735,7 @@ def weight_per_128x128_quant(weight, quant_dtype): 4: aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2 # a4w4 5: aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8, # a8w8""", ) - +torch.cuda.manual_seed_all(1) parser.add_argument( "-a", "--act", From 86990e5716666a5eb900955304dcc5d6af5f967a Mon Sep 17 00:00:00 2001 From: solin Date: Tue, 4 Nov 2025 13:31:34 +0000 Subject: [PATCH 02/20] align the interface of main branch to make cktile moe compile pass --- aiter/fused_moe.py | 68 +++++++++++++----- aiter/jit/optCompilerConfig.json | 6 +- aiter/ops/moe_op.py | 72 +++++++++++++++++++ aiter/ops/shuffle.py | 5 +- .../ck_tile_gemm_moe_2stages/gen_instances.py | 14 ++-- .../{ => include}/moe_cktile2stages.h | 28 ++++---- .../moe_cktile2stages_common.cuh | 0 .../moe_cktile2stages.cu | 37 +++++++++- op_tests/test_moe_2stage.py | 15 +++- 9 files changed, 197 insertions(+), 48 deletions(-) rename csrc/ck_tile_gemm_moe_2stages/{ => include}/moe_cktile2stages.h (70%) rename csrc/ck_tile_gemm_moe_2stages/{ => include}/moe_cktile2stages_common.cuh (100%) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 40b265b539..09ea158e2a 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -44,6 +44,9 @@ def moe_sorting( device = topk_ids.device M, topk = topk_ids.shape max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk + if M * topk <= num_experts: + max_num_tokens_padded = M * topk * block_size + max_num_m_blocks = int((max_num_tokens_padded + block_size - 1) // block_size) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=dtypes.i32, device=device) sorted_weights = torch.empty( @@ -586,7 +589,7 @@ def FinalFunc(): in fused_moe_1stage_dict[get_gfx()] ): if q_type == QuantType.per_1x128: - run_1stage = True and (inter_dim % 256 == 0) + run_1stage = True and (inter_dim % 256 == 0) and (token > 31) elif q_type == QuantType.per_Token and q_dtype_w in [dtypes.i8, dtypes.fp8]: run_1stage = token > 32 elif q_type != QuantType.per_1x32: @@ -595,7 +598,7 @@ def FinalFunc(): BLOCK_SIZE_M if run_1stage else ( - 64 + 16 if q_type == QuantType.per_1x128 else get_block_size_M(token, topk, expert, inter_dim) ) @@ -634,6 +637,8 @@ def FinalFunc(): torch.uint32, dtypes.fp4x2, ] + or (q_dtype_w == dtypes.fp8 and q_type == QuantType.per_1x128) + or (q_type == QuantType.per_1x128 and block_m == 16) ): return MOEMetadata( functools.partial( @@ -971,6 +976,15 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) +#temp workaround for swiglu +def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + def torch_moe_stage1( hidden_states, @@ -984,6 +998,7 @@ def torch_moe_stage1( # following for quant a1_scale=None, # [token, 1] w1_scale=None, # [expert, inter_dim, 1] + w1_bias=None, # [expert, inter_dim, 1] doweight=False, ): quant_type = quant_remap.get(quant_type, quant_type) @@ -994,11 +1009,14 @@ def torch_moe_stage1( E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: from aiter.utility import fp4_utils - - hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) w1 = fp4_utils.mxfp4_to_f32(w1) w1_scale = fp4_utils.e8m0_to_f32(w1_scale) - a1_scale = fp4_utils.e8m0_to_f32(a1_scale) + if a1_scale is not None: #skip a16w4 + hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) + a1_scale = fp4_utils.e8m0_to_f32(a1_scale) + else: #a16w4 + hidden_states = hidden_states.to(ctype) + else: hidden_states = hidden_states.to(ctype) w1 = w1.to(ctype) @@ -1006,8 +1024,8 @@ def torch_moe_stage1( if quant_type in [QuantType.per_Token, QuantType.per_Tensor]: w1 = w1 * w1_scale.view(w1_scale.shape[0], -1, 1) hidden_states = hidden_states * a1_scale - # per_1x128 - elif quant_type == QuantType.per_1x128: + # per_128x128 + elif quant_type in [QuantType.per_128x128, QuantType.per_1x128]: w1_shape = w1.shape w1 = w1.view( w1.shape[0], w1.shape[1] // 128, 128, w1.shape[2] // 128, 128 @@ -1031,9 +1049,10 @@ def torch_moe_stage1( w1 = w1.view(w1_shape) a1_shape = hidden_states.shape - a1_scale = a1_scale[: a1_shape[0]] hidden_states = hidden_states.view(a1_shape[0], a1_shape[1] // 32, 32) - hidden_states = hidden_states * a1_scale.view(a1_shape[0], a1_shape[1] // 32, 1) + if a1_scale is not None: + a1_scale = a1_scale[: a1_shape[0]] + hidden_states = hidden_states * a1_scale.view(a1_shape[0], a1_shape[1] // 32, 1) hidden_states = hidden_states.view(a1_shape) else: assert False, f"Unsupported quant_type: {quant_type}" @@ -1053,11 +1072,17 @@ def torch_moe_stage1( if doweight: act_input = act_input * topk_weight[mask].view(-1, 1) out[mask] = act_input + if w1_bias is not None: + out[mask] = out[mask] + w1_bias[E_id].view(1, -1) use_g1u1 = w1.shape[1] == (2 * inter_dim) + use_swiglu = (a1_scale is None) and (quant_type == QuantType.per_1x32) torch_act = aiter.get_torch_act(activation) if use_g1u1: gate, up = out.split([inter_dim, inter_dim], dim=-1) - out = torch_act(gate) * up + if use_swiglu: + out = swiglu(gate, up) + else: + out = torch_act(gate) * up else: out = torch_act(out) return out.to(dtype) @@ -1073,6 +1098,7 @@ def torch_moe_stage2( quant_type=QuantType.No, w2_scale=None, # [1] a2_scale=None, # [expert]]' + w2_bias=None, doweight=True, ): quant_type = quant_remap.get(quant_type, quant_type) @@ -1081,10 +1107,13 @@ def torch_moe_stage2( if quant_type == QuantType.per_1x32: from aiter.utility import fp4_utils - hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) w2 = fp4_utils.mxfp4_to_f32(w2) w2_scale = fp4_utils.e8m0_to_f32(w2_scale) - a2_scale = fp4_utils.e8m0_to_f32(a2_scale) + if a2_scale is not None: + hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) + a2_scale = fp4_utils.e8m0_to_f32(a2_scale) + else: #a16w4 + hidden_states = hidden_states.to(ctype) else: hidden_states = hidden_states.to(ctype) w2 = w2.to(ctype) @@ -1095,7 +1124,7 @@ def torch_moe_stage2( if quant_type in [QuantType.per_Token, QuantType.per_Tensor]: hidden_states = hidden_states * a2_scale.view(a2_scale.shape[0], -1, 1) w2 = w2 * w2_scale.view(w2_scale.shape[0], -1, 1) - elif quant_type == QuantType.per_1x128: + elif quant_type in [QuantType.per_128x128, QuantType.per_1x128]: a2_scale = a2_scale.view(hidden_states.shape[0], topk, -1, 1) a2_scale = a2_scale.repeat(1, 1, 1, 128).view(hidden_states.shape[0], topk, -1) hidden_states = hidden_states * a2_scale @@ -1109,11 +1138,12 @@ def torch_moe_stage2( w2 = w2.view(w2_shape) elif quant_type == QuantType.per_1x32: a2_shape = hidden_states.shape - a2_scale = a2_scale[: a2_shape[0] * topk] - a2_scale = a2_scale.view(token_num, topk, inter_dim // 32, 1) - hidden_states = ( - hidden_states.view(token_num, topk, inter_dim // 32, 32) * a2_scale - ) + if a2_scale is not None: + a2_scale = a2_scale[: a2_shape[0] * topk] + a2_scale = a2_scale.view(token_num, topk, inter_dim // 32, 1) + hidden_states = ( + hidden_states.view(token_num, topk, inter_dim // 32, 32) * a2_scale + ) hidden_states = hidden_states.view(a2_shape) w2_shape = w2.shape @@ -1133,6 +1163,8 @@ def torch_moe_stage2( sub_tokens = hidden_states[mask] act_input = sub_tokens @ (w2[E_id].transpose(0, 1)) out[mask] = act_input + if w2_bias is not None: + out[mask] = out[mask] + w2_bias[E_id].view(1, -1) if doweight: out = out * topk_weights.view(token_num, -1, 1) return out.sum(1).to(dtype) diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 2de0314dab..1fb118b592 100755 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -396,8 +396,6 @@ }, "module_moe_cktile2stages": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh'", - "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages.h'", "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu'", "f'{AITER_CSRC_DIR}/pybind/moe_cktile_2stages_pybind.cu'" ], @@ -405,7 +403,9 @@ "flags_extra_hip": [], "md_name": "'module_moe_cktile2stages'", "extra_ldflags": "None", - "extra_include": [], + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_tile_gemm_moe_2stages/include'" + ], "verbose": "False", "is_python_module": "True", "is_standalone": "False", diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index f70438b4e0..23b1b196ef 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -312,6 +312,78 @@ def ck_moe_stage2( activation: int = 0, ) -> None: ... +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1") +def moe_cktile2stages_gemm1_ck( + XQ : Tensor, + WQ : Tensor, + Y : Tensor, + sorted_ids : Tensor, + sorted_expert_ids : Tensor, + max_token_ids : Tensor, + topk : int, + n_padded_zeros : Optional[int] = 0, + k_padded_zeros : Optional[int] = 0, + topk_weight : Optional[Tensor] = None, + x_scale : Optional[Tensor] = None, + w_scale : Optional[Tensor] = None, + exp_bias : Optional[Tensor] = None, + block_m : Optional[int] = 32, +) -> Tensor: ... + +def moe_cktile2stages_gemm1( + XQ : Tensor, + WQ : Tensor, + Y : Tensor, + sorted_ids : Tensor, + sorted_expert_ids : Tensor, + max_token_ids : Tensor, + topk : int, + n_padded_zeros : Optional[int] = 0, + k_padded_zeros : Optional[int] = 0, + topk_weight : Optional[Tensor] = None, + x_scale : Optional[Tensor] = None, + w_scale : Optional[Tensor] = None, + exp_bias : Optional[Tensor] = None, + block_m : Optional[int] = 32, +): + return moe_cktile2stages_gemm1_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m) + +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2") +def moe_cktile2stages_gemm2_ck( + XQ : Tensor, + WQ : Tensor, + Y : Tensor, + sorted_ids : Tensor, + sorted_expert_ids : Tensor, + max_token_ids : Tensor, + topk : int, + n_padded_zeros : Optional[int] = 0, + k_padded_zeros : Optional[int] = 0, + topk_weight : Optional[Tensor] = None, + x_scale : Optional[Tensor] = None, + w_scale : Optional[Tensor] = None, + exp_bias : Optional[Tensor] = None, + block_m : Optional[int] = 32, +) -> Tensor: ... + +def moe_cktile2stages_gemm2( + XQ : Tensor, + WQ : Tensor, + Y : Tensor, + sorted_ids : Tensor, + sorted_expert_ids : Tensor, + max_token_ids : Tensor, + topk : int, + n_padded_zeros : Optional[int] = 0, + k_padded_zeros : Optional[int] = 0, + topk_weight : Optional[Tensor] = None, + x_scale : Optional[Tensor] = None, + w_scale : Optional[Tensor] = None, + exp_bias : Optional[Tensor] = None, + block_m : Optional[int] = 32, +): + return moe_cktile2stages_gemm2_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m) + dtype2str_dict = { dtypes.fp16: "f16", diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 6ddde3af0b..10f2c055f4 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -44,6 +44,9 @@ def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch. Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] """ # print("gemm shape:", src.shape) + src_type = src.dtype + if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: + src = src.view(torch.uint8) experts_cnt, N, K_pk = src.shape if gate_up: N = N // 2 @@ -59,7 +62,7 @@ def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch. src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) # print("interleaved shape:", interleaved.shape) - return interleaved.contiguous() + return interleaved.contiguous().view(src_type) def shuffle_mxfp4_scale(src: torch.Tensor, gate_up: bool) -> torch.Tensor: diff --git a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py index c1373ecae9..03d13d1846 100644 --- a/csrc/ck_tile_gemm_moe_2stages/gen_instances.py +++ b/csrc/ck_tile_gemm_moe_2stages/gen_instances.py @@ -73,12 +73,12 @@ def gen_instance(self, k: kernelInstance): torch::Tensor& sorted_expert_ids, torch::Tensor& max_token_ids, int topk, - std::optional n_padded_zeros = 0, - std::optional k_padded_zeros = 0, - std::optional topk_weight = std::nullopt, - std::optional x_scale = std::nullopt, - std::optional w_scale = std::nullopt, - std::optional exp_bias = std::nullopt) + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias) {{{{ // The smallest kernel we have available. Works well for memory bound shapes. int NumTokens = XQ.size(0); @@ -186,7 +186,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "../impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h similarity index 70% rename from csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h rename to csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h index 7be1db64d6..f431dc1653 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.h +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h @@ -47,13 +47,13 @@ cktile_moe_gemm1(torch::Tensor& XQ, torch::Tensor& sorted_expert_ids, torch::Tensor& max_token_ids, int topk, - std::optional n_padded_zeros = 0, - std::optional k_padded_zeros = 0, - std::optional topk_weight = std::nullopt, - std::optional x_scale = std::nullopt, - std::optional w_scale = std::nullopt, - std::optional exp_bias = std::nullopt, - std::optional block_m = 32); + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale , + std::optional w_scale , + std::optional exp_bias, + std::optional block_m ); __attribute__((visibility("default"))) torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, @@ -63,10 +63,10 @@ cktile_moe_gemm2(torch::Tensor& XQ, torch::Tensor& sorted_expert_ids, torch::Tensor& max_token_ids, int topk, - std::optional n_padded_zeros = 0, - std::optional k_padded_zeros = 0, - std::optional topk_weight = std::nullopt, - std::optional x_scale = std::nullopt, - std::optional w_scale = std::nullopt, - std::optional exp_bias = std::nullopt, - std::optional block_m = 32); \ No newline at end of file + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight , + std::optional x_scale , + std::optional w_scale , + std::optional exp_bias, + std::optional block_m ); \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh similarity index 100% rename from csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.cuh rename to csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu index c6432e9c17..04c35dd8d9 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -73,6 +73,31 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, std::optional exp_bias, std::optional block_m) { + // // ========== 添加调试打印 ========== + // std::cout << "========== cktile_moe_gemm1 Debug Info ==========" << std::endl; + // std::cout << "XQ dtype: " << XQ.dtype() << std::endl; + // std::cout << "WQ dtype: " << WQ.dtype() << std::endl; + // std::cout << "Y dtype: " << Y.dtype() << std::endl; + // std::cout << "x_scale has_value: " << (x_scale.has_value() ? "true" : "false") << std::endl; + // if (x_scale.has_value()) { + // std::cout << "x_scale dtype: " << x_scale.value().dtype() << std::endl; + // } + // std::cout << "w_scale has_value: " << (w_scale.has_value() ? "true" : "false") << std::endl; + // if (w_scale.has_value()) { + // std::cout << "w_scale dtype: " << w_scale.value().dtype() << std::endl; + // } + // std::cout << "exp_bias has_value: " << (exp_bias.has_value() ? "true" : "false") << std::endl; + // if (exp_bias.has_value()) { + // std::cout << "exp_bias dtype: " << exp_bias.value().dtype() << std::endl; + // } + // std::cout << "topk_weight has_value: " << (topk_weight.has_value() ? "true" : "false") << std::endl; + // if (topk_weight.has_value()) { + // std::cout << "topk_weight dtype: " << topk_weight.value().dtype() << std::endl; + // } + // std::cout << "M=" << sorted_ids.size(0) << ", N=" << WQ.size(1) << ", K=" << XQ.size(-1) << std::endl; + // std::cout << "===============================================" << std::endl; + // // ========== 调试打印结束 ========== + TORCH_CHECK(Y.dtype() == at::ScalarType::BFloat16 || Y.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!"); @@ -106,19 +131,27 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } } - else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == at::ScalarType::Byte)) //a16w4 + else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) //a16w4 { + // std::cout << "DEBUG: Entering A16W4 branch" << std::endl; // if (Y.dtype() == at::ScalarType::Half) // { // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } if (Y.dtype() == at::ScalarType::BFloat16) { + // std::cout << "DEBUG: Calling moe_dispatch with BF16 output" << std::endl; moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); } } else { + // std::cout << "DEBUG: Falling into unsupported dtype branch!" << std::endl; + // std::cout << "DEBUG: XQ.dtype()=" << XQ.dtype() << std::endl; + // std::cout << "DEBUG: WQ.dtype()=" << WQ.dtype() << std::endl; + // std::cout << "DEBUG: at::ScalarType::BFloat16=" << at::ScalarType::BFloat16 << std::endl; + // std::cout << "DEBUG: at::ScalarType::Half=" << at::ScalarType::Half << std::endl; + // std::cout << "DEBUG: at::ScalarType::Byte=" << torch_fp4x2 << std::endl; TORCH_CHECK(false, "Unsupported scales/output dtype!"); } return Y; @@ -165,7 +198,7 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } } - else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == at::ScalarType::Byte)) //a16w4 + else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) //a16w4 { // if (Y.dtype() == at::ScalarType::Half) // { diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 02a5d82908..731513c816 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -225,6 +225,9 @@ def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch. Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] """ # print("gemm shape:", src.shape) + src_type = src.dtype + if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: + src = src.view(torch.uint8) experts_cnt, N, K_pk = src.shape if gate_up: N = N // 2 @@ -242,7 +245,7 @@ def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch. src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) # print("interleaved shape:", interleaved.shape) - return interleaved.contiguous() + return interleaved.contiguous().view(src_type) def shuffle_mxfp4_scale(src: torch.Tensor, experts_cnt: int, gate_up: bool) -> torch.Tensor: n_experts, k_ = src.shape @@ -442,7 +445,10 @@ def weight_per_128x128_quant(weight, quant_dtype): ) # # ######################## ck stage 1 start ########### - out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) + if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: + out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) + else: + out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) # out1_ck, us1 = run_perftest( # ck_moe_stage1, @@ -565,7 +571,10 @@ def weight_per_128x128_quant(weight, quant_dtype): # # ) # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - out2_ck = torch.empty((token, model_dim), dtype=dtype) + if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: + out2_ck = torch.zeros((token, model_dim), dtype=dtype) + else: + out2_ck = torch.empty((token, model_dim), dtype=dtype) # # cktil2stage _, us2 = run_perftest( From 5c6d5d67dc434aee5a708a88f7d08b59249c4208 Mon Sep 17 00:00:00 2001 From: solin Date: Wed, 5 Nov 2025 07:06:19 +0000 Subject: [PATCH 03/20] refine code --- aiter/ops/shuffle.py | 96 ++++++----- .../gemm_moe_ck2stages_common.py | 4 +- .../gen_instances.py | 12 +- .../include/moe_cktile2stages.h | 54 ++++--- .../include/moe_cktile2stages_common.cuh | 8 +- .../moe_cktile2stages.cu | 40 +---- op_tests/test_moe_2stage.py | 150 +++++++----------- 7 files changed, 157 insertions(+), 207 deletions(-) diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 10f2c055f4..656044bd70 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -25,71 +25,87 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te return x_.view(x_type) -def shuffle_weight_NK(x: torch.Tensor, inst_N: int, inst_K: int, use_int4=False) -> torch.Tensor: - kPerLane = inst_K // (64 // inst_N) - if(use_int4): +def shuffle_weight_NK( + x: torch.Tensor, inst_N: int, inst_K: int, use_int4=False +) -> torch.Tensor: + kPerLane = inst_K // (64 // inst_N) + if use_int4: kPerLane *= 2 - assert x.shape[-2] % inst_N == 0, f"{x.shape[-2]} % {inst_N} == {x.shape[-2] % N_WARP_TILE }" - assert x.shape[-1] % inst_K == 0, f"{x.shape[-1]} % {inst_K} == {x.shape[-1] % K_WARP_TILE }" + assert ( + x.shape[-2] % inst_N == 0 + ), f"{x.shape[-2]} % {inst_N} == {x.shape[-2] % N_WARP_TILE }" + assert ( + x.shape[-1] % inst_K == 0 + ), f"{x.shape[-1]} % {inst_K} == {x.shape[-1] % K_WARP_TILE }" x_ = x - x_ = x_.view(-1, x.shape[-2] // inst_N, inst_N, x.shape[-1] // inst_K, 64 // inst_N, kPerLane) + x_ = x_.view( + -1, x.shape[-2] // inst_N, inst_N, x.shape[-1] // inst_K, 64 // inst_N, kPerLane + ) x_ = x_.permute(0, 1, 3, 4, 2, 5).contiguous() return x_.view(*x.shape) def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: - """ - src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 - Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] - """ - # print("gemm shape:", src.shape) - src_type = src.dtype - if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: - src = src.view(torch.uint8) - experts_cnt, N, K_pk = src.shape - if gate_up: - N = N // 2 - KPack = 16 - KLane = 64 // NLane #4 - N0 = N // NLane - K0 = K_pk // (KLane * KPack) - if (gate_up): - src_reshaped = src.view(experts_cnt, 2, N0, NLane, K0, KLane, KPack) # [E,2, N0, NLane ,K0, KLane, KPack] - src_reshaped = src_reshaped.permute(0, 2, 1, 4, 5, 3, 6).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] - interleaved = src_reshaped.view(*src.shape) - else: - src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) - interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) - # print("interleaved shape:", interleaved.shape) - return interleaved.contiguous().view(src_type) - + """ + src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 + Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] + """ + # print("gemm shape:", src.shape) + src_type = src.dtype + if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: + src = src.view(torch.uint8) + experts_cnt, N, K_pk = src.shape + if gate_up: + N = N // 2 + KPack = 16 + KLane = 64 // NLane # 4 + N0 = N // NLane + K0 = K_pk // (KLane * KPack) + if gate_up: + src_reshaped = src.view( + experts_cnt, 2, N0, NLane, K0, KLane, KPack + ) # [E,2, N0, NLane ,K0, KLane, KPack] + src_reshaped = src_reshaped.permute( + 0, 2, 1, 4, 5, 3, 6 + ).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] + interleaved = src_reshaped.view(*src.shape) + else: + src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) + interleaved = ( + src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) + ) + # print("interleaved shape:", interleaved.shape) + return interleaved.contiguous().view(src_type) -def shuffle_mxfp4_scale(src: torch.Tensor, gate_up: bool) -> torch.Tensor: - n_experts, n_, k_ = src.shape - # n_ = n_experts // experts_cnt + +def shuffle_mxfp4_scale( + src: torch.Tensor, experts_cnt: int, gate_up: bool +) -> torch.Tensor: + n_experts, k_ = src.shape + n_ = n_experts // experts_cnt # MXFP4 constants K_Pack = 2 N_Pack = 2 N_Lane = 16 K_Lane = 64 // N_Lane # 4 - + # Basic dimensions K1 = k_ // K_Pack // K_Lane # k_ // 8 - N1 = n_ // N_Lane // N_Pack # n_ // 32 - real_k =32 * k_ * K_Pack * K_Lane # 1x32 quant + N1 = n_ // N_Lane // N_Pack # n_ // 32 + real_k = 32 * k_ * K_Pack * K_Lane # 1x32 quant assert real_k >= 256, f"K {real_k} must be larger than Tile_K(256)" # print("src shape", src.shape) # Reshape based on moe_kind if gate_up: # Reshape to: [E, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane] - shfl_scale = src.view(n_experts, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) + shfl_scale = src.view(experts_cnt, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] shfl_scale = shfl_scale.permute(0, 2, 4, 6, 3, 5, 1).contiguous() else: # Reshape to: [E, K1, K_Pack, K_Lane, N1, N_Pack, N_Lane] - shfl_scale = src.view(n_experts, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) + shfl_scale = src.view(experts_cnt, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] shfl_scale = shfl_scale.permute(0, 1, 4, 6, 3, 5, 2).contiguous() # print("shf_scale shape:", shfl_scale.shape) - return shfl_scale.view((n_experts * n_, k_)).contiguous() + return shfl_scale.view(*src.shape).contiguous() diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index cab1d11e9a..d78c57245f 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -173,8 +173,8 @@ def name(self) -> str: } # gemm1 blockscale out:bf16/fp16 AB:fp8/i8 a8w8_gemm1_blockscale_kernels_list= { - #0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 1,), 0: kernelInstanceGEMM1( 256, 64, 128, 128, 1, 4, 3,), + 1: kernelInstanceGEMM1( 256, 16, 128, 256, 1, 4, 1,), #2: kernelInstanceGEMM1( 256, 128, 128, 128, 1, 4, 3,), } @@ -259,7 +259,7 @@ def name(self) -> str: # gemm2 MXDLPerWave out:bf16/fp16 AB:fp8/i8 a8w8_gemm2_blockscale_kernels_list= { - #0: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 1,), + 0: kernelInstanceGEMM2( 256, 16, 128, 256, 1, 4, 1,), 1: kernelInstanceGEMM2( 256, 64, 128, 128, 1, 4, 3,), #2: kernelInstanceGEMM2( 256, 128, 128, 128, 2, 2, 3,), } diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 8cebae2184..8d6b29c1b4 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -212,7 +212,11 @@ && {MulRoutedWeight} == mul_routed_weight_stage && {Quant} == quant) {{ - if (block_m == 64) + if (block_m == 16) + {{ + return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 16, 128, 256/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 64) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -391,7 +395,11 @@ && {MulRoutedWeight} == mul_routed_weight_stage && {Quant} == quant) {{ - if (block_m == 64) + if (block_m == 16) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 16, 128, 256/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 64) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h index f431dc1653..df9359d7bf 100644 --- a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h @@ -6,8 +6,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/moe_flatmm.hpp" #include "py_itfs_common.h" // #include @@ -20,24 +20,26 @@ #include #include -using MoeKernel = std::function< - torch::Tensor(torch::Tensor &, torch::Tensor &, - torch::Tensor &, torch::Tensor &, - torch::Tensor &, torch::Tensor &, - int, - std::optional, - std::optional, - std::optional, - std::optional, - std::optional, - std::optional)>; -using ck_stream_config = ck_tile::stream_config; -using row_major = ck_tile::tensor_layout::gemm::RowMajor; -using col_major = ck_tile::tensor_layout::gemm::ColumnMajor; -using bf16 = ck_tile::bf16_t; -using fp16 = ck_tile::half_t; -using fp8 = ck_tile::fp8_t; -using pk_fp4 = ck_tile::pk_fp4_t; +using MoeKernel = std::function, + std::optional, + std::optional, + std::optional, + std::optional, + std::optional)>; +using ck_stream_config = ck_tile::stream_config; +using row_major = ck_tile::tensor_layout::gemm::RowMajor; +using col_major = ck_tile::tensor_layout::gemm::ColumnMajor; +using bf16 = ck_tile::bf16_t; +using fp16 = ck_tile::half_t; +using fp8 = ck_tile::fp8_t; +using pk_fp4 = ck_tile::pk_fp4_t; __attribute__((visibility("default"))) torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, @@ -50,10 +52,10 @@ cktile_moe_gemm1(torch::Tensor& XQ, std::optional n_padded_zeros, std::optional k_padded_zeros, std::optional topk_weight, - std::optional x_scale , - std::optional w_scale , + std::optional x_scale, + std::optional w_scale, std::optional exp_bias, - std::optional block_m ); + std::optional block_m); __attribute__((visibility("default"))) torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, @@ -65,8 +67,8 @@ cktile_moe_gemm2(torch::Tensor& XQ, int topk, std::optional n_padded_zeros, std::optional k_padded_zeros, - std::optional topk_weight , - std::optional x_scale , - std::optional w_scale , + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, std::optional exp_bias, - std::optional block_m ); \ No newline at end of file + std::optional block_m); \ No newline at end of file diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh index c5040f126f..cd8d2724fa 100644 --- a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh @@ -18,9 +18,6 @@ #include #include -// #include -// #include -// #include #include #include #include @@ -287,9 +284,8 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) // } // else // { - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); // } // return ave_time; }; diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu index 04c35dd8d9..5ec8807192 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -72,33 +72,7 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, std::optional w_scale, std::optional exp_bias, std::optional block_m) -{ - // // ========== 添加调试打印 ========== - // std::cout << "========== cktile_moe_gemm1 Debug Info ==========" << std::endl; - // std::cout << "XQ dtype: " << XQ.dtype() << std::endl; - // std::cout << "WQ dtype: " << WQ.dtype() << std::endl; - // std::cout << "Y dtype: " << Y.dtype() << std::endl; - // std::cout << "x_scale has_value: " << (x_scale.has_value() ? "true" : "false") << std::endl; - // if (x_scale.has_value()) { - // std::cout << "x_scale dtype: " << x_scale.value().dtype() << std::endl; - // } - // std::cout << "w_scale has_value: " << (w_scale.has_value() ? "true" : "false") << std::endl; - // if (w_scale.has_value()) { - // std::cout << "w_scale dtype: " << w_scale.value().dtype() << std::endl; - // } - // std::cout << "exp_bias has_value: " << (exp_bias.has_value() ? "true" : "false") << std::endl; - // if (exp_bias.has_value()) { - // std::cout << "exp_bias dtype: " << exp_bias.value().dtype() << std::endl; - // } - // std::cout << "topk_weight has_value: " << (topk_weight.has_value() ? "true" : "false") << std::endl; - // if (topk_weight.has_value()) { - // std::cout << "topk_weight dtype: " << topk_weight.value().dtype() << std::endl; - // } - // std::cout << "M=" << sorted_ids.size(0) << ", N=" << WQ.size(1) << ", K=" << XQ.size(-1) << std::endl; - // std::cout << "===============================================" << std::endl; - // // ========== 调试打印结束 ========== - - +{ TORCH_CHECK(Y.dtype() == at::ScalarType::BFloat16 || Y.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!"); if (x_scale != std::nullopt && w_scale != std::nullopt){ @@ -110,8 +84,6 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, int K = XQ.size(-1); int MPerBlock = block_m.value(); - // const at::cuda::OptionalCUDAGuard device_guard(device_of(Y)); - // at::cuda::getCurrentCUDAStream().stream(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); at::hip::getCurrentHIPStream(); // if (!XQ || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) @@ -133,25 +105,17 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, } else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) //a16w4 { - // std::cout << "DEBUG: Entering A16W4 branch" << std::endl; // if (Y.dtype() == at::ScalarType::Half) // { // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } if (Y.dtype() == at::ScalarType::BFloat16) { - // std::cout << "DEBUG: Calling moe_dispatch with BF16 output" << std::endl; moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); } } else { - // std::cout << "DEBUG: Falling into unsupported dtype branch!" << std::endl; - // std::cout << "DEBUG: XQ.dtype()=" << XQ.dtype() << std::endl; - // std::cout << "DEBUG: WQ.dtype()=" << WQ.dtype() << std::endl; - // std::cout << "DEBUG: at::ScalarType::BFloat16=" << at::ScalarType::BFloat16 << std::endl; - // std::cout << "DEBUG: at::ScalarType::Half=" << at::ScalarType::Half << std::endl; - // std::cout << "DEBUG: at::ScalarType::Byte=" << torch_fp4x2 << std::endl; TORCH_CHECK(false, "Unsupported scales/output dtype!"); } return Y; @@ -177,8 +141,6 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, int K = XQ.size(-1); int MPerBlock = block_m.value(); - // const at::cuda::OptionalCUDAGuard device_guard(device_of(Y)); - // at::cuda::getCurrentCUDAStream().stream(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); at::hip::getCurrentHIPStream(); // if (!XQ. || !WQ || !sorted_ids || !sorted_expert_ids || !max_token_ids || !topk_weight) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 731513c816..a13c6787c7 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -23,7 +23,12 @@ ) -from aiter.ops.shuffle import shuffle_weight, shuffle_weight_NK +from aiter.ops.shuffle import ( + shuffle_weight, + shuffle_mxfp4_weight, + shuffle_mxfp4_scale, + shuffle_weight_NK, +) from aiter import ActivationType torch.int4 = getattr(torch, "int4", torch.uint32) @@ -121,6 +126,7 @@ def ck_moe_stage2( ) return out + def cktile_moe_stage1( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -133,8 +139,8 @@ def cktile_moe_stage1( exp_bias1, dtype, topk, - n_pad_zeros = 0, - k_pad_zeros = 0, + n_pad_zeros=0, + k_pad_zeros=0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, @@ -143,7 +149,7 @@ def cktile_moe_stage1( token_num = hidden_states.shape[0] _, n1, k1 = w1.shape _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2*2 #bit4 format + D = n2 if k2 == k1 else n2 * 2 # bit4 format # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size if w1.dtype is torch.uint32: @@ -168,6 +174,7 @@ def cktile_moe_stage1( ) return out + def cktile_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -180,13 +187,13 @@ def cktile_moe_stage2( exp_bias2, dtype, topk, - n_pad_zeros = 0, - k_pad_zeros = 0, + n_pad_zeros=0, + k_pad_zeros=0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, sorted_weights=None, # [max_num_tokens_padded] - zeros_out = False + zeros_out=False, ): token_num = hidden_states.shape[0] D = w2.shape[1] @@ -219,63 +226,6 @@ def cktile_moe_stage2( return out -def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: - """ - src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 - Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] - """ - # print("gemm shape:", src.shape) - src_type = src.dtype - if hasattr(torch, "float4_e2m1fn_x2") and src_type == torch.float4_e2m1fn_x2: - src = src.view(torch.uint8) - experts_cnt, N, K_pk = src.shape - if gate_up: - N = N // 2 - KPack = 16 - KLane = 64 // NLane #4 - N0 = N // NLane - K0 = K_pk // (KLane * KPack) - assert KLane * KPack * K0 == K_pk, f"K({K_pk}) is not a divisble of 64." - assert NLane * N0 == N, f"N({K_pk}) is not a divisble of 16." - if (gate_up): - src_reshaped = src.view(experts_cnt, 2, N0, NLane, K0, KLane, KPack) # [E,2, N0, NLane ,K0, KLane, KPack] - src_reshaped = src_reshaped.permute(0, 2, 1, 4, 5, 3, 6).contiguous() # [E, N0, 2, K0, KLane, NLane, KPack] - interleaved = src_reshaped.view(*src.shape) - else: - src_reshaped = src.view(experts_cnt, N0, NLane, K0, KLane, KPack) - interleaved = src_reshaped.permute(0, 1, 3, 4, 2, 5).contiguous().view(*src.shape) - # print("interleaved shape:", interleaved.shape) - return interleaved.contiguous().view(src_type) - -def shuffle_mxfp4_scale(src: torch.Tensor, experts_cnt: int, gate_up: bool) -> torch.Tensor: - n_experts, k_ = src.shape - n_ = n_experts // experts_cnt - # MXFP4 constants - K_Pack = 2 - N_Pack = 2 - N_Lane = 16 - K_Lane = 64 // N_Lane # 4 - - # Basic dimensions - K1 = k_ // K_Pack // K_Lane # k_ // 8 - N1 = n_ // N_Lane // N_Pack # n_ // 32 - real_k =32 * k_ * K_Pack * K_Lane # 1x32 quant - assert K1 * K_Pack * K_Lane == k_, f"K {k_*32} must be divisible of 256" - # print("src shape", src.shape) - # Reshape based on moe_kind - if gate_up: - # Reshape to: [E, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane] - shfl_scale = src.view(experts_cnt, N_Pack, N1, N_Lane, K1, K_Pack, K_Lane) - # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] - shfl_scale = shfl_scale.permute(0, 2, 4, 6, 3, 5, 1).contiguous() - else: - # Reshape to: [E, K1, K_Pack, K_Lane, N1, N_Pack, N_Lane] - shfl_scale = src.view(experts_cnt, N1, N_Pack, N_Lane, K1, K_Pack, K_Lane) - # Permute to: [E, N1, K1, K_Lane, N_Lane, K_Pack, N_Pack] - shfl_scale = shfl_scale.permute(0, 1, 4, 6, 3, 5, 2).contiguous() - # print("shf_scale shape:", shfl_scale.shape) - return shfl_scale.view(*src.shape).contiguous() - @benchmark() def test_fmoe( dtype, @@ -302,17 +252,17 @@ def test_fmoe( if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) if need_pad: - w1[:,:,-kpad0:] = 0 - w1[:,-npad0:,:] = 0 - w1[:,inter_dim-npad0:inter_dim,:] = 0 + w1[:, :, -kpad0:] = 0 + w1[:, -npad0:, :] = 0 + w1[:, inter_dim - npad0 : inter_dim, :] = 0 exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) if need_pad: - w2[:,:,-kpad0:] = 0 - w2[:,-npad0:,:] = 0 + w2[:, :, -kpad0:] = 0 + w2[:, -npad0:, :] = 0 exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) # rand @@ -387,21 +337,27 @@ def weight_per_128x128_quant(weight, quant_dtype): ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]): #a16w4 + elif qType == aiter.QuantType.per_1x32 and ( + AQDType in [dtypes.bf16, dtypes.fp16] + ): # a16w4 a1_qt = input.to(AQDType) a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) - #bias dtype convert - if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + # bias dtype convert + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 exp_bias1_aiter = exp_bias1.to(dtypes.fp32) exp_bias2_aiter = exp_bias2.to(dtypes.fp32) else: exp_bias1_aiter = exp_bias1 = None exp_bias2_aiter = exp_bias2 = None - #pre-shuffle + # pre-shuffle w1_scale_aiter = w1_scale w2_scale_aiter = w2_scale if WQDType == torch.int4: # int4 w quant @@ -415,12 +371,20 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16), use_int4=True) ) ) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) w2_scale_aiter = shuffle_mxfp4_scale(w2_scale, E, False) - elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]) and (qType != aiter.QuantType.per_128x128): + elif ( + WQDType != dtypes.fp4x2 + and (get_gfx() in ["gfx950"]) + and (qType != aiter.QuantType.per_128x128) + ): inst_K = 128 // w1_qt_aiter.element_size() w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) @@ -560,20 +524,20 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_bias=exp_bias2, doweight=not doweight_stage1, ) - # # out_ref = torch_moe( - # # input, - # # w1_qt, - # # w2_qt, - # # topk_weights, - # # topk_ids, - # # fc1_scale=w1_scale, - # # fc2_scale=w2_scale, - # # ) - # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") + # out_ref = torch_moe( + # input, + # w1_qt, + # w2_qt, + # topk_weights, + # topk_ids, + # fc1_scale=w1_scale, + # fc2_scale=w2_scale, + # ) + # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - out2_ck = torch.zeros((token, model_dim), dtype=dtype) - else: + out2_ck = torch.zeros((token, model_dim), dtype=dtype) + else: out2_ck = torch.empty((token, model_dim), dtype=dtype) # # cktil2stage @@ -618,15 +582,15 @@ def weight_per_128x128_quant(weight, quant_dtype): actType, quant_type, sorted_weights if not doweight_stage1 else None, - True + True, ) checkAllclose( - out1_ref[:,:-npad0] if need_pad else out1_ref, - out1_ck[:,:-npad0] if need_pad else out1_ck, + out1_ref[:, :-npad0] if need_pad else out1_ref, + out1_ck[:, :-npad0] if need_pad else out1_ck, msg=f"[stage1:perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) - + checkAllclose( out2_ref, out2_ck, @@ -638,7 +602,7 @@ def weight_per_128x128_quant(weight, quant_dtype): # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) # ######################## stage 2 end ########### - # # # ######################## fused 2 stage ######### + # # ######################## fused 2 stage ######### # us1=0 # out2_ck, us2 = run_perftest( # fused_moe, @@ -662,6 +626,8 @@ def weight_per_128x128_quant(weight, quant_dtype): # ) return {"gemm1(us)": us1, "gemm2(us)": us2} + + # seed = 1 # torch.manual_seed(seed) # torch.cuda.manual_seed_all(seed) From 7c673b1bbe6716c706d13067f076db49cf890eee Mon Sep 17 00:00:00 2001 From: Oscar Xu Date: Wed, 5 Nov 2025 03:01:34 -0600 Subject: [PATCH 04/20] refine ck moe --- .../gemm_moe_ck2stages_common_blockscale.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh index c417b72f58..dcd6d096cc 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh @@ -92,10 +92,10 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, - 4, 2, + MXDLPerWave, NXDLPerWave, S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + MXDLPerWave, NXDLPerWave, S<1, K0_M_A, 1, K0_A>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; // clang-format on @@ -245,13 +245,13 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - MPerBlock, 128, 128, + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, - 4, 2, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + MXDLPerWave, NXDLPerWave, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + MXDLPerWave, NXDLPerWave, S<1, K0_M, 1, K0_A>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; From 2637cd991c7d6f81f33bbd086805ed77ea27d31e Mon Sep 17 00:00:00 2001 From: solin Date: Wed, 5 Nov 2025 13:05:46 +0000 Subject: [PATCH 05/20] fix CI build fail about code style --- aiter/fused_moe.py | 16 +- aiter/ops/moe_op.py | 154 ++++---- .../gemm_moe_ck2stages_common.cuh | 301 +++++++++------- .../gemm_moe_ck2stages_common_blockscale.cuh | 331 ++++++++++-------- .../moe_cktile2stages.cu | 237 +++++++------ csrc/include/rocm_ops.hpp | 135 ++++--- csrc/pybind/moe_ck_2stages_pybind.cu | 7 +- csrc/pybind/moe_cktile_2stages_pybind.cu | 7 +- 8 files changed, 660 insertions(+), 528 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 09ea158e2a..e2843f089d 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -976,7 +976,8 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) -#temp workaround for swiglu + +# temp workaround for swiglu def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): # Clamp the input values x_glu = x_glu.clamp(min=None, max=limit) @@ -1009,12 +1010,13 @@ def torch_moe_stage1( E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: from aiter.utility import fp4_utils + w1 = fp4_utils.mxfp4_to_f32(w1) w1_scale = fp4_utils.e8m0_to_f32(w1_scale) - if a1_scale is not None: #skip a16w4 + if a1_scale is not None: # skip a16w4 hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) a1_scale = fp4_utils.e8m0_to_f32(a1_scale) - else: #a16w4 + else: # a16w4 hidden_states = hidden_states.to(ctype) else: @@ -1052,7 +1054,9 @@ def torch_moe_stage1( hidden_states = hidden_states.view(a1_shape[0], a1_shape[1] // 32, 32) if a1_scale is not None: a1_scale = a1_scale[: a1_shape[0]] - hidden_states = hidden_states * a1_scale.view(a1_shape[0], a1_shape[1] // 32, 1) + hidden_states = hidden_states * a1_scale.view( + a1_shape[0], a1_shape[1] // 32, 1 + ) hidden_states = hidden_states.view(a1_shape) else: assert False, f"Unsupported quant_type: {quant_type}" @@ -1098,7 +1102,7 @@ def torch_moe_stage2( quant_type=QuantType.No, w2_scale=None, # [1] a2_scale=None, # [expert]]' - w2_bias=None, + w2_bias=None, doweight=True, ): quant_type = quant_remap.get(quant_type, quant_type) @@ -1112,7 +1116,7 @@ def torch_moe_stage2( if a2_scale is not None: hidden_states = fp4_utils.mxfp4_to_f32(hidden_states) a2_scale = fp4_utils.e8m0_to_f32(a2_scale) - else: #a16w4 + else: # a16w4 hidden_states = hidden_states.to(ctype) else: hidden_states = hidden_states.to(ctype) diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index 23b1b196ef..07fa8ff94c 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -312,77 +312,111 @@ def ck_moe_stage2( activation: int = 0, ) -> None: ... -@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1") + +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm1") def moe_cktile2stages_gemm1_ck( - XQ : Tensor, - WQ : Tensor, - Y : Tensor, - sorted_ids : Tensor, - sorted_expert_ids : Tensor, - max_token_ids : Tensor, - topk : int, - n_padded_zeros : Optional[int] = 0, - k_padded_zeros : Optional[int] = 0, - topk_weight : Optional[Tensor] = None, - x_scale : Optional[Tensor] = None, - w_scale : Optional[Tensor] = None, - exp_bias : Optional[Tensor] = None, - block_m : Optional[int] = 32, + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, ) -> Tensor: ... + def moe_cktile2stages_gemm1( - XQ : Tensor, - WQ : Tensor, - Y : Tensor, - sorted_ids : Tensor, - sorted_expert_ids : Tensor, - max_token_ids : Tensor, - topk : int, - n_padded_zeros : Optional[int] = 0, - k_padded_zeros : Optional[int] = 0, - topk_weight : Optional[Tensor] = None, - x_scale : Optional[Tensor] = None, - w_scale : Optional[Tensor] = None, - exp_bias : Optional[Tensor] = None, - block_m : Optional[int] = 32, + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, ): - return moe_cktile2stages_gemm1_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m) + return moe_cktile2stages_gemm1_ck( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + block_m, + ) -@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2") + +@compile_ops("module_moe_cktile2stages", fc_name="cktile_moe_gemm2") def moe_cktile2stages_gemm2_ck( - XQ : Tensor, - WQ : Tensor, - Y : Tensor, - sorted_ids : Tensor, - sorted_expert_ids : Tensor, - max_token_ids : Tensor, - topk : int, - n_padded_zeros : Optional[int] = 0, - k_padded_zeros : Optional[int] = 0, - topk_weight : Optional[Tensor] = None, - x_scale : Optional[Tensor] = None, - w_scale : Optional[Tensor] = None, - exp_bias : Optional[Tensor] = None, - block_m : Optional[int] = 32, + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, ) -> Tensor: ... + def moe_cktile2stages_gemm2( - XQ : Tensor, - WQ : Tensor, - Y : Tensor, - sorted_ids : Tensor, - sorted_expert_ids : Tensor, - max_token_ids : Tensor, - topk : int, - n_padded_zeros : Optional[int] = 0, - k_padded_zeros : Optional[int] = 0, - topk_weight : Optional[Tensor] = None, - x_scale : Optional[Tensor] = None, - w_scale : Optional[Tensor] = None, - exp_bias : Optional[Tensor] = None, - block_m : Optional[int] = 32, + XQ: Tensor, + WQ: Tensor, + Y: Tensor, + sorted_ids: Tensor, + sorted_expert_ids: Tensor, + max_token_ids: Tensor, + topk: int, + n_padded_zeros: Optional[int] = 0, + k_padded_zeros: Optional[int] = 0, + topk_weight: Optional[Tensor] = None, + x_scale: Optional[Tensor] = None, + w_scale: Optional[Tensor] = None, + exp_bias: Optional[Tensor] = None, + block_m: Optional[int] = 32, ): - return moe_cktile2stages_gemm2_ck(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias, block_m) + return moe_cktile2stages_gemm2_ck( + XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias, + block_m, + ) dtype2str_dict = { diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh index 4c04cbf614..c1a98e6d74 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh @@ -1,40 +1,43 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "gemm_moe_ck2stages.h" -#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "gemm_moe_ck2stages.h" #include -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP> -void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage1_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&hidden_states, // [m, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, - void *&num_valid_ids, // [1] - void *&out, // [max_num_tokens_padded, inter_dim] - std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + void*& hidden_states, // [m, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, + void*& num_valid_ids, // [1] + void*& out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale // [m, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -42,41 +45,44 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using D0Layout = Row; using D1Layout = Col; - using ELayout = Row; + using ELayout = Row; using D2Layout = ELayout; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr ck::index_t MNPerXDL = 16; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 : 128; - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 1 : NXDLPerWave; + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 + // : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 1 : NXDLPerWave; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 16 / sizeof(EDataType); - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -97,45 +103,45 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, !PerTensorQuant, ck::index_t, A0DataType>; // clang-format on - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - constexpr auto I1 = ck::Number<1>{}; + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; static constexpr auto DStride = PerTensorQuant ? I0 : I1; // do GEMM auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - hidden_states, - w1, - std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, - w1_scale.has_value() ? w1_scale.value() : nullptr, - MulRoutedWeight ? sorted_weights : nullptr}, - out, - tokens, - topk, - sorted_size, - N, - K, - StrideA, - StrideB, - std::array{DStride, DStride, I0}, - StrideE, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + hidden_states, + w1, + std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, + w1_scale.has_value() ? w1_scale.value() : nullptr, + MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + std::array{DStride, DStride, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); - if (!device_op.IsSupportedArgument(argument)) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -145,51 +151,72 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ - template void ck_moe_stage1_gemm( \ - const hipStream_t &stream, \ - int tokens, int sorted_size, int N, int K, \ - int topk, \ - void *&hidden_states, \ - void *&w1, \ - void *&w2, \ - void *&sorted_token_ids, \ - void *&sorted_expert_ids, \ - void *&sorted_weights, \ - void *&num_valid_ids, \ - void *&out, \ - std::optional w1_scale, \ - std::optional a1_scale); +#define CK_MOE_STAGE1_GEMM_DEFINE( \ + BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm(const hipStream_t& stream, \ + int tokens, \ + int sorted_size, \ + int N, \ + int K, \ + int topk, \ + void*& hidden_states, \ + void*& w1, \ + void*& w2, \ + void*& sorted_token_ids, \ + void*& sorted_expert_ids, \ + void*& sorted_weights, \ + void*& num_valid_ids, \ + void*& out, \ + std::optional w1_scale, \ + std::optional a1_scale); -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP = 0> -void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage2_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&inter_states, // [max_num_tokens_padded, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, // [max_num_tokens_padded] - void *&num_valid_ids, //[1] - void *&out, // [m, out_dim] - std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + void*& inter_states, // [max_num_tokens_padded, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, // [max_num_tokens_padded] + void*& num_valid_ids, //[1] + void*& out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale // [max_num_tokens_padded, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -197,43 +224,47 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t BLOCKSIZE = 256; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; - static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 2 : NXDLPerWave; - static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 2 : NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = + ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; - static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = + ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; static constexpr ck::index_t D2Vec = 1; - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh index dcd6d096cc..e8a6c1283e 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh @@ -1,85 +1,91 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "gemm_moe_ck2stages.h" #include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "gemm_moe_ck2stages.h" #include -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP> -void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage1_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&hidden_states, // [m, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, - void *&num_valid_ids, // [1] - void *&out, // [max_num_tokens_padded, inter_dim] - std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + void*& hidden_states, // [m, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, + void*& num_valid_ids, // [1] + void*& out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale // [m, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things - using A1DataType = F32; - using B1DataType = F32; + using A1DataType = F32; + using B1DataType = F32; using CShuffleDataType = F32; - using D2DataType = F32; - using DsDataType = ck::Tuple; + using D2DataType = F32; + using DsDataType = ck::Tuple; ck::index_t StrideA = K; ck::index_t StrideB = K; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; using DsLayout = ck::Tuple; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0}; + constexpr auto StrideDs = std::array{0}; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr ck::index_t MNPerXDL = 16; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 : 128; - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 1 : NXDLPerWave; + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 + // : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 1 : NXDLPerWave; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 16 / sizeof(EDataType); - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -100,40 +106,40 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, // clang-format on - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; // do GEMM - auto device_op = DeviceOpInstance{}; - const void *a1_scale_ptr = *a1_scale; - const void *w1_scale_ptr = *w1_scale; + auto device_op = DeviceOpInstance{}; + const void* a1_scale_ptr = *a1_scale; + const void* w1_scale_ptr = *w1_scale; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - hidden_states, - w1, - std::array{MulRoutedWeight ? sorted_weights : nullptr}, - out, - tokens, - topk, - sorted_size, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - a1_scale_ptr, - w1_scale_ptr, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + hidden_states, + w1, + std::array{MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_scale_ptr, + w1_scale_ptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op); - if (!device_op.IsSupportedArgument(argument)) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -143,99 +149,124 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ - template void ck_moe_stage1_gemm( \ - const hipStream_t &stream, \ - int tokens, int sorted_size, int N, int K, \ - int topk, \ - void *&hidden_states, \ - void *&w1, \ - void *&w2, \ - void *&sorted_token_ids, \ - void *&sorted_expert_ids, \ - void *&sorted_weights, \ - void *&num_valid_ids, \ - void *&out, \ - std::optional w1_scale, \ - std::optional a1_scale); +#define CK_MOE_STAGE1_GEMM_DEFINE( \ + BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm(const hipStream_t& stream, \ + int tokens, \ + int sorted_size, \ + int N, \ + int K, \ + int topk, \ + void*& hidden_states, \ + void*& w1, \ + void*& w2, \ + void*& sorted_token_ids, \ + void*& sorted_expert_ids, \ + void*& sorted_weights, \ + void*& num_valid_ids, \ + void*& out, \ + std::optional w1_scale, \ + std::optional a1_scale); -template < - typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP = 0> -void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, +template +void ck_moe_stage2_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, int topk, - void *&inter_states, // [max_num_tokens_padded, k], input token - void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void *&sorted_token_ids, // [max_num_tokens_padded] - void *&sorted_expert_ids, // [max_num_m_blocks] - void *&sorted_weights, // [max_num_tokens_padded] - void *&num_valid_ids, //[1] - void *&out, // [m, out_dim] - std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + void*& inter_states, // [max_num_tokens_padded, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, // [max_num_tokens_padded] + void*& num_valid_ids, //[1] + void*& out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale // [max_num_tokens_padded, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things - using A1DataType = F32; // input scale - using B1DataType = F32; // input scale + using A1DataType = F32; // input scale + using B1DataType = F32; // input scale ck::index_t StrideA = K; ck::index_t StrideB = K; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; using CShuffleDataType = F32; - using D2DataType = F32; - using DsDataType = ck::Tuple; + using D2DataType = F32; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; using DsLayout = ck::Tuple; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0}; + constexpr auto StrideDs = std::array{0}; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t BLOCKSIZE = 256; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; - static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 2 : NXDLPerWave; - static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; + static constexpr ck::index_t CShuffleMXDLPerWave = + ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = + ck::is_same_v ? 2 : NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = + ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; - static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 2; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = + ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu index 5ec8807192..73674ed146 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages.cu @@ -3,85 +3,92 @@ #include "moe_cktile2stages_common.cuh" #include "moe_cktile2stages_lookup.h" #include "moe_cktile2stages_manifest.h" +#include "py_itfs_common.h" #include "moe_cktile2stages_heuristic_dispatch.h" #include -#include "py_itfs_common.h" -template +template MoeKernel moe_dispatch(int M, int N, int K, int block_m) { - // For a given shape, either find the best kernel via lookup or heuristic. - // For many small M shapes, we bucket them to the next largest kernel. - // This is fine since kernels are padded anyway. + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. - // static const auto lookup = [&] - // { - // return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(ABDataType, AccDataType, CDataType)}; - // }(); + // static const auto lookup = [&] + // { + // return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(ABDataType, AccDataType, CDataType)}; + // }(); - // // First check if this shape(M,N,K) is available in the direct lookup. - // auto it = lookup.find({M, N, K}); - // // If we found an optimal kernel, use it. - // if (it != lookup.end()) - // { - // return it->second; - // } + // // First check if this shape(M,N,K) is available in the direct lookup. + // auto it = lookup.find({M, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } - // int padded_m = M; - // if (M > 1 && M <= 16) - // { - // padded_m = 16; - // } - // else if (M <= 16384) - // { - // padded_m = nextPow2(M); - // } - // else if (M <= 20480) - // { - // padded_m = 20480; - // } - // // Second check if this shape(padded_m,N,K) is available in the direct lookup. - // it = lookup.find({padded_m, N, K}); - // // If we found an optimal kernel, use it. - // if (it != lookup.end()) - // { - // return it->second; - // } - // Otherwise, use heuristics. - if (stage == 1){ - return moe_gemm1_heuristic_dispatch(M, N, K, block_m); - } - else{ - return moe_gemm2_heuristic_dispatch(M, N, K, block_m); - } + // int padded_m = M; + // if (M > 1 && M <= 16) + // { + // padded_m = 16; + // } + // else if (M <= 16384) + // { + // padded_m = nextPow2(M); + // } + // else if (M <= 20480) + // { + // padded_m = 20480; + // } + // // Second check if this shape(padded_m,N,K) is available in the direct lookup. + // it = lookup.find({padded_m, N, K}); + // // If we found an optimal kernel, use it. + // if (it != lookup.end()) + // { + // return it->second; + // } + // Otherwise, use heuristics. + if(stage == 1) + { + return moe_gemm1_heuristic_dispatch( + M, N, K, block_m); + } + else + { + return moe_gemm2_heuristic_dispatch( + M, N, K, block_m); + } } - - torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, - torch::Tensor& WQ, - torch::Tensor& Y, - torch::Tensor& sorted_ids, - torch::Tensor& sorted_expert_ids, - torch::Tensor& max_token_ids, - int topk, - std::optional n_padded_zeros, - std::optional k_padded_zeros, - std::optional topk_weight, - std::optional x_scale, - std::optional w_scale, - std::optional exp_bias, - std::optional block_m) -{ + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) +{ TORCH_CHECK(Y.dtype() == at::ScalarType::BFloat16 || Y.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!"); - if (x_scale != std::nullopt && w_scale != std::nullopt){ + if(x_scale != std::nullopt && w_scale != std::nullopt) + { TORCH_CHECK(x_scale.value().dtype() == w_scale.value().dtype(), "Scales should have the same dtype!"); } - int M = sorted_ids.size(0); - int N = WQ.size(1); - int K = XQ.size(-1); + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); int MPerBlock = block_m.value(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); @@ -92,26 +99,42 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, // return; // } - if (XQ.dtype() == torch_fp8) + if(XQ.dtype() == torch_fp8) { - // if (Y.dtype() == at::ScalarType::Half) - // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); - // } + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } // if (Y.dtype() == at::ScalarType::BFloat16) // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } } - else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) //a16w4 + else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && + (WQ.dtype() == torch_fp4x2)) // a16w4 { // if (Y.dtype() == at::ScalarType::Half) // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } - if (Y.dtype() == at::ScalarType::BFloat16) + if(Y.dtype() == at::ScalarType::BFloat16) { - moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); + moe_dispatch(M, N, K, MPerBlock)(XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias); } } else @@ -122,23 +145,23 @@ torch::Tensor cktile_moe_gemm1(torch::Tensor& XQ, } torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, - torch::Tensor& WQ, - torch::Tensor& Y, - torch::Tensor& sorted_ids, - torch::Tensor& sorted_expert_ids, - torch::Tensor& max_token_ids, - int topk, - std::optional n_padded_zeros, - std::optional k_padded_zeros, - std::optional topk_weight, - std::optional x_scale, - std::optional w_scale, - std::optional exp_bias, - std::optional block_m) + torch::Tensor& WQ, + torch::Tensor& Y, + torch::Tensor& sorted_ids, + torch::Tensor& sorted_expert_ids, + torch::Tensor& max_token_ids, + int topk, + std::optional n_padded_zeros, + std::optional k_padded_zeros, + std::optional topk_weight, + std::optional x_scale, + std::optional w_scale, + std::optional exp_bias, + std::optional block_m) { - int M = sorted_ids.size(0); - int N = WQ.size(1); - int K = XQ.size(-1); + int M = sorted_ids.size(0); + int N = WQ.size(1); + int K = XQ.size(-1); int MPerBlock = block_m.value(); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Y)); @@ -149,26 +172,42 @@ torch::Tensor cktile_moe_gemm2(torch::Tensor& XQ, // return; // } - if (XQ.dtype() == torch_fp8) + if(XQ.dtype() == torch_fp8) { - // if (Y.dtype() == at::ScalarType::Half) - // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); - // } + // if (Y.dtype() == at::ScalarType::Half) + // { + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // } // if (Y.dtype() == at::ScalarType::BFloat16) // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } } - else if ((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && (WQ.dtype() == torch_fp4x2)) //a16w4 + else if((XQ.dtype() == at::ScalarType::BFloat16 || XQ.dtype() == at::ScalarType::Half) && + (WQ.dtype() == torch_fp4x2)) // a16w4 { // if (Y.dtype() == at::ScalarType::Half) // { - // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); + // moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, + // sorted_expert_ids, max_token_ids, topk, topk_weight, x_scale, w_scale, exp_bias); // } - if (Y.dtype() == at::ScalarType::BFloat16) + if(Y.dtype() == at::ScalarType::BFloat16) { - moe_dispatch(M, N, K, MPerBlock)(XQ, WQ, Y, sorted_ids, sorted_expert_ids, max_token_ids, topk, n_padded_zeros, k_padded_zeros, topk_weight, x_scale, w_scale, exp_bias); + moe_dispatch(M, N, K, MPerBlock)(XQ, + WQ, + Y, + sorted_ids, + sorted_expert_ids, + max_token_ids, + topk, + n_padded_zeros, + k_padded_zeros, + topk_weight, + x_scale, + w_scale, + exp_bias); } } else diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 94bb0b0434..45a1b441a1 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -732,43 +732,42 @@ namespace py = pybind11; py::arg("quant_type") = 0, \ py::arg("activation") = 0); -#define MOE_CKTILE_2STAGES_PYBIND \ - m.def("cktile_moe_gemm1", \ - &cktile_moe_gemm1, \ - "cktile_moe_gemm1", \ - py::arg("XQ"), \ - py::arg("WQ"), \ - py::arg("Y"), \ - py::arg("sorted_ids"), \ - py::arg("sorted_expert_ids"), \ - py::arg("max_token_ids"), \ - py::arg("topk"), \ - py::arg("n_padded_zeros") = 0, \ - py::arg("k_padded_zeros") = 0, \ - py::arg("topk_weight") = std::nullopt, \ - py::arg("x_scale") = std::nullopt, \ - py::arg("w_scale") = std::nullopt, \ - py::arg("exp_bias") = std::nullopt, \ - py::arg("block_m") = 32); \ - \ - \ - m.def("cktile_moe_gemm2", \ - &cktile_moe_gemm2, \ - "cktile_moe_gemm2", \ - py::arg("XQ"), \ - py::arg("WQ"), \ - py::arg("Y"), \ - py::arg("sorted_ids"), \ - py::arg("sorted_expert_ids"), \ - py::arg("max_token_ids"), \ - py::arg("topk"), \ - py::arg("n_padded_zeros") = 0, \ - py::arg("k_padded_zeros") = 0, \ - py::arg("topk_weight") = std::nullopt, \ - py::arg("x_scale") = std::nullopt, \ - py::arg("w_scale") = std::nullopt, \ - py::arg("exp_bias") = std::nullopt, \ - py::arg("block_m") = 32); +#define MOE_CKTILE_2STAGES_PYBIND \ + m.def("cktile_moe_gemm1", \ + &cktile_moe_gemm1, \ + "cktile_moe_gemm1", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); \ + \ + m.def("cktile_moe_gemm2", \ + &cktile_moe_gemm2, \ + "cktile_moe_gemm2", \ + py::arg("XQ"), \ + py::arg("WQ"), \ + py::arg("Y"), \ + py::arg("sorted_ids"), \ + py::arg("sorted_expert_ids"), \ + py::arg("max_token_ids"), \ + py::arg("topk"), \ + py::arg("n_padded_zeros") = 0, \ + py::arg("k_padded_zeros") = 0, \ + py::arg("topk_weight") = std::nullopt, \ + py::arg("x_scale") = std::nullopt, \ + py::arg("w_scale") = std::nullopt, \ + py::arg("exp_bias") = std::nullopt, \ + py::arg("block_m") = 32); #define MHA_VARLEN_FWD_PYBIND \ m.def("mha_varlen_fwd", \ @@ -1347,36 +1346,36 @@ namespace py = pybind11; py::arg("stride0"), \ py::arg("stride1")); -#define MLA_METADATA_PYBIND \ - m.def("get_mla_metadata_v1", \ - &get_mla_metadata_v1, \ - "get_mla_metadata_v1", \ - py::arg("seqlens_qo_indptr"), \ - py::arg("seqlens_kv_indptr"), \ - py::arg("num_heads_per_head_k"), \ - py::arg("num_heads_k"), \ - py::arg("is_causal"), \ - py::arg("work_metadata_ptrs"), \ - py::arg("work_info_set"), \ - py::arg("work_indptr"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("kv_granularity") = 16, \ - py::arg("max_seqlen_qo") = -1, \ - py::arg("uni_seqlen_qo") = -1, \ - py::arg("fast_mode") = true, \ - py::arg("topk") = -1); \ +#define MLA_METADATA_PYBIND \ + m.def("get_mla_metadata_v1", \ + &get_mla_metadata_v1, \ + "get_mla_metadata_v1", \ + py::arg("seqlens_qo_indptr"), \ + py::arg("seqlens_kv_indptr"), \ + py::arg("num_heads_per_head_k"), \ + py::arg("num_heads_k"), \ + py::arg("is_causal"), \ + py::arg("work_metadata_ptrs"), \ + py::arg("work_info_set"), \ + py::arg("work_indptr"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("kv_granularity") = 16, \ + py::arg("max_seqlen_qo") = -1, \ + py::arg("uni_seqlen_qo") = -1, \ + py::arg("fast_mode") = true, \ + py::arg("topk") = -1); \ m.def("get_mla_metadata_v1_no_redundant", &get_mla_metadata_v1_no_redundant); -#define MLA_REDUCE_PYBIND \ - m.def("mla_reduce_v1", \ - &mla_reduce_v1, \ - "mla_reduce_v1", \ - py::arg("partial_output"), \ - py::arg("partial_lse"), \ - py::arg("reduce_indptr"), \ - py::arg("reduce_final_map"), \ - py::arg("reduce_partial_map"), \ - py::arg("final_output"), \ - py::arg("final_lse") = std::nullopt); +#define MLA_REDUCE_PYBIND \ + m.def("mla_reduce_v1", \ + &mla_reduce_v1, \ + "mla_reduce_v1", \ + py::arg("partial_output"), \ + py::arg("partial_lse"), \ + py::arg("reduce_indptr"), \ + py::arg("reduce_final_map"), \ + py::arg("reduce_partial_map"), \ + py::arg("final_output"), \ + py::arg("final_lse") = std::nullopt); diff --git a/csrc/pybind/moe_ck_2stages_pybind.cu b/csrc/pybind/moe_ck_2stages_pybind.cu index 6b237b1898..e720771df2 100644 --- a/csrc/pybind/moe_ck_2stages_pybind.cu +++ b/csrc/pybind/moe_ck_2stages_pybind.cu @@ -1,9 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "rocm_ops.hpp" #include "moe_ck.h" +#include "rocm_ops.hpp" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - MOE_CK_2STAGES_PYBIND; -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_CK_2STAGES_PYBIND; } diff --git a/csrc/pybind/moe_cktile_2stages_pybind.cu b/csrc/pybind/moe_cktile_2stages_pybind.cu index 35bc1ebd04..82947422ce 100644 --- a/csrc/pybind/moe_cktile_2stages_pybind.cu +++ b/csrc/pybind/moe_cktile_2stages_pybind.cu @@ -1,9 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "rocm_ops.hpp" #include "moe_cktile2stages.h" +#include "rocm_ops.hpp" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - MOE_CKTILE_2STAGES_PYBIND; -} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { MOE_CKTILE_2STAGES_PYBIND; } From ad4f177a1642d776de2ab77cf1a1f3b8620f2e3c Mon Sep 17 00:00:00 2001 From: solin Date: Wed, 5 Nov 2025 14:10:26 +0000 Subject: [PATCH 06/20] remove ck blockscale moe modification --- .../gemm_moe_ck2stages_common.cuh | 201 ++++++-------- .../gemm_moe_ck2stages_common.py | 4 +- .../gemm_moe_ck2stages_common_blockscale.cuh | 245 ++++++++---------- .../gen_instances.py | 12 +- 4 files changed, 196 insertions(+), 266 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh index c1a98e6d74..a130b8c5bd 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh @@ -1,12 +1,13 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" #include "gemm_moe_ck2stages.h" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include -template -void ck_moe_stage1_gemm(const hipStream_t& stream, - int tokens, - int sorted_size, - int N, - int K, +void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, - void*& hidden_states, // [m, k], input token - void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void*& sorted_token_ids, // [max_num_tokens_padded] - void*& sorted_expert_ids, // [max_num_m_blocks] - void*& sorted_weights, - void*& num_valid_ids, // [1] - void*& out, // [max_num_tokens_padded, inter_dim] - std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + void *&hidden_states, // [m, k], input token + void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void *&sorted_token_ids, // [max_num_tokens_padded] + void *&sorted_expert_ids, // [max_num_m_blocks] + void *&sorted_weights, + void *&num_valid_ids, // [1] + void *&out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale // [m, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -45,44 +42,41 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using D0Layout = Row; using D1Layout = Col; - using ELayout = Row; + using ELayout = Row; using D2Layout = ELayout; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr ck::index_t MNPerXDL = 16; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 - // : 128; - static constexpr ck::index_t CShuffleMXDLPerWave = - ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = - ck::is_same_v ? 1 : NXDLPerWave; + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 1 : NXDLPerWave; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 16 / sizeof(EDataType); - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off @@ -103,27 +97,27 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, !PerTensorQuant, ck::index_t, A0DataType>; // clang-format on - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - constexpr auto I1 = ck::Number<1>{}; + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; static constexpr auto DStride = PerTensorQuant ? I0 : I1; // do GEMM auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - sorted_token_ids, + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids, sorted_expert_ids, num_valid_ids, hidden_states, w1, - std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, + std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, w1_scale.has_value() ? w1_scale.value() : nullptr, MulRoutedWeight ? sorted_weights : nullptr}, out, @@ -141,7 +135,7 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument)) + if (!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -151,41 +145,24 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE( \ - BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ - template void ck_moe_stage1_gemm(const hipStream_t& stream, \ - int tokens, \ - int sorted_size, \ - int N, \ - int K, \ - int topk, \ - void*& hidden_states, \ - void*& w1, \ - void*& w2, \ - void*& sorted_token_ids, \ - void*& sorted_expert_ids, \ - void*& sorted_weights, \ - void*& num_valid_ids, \ - void*& out, \ - std::optional w1_scale, \ - std::optional a1_scale); +#define CK_MOE_STAGE1_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm( \ + const hipStream_t &stream, \ + int tokens, int sorted_size, int N, int K, \ + int topk, \ + void *&hidden_states, \ + void *&w1, \ + void *&w2, \ + void *&sorted_token_ids, \ + void *&sorted_expert_ids, \ + void *&sorted_weights, \ + void *&num_valid_ids, \ + void *&out, \ + std::optional w1_scale, \ + std::optional a1_scale); -template -void ck_moe_stage2_gemm(const hipStream_t& stream, - int tokens, - int sorted_size, - int N, - int K, +void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, - void*& inter_states, // [max_num_tokens_padded, k], input token - void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void*& sorted_token_ids, // [max_num_tokens_padded] - void*& sorted_expert_ids, // [max_num_m_blocks] - void*& sorted_weights, // [max_num_tokens_padded] - void*& num_valid_ids, //[1] - void*& out, // [m, out_dim] - std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + void *&inter_states, // [max_num_tokens_padded, k], input token + void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void *&sorted_token_ids, // [max_num_tokens_padded] + void *&sorted_expert_ids, // [max_num_m_blocks] + void *&sorted_weights, // [max_num_tokens_padded] + void *&num_valid_ids, //[1] + void *&out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale // [max_num_tokens_padded, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things @@ -224,47 +197,43 @@ void ck_moe_stage2_gemm(const hipStream_t& stream, ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; // using AccDataType = F32; using CShuffleDataType = F32; - using DsDataType = ck::Tuple; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using DsLayout = ck::Tuple; using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t BLOCKSIZE = 256; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; - static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - static constexpr ck::index_t CShuffleMXDLPerWave = - ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = - ck::is_same_v ? 2 : NXDLPerWave; - static constexpr ck::index_t CShuffleNLane = - ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 2 : NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; - static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = - ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; static constexpr ck::index_t D2Vec = 1; - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index d78c57245f..cab1d11e9a 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -173,8 +173,8 @@ def name(self) -> str: } # gemm1 blockscale out:bf16/fp16 AB:fp8/i8 a8w8_gemm1_blockscale_kernels_list= { + #0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 1,), 0: kernelInstanceGEMM1( 256, 64, 128, 128, 1, 4, 3,), - 1: kernelInstanceGEMM1( 256, 16, 128, 256, 1, 4, 1,), #2: kernelInstanceGEMM1( 256, 128, 128, 128, 1, 4, 3,), } @@ -259,7 +259,7 @@ def name(self) -> str: # gemm2 MXDLPerWave out:bf16/fp16 AB:fp8/i8 a8w8_gemm2_blockscale_kernels_list= { - 0: kernelInstanceGEMM2( 256, 16, 128, 256, 1, 4, 1,), + #0: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 1,), 1: kernelInstanceGEMM2( 256, 64, 128, 128, 1, 4, 3,), #2: kernelInstanceGEMM2( 256, 128, 128, 128, 2, 2, 3,), } diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh index e8a6c1283e..8bf72721ef 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" #include "gemm_moe_ck2stages.h" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" #include -template -void ck_moe_stage1_gemm(const hipStream_t& stream, - int tokens, - int sorted_size, - int N, - int K, +void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, - void*& hidden_states, // [m, k], input token - void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void*& sorted_token_ids, // [max_num_tokens_padded] - void*& sorted_expert_ids, // [max_num_m_blocks] - void*& sorted_weights, - void*& num_valid_ids, // [1] - void*& out, // [max_num_tokens_padded, inter_dim] - std::optional w1_scale, // [e, 1, n], gate(up) scale - std::optional a1_scale // [m, 1], token scale + void *&hidden_states, // [m, k], input token + void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void *&sorted_token_ids, // [max_num_tokens_padded] + void *&sorted_expert_ids, // [max_num_m_blocks] + void *&sorted_weights, + void *&num_valid_ids, // [1] + void *&out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale // [m, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things - using A1DataType = F32; - using B1DataType = F32; + using A1DataType = F32; + using B1DataType = F32; using CShuffleDataType = F32; - using D2DataType = F32; - using DsDataType = ck::Tuple; + using D2DataType = F32; + using DsDataType = ck::Tuple; ck::index_t StrideA = K; ck::index_t StrideB = K; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; using DsLayout = ck::Tuple; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0}; + constexpr auto StrideDs = std::array{0}; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr ck::index_t MNPerXDL = 16; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 - // : 128; - static constexpr ck::index_t CShuffleMXDLPerWave = - ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = - ck::is_same_v ? 1 : NXDLPerWave; + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 1 : NXDLPerWave; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 16 / sizeof(EDataType); - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -98,30 +92,30 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, + 4, 2, S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, NXDLPerWave, S<1, K0_M_A, 1, K0_A>, S<2, 1, 1, 1>, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; // clang-format on - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; // do GEMM - auto device_op = DeviceOpInstance{}; - const void* a1_scale_ptr = *a1_scale; - const void* w1_scale_ptr = *w1_scale; + auto device_op = DeviceOpInstance{}; + const void *a1_scale_ptr = *a1_scale; + const void *w1_scale_ptr = *w1_scale; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - sorted_token_ids, + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids, sorted_expert_ids, num_valid_ids, hidden_states, w1, - std::array{MulRoutedWeight ? sorted_weights : nullptr}, + std::array{MulRoutedWeight ? sorted_weights : nullptr}, out, tokens, topk, @@ -139,7 +133,7 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument)) + if (!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " @@ -149,41 +143,24 @@ void ck_moe_stage1_gemm(const hipStream_t& stream, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE( \ - BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ - template void ck_moe_stage1_gemm(const hipStream_t& stream, \ - int tokens, \ - int sorted_size, \ - int N, \ - int K, \ - int topk, \ - void*& hidden_states, \ - void*& w1, \ - void*& w2, \ - void*& sorted_token_ids, \ - void*& sorted_expert_ids, \ - void*& sorted_weights, \ - void*& num_valid_ids, \ - void*& out, \ - std::optional w1_scale, \ - std::optional a1_scale); +#define CK_MOE_STAGE1_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm( \ + const hipStream_t &stream, \ + int tokens, int sorted_size, int N, int K, \ + int topk, \ + void *&hidden_states, \ + void *&w1, \ + void *&w2, \ + void *&sorted_token_ids, \ + void *&sorted_expert_ids, \ + void *&sorted_weights, \ + void *&num_valid_ids, \ + void *&out, \ + std::optional w1_scale, \ + std::optional a1_scale); -template -void ck_moe_stage2_gemm(const hipStream_t& stream, - int tokens, - int sorted_size, - int N, - int K, +void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, - void*& inter_states, // [max_num_tokens_padded, k], input token - void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) - void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) - void*& sorted_token_ids, // [max_num_tokens_padded] - void*& sorted_expert_ids, // [max_num_m_blocks] - void*& sorted_weights, // [max_num_tokens_padded] - void*& num_valid_ids, //[1] - void*& out, // [m, out_dim] - std::optional w2_scale, // [e, 1, n], gate(up) scale - std::optional a2_scale // [max_num_tokens_padded, 1], token scale + void *&inter_states, // [max_num_tokens_padded, k], input token + void *&w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void *&w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void *&sorted_token_ids, // [max_num_tokens_padded] + void *&sorted_expert_ids, // [max_num_m_blocks] + void *&sorted_weights, // [max_num_tokens_padded] + void *&num_valid_ids, //[1] + void *&out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale // [max_num_tokens_padded, 1], token scale ) { // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things - using A1DataType = F32; // input scale - using B1DataType = F32; // input scale + using A1DataType = F32; // input scale + using B1DataType = F32; // input scale ck::index_t StrideA = K; ck::index_t StrideB = K; ck::index_t StrideE = N; - ck::index_t KBatch = 1; + ck::index_t KBatch = 1; using CShuffleDataType = F32; - using D2DataType = F32; - using DsDataType = ck::Tuple; + using D2DataType = F32; + using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; - using ELayout = Row; + using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; using DsLayout = ck::Tuple; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using AElementOp = PassThrough; - using BElementOp = PassThrough; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0}; + constexpr auto StrideDs = std::array{0}; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t BLOCKSIZE = 256; - static constexpr ck::index_t WAVES = BLOCKSIZE / 64; - static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); - static constexpr ck::index_t CShuffleMXDLPerWave = - ck::is_same_v ? 2 : MXDLPerWave; - static constexpr ck::index_t CShuffleNXDLPerWave = - ck::is_same_v ? 2 : NXDLPerWave; - static constexpr ck::index_t CShuffleNLane = - ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 2 : MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = ck::is_same_v ? 2 : NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; - static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = - ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); - static constexpr ck::index_t EVec = 2; - static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; - static constexpr ck::index_t D2Vec = 1; - static constexpr ck::index_t K0_A = KPerBlock / AK1; - static constexpr ck::index_t K0_B = KPerBlock / BK1; - static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; - static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -276,13 +245,13 @@ void ck_moe_stage2_gemm(const hipStream_t& stream, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - MPerBlock, NPerBlock, KPerBlock, + MPerBlock, 128, 128, AK1, BK1, MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - MXDLPerWave, NXDLPerWave, S<1, K0_M, 1, K0_A>, S<2, 1, 1, 1>, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 8d6b29c1b4..8cebae2184 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -212,11 +212,7 @@ && {MulRoutedWeight} == mul_routed_weight_stage && {Quant} == quant) {{ - if (block_m == 16) - {{ - return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 16, 128, 256/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; - }} - else if (block_m == 64) + if (block_m == 64) {{ return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} @@ -395,11 +391,7 @@ && {MulRoutedWeight} == mul_routed_weight_stage && {Quant} == quant) {{ - if (block_m == 16) - {{ - return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 256, 16, 128, 256/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; - }} - else if (block_m == 64) + if (block_m == 64) {{ return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; }} From 14ccac206db197c9470c5135b7f088cb88b8c210 Mon Sep 17 00:00:00 2001 From: solin Date: Wed, 5 Nov 2025 14:15:56 +0000 Subject: [PATCH 07/20] refine code --- .../gemm_moe_ck2stages_common.cuh | 100 +++++++++--------- .../gemm_moe_ck2stages_common_blockscale.cuh | 100 +++++++++--------- 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh index a130b8c5bd..4c04cbf614 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh @@ -8,21 +8,21 @@ template < typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP> + typename B0DataType, + typename AccDataType, + typename EDataType, + typename CDEElementOp, + PipelineVersion PipelineVer, + int BLOCKSIZE, + int MPerBlock, + int NPerBlock, + int KPerBlock, + int MWaves, + int NWaves, + bool Nswizzle, + bool PerTensorQuant, + bool MulRoutedWeight, + int ActOP> void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&hidden_states, // [m, k], input token @@ -113,27 +113,27 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - hidden_states, - w1, + sorted_expert_ids, + num_valid_ids, + hidden_states, + w1, std::array{a1_scale.has_value() ? a1_scale.value() : nullptr, - w1_scale.has_value() ? w1_scale.value() : nullptr, - MulRoutedWeight ? sorted_weights : nullptr}, - out, - tokens, - topk, - sorted_size, - N, - K, - StrideA, - StrideB, - std::array{DStride, DStride, I0}, - StrideE, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + w1_scale.has_value() ? w1_scale.value() : nullptr, + MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + std::array{DStride, DStride, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); if (!device_op.IsSupportedArgument(argument)) { @@ -163,21 +163,21 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, template < typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP = 0> + typename B0DataType, + typename AccDataType, + typename EDataType, + typename CDEElementOp, + PipelineVersion PipelineVer, + int BLOCKSIZE, + int MPerBlock, + int NPerBlock, + int KPerBlock, + int MWaves, + int NWaves, + bool Nswizzle, + bool PerTensorQuant, + bool MulRoutedWeight, + int ActOP = 0> void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&inter_states, // [max_num_tokens_padded, k], input token diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh index 8bf72721ef..c417b72f58 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh @@ -7,21 +7,21 @@ template < typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP> + typename B0DataType, + typename AccDataType, + typename EDataType, + typename CDEElementOp, + PipelineVersion PipelineVer, + int BLOCKSIZE, + int MPerBlock, + int NPerBlock, + int KPerBlock, + int MWaves, + int NWaves, + bool Nswizzle, + bool PerTensorQuant, + bool MulRoutedWeight, + int ActOP> void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&hidden_states, // [m, k], input token @@ -111,27 +111,27 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - hidden_states, - w1, + sorted_expert_ids, + num_valid_ids, + hidden_states, + w1, std::array{MulRoutedWeight ? sorted_weights : nullptr}, - out, - tokens, - topk, - sorted_size, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - a1_scale_ptr, - w1_scale_ptr, - KBatch, - a_element_op, - b_element_op, - cde_element_op); + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_scale_ptr, + w1_scale_ptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op); if (!device_op.IsSupportedArgument(argument)) { @@ -161,21 +161,21 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, template < typename A0DataType, - typename B0DataType, - typename AccDataType, - typename EDataType, - typename CDEElementOp, - PipelineVersion PipelineVer, - int BLOCKSIZE, - int MPerBlock, - int NPerBlock, - int KPerBlock, - int MWaves, - int NWaves, - bool Nswizzle, - bool PerTensorQuant, - bool MulRoutedWeight, - int ActOP = 0> + typename B0DataType, + typename AccDataType, + typename EDataType, + typename CDEElementOp, + PipelineVersion PipelineVer, + int BLOCKSIZE, + int MPerBlock, + int NPerBlock, + int KPerBlock, + int MWaves, + int NWaves, + bool Nswizzle, + bool PerTensorQuant, + bool MulRoutedWeight, + int ActOP = 0> void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&inter_states, // [max_num_tokens_padded, k], input token From 384c4c9ccdf7bd3831582711353f225b4818b673 Mon Sep 17 00:00:00 2001 From: solin Date: Thu, 6 Nov 2025 07:26:55 +0000 Subject: [PATCH 08/20] fix CI build fail of unsupport block_m=16 --- aiter/fused_moe.py | 6 +- op_tests/test_moe_2stage.py | 233 ++++++++++++++++++------------------ 2 files changed, 120 insertions(+), 119 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index e2843f089d..b356e0b235 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -589,7 +589,7 @@ def FinalFunc(): in fused_moe_1stage_dict[get_gfx()] ): if q_type == QuantType.per_1x128: - run_1stage = True and (inter_dim % 256 == 0) and (token > 31) + run_1stage = True and (inter_dim % 256 == 0) elif q_type == QuantType.per_Token and q_dtype_w in [dtypes.i8, dtypes.fp8]: run_1stage = token > 32 elif q_type != QuantType.per_1x32: @@ -598,7 +598,7 @@ def FinalFunc(): BLOCK_SIZE_M if run_1stage else ( - 16 + 64 if q_type == QuantType.per_1x128 else get_block_size_M(token, topk, expert, inter_dim) ) @@ -637,8 +637,6 @@ def FinalFunc(): torch.uint32, dtypes.fp4x2, ] - or (q_dtype_w == dtypes.fp8 and q_type == QuantType.per_1x128) - or (q_type == QuantType.per_1x128 and block_m == 16) ): return MOEMetadata( functools.partial( diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index a13c6787c7..7b13b9370b 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -409,10 +409,7 @@ def weight_per_128x128_quant(weight, quant_dtype): ) # # ######################## ck stage 1 start ########### - if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) - else: - out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) + out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) # out1_ck, us1 = run_perftest( # ck_moe_stage1, @@ -434,29 +431,29 @@ def weight_per_128x128_quant(weight, quant_dtype): # ) # cktile_2stage - out1_ck, us1 = run_perftest( - cktile_moe_stage1, - a1_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w1_scale_aiter, - a1_scale, - exp_bias1_aiter, - dtype, - topk, - npad0 * 2, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type=qType, - sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, - ) + # out1_ck, us1 = run_perftest( + # cktile_moe_stage1, + # a1_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w1_scale_aiter, + # a1_scale, + # exp_bias1_aiter, + # dtype, + # topk, + # npad0 * 2, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type=qType, + # sorted_weights=sorted_weights if doweight_stage1 else None, + # # needTrace=True, + # # num_iters=2, + # # num_warmup=0, + # ) # checkAllclose( # out1_ref[:,:-npad0] if need_pad else out1_ref, # out1_ck[:,:-npad0] if need_pad else out1_ck, @@ -535,67 +532,75 @@ def weight_per_128x128_quant(weight, quant_dtype): # ) # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - out2_ck = torch.zeros((token, model_dim), dtype=dtype) - else: - out2_ck = torch.empty((token, model_dim), dtype=dtype) + out2_ck = torch.empty((token, model_dim), dtype=dtype) + # out2_ck, us2 = run_perftest( + # ck_moe_stage2, + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale, + # a2_scale, + # dtype, + # topk, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # ) # # cktil2stage - _, us2 = run_perftest( - cktile_moe_stage2, - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, - ) - out2_ck = cktile_moe_stage2( - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - True, - ) - - checkAllclose( - out1_ref[:, :-npad0] if need_pad else out1_ref, - out1_ck[:, :-npad0] if need_pad else out1_ck, - msg=f"[stage1:perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - ) + # _, us2 = run_perftest( + # cktile_moe_stage2, + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale_aiter, + # a2_scale, + # exp_bias2_aiter, + # dtype, + # topk, + # npad0, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # # needTrace=True, + # # num_iters=2, + # # num_warmup=0, + # ) + # out2_ck = cktile_moe_stage2( + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale_aiter, + # a2_scale, + # exp_bias2_aiter, + # dtype, + # topk, + # npad0, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # True + # ) - checkAllclose( - out2_ref, - out2_ck, - msg=f"[stage2:perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - ) + # checkAllclose( + # out2_ref, + # out2_ck, + # msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + # ) # diff = torch.abs(out2_ref - out2_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) @@ -603,34 +608,32 @@ def weight_per_128x128_quant(weight, quant_dtype): # ######################## stage 2 end ########### # # ######################## fused 2 stage ######### - # us1=0 - # out2_ck, us2 = run_perftest( - # fused_moe, - # input, - # w1_qt_aiter, - # w2_qt_aiter, - # topk_weights, - # topk_ids, - # w1_scale=fp4_utils.e8m0_shuffle( - # w1_scale - # ), # e8m0_shuffle will do nothing if it's a fp32 - # w2_scale=fp4_utils.e8m0_shuffle(w2_scale), - # quant_type=qType, - # activation=actType, - # doweight_stage1=doweight_stage1, - # ) - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + us1=0 + out2_ck, us2 = run_perftest( + fused_moe, + input, + w1_qt_aiter, + w2_qt_aiter, + topk_weights, + topk_ids, + w1_scale=fp4_utils.e8m0_shuffle( + w1_scale + ), # e8m0_shuffle will do nothing if it's a fp32 + w2_scale=fp4_utils.e8m0_shuffle(w2_scale), + quant_type=qType, + activation=actType, + doweight_stage1=doweight_stage1, + ) + checkAllclose( + out2_ref, + out2_ck, + msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) return {"gemm1(us)": us1, "gemm2(us)": us2} - - -# seed = 1 -# torch.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) +seed = 1 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) l_dtype = ["bf16", "fp16"][:1] # l_dim = [(6144, 4096)] l_dim = [(7168, 256)] @@ -654,8 +657,8 @@ def weight_per_128x128_quant(weight, quant_dtype): # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 - # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 - (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 + (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + # (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] l_doweight_stage1 = [False, True][:1] @@ -737,7 +740,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-e", "--expert", type=int, - default=8, + default=256, help="""Number of experts. e.g.: -e 8""", ) @@ -746,7 +749,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-k", "--topk", type=int, - default=2, + default=8, help="""Number of top experts. e.g.: -k 2""", ) From 38dd2be551cb5b800003a659d612af66642e218b Mon Sep 17 00:00:00 2001 From: solin Date: Thu, 6 Nov 2025 07:47:47 +0000 Subject: [PATCH 09/20] refine format --- op_tests/test_moe_2stage.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 7b13b9370b..0e41a46374 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -608,7 +608,7 @@ def weight_per_128x128_quant(weight, quant_dtype): # ######################## stage 2 end ########### # # ######################## fused 2 stage ######### - us1=0 + us1 = 0 out2_ck, us2 = run_perftest( fused_moe, input, @@ -631,9 +631,8 @@ def weight_per_128x128_quant(weight, quant_dtype): ) return {"gemm1(us)": us1, "gemm2(us)": us2} -seed = 1 -torch.manual_seed(seed) -torch.cuda.manual_seed_all(seed) + + l_dtype = ["bf16", "fp16"][:1] # l_dim = [(6144, 4096)] l_dim = [(7168, 256)] @@ -713,7 +712,7 @@ def weight_per_128x128_quant(weight, quant_dtype): 4: aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2 # a4w4 5: aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8, # a8w8""", ) -torch.cuda.manual_seed_all(1) + parser.add_argument( "-a", "--act", From 4d8f48116ff2d694c03f51c37280ea921f88296a Mon Sep 17 00:00:00 2001 From: solin Date: Thu, 6 Nov 2025 13:28:40 +0000 Subject: [PATCH 10/20] fix conflict --- op_tests/test_moe_2stage.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index f5814dafaa..d66133b56b 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -101,12 +101,11 @@ def ck_moe_stage2( D = w2.shape[1] # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - out = torch.empty( + out = torch.zeros( (token_num, D), dtype=dtype, device=hidden_states.device, ) - out.fill_(0) aiter.ck_moe_stage2_fwd( hidden_states, w1, @@ -624,23 +623,23 @@ def weight_per_128x128_quant(weight, quant_dtype): activation=actType, doweight_stage1=doweight_stage1, ) - checkAllclose( + err = checkAllclose( out2_ref, out2_ck, msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) - return {"gemm1(us)": us1, "gemm2(us)": us2} - def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim + # return {"gemm1(us)": us1, "gemm2(us)": us2} + def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim - logits_diff = calc_diff(out2_ref, out2_aiter) - assert logits_diff < 1e-3 + logits_diff = calc_diff(out2_ref, out2_ck) + assert logits_diff < 1e-3 - return {"us": us_fuse, "err": err} + return {"us": us2, "err": err} l_dtype = ["bf16", "fp16"][:1] @@ -749,7 +748,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-e", "--expert", type=int, - default=256, + default=8, help="""Number of experts. e.g.: -e 8""", ) @@ -758,7 +757,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-k", "--topk", type=int, - default=8, + default=2, help="""Number of top experts. e.g.: -k 2""", ) From d36fc8ac1ffc2913c770b13cfc174b7ecfa0e363 Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 01:51:51 +0000 Subject: [PATCH 11/20] update --- aiter/fused_moe.py | 176 +++++++++++++- csrc/include/aiter_enum.h | 3 +- csrc/include/rocm_ops.hpp | 1 + op_tests/test_moe_2stage.py | 473 +++++++++++++++++++----------------- 4 files changed, 422 insertions(+), 231 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index e2843f089d..36c14ac976 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -107,7 +107,12 @@ def fused_moe( num_local_tokens: Optional[torch.tensor] = None, moe_sorting_dispatch_policy=0, dtype=None, -): + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, +): if not block_size_M: block_size_M = -1 return fused_moe_( @@ -128,6 +133,10 @@ def fused_moe( num_local_tokens=num_local_tokens, moe_sorting_dispatch_policy=moe_sorting_dispatch_policy, dtype=dtype, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -181,6 +190,10 @@ def fused_moe_( num_local_tokens: Optional[torch.Tensor] = None, moe_sorting_dispatch_policy: bool = 0, dtype: Optional[torch.dtype] = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: Optional[torch.Tensor] = None, + bias2: Optional[torch.Tensor] = None, ) -> torch.Tensor: # We do such convert since custom_op schema restriction on block_size_M, and Enum type activation = ActivationType(activation) @@ -223,6 +236,10 @@ def fused_moe_( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -255,6 +272,8 @@ def fused_moe_( moe_buf, isG1U1, block_size_M, + # activation=activation, + # quant_type=quant_type, q_dtype_a=q_dtype_a, q_dtype_w=q_dtype_w, w1_scale=w1_scale, @@ -286,6 +305,11 @@ def fused_moe_( a1_scale=a1_scale, a2_scale=a2_scale, num_local_tokens=num_local_tokens, + # following for cktile support + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -494,6 +518,10 @@ def get_2stage_cfgs( use_g1u1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ): def get_cfg_2stages(tune_file): import pandas as pd @@ -501,7 +529,6 @@ def get_cfg_2stages(tune_file): cfg_2stages = pd.read_csv(tune_file) cfg_2stages = cfg_2stages.set_index( [ - "cu_num", "token", "model_dim", "inter_dim", @@ -548,7 +575,6 @@ def MainFunc(): f.write( "token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1" ) - q_dtype_ws = q_dtype_w if q_dtype_w != torch.uint32 else "torch.int4" f.write( f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)}" @@ -627,6 +653,24 @@ def FinalFunc(): ksplit, run_1stage, ) + if dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu: + return MOEMetadata( + functools.partial( + cktile_moe_stage1, + n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1), + k_pad_zeros=hidden_pad // 128 * 128, + bias1=bias1, + ), + functools.partial( + cktile_moe_stage2, + n_pad_zeros=hidden_pad // 64 * 64, + k_pad_zeros=intermediate_pad // 128 * 128, + bias2=bias2, + ), + 16 if token < 2048 else 32, + ksplit, + False, + ) if ( "ck2stages" in kernelName1 or (q_type == QuantType.per_1x128 and doweight_stage1) @@ -706,6 +750,11 @@ def fused_moe_2stages( a1_scale=None, # [expert(local_expert:EP), 1, model_dim] a2_scale=None, # [expert(local_expert:EP), 1, inter_dim] num_local_tokens: Optional[torch.tensor] = None, + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, ): quant_func = get_quant(quant_type) @@ -727,9 +776,18 @@ def fused_moe_2stages( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) - - if quant_type == QuantType.per_1x32: + if quant_type == QuantType.per_1x32 \ + and dtype in [dtypes.bf16, dtypes.fp16] \ + and w1.dtype == dtypes.fp4x2 \ + and activation == ActivationType.Swiglu: + a1 = hidden_states.to(dtype) + a1_scale = None + elif quant_type == QuantType.per_1x32: a1, a1_scale = quant_func( hidden_states, scale=a1_scale, @@ -770,7 +828,7 @@ def fused_moe_2stages( dtype=dtype, device=device, ) - + a2 = metadata.stage1( a1, w1, @@ -786,7 +844,11 @@ def fused_moe_2stages( sorted_weights=sorted_weights if doweight_stage1 else None, ) - if quant_type == QuantType.per_1x32: + if quant_type == QuantType.per_1x32 \ + and dtype in [dtypes.bf16, dtypes.fp16] \ + and w1.dtype == dtypes.fp4x2: + a2_scale = None + elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) a2, a2_scale = quant_func( a2, @@ -977,7 +1039,7 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) -# temp workaround for swiglu +#temp workaround for swiglu def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): # Clamp the input values x_glu = x_glu.clamp(min=None, max=limit) @@ -1105,7 +1167,6 @@ def torch_moe_stage2( w2_bias=None, doweight=True, ): - quant_type = quant_remap.get(quant_type, quant_type) ctype = dtypes.fp32 # compute type E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: @@ -1174,6 +1235,101 @@ def torch_moe_stage2( return out.sum(1).to(dtype) +def cktile_moe_stage1( + hidden_states, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + block_m, + a1_scale, + w1_scale, + sorted_weights=None, + n_pad_zeros=0, + k_pad_zeros=0, + bias1=None, +): + token_num = hidden_states.shape[0] + _, n1, k1 = w1.shape + _, k2, n2 = w2.shape + D = n2 if k2 == k1 else n2*2 #bit4 format + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + if w1.dtype is torch.uint32: + D = D * 8 + out = torch.empty((token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device) + # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) + aiter.moe_cktile2stages_gemm1( + hidden_states, + w1, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a1_scale, + w1_scale, + bias1, + block_m, + ) + return out + + +def cktile_moe_stage2( + a2, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + w2_scale, + a2_scale, + block_m, + sorted_weights=None, + zeros_out=False, + n_pad_zeros=0, + k_pad_zeros=0, + bias2=None, +): + token_num = a2.shape[0] + D = w2.shape[1] + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + # out = torch.empty( + # (token_num, D), + # dtype=a2.dtype, + # device=a2.device, + # ) + # if zeros_out: + # out.fill_(0) + # print("Run cktile_moe_stage2: M=%d, N=%d, K=%d, topk=%d, expert=%d"%(a2.shape[0]*a2.shape[1], w2.shape[1], a2.shape[2], topk, w2.shape[0])) + aiter.moe_cktile2stages_gemm2( + a2, + w2, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a2_scale, + w2_scale, + bias2, + block_m, + ) + return out + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -1235,4 +1391,4 @@ def fused_topk( # if renormalize: # topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights, topk_ids \ No newline at end of file diff --git a/csrc/include/aiter_enum.h b/csrc/include/aiter_enum.h index 0c35e8158f..15126c8cf6 100644 --- a/csrc/include/aiter_enum.h +++ b/csrc/include/aiter_enum.h @@ -6,7 +6,8 @@ enum class ActivationType : int { No = -1, Silu = 0, - Gelu + Gelu = 1, + Swiglu = 2, }; enum class QuantType : int { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 45a1b441a1..6b2863d338 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1320,6 +1320,7 @@ namespace py = pybind11; .value("No", ActivationType::No) \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ + .value("Swiglu", ActivationType::Swiglu) \ .export_values(); \ pybind11::implicitly_convertible(); \ pybind11::implicitly_convertible(); diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index a13c6787c7..c01f10945a 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -25,8 +25,8 @@ from aiter.ops.shuffle import ( shuffle_weight, - shuffle_mxfp4_weight, shuffle_mxfp4_scale, + shuffle_mxfp4_weight, shuffle_weight_NK, ) from aiter import ActivationType @@ -80,7 +80,6 @@ def ck_moe_stage1( return out - def ck_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -126,7 +125,6 @@ def ck_moe_stage2( ) return out - def cktile_moe_stage1( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -139,8 +137,8 @@ def cktile_moe_stage1( exp_bias1, dtype, topk, - n_pad_zeros=0, - k_pad_zeros=0, + n_pad_zeros = 0, + k_pad_zeros = 0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, @@ -149,7 +147,7 @@ def cktile_moe_stage1( token_num = hidden_states.shape[0] _, n1, k1 = w1.shape _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2 * 2 # bit4 format + D = n2 if k2 == k1 else n2 * 2 #bit4 format # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size if w1.dtype is torch.uint32: @@ -174,7 +172,6 @@ def cktile_moe_stage1( ) return out - def cktile_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -187,13 +184,13 @@ def cktile_moe_stage2( exp_bias2, dtype, topk, - n_pad_zeros=0, - k_pad_zeros=0, + n_pad_zeros = 0, + k_pad_zeros = 0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, sorted_weights=None, # [max_num_tokens_padded] - zeros_out=False, + zeros_out = False ): token_num = hidden_states.shape[0] D = w2.shape[1] @@ -240,32 +237,30 @@ def test_fmoe( WQDType, use_g1u1=False, doweight_stage1=False, + hidden_pad=0, + intermediate_pad=0, ): if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return torch_quant = aiter.get_torch_quant(qType) torch_act = aiter.get_torch_act(actType) input = torch.randn((token, model_dim), dtype=dtype) - need_pad = qType == aiter.QuantType.per_1x32 - npad0 = 192 - kpad0 = 128 if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) - if need_pad: - w1[:, :, -kpad0:] = 0 - w1[:, -npad0:, :] = 0 - w1[:, inter_dim - npad0 : inter_dim, :] = 0 + if (hidden_pad != 0 and intermediate_pad != 0): + w1[:,:,-hidden_pad:] = 0 + w1[:,-intermediate_pad:,:] = 0 + w1[:,inter_dim-intermediate_pad:inter_dim,:] = 0 exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) - if need_pad: - w2[:, :, -kpad0:] = 0 - w2[:, -npad0:, :] = 0 + if (hidden_pad != 0 and intermediate_pad != 0): + w2[:,:,-intermediate_pad:] = 0 + w2[:,-hidden_pad:,:] = 0 exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) - # rand topk_weights, topk_ids = fused_topk(input, score, topk, True) # sequence # topk_ids_list = [[((i * topk) + j)% E for j in range(topk)] for i in range(token)] @@ -273,10 +268,10 @@ def test_fmoe( M, _ = topk_ids.shape - # BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) - BLOCK_SIZE_M = 32 if M > 1024 else 16 - if qType == aiter.QuantType.per_128x128: - BLOCK_SIZE_M = 64 if M > 64 else 16 + BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) + # BLOCK_SIZE_M = 32 if M > 1024 else 16 + # if qType == aiter.QuantType.per_128x128: + # BLOCK_SIZE_M = 64 if M > 64 else 16 sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M ) @@ -326,31 +321,25 @@ def weight_per_128x128_quant(weight, quant_dtype): if qType != aiter.QuantType.per_1x32: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape) - else: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape[0], w1.shape[1], w1.shape[2] // 2) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape[0], w2.shape[1], w2.shape[2] // 2) + # Quant-ing a if qType == aiter.QuantType.per_128x128: a1_qt, a1_scale = aiter.pertoken_quant( input.view(token, -1, 128), quant_dtype=AQDType ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) - elif qType == aiter.QuantType.per_1x32 and ( - AQDType in [dtypes.bf16, dtypes.fp16] - ): # a16w4 + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and WQDType == dtypes.fp4x2: #a16w4 a1_qt = input.to(AQDType) a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) # bias dtype convert - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 exp_bias1_aiter = exp_bias1.to(dtypes.fp32) exp_bias2_aiter = exp_bias2.to(dtypes.fp32) else: @@ -371,26 +360,22 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16), use_int4=True) ) ) - elif ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) w2_scale_aiter = shuffle_mxfp4_scale(w2_scale, E, False) - elif ( - WQDType != dtypes.fp4x2 - and (get_gfx() in ["gfx950"]) - and (qType != aiter.QuantType.per_128x128) - ): - inst_K = 128 // w1_qt_aiter.element_size() - w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) - w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) + # elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]): + # inst_K = 128 // w1_qt_aiter.element_size() + # w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) + # w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) elif WQDType != dtypes.fp4x2: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) # # ######################## stage 1 start ########### out1_ref = torch_moe_stage1( @@ -413,55 +398,58 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) else: out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - - # out1_ck, us1 = run_perftest( - # ck_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale, - # a1_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # ) - - # cktile_2stage - out1_ck, us1 = run_perftest( - cktile_moe_stage1, - a1_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w1_scale_aiter, - a1_scale, - exp_bias1_aiter, - dtype, - topk, - npad0 * 2, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type=qType, - sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4: + npad0 = intermediate_pad // 64 * 64 + kpad0 = hidden_pad // 128 * 128 + out1_ck, us1 = run_perftest( + cktile_moe_stage1, + a1_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w1_scale_aiter, + a1_scale, + exp_bias1_aiter, + dtype, + topk, + npad0 * 2, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type=qType, + sorted_weights=sorted_weights if doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) + else: + out1_ck, us1 = run_perftest( + ck_moe_stage1, + a1_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w1_scale, + a1_scale, + dtype, + topk, + BLOCK_SIZE_M, + actType, + quant_type=qType, + sorted_weights=sorted_weights if doweight_stage1 else None, + needTrace=True, ) - # checkAllclose( - # out1_ref[:,:-npad0] if need_pad else out1_ref, - # out1_ck[:,:-npad0] if need_pad else out1_ck, - # msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + + checkAllclose( + out1_ref, + out1_ck, + msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + # diff = torch.abs(out1_ref - out1_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) @@ -504,7 +492,7 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ref.view(token, -1, 128), quant_dtype=AQDType ) a2_scale = a2_scale.view(token, topk, -1) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]): + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 a2_qt = out1_ref a2_scale = None else: @@ -524,120 +512,137 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_bias=exp_bias2, doweight=not doweight_stage1, ) - # out_ref = torch_moe( - # input, - # w1_qt, - # w2_qt, - # topk_weights, - # topk_ids, - # fc1_scale=w1_scale, - # fc2_scale=w2_scale, - # ) - # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - - if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - out2_ck = torch.zeros((token, model_dim), dtype=dtype) + # # out_ref = torch_moe( + # # input, + # # w1_qt, + # # w2_qt, + # # topk_weights, + # # topk_ids, + # # fc1_scale=w1_scale, + # # fc2_scale=w2_scale, + # # ) + # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") + + out2_ck = torch.empty((token, model_dim), dtype=dtype) + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + npad0 = hidden_pad // 64 * 64 + kpad0 = intermediate_pad // 128 * 128 + _, us2 = run_perftest( + cktile_moe_stage2, + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) + out2_ck = cktile_moe_stage2( + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + True + ) else: - out2_ck = torch.empty((token, model_dim), dtype=dtype) - - # # cktil2stage - _, us2 = run_perftest( - cktile_moe_stage2, - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, - ) - out2_ck = cktile_moe_stage2( - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - True, - ) - - checkAllclose( - out1_ref[:, :-npad0] if need_pad else out1_ref, - out1_ck[:, :-npad0] if need_pad else out1_ck, - msg=f"[stage1:perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - ) + out2_ck, us2 = run_perftest( + ck_moe_stage2, + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale, + a2_scale, + dtype, + topk, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + ) checkAllclose( out2_ref, out2_ck, - msg=f"[stage2:perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) + # diff = torch.abs(out2_ref - out2_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) # ######################## stage 2 end ########### - # # ######################## fused 2 stage ######### - # us1=0 - # out2_ck, us2 = run_perftest( - # fused_moe, - # input, - # w1_qt_aiter, - # w2_qt_aiter, - # topk_weights, - # topk_ids, - # w1_scale=fp4_utils.e8m0_shuffle( - # w1_scale - # ), # e8m0_shuffle will do nothing if it's a fp32 - # w2_scale=fp4_utils.e8m0_shuffle(w2_scale), - # quant_type=qType, - # activation=actType, - # doweight_stage1=doweight_stage1, - # ) - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - - return {"gemm1(us)": us1, "gemm2(us)": us2} - - -# seed = 1 -# torch.manual_seed(seed) -# torch.cuda.manual_seed_all(seed) + out2_ck = fused_moe( + input, + w1_qt_aiter, + w2_qt_aiter, + topk_weights, + topk_ids, + w1_scale=w1_scale_aiter, + w2_scale=w2_scale_aiter, + quant_type=qType, + activation=actType, + doweight_stage1=doweight_stage1, + intermediate_pad=intermediate_pad, + hidden_pad=hidden_pad, + bias1=exp_bias1_aiter, + bias2=exp_bias2_aiter, + ) + err = checkAllclose( + out2_ref, + out2_ck, + msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + + # return {"gemm1(us)": us1, "gemm2(us)": us2} + def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + logits_diff = calc_diff(out2_ref, out2_ck) + assert logits_diff < 1e-3 + + return {"us": us2, "err": err} + l_dtype = ["bf16", "fp16"][:1] # l_dim = [(6144, 4096)] l_dim = [(7168, 256)] +# l_dim = [(3072, 3072)] l_tokenNum = [ # 1, - # 3, - # 5, + # 2, + # 4, 8, # 16, # 32, @@ -645,20 +650,26 @@ def weight_per_128x128_quant(weight, quant_dtype): # 128, # 256, # 1024, + # 2048, + # 3072, # 4096, + # 8192, # 163840, ] +l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu] l_quant = [ - # (aiter.QuantType.No, None, None), # a16w16 + # (aiter.QuantType.No, None, None), # a16w16 # (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 - (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 + (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 + ] -l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] l_doweight_stage1 = [False, True][:1] +l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)] + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, @@ -737,7 +748,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-e", "--expert", type=int, - default=8, + default=256, help="""Number of experts. e.g.: -e 8""", ) @@ -746,7 +757,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-k", "--topk", type=int, - default=2, + default=8, help="""Number of top experts. e.g.: -k 2""", ) @@ -770,30 +781,52 @@ def weight_per_128x128_quant(weight, quant_dtype): if args.doweight_stage1 is not None: l_doweight_stage1 = [args.doweight_stage1] - + +df = [] for ( dtype, - act_type, (quant_type, aq_dtype, wq_dtype), (model_dim, inter_dim), doweight_stage1, -) in itertools.product(l_dtype, l_act, l_quant, l_dim, l_doweight_stage1): - df = [] - for m in l_tokenNum: - ret = test_fmoe( - dtype, - m, - model_dim, - inter_dim, - args.expert, - args.topk, - act_type, - quant_type, - aq_dtype, - wq_dtype, - use_g1u1=True, - doweight_stage1=doweight_stage1, - ) - df.append(ret) - df = pd.DataFrame(df) - aiter.logger.info(f"summary:\n{df}") +) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1): + if (quant_type, aq_dtype, wq_dtype) == (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2): + for (hidden_pad, intermediate_pad) in l_hidden_intermediate_pad: + print(f"hidden_pad={hidden_pad}, intermediate_pad={intermediate_pad}") + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + aiter.ActivationType.Swiglu, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + ) + df.append(ret) + else: + for act_type in l_act: + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + act_type, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + ) + df.append(ret) +df = pd.DataFrame(df) +aiter.logger.info(f"summary:\n{df}") \ No newline at end of file From 26ebcd323b25b714b9b6cafd5b0496a0bf34b143 Mon Sep 17 00:00:00 2001 From: zhimding Date: Thu, 6 Nov 2025 21:21:35 -0600 Subject: [PATCH 12/20] format --- aiter/fused_moe.py | 72 ++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index b9f6c80161..6acbb89173 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -112,7 +112,7 @@ def fused_moe( intermediate_pad=0, bias1=None, bias2=None, -): +): if not block_size_M: block_size_M = -1 return fused_moe_( @@ -653,24 +653,28 @@ def FinalFunc(): ksplit, run_1stage, ) - if dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu: + if ( + dtype in [dtypes.bf16, dtypes.fp16] + and q_type == QuantType.per_1x32 + and activation == ActivationType.Swiglu + ): return MOEMetadata( - functools.partial( - cktile_moe_stage1, - n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1), - k_pad_zeros=hidden_pad // 128 * 128, - bias1=bias1, - ), - functools.partial( - cktile_moe_stage2, - n_pad_zeros=hidden_pad // 64 * 64, - k_pad_zeros=intermediate_pad // 128 * 128, - bias2=bias2, - ), - 16 if token < 2048 else 32, - ksplit, - False, - ) + functools.partial( + cktile_moe_stage1, + n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1), + k_pad_zeros=hidden_pad // 128 * 128, + bias1=bias1, + ), + functools.partial( + cktile_moe_stage2, + n_pad_zeros=hidden_pad // 64 * 64, + k_pad_zeros=intermediate_pad // 128 * 128, + bias2=bias2, + ), + 16 if token < 2048 else 32, + ksplit, + False, + ) if ( "ck2stages" in kernelName1 or (q_type == QuantType.per_1x128 and doweight_stage1) @@ -779,10 +783,12 @@ def fused_moe_2stages( bias1, bias2, ) - if quant_type == QuantType.per_1x32 \ - and dtype in [dtypes.bf16, dtypes.fp16] \ - and w1.dtype == dtypes.fp4x2 \ - and activation == ActivationType.Swiglu: + if ( + quant_type == QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + and w1.dtype == dtypes.fp4x2 + and activation == ActivationType.Swiglu + ): a1 = hidden_states.to(dtype) a1_scale = None elif quant_type == QuantType.per_1x32: @@ -826,7 +832,7 @@ def fused_moe_2stages( dtype=dtype, device=device, ) - + a2 = metadata.stage1( a1, w1, @@ -842,9 +848,11 @@ def fused_moe_2stages( sorted_weights=sorted_weights if doweight_stage1 else None, ) - if quant_type == QuantType.per_1x32 \ - and dtype in [dtypes.bf16, dtypes.fp16] \ - and w1.dtype == dtypes.fp4x2: + if ( + quant_type == QuantType.per_1x32 + and dtype in [dtypes.bf16, dtypes.fp16] + and w1.dtype == dtypes.fp4x2 + ): a2_scale = None elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) @@ -1037,7 +1045,7 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) -#temp workaround for swiglu +# temp workaround for swiglu def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): # Clamp the input values x_glu = x_glu.clamp(min=None, max=limit) @@ -1253,12 +1261,14 @@ def cktile_moe_stage1( token_num = hidden_states.shape[0] _, n1, k1 = w1.shape _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2*2 #bit4 format + D = n2 if k2 == k1 else n2 * 2 # bit4 format # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size if w1.dtype is torch.uint32: D = D * 8 - out = torch.empty((token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device) + out = torch.empty( + (token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device + ) # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) aiter.moe_cktile2stages_gemm1( hidden_states, @@ -1277,7 +1287,7 @@ def cktile_moe_stage1( block_m, ) return out - + def cktile_moe_stage2( a2, @@ -1389,4 +1399,4 @@ def fused_topk( # if renormalize: # topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids \ No newline at end of file + return topk_weights, topk_ids From 850a8baa849794daa2d5e98bca6f21a3d24849d1 Mon Sep 17 00:00:00 2001 From: solin Date: Fri, 7 Nov 2025 03:28:46 +0000 Subject: [PATCH 13/20] fix format --- op_tests/test_moe_2stage.py | 88 +++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 29 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index cbe2a80fa6..e68b0d537d 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -80,6 +80,7 @@ def ck_moe_stage1( return out + def ck_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -124,6 +125,7 @@ def ck_moe_stage2( ) return out + def cktile_moe_stage1( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -136,8 +138,8 @@ def cktile_moe_stage1( exp_bias1, dtype, topk, - n_pad_zeros = 0, - k_pad_zeros = 0, + n_pad_zeros=0, + k_pad_zeros=0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, @@ -146,7 +148,7 @@ def cktile_moe_stage1( token_num = hidden_states.shape[0] _, n1, k1 = w1.shape _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2 * 2 #bit4 format + D = n2 if k2 == k1 else n2 * 2 # bit4 format # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size if w1.dtype is torch.uint32: @@ -171,6 +173,7 @@ def cktile_moe_stage1( ) return out + def cktile_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -183,13 +186,13 @@ def cktile_moe_stage2( exp_bias2, dtype, topk, - n_pad_zeros = 0, - k_pad_zeros = 0, + n_pad_zeros=0, + k_pad_zeros=0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, sorted_weights=None, # [max_num_tokens_padded] - zeros_out = False + zeros_out=False, ): token_num = hidden_states.shape[0] D = w2.shape[1] @@ -246,18 +249,18 @@ def test_fmoe( input = torch.randn((token, model_dim), dtype=dtype) if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) - if (hidden_pad != 0 and intermediate_pad != 0): - w1[:,:,-hidden_pad:] = 0 - w1[:,-intermediate_pad:,:] = 0 - w1[:,inter_dim-intermediate_pad:inter_dim,:] = 0 + if hidden_pad != 0 and intermediate_pad != 0: + w1[:, :, -hidden_pad:] = 0 + w1[:, -intermediate_pad:, :] = 0 + w1[:, inter_dim - intermediate_pad : inter_dim, :] = 0 exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) - if (hidden_pad != 0 and intermediate_pad != 0): - w2[:,:,-intermediate_pad:] = 0 - w2[:,-hidden_pad:,:] = 0 + if hidden_pad != 0 and intermediate_pad != 0: + w2[:, :, -intermediate_pad:] = 0 + w2[:, -hidden_pad:, :] = 0 exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) topk_weights, topk_ids = fused_topk(input, score, topk, True) @@ -331,14 +334,22 @@ def weight_per_128x128_quant(weight, quant_dtype): ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and WQDType == dtypes.fp4x2: #a16w4 + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and WQDType == dtypes.fp4x2 + ): # a16w4 a1_qt = input.to(AQDType) a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) # bias dtype convert - if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 exp_bias1_aiter = exp_bias1.to(dtypes.fp32) exp_bias2_aiter = exp_bias2.to(dtypes.fp32) else: @@ -361,7 +372,11 @@ def weight_per_128x128_quant(weight, quant_dtype): ) w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) @@ -397,7 +412,11 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) else: out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4: + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4: npad0 = intermediate_pad // 64 * 64 kpad0 = hidden_pad // 128 * 128 out1_ck, us1 = run_perftest( @@ -441,14 +460,14 @@ def weight_per_128x128_quant(weight, quant_dtype): quant_type=qType, sorted_weights=sorted_weights if doweight_stage1 else None, needTrace=True, - ) - + ) + checkAllclose( out1_ref, out1_ck, msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) - + # diff = torch.abs(out1_ref - out1_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) @@ -491,7 +510,11 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ref.view(token, -1, 128), quant_dtype=AQDType ) a2_scale = a2_scale.view(token, topk, -1) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + elif ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 a2_qt = out1_ref a2_scale = None else: @@ -523,7 +546,11 @@ def weight_per_128x128_quant(weight, quant_dtype): # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") out2_ck = torch.empty((token, model_dim), dtype=dtype) - if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 npad0 = hidden_pad // 64 * 64 kpad0 = intermediate_pad // 128 * 128 _, us2 = run_perftest( @@ -567,7 +594,7 @@ def weight_per_128x128_quant(weight, quant_dtype): actType, quant_type, sorted_weights if not doweight_stage1 else None, - True + True, ) else: out2_ck, us2 = run_perftest( @@ -664,8 +691,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 - (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 - + (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_doweight_stage1 = [False, True][:1] l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)] @@ -781,7 +807,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): if args.doweight_stage1 is not None: l_doweight_stage1 = [args.doweight_stage1] - + df = [] for ( dtype, @@ -789,8 +815,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): (model_dim, inter_dim), doweight_stage1, ) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1): - if (quant_type, aq_dtype, wq_dtype) == (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2): - for (hidden_pad, intermediate_pad) in l_hidden_intermediate_pad: + if (quant_type, aq_dtype, wq_dtype) == ( + aiter.QuantType.per_1x32, + dtypes.bf16, + dtypes.fp4x2, + ): + for hidden_pad, intermediate_pad in l_hidden_intermediate_pad: print(f"hidden_pad={hidden_pad}, intermediate_pad={intermediate_pad}") for m in l_tokenNum: ret = test_fmoe( @@ -829,4 +859,4 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): ) df.append(ret) df = pd.DataFrame(df) -aiter.logger.info(f"summary:\n{df}") \ No newline at end of file +aiter.logger.info(f"summary:\n{df}") From 25f2dd507c1bba590edf64a8d8e65eddeb8a5bbc Mon Sep 17 00:00:00 2001 From: zhimding Date: Thu, 6 Nov 2025 22:37:37 -0600 Subject: [PATCH 14/20] update --- aiter/ops/shuffle.py | 4 +- op_tests/test_moe_2stage.py | 348 ++++++++++++++++++------------------ 2 files changed, 179 insertions(+), 173 deletions(-) diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index 656044bd70..1ea0e35ac7 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -46,7 +46,7 @@ def shuffle_weight_NK( return x_.view(*x.shape) -def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: +def shuffle_weight_a16w4(src: torch.Tensor, NLane: int, gate_up: bool) -> torch.Tensor: """ src: shape [experts_cnt, N, K_pk], where K_pk = K // 2 Returns: shuffled tensor of shape [experts_cnt, N0*2, K0, KLane, NLane, KPack] @@ -79,7 +79,7 @@ def shuffle_mxfp4_weight(src: torch.Tensor, NLane: int, gate_up: bool) -> torch. return interleaved.contiguous().view(src_type) -def shuffle_mxfp4_scale( +def shuffle_scale_a16w4( src: torch.Tensor, experts_cnt: int, gate_up: bool ) -> torch.Tensor: n_experts, k_ = src.shape diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index e68b0d537d..e3b6d5da8d 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -25,8 +25,8 @@ from aiter.ops.shuffle import ( shuffle_weight, - shuffle_mxfp4_scale, - shuffle_mxfp4_weight, + shuffle_scale_a16w4, + shuffle_weight_a16w4, shuffle_weight_NK, ) from aiter import ActivationType @@ -271,9 +271,14 @@ def test_fmoe( M, _ = topk_ids.shape BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) - # BLOCK_SIZE_M = 32 if M > 1024 else 16 - # if qType == aiter.QuantType.per_128x128: - # BLOCK_SIZE_M = 64 if M > 64 else 16 + if ( + qType == aiter.QuantType.per_1x32 + and (AQDType in [dtypes.bf16, dtypes.fp16]) + and (WQDType == dtypes.fp4x2) + ): # a16w4 + BLOCK_SIZE_M = 32 if M > 1024 else 16 + if qType == aiter.QuantType.per_128x128: + BLOCK_SIZE_M = 64 if M > 64 else 16 sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M ) @@ -377,10 +382,10 @@ def weight_per_128x128_quant(weight, quant_dtype): and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2) ): # a16w4 - w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) - w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) - w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) - w2_scale_aiter = shuffle_mxfp4_scale(w2_scale, E, False) + w1_qt_aiter = shuffle_weight_a16w4(w1_qt_aiter, 16, True) + w1_scale_aiter = shuffle_scale_a16w4(w1_scale, E, True) + w2_qt_aiter = shuffle_weight_a16w4(w2_qt_aiter, 16, False) + w2_scale_aiter = shuffle_scale_a16w4(w2_scale, E, False) # elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]): # inst_K = 128 // w1_qt_aiter.element_size() # w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) @@ -408,65 +413,65 @@ def weight_per_128x128_quant(weight, quant_dtype): ) # # ######################## ck stage 1 start ########### - if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) - else: - out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4: - npad0 = intermediate_pad // 64 * 64 - kpad0 = hidden_pad // 128 * 128 - out1_ck, us1 = run_perftest( - cktile_moe_stage1, - a1_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w1_scale_aiter, - a1_scale, - exp_bias1_aiter, - dtype, - topk, - npad0 * 2, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type=qType, - sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, - ) - else: - out1_ck, us1 = run_perftest( - ck_moe_stage1, - a1_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w1_scale, - a1_scale, - dtype, - topk, - BLOCK_SIZE_M, - actType, - quant_type=qType, - sorted_weights=sorted_weights if doweight_stage1 else None, - needTrace=True, - ) + # if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: + # out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) + # else: + # out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) + # if ( + # qType == aiter.QuantType.per_1x32 + # and (AQDType in [dtypes.bf16, dtypes.fp16]) + # and (WQDType == dtypes.fp4x2) + # ): # a16w4: + # npad0 = intermediate_pad // 64 * 64 + # kpad0 = hidden_pad // 128 * 128 + # out1_ck, us1 = run_perftest( + # cktile_moe_stage1, + # a1_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w1_scale_aiter, + # a1_scale, + # exp_bias1_aiter, + # dtype, + # topk, + # npad0 * 2, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type=qType, + # sorted_weights=sorted_weights if doweight_stage1 else None, + # # needTrace=True, + # # num_iters=2, + # # num_warmup=0, + # ) + # else: + # out1_ck, us1 = run_perftest( + # ck_moe_stage1, + # a1_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w1_scale, + # a1_scale, + # dtype, + # topk, + # BLOCK_SIZE_M, + # actType, + # quant_type=qType, + # sorted_weights=sorted_weights if doweight_stage1 else None, + # needTrace=True, + # ) - checkAllclose( - out1_ref, - out1_ck, - msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - ) + # checkAllclose( + # out1_ref, + # out1_ck, + # msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", + # ) # diff = torch.abs(out1_ref - out1_ck) # max_value= diff.max() @@ -545,89 +550,90 @@ def weight_per_128x128_quant(weight, quant_dtype): # # ) # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - out2_ck = torch.empty((token, model_dim), dtype=dtype) - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 - npad0 = hidden_pad // 64 * 64 - kpad0 = intermediate_pad // 128 * 128 - _, us2 = run_perftest( - cktile_moe_stage2, - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - # needTrace=True, - # num_iters=2, - # num_warmup=0, - ) - out2_ck = cktile_moe_stage2( - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale_aiter, - a2_scale, - exp_bias2_aiter, - dtype, - topk, - npad0, - kpad0, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - True, - ) - else: - out2_ck, us2 = run_perftest( - ck_moe_stage2, - a2_qt, - w1_qt_aiter, - w2_qt_aiter, - sorted_ids, - sorted_expert_ids, - num_valid_ids, - w2_scale, - a2_scale, - dtype, - topk, - BLOCK_SIZE_M, - actType, - quant_type, - sorted_weights if not doweight_stage1 else None, - ) + # out2_ck = torch.empty((token, model_dim), dtype=dtype) + # if ( + # qType == aiter.QuantType.per_1x32 + # and (AQDType in [dtypes.bf16, dtypes.fp16]) + # and (WQDType == dtypes.fp4x2) + # ): # a16w4 + # npad0 = hidden_pad // 64 * 64 + # kpad0 = intermediate_pad // 128 * 128 + # _, us2 = run_perftest( + # cktile_moe_stage2, + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale_aiter, + # a2_scale, + # exp_bias2_aiter, + # dtype, + # topk, + # npad0, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # # needTrace=True, + # # num_iters=2, + # # num_warmup=0, + # ) + # out2_ck = cktile_moe_stage2( + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale_aiter, + # a2_scale, + # exp_bias2_aiter, + # dtype, + # topk, + # npad0, + # kpad0, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # True, + # ) + # else: + # out2_ck, us2 = run_perftest( + # ck_moe_stage2, + # a2_qt, + # w1_qt_aiter, + # w2_qt_aiter, + # sorted_ids, + # sorted_expert_ids, + # num_valid_ids, + # w2_scale, + # a2_scale, + # dtype, + # topk, + # BLOCK_SIZE_M, + # actType, + # quant_type, + # sorted_weights if not doweight_stage1 else None, + # ) - checkAllclose( - out2_ref, - out2_ck, - msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - ) + # checkAllclose( + # out2_ref, + # out2_ck, + # msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + # ) # diff = torch.abs(out2_ref - out2_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) # ######################## stage 2 end ########### - - out2_ck = fused_moe( + us1 = 0 + out2_ck, us2 = run_perftest( + fused_moe, input, w1_qt_aiter, w2_qt_aiter, @@ -650,15 +656,16 @@ def weight_per_128x128_quant(weight, quant_dtype): ) # return {"gemm1(us)": us1, "gemm2(us)": us2} - def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim + # def calc_diff(x: torch.Tensor, y: torch.Tensor): + # x, y = x.double(), y.double() + # denominator = (x * x + y * y).sum() + # sim = 2 * (x * y).sum() / denominator + # return 1 - sim - logits_diff = calc_diff(out2_ref, out2_ck) - assert logits_diff < 1e-3 + # logits_diff = calc_diff(out2_ref, out2_ck) + # assert logits_diff < 1e-3 + # return {"gemm1(us)": us1, "gemm2(us)": us2, "err": err} return {"us": us2, "err": err} @@ -667,30 +674,30 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): l_dim = [(7168, 256)] # l_dim = [(3072, 3072)] l_tokenNum = [ - # 1, - # 2, - # 4, + 1, + 2, + 4, 8, - # 16, - # 32, - # 64, - # 128, - # 256, - # 1024, - # 2048, - # 3072, - # 4096, - # 8192, - # 163840, + 16, + 32, + 64, + 128, + 256, + 1024, + 2048, + 3072, + 4096, + 8192, + 163840, ] l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu] l_quant = [ - # (aiter.QuantType.No, None, None), # a16w16 - # (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 - # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 - # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 + (aiter.QuantType.No, None, None), # a16w16 + (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 - # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_doweight_stage1 = [False, True][:1] @@ -821,8 +828,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): dtypes.fp4x2, ): for hidden_pad, intermediate_pad in l_hidden_intermediate_pad: - print(f"hidden_pad={hidden_pad}, intermediate_pad={intermediate_pad}") - for m in l_tokenNum: + for m in l_tokenNum: ret = test_fmoe( dtype, m, From d883f0e4f0b03d6a5999f70e1df1a2975972a6ca Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 02:11:15 -0600 Subject: [PATCH 15/20] update --- aiter/fused_moe.py | 4 ++-- op_tests/test_moe_2stage.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 6acbb89173..818e150e8a 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -764,7 +764,6 @@ def fused_moe_2stages( E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) dtype = moe_out.dtype device = hidden_states.device - metadata = get_2stage_cfgs( get_padded_M(token_num), # consider token_num > 1024 as prefill model_dim, @@ -852,6 +851,7 @@ def fused_moe_2stages( quant_type == QuantType.per_1x32 and dtype in [dtypes.bf16, dtypes.fp16] and w1.dtype == dtypes.fp4x2 + and activation == ActivationType.Swiglu ): a2_scale = None elif quant_type == QuantType.per_1x32: @@ -890,7 +890,7 @@ def fused_moe_2stages( num_rows_factor=topk, ) a2 = a2.view(token_num, topk, inter_dim) - + metadata.stage2( a2, w1, diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index e3b6d5da8d..1ba1bb6216 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -27,7 +27,6 @@ shuffle_weight, shuffle_scale_a16w4, shuffle_weight_a16w4, - shuffle_weight_NK, ) from aiter import ActivationType @@ -245,7 +244,6 @@ def test_fmoe( if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return torch_quant = aiter.get_torch_quant(qType) - torch_act = aiter.get_torch_act(actType) input = torch.randn((token, model_dim), dtype=dtype) if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) @@ -264,9 +262,6 @@ def test_fmoe( exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) topk_weights, topk_ids = fused_topk(input, score, topk, True) - # sequence - # topk_ids_list = [[((i * topk) + j)% E for j in range(topk)] for i in range(token)] - # topk_ids = torch.tensor(topk_ids_list, device=topk_ids.device, dtype=topk_ids.dtype) M, _ = topk_ids.shape @@ -395,6 +390,9 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + else: + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) # # ######################## stage 1 start ########### out1_ref = torch_moe_stage1( @@ -648,6 +646,8 @@ def weight_per_128x128_quant(weight, quant_dtype): hidden_pad=hidden_pad, bias1=exp_bias1_aiter, bias2=exp_bias2_aiter, + num_iters=5, + num_warmup=2, ) err = checkAllclose( out2_ref, @@ -696,12 +696,12 @@ def weight_per_128x128_quant(weight, quant_dtype): (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 - # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 + (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] l_doweight_stage1 = [False, True][:1] -l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)] +l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)][1:2] parser = argparse.ArgumentParser( From f2019b49b43083f15315d9879f0ce621827d761c Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 03:35:15 -0600 Subject: [PATCH 16/20] update --- aiter/fused_moe.py | 1 + op_tests/test_moe_2stage.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 818e150e8a..cc892810ac 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -529,6 +529,7 @@ def get_cfg_2stages(tune_file): cfg_2stages = pd.read_csv(tune_file) cfg_2stages = cfg_2stages.set_index( [ + "cu_num", "token", "model_dim", "inter_dim", diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 1ba1bb6216..ca6ef05c59 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -381,10 +381,6 @@ def weight_per_128x128_quant(weight, quant_dtype): w1_scale_aiter = shuffle_scale_a16w4(w1_scale, E, True) w2_qt_aiter = shuffle_weight_a16w4(w2_qt_aiter, 16, False) w2_scale_aiter = shuffle_scale_a16w4(w2_scale, E, False) - # elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]): - # inst_K = 128 // w1_qt_aiter.element_size() - # w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) - # w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) elif WQDType != dtypes.fp4x2: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) From 3c1fe2d2d0ead4f4d482d0297641bdd802f2afc3 Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 03:47:28 -0600 Subject: [PATCH 17/20] format --- op_tests/test_moe_2stage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index ca6ef05c59..f9aa466089 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -688,7 +688,7 @@ def weight_per_128x128_quant(weight, quant_dtype): ] l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu] l_quant = [ - (aiter.QuantType.No, None, None), # a16w16 + (aiter.QuantType.No, None, None), # a16w16 (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 @@ -824,7 +824,7 @@ def weight_per_128x128_quant(weight, quant_dtype): dtypes.fp4x2, ): for hidden_pad, intermediate_pad in l_hidden_intermediate_pad: - for m in l_tokenNum: + for m in l_tokenNum: ret = test_fmoe( dtype, m, From 942896abf70885d531d02b5c91a79738f840a706 Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 03:49:30 -0600 Subject: [PATCH 18/20] format --- aiter/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index cc892810ac..2a291b6250 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -891,7 +891,7 @@ def fused_moe_2stages( num_rows_factor=topk, ) a2 = a2.view(token_num, topk, inter_dim) - + metadata.stage2( a2, w1, From 18f9a53217b6e8d1c7708b9585a2d6621d075b07 Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 07:18:00 -0600 Subject: [PATCH 19/20] remove useless --- op_tests/test_moe_2stage.py | 422 +----------------------------------- 1 file changed, 6 insertions(+), 416 deletions(-) diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index f9aa466089..6adb89fa11 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -15,11 +15,9 @@ from aiter.fused_moe import ( fused_topk, - moe_sorting, fused_moe, torch_moe_stage1, torch_moe_stage2, - get_block_size_M, ) @@ -34,196 +32,6 @@ torch.set_default_device("cuda") -def ck_moe_stage1( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w1_scale, - a1_scale, - dtype, - topk, - block_size=32, - Activation=ActivationType.Gelu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] -): - token_num = hidden_states.shape[0] - D = w2.shape[-1] - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - if w1.dtype is torch.uint32: - D = D * 8 - - out = torch.empty((token_num, topk, D), dtype=dtype) - - aiter.ck_moe_stage1_fwd( - hidden_states, - w1, - w2, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - out, - topk, - "", - w1_scale, - a1_scale, - block_size, - sorted_weights, - quant_type, - Activation, - ) - - return out - - -def ck_moe_stage2( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w2_scale, - a2_scale, - dtype, - topk, - block_size=32, - Activation=ActivationType.Gelu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] -): - token_num = hidden_states.shape[0] - D = w2.shape[1] - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - out = torch.zeros( - (token_num, D), - dtype=dtype, - device=hidden_states.device, - ) - aiter.ck_moe_stage2_fwd( - hidden_states, - w1, - w2, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - out, - topk, - "", - w2_scale, - a2_scale, - block_size, - sorted_weights, - quant_type, - Activation, - ) - return out - - -def cktile_moe_stage1( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w1_scale, - a1_scale, - exp_bias1, - dtype, - topk, - n_pad_zeros=0, - k_pad_zeros=0, - block_size=32, - Activation=ActivationType.Silu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] -): - token_num = hidden_states.shape[0] - _, n1, k1 = w1.shape - _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2 * 2 # bit4 format - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - if w1.dtype is torch.uint32: - D = D * 8 - out = torch.empty((token_num, topk, D), dtype=dtype) - # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) - aiter.moe_cktile2stages_gemm1( - hidden_states, - w1, - out, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - topk, - n_pad_zeros, - k_pad_zeros, - sorted_weights, - a1_scale, - w1_scale, - exp_bias1, - block_size, - ) - return out - - -def cktile_moe_stage2( - hidden_states, - w1, # [E, inter_dim*2, model_dim] - w2, # [E, model_dim, inter_dim] - sorted_token_ids, # [max_num_tokens_padded] - sorted_expert_ids, # [max_num_m_blocks] - num_valid_ids, # [1] - w2_scale, - a2_scale, - exp_bias2, - dtype, - topk, - n_pad_zeros=0, - k_pad_zeros=0, - block_size=32, - Activation=ActivationType.Silu, - quant_type=aiter.QuantType.No, - sorted_weights=None, # [max_num_tokens_padded] - zeros_out=False, -): - token_num = hidden_states.shape[0] - D = w2.shape[1] - # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size - - out = torch.empty( - (token_num, D), - dtype=dtype, - device=hidden_states.device, - ) - if zeros_out: - out.fill_(0) - # print("Run cktile_moe_stage2: M=%d, N=%d, K=%d, topk=%d, expert=%d"%(hidden_states.shape[0]*hidden_states.shape[1], w2.shape[1], hidden_states.shape[2], topk, w2.shape[0])) - aiter.moe_cktile2stages_gemm2( - hidden_states, - w2, - out, - sorted_token_ids, - sorted_expert_ids, - num_valid_ids, - topk, - n_pad_zeros, - k_pad_zeros, - sorted_weights, - a2_scale, - w2_scale, - exp_bias2, - block_size, - ) - return out - - @benchmark() def test_fmoe( dtype, @@ -263,21 +71,6 @@ def test_fmoe( score = torch.randn((token, E), dtype=dtype) topk_weights, topk_ids = fused_topk(input, score, topk, True) - M, _ = topk_ids.shape - - BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 - BLOCK_SIZE_M = 32 if M > 1024 else 16 - if qType == aiter.QuantType.per_128x128: - BLOCK_SIZE_M = 64 if M > 64 else 16 - sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( - topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M - ) - if qType == aiter.QuantType.per_Tensor: w1_qt, w1_scale = aiter.pertoken_quant(w1.view(E, -1), quant_dtype=WQDType) w2_qt, w2_scale = aiter.pertoken_quant(w2.view(E, -1), quant_dtype=WQDType) @@ -406,103 +199,6 @@ def weight_per_128x128_quant(weight, quant_dtype): doweight=doweight_stage1, ) - # # ######################## ck stage 1 start ########### - # if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: - # out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) - # else: - # out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - # if ( - # qType == aiter.QuantType.per_1x32 - # and (AQDType in [dtypes.bf16, dtypes.fp16]) - # and (WQDType == dtypes.fp4x2) - # ): # a16w4: - # npad0 = intermediate_pad // 64 * 64 - # kpad0 = hidden_pad // 128 * 128 - # out1_ck, us1 = run_perftest( - # cktile_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale_aiter, - # a1_scale, - # exp_bias1_aiter, - # dtype, - # topk, - # npad0 * 2, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # # needTrace=True, - # # num_iters=2, - # # num_warmup=0, - # ) - # else: - # out1_ck, us1 = run_perftest( - # ck_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale, - # a1_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # ) - - # checkAllclose( - # out1_ref, - # out1_ck, - # msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - - # diff = torch.abs(out1_ref - out1_ck) - # max_value= diff.max() - # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) - # print("max_diff", max_value.item(), ",ref=", out1_ref[multi_index].item(), ",ck=", out1_ck[multi_index].item()) - # ######################## stage 1 end ########### - - # if WQDType != torch.int4: - # # asm int4 2 stage not support yet - # if qType == aiter.QuantType.per_Tensor: - # a1_scale = a1_scale.view(1).repeat(token) - # w1_scale = w1_scale.view(E, 1).repeat(1, w1.shape[-2]) - - # out1_asm = torch.empty((token, topk, inter_dim), dtype=dtype) - # _, us = run_perftest( - # asm_stage1, - # a1_qt, - # shuffle_weight(w1_qt, (16, 16)), - # shuffle_weight(w2_qt, (16, 16)), - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # out1_asm, - # topk, - # kernelName="fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2", - # w1_scale=w1_scale, - # a1_scale=a1_scale, - # activation=actType, - # quant_type=qType, - # block_m=BLOCK_SIZE_M, - # ) - # checkAllclose( - # out1_ref, - # out1_asm, - # msg=f"[perf] asm_moe_stage1:{us:>8.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - # ######################## stage 2 start ########### if qType == aiter.QuantType.per_128x128: a2_qt, a2_scale = aiter.pertoken_quant( @@ -533,99 +229,8 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_bias=exp_bias2, doweight=not doweight_stage1, ) - # # out_ref = torch_moe( - # # input, - # # w1_qt, - # # w2_qt, - # # topk_weights, - # # topk_ids, - # # fc1_scale=w1_scale, - # # fc2_scale=w2_scale, - # # ) - # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") - - # out2_ck = torch.empty((token, model_dim), dtype=dtype) - # if ( - # qType == aiter.QuantType.per_1x32 - # and (AQDType in [dtypes.bf16, dtypes.fp16]) - # and (WQDType == dtypes.fp4x2) - # ): # a16w4 - # npad0 = hidden_pad // 64 * 64 - # kpad0 = intermediate_pad // 128 * 128 - # _, us2 = run_perftest( - # cktile_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale_aiter, - # a2_scale, - # exp_bias2_aiter, - # dtype, - # topk, - # npad0, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # # needTrace=True, - # # num_iters=2, - # # num_warmup=0, - # ) - # out2_ck = cktile_moe_stage2( - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale_aiter, - # a2_scale, - # exp_bias2_aiter, - # dtype, - # topk, - # npad0, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # True, - # ) - # else: - # out2_ck, us2 = run_perftest( - # ck_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale, - # a2_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # ) - - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) - - # diff = torch.abs(out2_ref - out2_ck) - # max_value= diff.max() - # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) - # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) + # ######################## stage 2 end ########### - us1 = 0 out2_ck, us2 = run_perftest( fused_moe, input, @@ -651,17 +256,6 @@ def weight_per_128x128_quant(weight, quant_dtype): msg=f"ck_moe_2stages:{us2:>8.2f} us, {token*model_dim*inter_dim*3*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", ) - # return {"gemm1(us)": us1, "gemm2(us)": us2} - # def calc_diff(x: torch.Tensor, y: torch.Tensor): - # x, y = x.double(), y.double() - # denominator = (x * x + y * y).sum() - # sim = 2 * (x * y).sum() / denominator - # return 1 - sim - - # logits_diff = calc_diff(out2_ref, out2_ck) - # assert logits_diff < 1e-3 - - # return {"gemm1(us)": us1, "gemm2(us)": us2, "err": err} return {"us": us2, "err": err} @@ -671,22 +265,17 @@ def weight_per_128x128_quant(weight, quant_dtype): # l_dim = [(3072, 3072)] l_tokenNum = [ 1, - 2, - 4, - 8, + 3, + 5, 16, 32, 64, 128, 256, 1024, - 2048, - 3072, 4096, - 8192, 163840, ] -l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu] l_quant = [ (aiter.QuantType.No, None, None), # a16w16 (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 @@ -696,6 +285,7 @@ def weight_per_128x128_quant(weight, quant_dtype): (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 ] +l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] l_doweight_stage1 = [False, True][:1] l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)][1:2] @@ -777,7 +367,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-e", "--expert", type=int, - default=256, + default=8, help="""Number of experts. e.g.: -e 8""", ) @@ -786,7 +376,7 @@ def weight_per_128x128_quant(weight, quant_dtype): "-k", "--topk", type=int, - default=8, + default=2, help="""Number of top experts. e.g.: -k 2""", ) From af608e90772cf5851f928257b245ef340a4a540c Mon Sep 17 00:00:00 2001 From: zhimding Date: Fri, 7 Nov 2025 13:50:09 +0000 Subject: [PATCH 20/20] fix sorting --- aiter/fused_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 2a291b6250..e7919043ce 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -44,8 +44,6 @@ def moe_sorting( device = topk_ids.device M, topk = topk_ids.shape max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk - if M * topk <= num_experts: - max_num_tokens_padded = M * topk * block_size max_num_m_blocks = int((max_num_tokens_padded + block_size - 1) // block_size) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=dtypes.i32, device=device)