Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8014e68
moe mxfp4 block_m = 64/128
xudoyuan Oct 27, 2025
52c55aa
update a4w4_gemm2_kernels_list
xudoyuan Oct 27, 2025
88a72c2
add instance tile_m=32
lalala-sh Oct 27, 2025
049a8d2
tuned configuration
zhiding512 Oct 29, 2025
125d488
Update test_moe_2stage.py
lalala-sh Oct 29, 2025
8793a35
refactor
xudoyuan Oct 29, 2025
b65e405
update v1 pipeline
lalala-sh Nov 3, 2025
b82fe1f
update badcase
lalala-sh Nov 5, 2025
4b2594a
fix fp4 moe tuner
lalala-sh Nov 7, 2025
e485e7f
reformat
lalala-sh Nov 7, 2025
c07d2e4
Merge remote-tracking branch 'origin/main' into moe_mxfp4_ck_64_128
lalala-sh Nov 7, 2025
f0b7911
revert ck update
lalala-sh Nov 7, 2025
55e0b33
update ck
lalala-sh Nov 7, 2025
c1da914
Merge branch 'main' into moe_mxfp4_ck_64_128
lalala-sh Nov 7, 2025
685fd10
Moe mxfp4 ck preshf bns (#1312)
xudoyuan Nov 7, 2025
a62e3db
add AITER_MXFP4_MOE_SF switch for mxfp4 moe
lalala-sh Nov 7, 2025
0b712a2
v3 n128
zhiding512 Nov 7, 2025
9fba0c5
32x32 v1
zhiding512 Nov 7, 2025
f91773a
resolve ck conflict
xudoyuan Nov 7, 2025
9458aa8
Merge branch 'main' into moe_mxfp4_ck_64_128
xudoyuan Nov 7, 2025
f9c9097
rm use_int4=True
xudoyuan Nov 7, 2025
9fe0d84
reformatted op_tests/test_moe_2stage.py
xudoyuan Nov 7, 2025
e4fbdbe
AITER_MXFP4_MOE_SF bugfix
zhiding512 Nov 7, 2025
73128e7
Merge branch 'main' into moe_mxfp4_ck_64_128
xudoyuan Nov 7, 2025
9eea40a
revert torch.int4
xudoyuan Nov 7, 2025
27801bf
Merge branch 'main' into moe_mxfp4_ck_64_128
xudoyuan Nov 7, 2025
b3fe899
Merge branch 'main' into moe_mxfp4_ck_64_128
coderfeli Nov 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import functools
import os
import sys
from dataclasses import dataclass
from typing import Callable, Optional

Expand Down Expand Up @@ -619,6 +618,7 @@ def FinalFunc():
run_1stage = token > 32
elif q_type != QuantType.per_1x32:
run_1stage = token < 256

block_m = (
BLOCK_SIZE_M
if run_1stage
Expand Down
11 changes: 9 additions & 2 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,20 @@ MoeKernel moe_dispatch(std::string &kernelName, int block_m, int inter_dim, at::
}
std::cout << "[aiter] ck kernel not found: " << kernelName << std::endl;
}

std::string moe_env_value = "0";
if (const char* env = std::getenv("AITER_MXFP4_MOE_SF")) {
moe_env_value = std::string(env);
}
bool use_mxfp4_moe_preshuffle = std::string(moe_env_value) == "1";

if constexpr (stage == 1)
{
return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight);
return moe_stage1_heuristic_dispatch(block_m, x_dtype, w_dtype, y_dtype, act_op, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle);
}
else
{
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight);
return moe_stage2_heuristic_dispatch(block_m, inter_dim, x_dtype, w_dtype, y_dtype, 0, quant_type, mul_routed_weight, use_mxfp4_moe_preshuffle);
}
}

Expand Down
34 changes: 30 additions & 4 deletions csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,17 @@ def name(self) -> str:

# gemm1 out:bf16/fp16 A:mxfp4 B:mxfp4
a4w4_gemm1_kernels_list= {
0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 3,),
1: kernelInstanceGEMM1( 256, 64, 128, 128, 1, 4, 3,),
2: kernelInstanceGEMM1( 256, 128, 128, 128, 1, 4, 3,),
# 3: kernelInstanceGEMM1( 256, 256, 128, 128, 2, 2, 3,),
}

# bns gemm1 out:bf16/fp16 A:mxfp4 B:mxfp4
a4w4_bns_gemm1_kernels_list= {
0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 3,),
1: kernelInstanceGEMM1( 256, 64, 64, 128, 2, 2, 3,),
2: kernelInstanceGEMM1( 256, 128, 64, 128, 2, 2, 3,),
# 3: kernelInstanceGEMM1( 256, 256, 128, 128, 2, 2, 3,),
}

