diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py index 2da800037e..c8863b74dc 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py @@ -197,6 +197,13 @@ def name(self) -> str: # 3: kernelInstanceGEMM1( 256, 256, 128, 128, 2, 2, 3,), } +# bns gemm1 out:bf16/fp16 A:mxfp4 B:mxfp4 +a4w4_bns_gemm1_kernels_list = { + 0: kernelInstanceGEMM1( 256, 32, 128, 128, 1, 4, 3,), + 1: kernelInstanceGEMM1( 256, 64, 64, 128, 2, 2, 3,), + 2: kernelInstanceGEMM1( 256, 128, 64, 128, 2, 2, 3,), +} + gemm1_kernels_dict = { "a16w16_gfx950": a16w16_gemm1_kernels_list_gfx950, "a16w16": a16w16_gemm1_kernels_list, @@ -205,6 +212,7 @@ def name(self) -> str: "a8w8blkscale": a8w8_gemm1_blockscale_kernels_list, "a8w4": a8w4_gemm1_kernels_list, "a4w4": a4w4_gemm1_kernels_list, + "a4w4_bns": a4w4_bns_gemm1_kernels_list, } @@ -284,6 +292,15 @@ def name(self) -> str: # 6: kernelInstanceGEMM2( 256, 128, 64, 128, 2, 2, 3,), # 7: kernelInstanceGEMM2( 256, 256, 64, 128, 2, 2, 3,), } +# bns gemm2 out:bf16/fp16 A:mxfp4 B:mxfp4 +a4w4_bns_gemm2_kernels_list = { + 0: kernelInstanceGEMM2( 64, 32, 32, 128, 1, 1, 1,), + 1: kernelInstanceGEMM2( 64, 64, 64, 128, 1, 1, 1,), + 2: kernelInstanceGEMM2( 64, 128, 128, 128, 1, 1, 1,), + 4: kernelInstanceGEMM2( 256, 32, 128, 128, 1, 4, 3,), + 5: kernelInstanceGEMM2( 256, 64, 64, 128, 2, 2, 3,), + 6: kernelInstanceGEMM2( 256, 128, 64, 128, 2, 2, 3,), +} # fmt: on gemm2_kernels_dict = { @@ -294,6 +311,7 @@ def name(self) -> str: "a8w8blkscale": a8w8_gemm2_blockscale_kernels_list, "a8w4": a8w4_gemm2_kernels_list, "a4w4": a4w4_gemm2_kernels_list, + "a4w4_bns": a4w4_bns_gemm2_kernels_list, } @@ -302,6 +320,7 @@ def name(self) -> str: bit4_list = ["I4", "i4", "FP4X2", "fp4x2"] QuantType_list = [3, 4] +bns_or_preslf = True def get_gemm1_kernels_list( Adtype: str,
ActOP: str, MulRoutedWeight: bool, ) -> list: + global bns_or_preslf arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: if arch == "gfx950": @@ -337,7 +357,10 @@ def get_gemm1_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - tag = "a4w4" + if bns_or_preslf: + tag = "a4w4_bns" + else: + tag = "a4w4" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm1_kernels_dict[tag] @@ -372,6 +395,7 @@ def get_gemm2_kernels_list( QuantType: str, MulRoutedWeight: bool, ) -> list: + global bns_or_preslf arch = get_gfx() if Adtype in bit16_list and Bdtype in bit16_list and Adtype == Adtype: @@ -398,7 +422,10 @@ def get_gemm2_kernels_list( ): tag = "a8w4" elif Adtype in bit4_list and Bdtype in bit4_list: - tag = "a4w4" + if bns_or_preslf: + tag = "a4w4_bns" + else: + tag = "a4w4" else: raise ValueError(f"Unsupported data type combination: {Adtype}, {Bdtype}") kernels_list = gemm2_kernels_dict[tag] diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh new file mode 100644 index 0000000000..63d7b29d33 --- /dev/null +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4_bns.cuh @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "gemm_moe_ck2stages.h" +#include + +template +void ck_moe_stage1_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, + int topk, + void*& hidden_states, // [m, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, + void*& num_valid_ids, // [1] + void*& out, // [max_num_tokens_padded, inter_dim] + std::optional w1_scale, // [e, 1, n], gate(up) scale + std::optional a1_scale // [m, 1], token scale +) +{ + // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things + using A1DataType = E8M0; + using B1DataType = E8M0; + static constexpr ck::index_t ScaleBlockSize = 32; + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + ck::index_t KBatch = 1; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + // using AccDataType = F32; + using CShuffleDataType = F32; + using DsDataType = ck::Tuple; + + using A0Layout = Row; + using B0Layout = Col; + using D0Layout = Row; + using D1Layout = Col; + using ELayout = Row; + using D2Layout = ELayout; + using DsLayout = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); + 
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); + // static constexpr ck::index_t NPerBlock = PipelineVer == ck::BlockGemmPipelineVersion::v1 ? 64 + // : 128; + static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = NPerBlock / 2 / NXDLPerWave; // 64 + static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 1 : EVec; + static constexpr ck::index_t D2Vec = 1; + + using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| 
ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 32, BLOCKSIZE, + MPerBlock, NPerBlock, 128, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + 2, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; // clang-format on + // clang-format on + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + static constexpr auto DStride = PerTensorQuant ? I0 : I1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + hidden_states, + a1_scale.value(), + w1, + w1_scale.value(), + std::array{ + nullptr, nullptr, MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + std::array{DStride, DStride, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker.Run(argument, StreamConfig{stream}); +} + +#define CK_MOE_STAGE1_GEMM_DEFINE( \ + BLOCKSIZE, MPerfBlock, NPerBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage1_gemm(const hipStream_t& stream, \ + int tokens, \ + int sorted_size, \ + int N, \ + int K, \ + int topk, \ + void*& hidden_states, \ + void*& w1, \ + void*& w2, \ + void*& sorted_token_ids, \ + void*& sorted_expert_ids, \ + void*& sorted_weights, \ + void*& num_valid_ids, \ + void*& out, \ + std::optional w1_scale, \ + std::optional a1_scale); + +template +void ck_moe_stage2_gemm(const hipStream_t& stream, + int tokens, + int sorted_size, + int N, + int K, + int topk, + void*& inter_states, // [max_num_tokens_padded, k], input token + void*& w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + void*& w2, // [expert, dim, inter_dim], pre-shuffle([e, nr, kr, w]) + void*& sorted_token_ids, // [max_num_tokens_padded] + void*& sorted_expert_ids, // [max_num_m_blocks] + void*& sorted_weights, // [max_num_tokens_padded] + void*& num_valid_ids, //[1] + void*& out, // [m, out_dim] + std::optional w2_scale, // [e, 1, n], gate(up) scale + std::optional a2_scale // [max_num_tokens_padded, 1], token scale +) +{ + // ~~~~~~~~~~~~~~~~~~~~~~~~following start with ck things + using A1DataType = E8M0; + using B1DataType = E8M0; + static constexpr ck::index_t ScaleBlockSize = 32; + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + ck::index_t KBatch = 1; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + + // using AccDataType = F32; + using CShuffleDataType = F32; + using DsDataType = ck::Tuple; + + using A0Layout = Row; + using B0Layout = Col; + using ELayout = Row; + using D0Layout = Row; + using D1Layout = Col; + using 
DsLayout = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + + static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + // static constexpr ck::index_t BLOCKSIZE = 256; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + static constexpr ck::index_t MNPerXDL = 16; + static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); + static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); + static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; + static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; + static constexpr ck::index_t CShuffleNLane = + BLOCKSIZE == 64 ? NPerBlock / 2 : NPerBlock / 2 / NXDLPerWave; // 64 + static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; + static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); + static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); + static constexpr ck::index_t EVec = 2; + static constexpr ck::index_t D0Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 
1 : EVec; + static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + + using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS + // clang-format off +///#####| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///#####| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///#####| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///##### RCR + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 32, BLOCKSIZE, + MPerBlock, NPerBlock, 128, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 
2, BK1, BK1, 0, + 2, CShuffleNXDLPerWave, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, 0, Nswizzle, false, MulRoutedWeight, ck::index_t, A0DataType>; + + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + static constexpr auto DStride = PerTensorQuant ? I0 : I1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + inter_states, + a2_scale.value(), + w2, + w2_scale.value(), + std::array{nullptr, + nullptr, + MulRoutedWeight ? sorted_weights : nullptr}, + out, + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + std::array{DStride, DStride, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if (!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + invoker.Run(argument, StreamConfig{stream}); +} + +#define CK_MOE_STAGE2_GEMM_DEFINE(BLOCKSIZE, MPerfBlock, NPerfBlock, KPerBlock, MWaves, NWaves, PipelineVer) \ + template void ck_moe_stage2_gemm( \ + const hipStream_t &stream, \ + int tokens, int sorted_size, int N, int K, \ + int topk, \ + void *&inter_states, \ + void *&w1, \ + void *&w2, \ + void *&sorted_token_ids, \ + void *&sorted_expert_ids, \ + void *&sorted_weights, \ + void *&num_valid_ids, \ + void *&out, \ + std::optional w2_scale, \ + std::optional a2_scale); \ No newline at end of file diff --git a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py index 00cd70eb1e..a891797761 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py +++ b/csrc/ck_gemm_moe_2stages_codegen/gen_instances.py @@ -203,6 +203,38 @@ """ +A4W4_bns_gemm1_heuristic_dispatch = """ +#if defined(__Float4_e2m1fn_x2) + if (dtype_checker<{A0DataType}>{{}}(x_dtype) + && dtype_checker<{B0DataType}>{{}}(w_dtype) + && dtype_checker<{EDataType}>{{}}(y_dtype) + && {ActOP} == act_op + && {MulRoutedWeight} == mul_routed_weight_stage + && {Quant} == quant) + {{ + if (block_m == 32) + {{ + return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 32, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 64) + {{ + return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 64, 64, 128/sizeof({A0DataType}), 2, 2, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 128) + {{ + return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 64, 128/sizeof({A0DataType}), 2, 2, {Nswizzle}, 
{Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe heuristic dispatch: ", + block_m); + }} + }} +#endif + +""" A8W8_blockscale_gemm1_heuristic_dispatch = """ if (dtype_checker<{A0DataType}>{{}}(x_dtype) @@ -357,6 +389,61 @@ #endif """ +A4W4_bns_gemm2_heuristic_dispatch = """ +#if defined(__Float4_e2m1fn_x2) + if (dtype_checker<{A0DataType}>{{}}(x_dtype) + && dtype_checker<{B0DataType}>{{}}(w_dtype) + && dtype_checker<{EDataType}>{{}}(y_dtype) + && {MulRoutedWeight} == mul_routed_weight_stage + && {Quant} == quant) + {{ + if (inter_dim <= 256) + {{ + if (block_m == 32) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 64, 32, 32, 128/sizeof({A0DataType}), 1, 1, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 64) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 64, 64, 64, 128/sizeof({A0DataType}), 1, 1, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 128) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V1, 64, 128, 128, 128/sizeof({A0DataType}), 1, 1, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe heuristic dispatch: ", + block_m); + }} + }} + else + {{ + if (block_m == 32) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 32, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 64) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, 
{CDEElementOp}, V3, 256, 64, 64, 128/sizeof({A0DataType}), 2, 2, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else if (block_m == 128) + {{ + return ck_moe_stage2_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256, 128, 64, 128/sizeof({A0DataType}), 2, 2, {Nswizzle}, {Quant} == static_cast(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>; + }} + else + {{ + TORCH_CHECK( + false, + "Unsupported block_m value for moe heuristic dispatch: ", + block_m); + }} + }} + }} +#endif +""" A8W8_blockscale_gemm2_heuristic_dispatch = """ @@ -410,6 +497,10 @@ A4W4_gemm1_heuristic_dispatch, A4W4_gemm2_heuristic_dispatch, ], + "a4w4_bns": [ + A4W4_bns_gemm1_heuristic_dispatch, + A4W4_bns_gemm2_heuristic_dispatch, + ], } @@ -510,7 +601,10 @@ def generate_instance_and_lookUpTable(self): if self.quant_type in [4, 5]: quanttype = "_blockscale" elif "FP4" in self.a_dtype: - quanttype = "_mxfp4" + if "bns" in tag: + quanttype = "_mxfp4_bns" + else: + quanttype = "_mxfp4" else: quanttype = "" if not os.path.exists(f_instance): diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 573b12907e..ec318d408c 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -119,6 +119,7 @@ def ck_moe_stage2( ) return out +bns_or_preslf = True @benchmark() def test_fmoe( @@ -135,6 +136,7 @@ def test_fmoe( use_g1u1=False, doweight_stage1=False, ): + global bns_or_preslf if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return torch_quant = aiter.get_torch_quant(qType) @@ -248,7 +250,7 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16),
use_int4=True) ) ) - else: + elif WQDType != dtypes.fp4x2 or not bns_or_preslf: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) # # ######################## ck stage 1 start ###########