gemm1_kernels_dict = {
Expand All @@ -205,6 +212,7 @@ def name(self) -> str:
"a8w8blkscale": a8w8_gemm1_blockscale_kernels_list,
"a8w4": a8w4_gemm1_kernels_list,
"a4w4": a4w4_gemm1_kernels_list,
"a4w4_bns": a4w4_bns_gemm1_kernels_list,
}


Expand Down Expand Up @@ -276,13 +284,22 @@ def name(self) -> str:
}
# gemm2 out:bf16/fp16 A:fp8 B:in4
a4w4_gemm2_kernels_list= {
0: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 3,),
1: kernelInstanceGEMM2( 256, 64, 128, 128, 1, 4, 3,),
2: kernelInstanceGEMM2( 256, 128, 128, 128, 1, 4, 3,),
4: kernelInstanceGEMM2( 64, 32, 32, 128, 1, 1, 1,),
5: kernelInstanceGEMM2( 64, 64, 128, 128, 1, 1, 3,),
6: kernelInstanceGEMM2( 64, 128, 128, 128, 1, 1, 3,),
# 7: kernelInstanceGEMM2( 256, 256, 64, 128, 2, 2, 3,),
}
# gemm2 out:bf16/fp16 A:fp8 B:in4
a4w4_bns_gemm2_kernels_list= {
0: kernelInstanceGEMM2( 64, 32, 32, 128, 1, 1, 1,),
1: kernelInstanceGEMM2( 64, 64, 64, 128, 1, 1, 1,),
2: kernelInstanceGEMM2( 64, 128, 128, 128, 1, 1, 1,),
4: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 3,),
5: kernelInstanceGEMM2( 256, 64, 64, 128, 2, 2, 3,),
6: kernelInstanceGEMM2( 256, 128, 64, 128, 2, 2, 3,),
# 7: kernelInstanceGEMM2( 256, 256, 64, 128, 2, 2, 3,),
}

# fmt: on
Expand All @@ -294,6 +311,7 @@ def name(self) -> str:
"a8w8blkscale": a8w8_gemm2_blockscale_kernels_list,
"a8w4": a8w4_gemm2_kernels_list,
"a4w4": a4w4_gemm2_kernels_list,
"a4w4_bns": a4w4_bns_gemm2_kernels_list,
}


Expand All @@ -312,6 +330,7 @@ def get_gemm1_kernels_list(
ActOP: str,
MulRoutedWeight: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()
if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
if arch == "gfx950":
Expand All @@ -337,7 +356,10 @@ def get_gemm1_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
tag = "a4w4"
if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1:
tag = "a4w4"
else:
tag = "a4w4_bns"
else:
raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
kernels_list = gemm1_kernels_dict[tag]
Expand Down Expand Up @@ -372,6 +394,7 @@ def get_gemm2_kernels_list(
QuantType: str,
MulRoutedWeight: bool,
) -> list:
global bns_or_preslf
arch = get_gfx()

if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype:
Expand All @@ -398,7 +421,10 @@ def get_gemm2_kernels_list(
):
tag = "a8w4"
elif Adtype in bit4_list and Bdtype in bit4_list:
tag = "a4w4"
if int(os.getenv("AITER_MXFP4_MOE_SF", 0)) == 1:
tag = "a4w4"
else:
tag = "a4w4_bns"
else:
raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}")
kernels_list = gemm2_kernels_dict[tag]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
// #include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
#include "gemm_moe_ck2stages.h"
#include <iostream>

Expand Down Expand Up @@ -89,7 +90,7 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec;
static constexpr ck::index_t D2Vec = 1;

using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
Expand All @@ -104,8 +105,8 @@ void ck_moe_stage1_gemm(const hipStream_t& stream,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<K0_A, K0_M_A, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<K0_B, K0_N_B, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
S<K0_A, K0_M_A, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 1,
S<K0_B, K0_N_B, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 1,
2, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on
// clang-format on
Expand Down Expand Up @@ -278,7 +279,7 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A;
static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B;

using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle
// clang-format off
///#####| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///#####| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
Expand All @@ -293,8 +294,8 @@ void ck_moe_stage2_gemm(const hipStream_t& stream,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<K0_A, K0_M, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<K0_B, K0_N, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
S<K0_A, K0_M, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 1,
S<K0_B, K0_N, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 1,
2, CShuffleNXDLPerWave, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, Nswizzle, false, MulRoutedWeight, ck::index_t, A0DataType>;

Expand Down
Loading