diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..51cbe64bd4a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "csrc/third_party/catlass"] + path = csrc/third_party/catlass + url = https://gitcode.com/cann/catlass.git + branch = catlass-v1-stable diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index f789635977e..b2c4d68fc49 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -3,6 +3,8 @@ ROOT_DIR=$1 SOC_VERSION=$2 +git config --global --add safe.directory "$ROOT_DIR" + if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then # ASCEND310P series # currently, no custom aclnn ops for ASCEND310 series @@ -11,11 +13,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then exit 0 elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then # ASCEND910B (A2) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine" SOC_ARG="ascend910_93" else # others @@ -23,6 +25,30 @@ else exit 0 fi +git submodule init +git submodule update + + +# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h +file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1) +if [ -z "$file_path" ]; then + echo "cannot find moe_distribute_base.h file in CANN env" + exit 1 +fi + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/" +TARGET_FILE="$TARGET_DIR/$(basename "$file_path")" + +echo "*************************************" +echo $file_path +echo 
"$TARGET_DIR" +cp "$file_path" "$TARGET_DIR" + +sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE" +sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE" + + # build custom ops cd csrc rm -rf build output diff --git a/csrc/cmake/func.cmake b/csrc/cmake/func.cmake index f2bebf75639..e8ce57564fc 100644 --- a/csrc/cmake/func.cmake +++ b/csrc/cmake/func.cmake @@ -282,7 +282,7 @@ function(add_ops_src_copy) set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done) add_custom_command(OUTPUT ${_BUILD_FLAG} COMMAND mkdir -p ${SRC_COPY_DST} - COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST} + COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST} COMMAND touch ${_BUILD_FLAG} ) diff --git a/csrc/dispatch_ffn_combine/op_host/CMakeLists.txt b/csrc/dispatch_ffn_combine/op_host/CMakeLists.txt new file mode 100644 index 00000000000..9101b185393 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/CMakeLists.txt @@ -0,0 +1,66 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +set(_DISPATCH_FFN_INC_OPTS) +if (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/aarch64-linux/ascendc/include) +elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/arm64-linux/ascendc/include) +elseif (EXISTS ${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${ASCEND_CANN_PACKAGE_PATH}/${CMAKE_SYSTEM_PROCESSOR}-linux/ascendc/include) +endif() +if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include) + list(APPEND _DISPATCH_FFN_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include) +endif() + +add_ops_compile_options( + OP_NAME DispatchFFNCombine + OPTIONS --cce-auto-sync=on + -Wno-deprecated-declarations + -Werror + -DHCCL_COMM + ${_DISPATCH_FFN_INC_OPTS} +) + +target_sources(op_host_aclnnInner PRIVATE + dispatch_ffn_combine_def.cpp +) + +target_sources(opapi PRIVATE + aclnn_dispatch_ffn_combine.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + aclnn_dispatch_ffn_combine.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + aclnn_dispatch_ffn_combine.cpp + ) +endif () + +target_sources(optiling PRIVATE + dispatch_ffn_combine_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel +) + +target_sources(opsproto PRIVATE + dispatch_ffn_combine_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_ffn_combine.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp new file mode 100644 
index 00000000000..0206fa5cfdb --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp @@ -0,0 +1,84 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "aclnn_dispatch_ffn_combine.h" +#include +// #include "aclnn_kernels/common/op_error_check.h" +// #include "opdev/op_log.h" +// #include "opdev/common_types.h" +// #include "opdev/platform.h" +// #include "ophost/matmul_util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../op_host/error_log.h" +// using namespace op; + +// using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +static constexpr size_t TWO_DIMS = 2; +static constexpr int64_t KVALUE_MIN = 256; +static constexpr int64_t KVALUE_MAX = 65535; +static constexpr size_t HCCL_GROUP_NAME_MAX = 128U; +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; + +extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, + const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + bool transB, bool weightNz, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor); +extern aclnnStatus aclnnInnerDispatchFFNCombine(void *workspace, uint64_t 
workspaceSize, + aclOpExecutor *executor, aclrtStream stream); +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + + + +aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, + const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor) +{ + bool transB = false; + bool weightNz = true; + + aclnnStatus ret = aclnnInnerDispatchFFNCombineGetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group, + maxOutputSize, transB, weightNz, + out, workspaceSize, executor); + return ret; +} + +aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + aclnnStatus ret = aclnnInnerDispatchFFNCombine(workspace, workspaceSize, executor, stream); + return ret; +} +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h new file mode 100644 index 00000000000..b0063d66905 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h @@ -0,0 +1,61 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef OP_API_INC_DISPATCH_FFN_COMBINE_ +#define OP_API_INC_DISPATCH_FFN_COMBINE_ + +#include + +#include "aclnn/aclnn_base.h" +#include "hccl/hccl.h" +#include "hccl/hccl_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * 算子功能:实现分布式MoE从InitRouting到Unpermute全部算子的融合 + * @brief aclnnDispatchFFNCombine的第一段接口,根据具体的计算流程,计算workspace大小。 + * @domain aclnn_ops_infer + * @param [in] x: 输入token,数据类型支持:float16, bf16。 + * @param [in] weight1, weight2: 专家FFN第一/二层权重,数据类型支持:int8。 + * @param [in] expertId, probs: 每个token的专家索引(int32)及组合权重(float)。 + * @param [in] scale1, scale2: 第一/二层matmul的量化scale,数据类型支持:int64。 + * @param [in] group: 标识通信域名称的字符串。 + * @param [in] maxOutputSize: 本卡输出token数上限,用于workspace大小计算。 + * @param [out] out: 计算+通信的结果,数据类型:同输入x。 + * @param [out] workspaceSize: 返回需要在npu device侧申请的workspace大小。 + * @param [out] executor: 返回op执行器,包含了算子计算流程。 + * @return aclnnStatus: 返回状态码 + */ +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, + const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, + const aclTensor* probs, + const char* group, int64_t maxOutputSize, + const aclTensor* out, + uint64_t* workspaceSize, aclOpExecutor** executor); + +/** + * @brief aclnnDispatchFFNCombine的第二段接口,用于执行计算。 + * @param [in] workspace: 在npu device侧申请的workspace内存起址。 + * @param [in] workspaceSize: 在npu device侧申请的workspace大小,由第一段接口aclnnDispatchFFNCombineGetWorkspaceSize获取。 + * @param [in] executor: op执行器,包含了算子计算流程。 + * @param [in] stream: acl stream流。 + * @return aclnnStatus: 返回状态码 + */ +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t
workspaceSize, aclOpExecutor* executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif // OP_API_INC_DISPATCH_FFN_COMBINE_ \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp new file mode 100644 index 00000000000..d487c453a97 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*!
+ * \file dispatch_ffn_combine_def.cpp + * \brief + */ +#include "register/op_def_registry.h" + +namespace ops { +class DispatchFFNCombine : public OpDef { + public: + explicit DispatchFFNCombine(const char *name) : OpDef(name) { + this->Input("a") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("w1") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) + .IgnoreContiguous(); + this->Input("w2") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) + .IgnoreContiguous(); + this->Input("expertIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("scale1") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("scale2") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("probs") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + // 输出 + this->Output("out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_BF16, 
ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND}); + + this->Attr("group").AttrType(REQUIRED).String(); + this->Attr("M").AttrType(OPTIONAL).Int(); + this->Attr("transB").AttrType(OPTIONAL).Bool(false); + this->Attr("weightNz").AttrType(OPTIONAL).Bool(false); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config); + this->MC2().HcclGroup("group"); + } +}; + +OP_ADD(DispatchFFNCombine); +} // namespace ops diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_proto.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_proto.cpp new file mode 100644 index 00000000000..a55d4441dd0 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_proto.cpp @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file dispatch_ffn_proto.cpp + * \brief + */ +#include +#include +// #include "../../common/ophost/op_util.h" +// #include "../../common/ophost/hcom_topo_info.h" +// #include "log/ops_log.h" + +using namespace ge; +namespace ops { +const size_t ATTR_GROUP = 0; +const size_t ATTR_RANK_SIZE = 1; +const size_t SUPPORT_DIM_SIZE = 2; + +static ge::graphStatus InferShapeDispatchFFNCombine(gert::InferShapeContext* context) { + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeDispatchFFNCombine(gert::InferDataTypeContext* context) { + // auto d_type = context->GetInputDataType(0); + // context->SetOutputDataType(0, d_type); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(DispatchFFNCombine) + .InferShape(InferShapeDispatchFFNCombine) + .InferDataType(InferDataTypeDispatchFFNCombine); +} // namespace ops diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp new file mode 100644 index 00000000000..a7f5f7ed601 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp @@ -0,0 +1,265 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +/*! 
+ * \file dispatch_ffn_tiling.cpp + * \brief + */ +#include "vector" +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include "error_log.h" +#include "hcom_topo_info.h" +#include "register/op_def_registry.h" +#include "dispatch_ffn_combine_tiling.h" +#include +#include +#include +#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" + +using namespace AscendC; +using namespace ge; + +namespace { + // 1. 常量定义 + const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug"; + constexpr uint32_t ATTR_GROUP_INDEX = 0; + constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1; + constexpr uint32_t ATTR_IS_TRANS_B = 2; + constexpr uint32_t ATTR_WEIGHT_NZ = 3; + constexpr uint64_t INIT_TILINGKEY = 1000000; + constexpr uint64_t TILINGKEY_TRANS_B = 1U; + constexpr uint64_t TILINGKEY_WEIGHT_NZ = 10; + constexpr uint32_t X_INDEX = 0; + constexpr uint32_t WEIGHT_INDEX = 1; + constexpr uint32_t WEIGHT2_INDEX = 2; + constexpr uint32_t EXPERTID_INDEX = 3; + constexpr uint32_t BLOCK_NUM = 20; + constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + +static int32_t CeilDev(int32_t num, int32_t div) +{ + if (div == 0) { + return 0; + } + return (num + div - 1) / div; +} + +// 解析并校验 rankId, group, worldSize, isTransB 属性值 +static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info) +{ + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED); + + // todo:Attr相关tilingdata的设置、校验、打印 + auto groupPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_INDEX)); + auto maxOutputSizePtr = attrs->GetAttrPointer(ATTR_MAX_OUTPUT_SIZE_INDEX); + auto is_trans_b = attrs->GetAttrPointer(ATTR_IS_TRANS_B); + auto weight_nz = attrs->GetAttrPointer(ATTR_WEIGHT_NZ); + OP_TILING_CHECK(groupPtr == nullptr || strlen(groupPtr) == 0, + OP_LOGE(K_INNER_DEBUG, "group is invalid."), return GRAPH_FAILED); + 
+ OP_TILING_CHECK(is_trans_b == nullptr, + OP_LOGE(K_INNER_DEBUG, "is_trans_b is invalid."), return GRAPH_FAILED); + OP_TILING_CHECK(weight_nz == nullptr, + OP_LOGE(K_INNER_DEBUG, "weight_nz is invalid."), return GRAPH_FAILED); + + info.maxOutputSize = *maxOutputSizePtr; + info.isTransposeB = *is_trans_b; + info.isWeightNz = *weight_nz; + + int64_t rankSize; + (void)ge::HcomTopoInfo::Instance().GetGroupRankSize(groupPtr, rankSize); + info.worldSize = rankSize; + + OP_LOGD(K_INNER_DEBUG, "maxOutputSize=%d ", info.maxOutputSize); + OP_LOGD(K_INNER_DEBUG, "rankSize=%d ", info.worldSize); + + return ge::GRAPH_SUCCESS; +} + +// 提取输入张量 A 和 B 的形状,计算出 M、K、N 值 +static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info) +{ + const char *nodeName = context->GetNodeName(); + // OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling."); + + const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX); + const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX); + const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX); + uint32_t M = aStorageShape->GetStorageShape().GetDim(0); + uint32_t K = aStorageShape->GetStorageShape().GetDim(1); + uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0); + uint32_t N = bStorageShape->GetStorageShape().GetDim(2); + uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1); + + info.M = M; + info.N = N; + info.K = K; + info.expertPerRank = expertPerRank; + info.topK = topK; + OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M); + OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K); + OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N); + OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank); + OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK); + + return ge::GRAPH_SUCCESS; +} + +// 获取当前芯片平台的 AI Core 数目、UB 容量等硬件信息。 +static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext 
*context, DispatchFFNCombineInfo& info) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0U; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + info.aivNum = aivNum; + info.totalUbSize = ubSize; + + OP_LOGD(K_INNER_DEBUG, "aivNum=%d", info.aivNum); + OP_LOGD(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize); + + return ge::GRAPH_SUCCESS; +} + +void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineInfo &info) +{ + cocTilingData.m0 = 128; + cocTilingData.k0 = 256; + cocTilingData.n0 = 256; + cocTilingData.swizzleDirect = 1; + cocTilingData.swizzleOffset = 7; + cocTilingData.ubMoveNum = 16 * 1024; + cocTilingData.pValue = 1; + cocTilingData.commNpuSplit = info.worldSize; + cocTilingData.commDataSplit = 1; + cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2; +} + +// 主调度函数: +// 获取 tilingData ➝ 检查 Attr ➝ 检查 Shape ➝ 获取平台信息 +// ➝ 调用 SetTilingData(根据rank数目) ➝ 设置 blockDim ➝ 设置 tilingKey ➝ 设置 workspace ➝ 配置通信参数 + +static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + OP_LOGI(nodeName, "Enter DispatchFFNCombine tiling func."); + + // 1. 
tilingData + DispatchFFNCombineTilingData *tilingData = context->GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), + return ge::GRAPH_FAILED); + OP_LOGI(nodeName, "DispatchFFNCombine get tilingData."); + DispatchFFNCombineInfo& info = tilingData->dispatchFFNCombineInfo; + OP_LOGI(nodeName, "DispatchFFNCombine get tilingData info."); + + OP_TILING_CHECK(DispatchFFNCombineCheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckAttrAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(DispatchFFNCombineCheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombine CheckShapeAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(DispatchFFNCombineGetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OP_LOGE(context->GetNodeName(), "DispatchFFNCombine GetPlatformInfoAndSetTiling Failed"), + return ge::GRAPH_FAILED); + + SetTilingData(tilingData->cocTiling, info); + + // 2. set blockDim + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + auto aicNum = ascendcPlatform.GetCoreNumAic(); + auto aivNum = ascendcPlatform.GetCoreNumAiv(); + blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, aicNum, aivNum); + context->SetBlockDim(blockDim); + + // 3. set tilingKey + uint64_t tilingKey = INIT_TILINGKEY; + tilingKey += info.isTransposeB ? TILINGKEY_TRANS_B : 0; + tilingKey += info.isWeightNz ? 
TILINGKEY_WEIGHT_NZ : 0; + context->SetTilingKey(tilingKey); + + OP_LOGD(K_INNER_DEBUG, "tilingKey=%d", tilingKey); + + optiling::MoeInitRoutingQuantV2TilingBase moeInitRoutingQuantV2TilingBase; + int64_t inuptXDtypeSize = sizeof(int16_t); + int64_t scaleDim0 = 0; + int64_t ubSize = 196352; + int64_t expertCapacity = 0; + int64_t expertNum = info.expertPerRank * info.worldSize; + int64_t activeNum = 0; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 2; + bool expertTokensBeforeCapacityFlag = false; + int64_t quantMode = 1; + uint32_t aivNumInitRouting = 2 * BLOCK_NUM; + moeInitRoutingQuantV2TilingBase.DoTiling(info.M, info.K, info.topK, expertCapacity, expertNum, activeNum, dropPadMode, + expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0, aivNumInitRouting, ubSize); + uint64_t initRoutingQuantTilingKey = moeInitRoutingQuantV2TilingBase.tilingKey_; + size_t initRoutingWorkspace = moeInitRoutingQuantV2TilingBase.workspaceSize_; + + tilingData->cocTiling.moeInitRoutingQuantV2TilingData = moeInitRoutingQuantV2TilingBase.quantTilingData; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vbsComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vbsComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.vmsMiddleComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.vmsMiddleComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.sortOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.sortOutComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.srcToDstCapacityComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.srcToDstCapacityComputeParamsOp; + tilingData->cocTiling.moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = 
moeInitRoutingQuantV2TilingBase.quantTilingData.gatherOutComputeParamsOp; + tilingData->cocTiling.initRoutingQuantTilingKey = initRoutingQuantTilingKey; + + // 4. workspace + size_t *workSpaces = context->GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), + return ge::GRAPH_FAILED); + + uint32_t n2 = info.K; + uint32_t k2 = info.N / 2; + + uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) + + info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 + + info.maxOutputSize * sizeof(float) * 2 + + std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) + + std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t)); + + workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace); + + + // 5. communication + auto attrs = context->GetAttrs(); + auto group = attrs->GetAttrPointer(static_cast(ATTR_GROUP_INDEX)); + uint32_t opType = 8U; + std::string algConfig = "AlltoAll=level0:fullmesh;level1:pairwise"; + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig); + mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling); + + OP_LOGI(nodeName, "Leave DispatchFFNCombine tiling func."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchFFNCombineTilingFunc(gert::TilingContext* context) +{ + return DispatchFFNCombineTilingFuncImpl(context); +} + +struct DispatchFFNCombineCompileInfo {}; +ge::graphStatus TilingParseForDispatchFFNCombine(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(DispatchFFNCombine) + .Tiling(DispatchFFNCombineTilingFunc) + .TilingParse(TilingParseForDispatchFFNCombine); +} // namespace optiling \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_host/error_log.h 
b/csrc/dispatch_ffn_combine/op_host/error_log.h new file mode 100644 index 00000000000..4ef02cd4379 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/error_log.h @@ -0,0 +1,47 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/dispatch_ffn_combine/op_host/hcom_topo_info.h b/csrc/dispatch_ffn_combine/op_host/hcom_topo_info.h new file mode 100644 index 00000000000..7bc4b835bb1 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/hcom_topo_info.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ +#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ + +#include +#include + +using Status = int32_t; + +namespace ge { +static constexpr uint32_t COMM_MESH = 0b1U; +static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U); +static constexpr uint32_t COMM_RING = (COMM_MESH << 2U); +static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U); +class HcomTopoInfo { + public: + enum class TopoLevel { + L0 = 0, + L1, + MAX, + }; + struct TopoLevelDesc { + uint32_t comm_sets; + uint32_t rank_size; + }; + using TopoDescs = TopoLevelDesc[static_cast(TopoLevel::MAX)]; + struct TopoInfo { + int64_t rank_size; + void *notify_handle; + TopoDescs topo_level_descs; + }; + static HcomTopoInfo &Instance(); + bool TopoInfoHasBeenSet(const char_t *group); + bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info); + Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info); + Status GetGroupRankSize(const char_t *group, int64_t &rank_size); + TopoDescs *GetGroupTopoDesc(const char_t *group); + Status GetGroupNotifyHandle(const char_t *group, void *¬ify_handle); + void UnsetGroupTopoInfo(const char_t *group) { + const std::lock_guard lock(mutex_); + (void) rank_info_.erase(group); + } + + Status SetGroupOrderedStream(const char_t *group, void *stream); + Status GetGroupOrderedStream(const char_t *group, void *&stream); + void UnsetGroupOrderedStream(const char_t *group) { + const std::lock_guard lock(mutex_); + (void) group_to_ordered_stream_.erase(group); + }; + + Status SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream); + Status GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream); + void UnsetGroupOrderedStream(const int32_t device_id, 
const char_t *group); + private: + HcomTopoInfo() = default; + ~HcomTopoInfo() = default; + std::unordered_map<std::string, TopoInfo> rank_info_; + std::mutex mutex_; + std::unordered_map<std::string, void *> group_to_ordered_stream_; // ordered stream per communication group + std::unordered_map<int32_t, std::unordered_map<std::string, void *>> device_id_to_group_to_ordered_stream_; // ordered stream per communication group, keyed by device id +}; +} + +#endif // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ diff --git a/csrc/dispatch_ffn_combine/op_host/tiling_args.h b/csrc/dispatch_ffn_combine/op_host/tiling_args.h new file mode 100644 index 00000000000..950cbe9047a --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_host/tiling_args.h @@ -0,0 +1,9 @@ +#ifndef TILING_ARGS_H +#define TILING_ARGS_H +#include <cstdint> + +namespace Moe { +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; +constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL; +} // namespace Moe +#endif // TILING_ARGS_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp new file mode 100644 index 00000000000..db3cf771fd0 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! 
+ * \file dispatch_ffn_combine.cpp + * \brief + */ +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "dispatch_ffn_combine_tiling.h" +#include "dispatch_ffn_combine.h" + +using namespace AscendC; +using namespace DispatchFFNCombineImpl; +extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1, GM_ADDR w2, GM_ADDR expertId, GM_ADDR scale1, GM_ADDR scale2, GM_ADDR probs, + GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM) +{ + REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData); + if (TILING_KEY_IS(1000000)) { + KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); + DispatchFFNCombine op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000001)) { + KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); + DispatchFFNCombine op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000010)) { + KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); + DispatchFFNCombine op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } else if (TILING_KEY_IS(1000011)) { + KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2); + GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); + DispatchFFNCombine op; + op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); + op.Process(); + } +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h new file mode 100644 index 00000000000..eb19ede9fca --- /dev/null +++ 
b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -0,0 +1,276 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_combine.h + * \brief + */ + +#ifndef DISPATCH_FFN_COMBINE_H +#define DISPATCH_FFN_COMBINE_H + +using namespace AscendC; + +#include "kernel_operator.h" + +#include "utils/moe_distribute_base.h" + +#include "dispatch_ffn_combine_tiling.h" + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/block/block_epilogue.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/epilogue/tile/tile_elemwise_add.hpp" +#include "catlass/epilogue/tile/tile_elemwise_muls.hpp" +#include "catlass/gemm/block/block_mmad.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/kernel/matmul_epilogue.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/layout/layout.hpp" + +#include "utils/select_helper.hpp" +#include "utils/const_args.hpp" +#include "dispatch_ffn_combine_kernel.hpp" +#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" + +using namespace Catlass; + +namespace DispatchFFNCombineImpl { +#define TemplateMMA2AClass typename AType_, typename BType_, typename CType_, bool TB_, bool Nz_ +#define TemplateMMA2ACFunc AType_, BType_, CType_, TB_, Nz_ + +using namespace 
AscendC; +template +class DispatchFFNCombine { +public: + __aicore__ inline DispatchFFNCombine() {}; + __aicore__ inline void Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM, + GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM); + __aicore__ inline void Process(); + + +private: + GM_ADDR xGM_; + GM_ADDR weight1GM_; + GM_ADDR weight2GM_; + GM_ADDR expertIdGM_; + GM_ADDR scale1GM_; + GM_ADDR scale2GM_; + GM_ADDR probs_; + GM_ADDR outGM_; + GM_ADDR workspaceGM_; + + GM_ADDR moeInitRoutingQuantV2Scale = nullptr; + GM_ADDR moeInitRoutingQuantV2Offset = nullptr; + GM_ADDR expertTokensBeforeCapacity = nullptr; + + + TBuf uBuf_; + + int32_t rank; + int32_t rankSize; + int32_t aivNum; + + int32_t m0; + int32_t k0; + int32_t n0; + int32_t swizzlOffset; + int32_t swizzlDirect; + int32_t ubMoveNum; + int32_t pValue; + + int32_t commNpuSplit; + int32_t commDataSplit; + int32_t lenPerLoop; + + int32_t m; + int32_t k; + int32_t n; + int32_t topK; + int32_t expertPerRank; + int32_t maxOutputSize; + int32_t EP; + + optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData; + uint64_t initRoutingQuantTilingKey; + + // Hccl hccl_; + +}; + + +template +__aicore__ inline void DispatchFFNCombine::Init(GM_ADDR xGM, GM_ADDR weight1GM, GM_ADDR weight2GM, GM_ADDR expertIdGM, GM_ADDR scale1GM, GM_ADDR scale2GM, + GM_ADDR probs, GM_ADDR outGM, GM_ADDR workspaceGM, GM_ADDR tilingGM) +{ + REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData); + auto tiling = (__gm__ DispatchFFNCombineTilingData*)tilingGM; + GET_TILING_DATA(tilingData, tilingGM); + + xGM_ = xGM; + weight1GM_ = weight1GM; + weight2GM_ = weight2GM; + expertIdGM_ = expertIdGM; + scale1GM_ = scale1GM; + scale2GM_ = scale2GM; + probs_ = probs; + + outGM_ = outGM; + + workspaceGM_ = workspaceGM; + + aivNum = tilingData.dispatchFFNCombineInfo.aivNum; + + m = tilingData.dispatchFFNCombineInfo.M; + k = tilingData.dispatchFFNCombineInfo.K; + n 
= tilingData.dispatchFFNCombineInfo.N; + EP = tilingData.dispatchFFNCombineInfo.worldSize; + topK = tilingData.dispatchFFNCombineInfo.topK; + expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank; + maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize; + + m0 = tilingData.cocTiling.m0; + k0 = tilingData.cocTiling.k0; + n0 = tilingData.cocTiling.n0; + swizzlDirect = tilingData.cocTiling.swizzleDirect; + swizzlOffset = tilingData.cocTiling.swizzleOffset; + ubMoveNum = tilingData.cocTiling.ubMoveNum; + pValue = tilingData.cocTiling.pValue; + commNpuSplit = tilingData.cocTiling.commNpuSplit; + commDataSplit = tilingData.cocTiling.commDataSplit; + lenPerLoop = tilingData.cocTiling.lenPerLoop; + moeInitRoutingQuantV2TilingData = tilingData.cocTiling.moeInitRoutingQuantV2TilingData; + initRoutingQuantTilingKey = tilingData.cocTiling.initRoutingQuantTilingKey; + + auto contextGM0 = AscendC::GetHcclContext(); + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; + + rank = WinContext_->localUsrRankId; + rankSize = WinContext_->rankSize; +} + +template +__aicore__ inline void DispatchFFNCombine::Process() +{ + // Define ArchTag + using ArchTag = Arch::AtlasA2; + constexpr bool enableUnitFlag = false; + constexpr bool enableShuffleK = true; + + uint32_t k2 = n/2; + uint32_t n2 = k; + + int64_t activeNum = 0; + int64_t expertCapacity = 0; + int64_t expertNum = expertPerRank * EP; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 2; + bool expertTokensBeforeCapacityFlag = false; + int64_t quantMode = 1; + + using LayoutA = layout::RowMajor; + using LayoutB = typename std::conditional< + Nz_, + layout::zN, + typename std::conditional::type + >::type; + + LayoutB layoutB1 = LayoutBInitializer::create(k, n); + LayoutB layoutB2 = LayoutBInitializer::create(k2, n2); + using LayoutC = layout::RowMajor; + using L1TileShape = GemmShape<128, 256, 512>; // M, N, K + + constexpr uint32_t 
workspaceStages = 2; + constexpr uint32_t preloadStages = 1; + constexpr uint32_t l1Stages = 2; + constexpr uint32_t l0AStages = 2; + constexpr uint32_t l0BStages = 2; + constexpr uint32_t l0CStages = 1; + + using DispatchPolicy = Gemm::MmadAtlasA2PreloadAsyncFixpipe< + preloadStages, + l1Stages, l0AStages, l0BStages, l0CStages, + enableUnitFlag, enableShuffleK + >; + + using L0TileShape = GemmShape<128, 256, 128>; + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + using D1Type = Gemm::GemmType; + + using D2Type = typename std::conditional< + std::is_same_v, + Gemm::GemmType, + Gemm::GemmType + >::type; + + using BlockMmad = Gemm::Block::BlockMmad; + constexpr uint32_t ubStages = 2; + + using EpilogueDispatchPolicy1 = Epilogue::EpilogueAtlasA2PerTokenDequantSwigluQuant; + + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using ElementMulType = Gemm::GemmType; + using TileElemWiseMuls = Epilogue::Tile::TileElemWiseMuls; + + using TileCopy1 = Epilogue::Tile::TileCopy; + using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; + + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant; + using TileCopy2 = Epilogue::Tile::TileCopy; + using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle<9, 1>; + using ElementGroupList = int64_t; + using MatmulKernel = Gemm::Kernel::DispatchFFNCombineKernel; + + LayoutA layoutA1{static_cast(m), static_cast(k)}; + LayoutA layoutA2{static_cast(m), static_cast(k2)}; + layout::VectorLayout layoutScale1{static_cast(n)}; + layout::VectorLayout layoutScale2{static_cast(n2)}; + layout::RowMajor layoutD1{static_cast(maxOutputSize), static_cast(k2)}; + layout::RowMajor layoutD2{static_cast(m*topK), static_cast(n2)}; + // Prepare params + + GemmCoord problemShape{static_cast(m), static_cast(n), static_cast(k)}; + + uint32_t epilogueCoreNum = aivNum / 2; + uint32_t 
epilogueGranularity = expertPerRank - 1; + + typename MatmulKernel::Params params{ + problemShape, static_cast(EP), static_cast(expertPerRank), static_cast(maxOutputSize), + static_cast(rank), static_cast(rankSize), + static_cast(topK), initRoutingQuantTilingKey, + epilogueCoreNum, epilogueGranularity, + xGM_, layoutA1, layoutA2, + weight1GM_, layoutB1, + weight2GM_, layoutB2, + scale1GM_, layoutScale1, + scale2GM_, layoutScale2, + outGM_, layoutD1, layoutD2, + expertIdGM_, moeInitRoutingQuantV2Scale, moeInitRoutingQuantV2Offset, + expertTokensBeforeCapacity, probs_, + workspaceGM_, ubMoveNum, moeInitRoutingQuantV2TilingData}; + //Call kernel + MatmulKernel kernel(params); + kernel(params); +} + +} // DispatchFFNCombineImpl +#endif // DISPATCH_FFN_COMBINE_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp new file mode 100644 index 00000000000..311e26088e2 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -0,0 +1,814 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef DISPATH_FFN_COMBINE_KERNEL_HPP +#define DISPATH_FFN_COMBINE_KERNEL_HPP + +#include "kernel_operator.h" + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +#include "utils/block_mmad_preload_async_fixpipe_quant.hpp" +#include "utils/copy_gm_to_l1_custom.hpp" +#include "utils/copy_l0c_to_gm_custom.hpp" +#include "utils/block_epilogue_pertoken_row.hpp" +#include "utils/block_epilogue_pertoken_swiglu.hpp" +#include "utils/hccl_shmem.hpp" +#include "utils/const_args.hpp" +#include "utils/layout3d.hpp" + +#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" +#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" +#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" +#include "unpermute/moe_token_unpermute.h" + + +using namespace AscendC; + +namespace Catlass::Gemm::Kernel { + +template < + class BlockMmad_, + class BlockScheduler_, + class ElementGroupList_, + class BlockEpilogue1_, + class BlockEpilogue2_ +> +class DispatchFFNCombineKernel { +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + using ElementScale = uint64_t; + using LayoutScale = typename layout::VectorLayout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = typename layout::VectorLayout; + using BlockScheduler = BlockScheduler_; 
+ using BlockEpilogue1 = BlockEpilogue1_; + using BlockEpilogue2 = BlockEpilogue2_; + using ElementD1 = typename BlockEpilogue1::ElementD; + using LayoutD1 = typename BlockEpilogue1::LayoutD; + using ElementD2 = typename BlockEpilogue2::ElementD; + using LayoutD2 = typename BlockEpilogue2::LayoutD; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + __gm__ ElementA *ptrA; + LayoutA layoutA; + LayoutA layoutA2; + __gm__ ElementB *ptrB1; + LayoutB layoutB1; + __gm__ ElementB *ptrB2; + LayoutB layoutB2; + __gm__ ElementScale *ptrScale1; + LayoutScale layoutScale1; + __gm__ ElementScale *ptrScale2; + LayoutScale layoutScale2; + __gm__ ElementD2 *ptrOutput; + LayoutD1 layoutD1; + LayoutD2 layoutD2; + GM_ADDR ptrWorkspace; + int32_t EP; + int32_t expertPerRank; + uint32_t maxOutputSize; + uint32_t rank; + uint32_t rankSize; + int32_t ubMoveNum; + //-------------- + GM_ADDR expertIdx; + GM_ADDR moeInitRoutingQuantV2Scale; + GM_ADDR moeInitRoutingQuantV2Offset; + GM_ADDR expandedX; + GM_ADDR expandedRowIdx; + GM_ADDR expertTokensCountOrCumsum; + GM_ADDR expertTokensBeforeCapacity; + GM_ADDR dynamicQuantScale; + GM_ADDR probs; + int64_t topK; + uint64_t initRoutingQuantTilingKey; + uint32_t epilogueCoreNum; + uint32_t epilogueGranularity; + optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData; + //-------------- + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params( + GemmCoord problemShape_, + uint32_t EP_, uint32_t expertPerRank_, uint32_t maxOutputSize_, + uint32_t rank_, uint32_t rankSize_, int64_t topK_, + uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_, + GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_, + GM_ADDR ptrB1_, LayoutB layoutB1_, + GM_ADDR ptrB2_, LayoutB layoutB2_, + GM_ADDR ptrScale1_, LayoutScale layoutScale1_, + GM_ADDR ptrScale2_, LayoutScale layoutScale2_, + GM_ADDR ptrOutput_, LayoutD2 layoutD1_, LayoutD2 
layoutD2_, + GM_ADDR expertIdx_, GM_ADDR moeInitRoutingQuantV2Scale_, + GM_ADDR moeInitRoutingQuantV2Offset_, + GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_, + GM_ADDR ptrWorkspace_, int32_t ubMoveNum_, + optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_ + ) : problemShape(problemShape_), + EP(EP_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_), + rank(rank_), rankSize(rankSize_), topK(topK_), + initRoutingQuantTilingKey(initRoutingQuantTilingKey_), + epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_), + ptrB1(reinterpret_cast<__gm__ ElementB *>(ptrB1_)), layoutB1(layoutB1_), + ptrB2(reinterpret_cast<__gm__ ElementB *>(ptrB2_)), layoutB2(layoutB2_), + ptrScale1(reinterpret_cast<__gm__ ElementScale *>(ptrScale1_)), layoutScale1(layoutScale1_), + ptrScale2(reinterpret_cast<__gm__ ElementScale *>(ptrScale2_)), layoutScale2(layoutScale2_), + ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_), + expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_), + moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), + expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_), + ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_), + moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_) + { + } + }; + + // Methods + CATLASS_DEVICE + DispatchFFNCombineKernel(Params const ¶ms) + { + if ASCEND_IS_AIC { + coreIdx = AscendC::GetBlockIdx(); + coreNum = AscendC::GetBlockNum(); + } + + if ASCEND_IS_AIV { + coreIdx = get_block_idx() + get_subblockid() * get_block_num(); + coreNum = get_block_num() * get_subblockdim(); + } + + initBuffer(params); + } + + CATLASS_DEVICE + ~DispatchFFNCombineKernel() + { + } + + template + CATLASS_DEVICE + void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE + void operator()(Params 
const ¶ms) + { + GMM1(params); + + AscendC::CrossCoreWaitFlag<0x2>(2); + + GMM2(params); + } + + + template <> + CATLASS_DEVICE + void operator()(Params const ¶ms) + { + Dispatch(params); + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + + Combine(params); + } + +private: + CATLASS_DEVICE void initBuffer(Params const ¶ms) { + workspaceInfo = WorkspaceInfo(params); + peermemInfo = PeermemInfo(params, shmem); + + cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); + + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); + gmS.SetGlobalBuffer(params.ptrScale1); + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); + + gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); + gmS2.SetGlobalBuffer(params.ptrScale2); + gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); + + gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); + gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2)); + + tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); + + tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank + 8, params.expertPerRank); + } + + template + CATLASS_DEVICE void CopyGMToGM( + AscendC::GlobalTensor dst, + AscendC::GlobalTensor src, + int32_t elemNum, + int32_t ubMoveNum + ) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + + using TType = Gemm::GemmType; + using CopyGmToUb = Epilogue::Tile::CopyGm2Ub; + using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; + CopyGmToUb copyGmToUb; + CopyUbToGm copyUbToGm; + constexpr int32_t BufferNum = 2; + int tmpBufferSize = 32 * 1024 / sizeof(T); // 32 KB + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + 
tmpBuffer1.SetSize(tmpBufferSize); + int tmpBufferOffset = 96 * 1024; // half of UB + AscendC::LocalTensor tmpBuffer2 = resource.ubBuf.template GetBufferByByte(tmpBufferOffset); + tmpBuffer2.SetSize(tmpBufferSize); + + // [ReduceScatter] 2. Pre Interface Sync + int pingpongId = 0; + auto processCount = CeilDiv(elemNum, ubMoveNum); + for (uint32_t processIndex = 0; processIndex < processCount; ++processIndex) { + uint32_t curProcessNum = (processIndex == processCount - 1) ? elemNum - ubMoveNum * (processCount - 1) : ubMoveNum; + AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2; + auto processOffset = processIndex * ubMoveNum; + + auto inputOffset = processOffset; + auto outputOffset = processOffset; + // [ReduceScatter] 2. Pre Interface Sync + AscendC::WaitFlag(EVENT_ID); + // [ReduceScatter] 3. Start shmem_mte_get_mem_nbi + copyGmToUb(buf, src[inputOffset], layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum}); + AscendC::SetFlag(EVENT_ID); + AscendC::WaitFlag(EVENT_ID); + copyUbToGm(dst[outputOffset], buf, layout::RowMajor{ 1, curProcessNum}, layout::RowMajor{1, curProcessNum}); + + // [ReduceScatter] 4. Post Interface Sync + AscendC::SetFlag(EVENT_ID); + pingpongId = (pingpongId + 1) % BufferNum; + } + // [ReduceScatter] 4. 
Post Interface Sync + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + + CATLASS_DEVICE + void GetCumsumForMMAIV(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP) + { + int32_t expertPerRankAligned = (expertPerRank + 8 - 1) / 8 * 8; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + AscendC::LocalTensor tmpResult = resource.ubBuf.template GetBufferByByte(EP * expertPerRank * sizeof(int32_t)); + #define U16(x) static_cast(x) + + AscendC::DataCopyPad( + tmpBuffer1, + tokenPerExpert[rankId * expertPerRank], + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank + 8) * sizeof(int32_t)), 0}, + {} + ); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + for (uint32_t i = 1; i < EP; ++i) { + AscendC::Add(tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[i * expertPerRankAligned], tmpBuffer1[(i - 1) * expertPerRankAligned], expertPerRank); + AscendC::PipeBarrier(); + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::DataCopyPad( + result, + tmpBuffer1, + {U16(EP), U16((expertPerRank) * sizeof(int32_t)), 0, 0} + ); + } + + CATLASS_DEVICE + void GMM1(Params const ¶ms){ + icache_preload(8); + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + uint32_t startCoreIdx = 0; + uint32_t syncGroupIdx = 0; + AscendC::CrossCoreWaitFlag<0x2>(0); // 等待aiv计算cumsumformm + int64_t preCurrentmSum = 0; + int32_t syncLoopIdx = -1; + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preCurrentmSum >= params.maxOutputSize) { + currentM = 0; + } else if (preCurrentmSum + currentM >= params.maxOutputSize) { + currentM = params.maxOutputSize - preCurrentmSum; + } + AscendC::GlobalTensor 
gmB1; + gmB1.SetGlobalBuffer(params.ptrB1); + if (currentM <= L1TileShape::M) { + gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB1 = params.layoutB1; + LayoutScale layoutScale = params.layoutScale1; + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + for(;syncGroupIdx <= groupIdx; syncGroupIdx++) { + AscendC::CrossCoreWaitFlag<0x2>(0); + } + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB1.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // 每个expert一组scale + if (currentM > 0) { + blockMmad( + gmA[gmGroupOffsetA + gmOffsetA], layoutA, + gmB1[gmGroupOffsetB + gmOffsetB], layoutB1, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, + gmS[gmOffsetS], layoutScale, + 
actualBlockShape + ); + } + } + + if ((groupIdx + 1) == params.epilogueGranularity && (groupIdx < params.expertPerRank - 1)) { + syncLoopIdx ++; + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + blockMmad.Finalize(syncLoopIdx, 1); + } + + preCurrentmSum += currentM; + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + blockMmad.Finalize(syncLoopIdx + 1, 1); + } + + CATLASS_DEVICE + void GMM2(Params const ¶ms) { + icache_preload(8); + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + uint32_t n2 = params.problemShape.k(); + uint32_t k2 = params.problemShape.n() / 2; + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + + uint32_t startCoreIdx = 0; + + AscendC::PipeBarrier(); + + int64_t preCurrentmSum = 0; + int32_t syncLoopIdx = -1; + uint32_t lastDequantExpertNum = params.expertPerRank; + + if (params.epilogueGranularity < params.expertPerRank) { + lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; + } + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preCurrentmSum >= params.maxOutputSize) { + currentM = 0; + } else if (preCurrentmSum + currentM > params.maxOutputSize) { + currentM = params.maxOutputSize - preCurrentmSum; + } + AscendC::GlobalTensor gmB2; + gmB2.SetGlobalBuffer(params.ptrB2); + if (currentM <= L1TileShape::M) { + gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + GemmCoord inGroupProblemShape{currentM, n2, k2}; // M N K + + LayoutA layoutA = 
params.layoutA2.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB2 = params.layoutB2; + LayoutScale layoutScale = params.layoutScale2; + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { + AscendC::CrossCoreWaitFlag<0x2>(2); + } + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + if (loopIdx + coreNum >= coreLoops) { + syncLoopIdx = groupIdx; + } + + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB2.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // 每个expert一组scale + if (currentM > 0) { + blockMmad( + gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, + gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, + gmC2[gmGroupOffsetC + gmOffsetC], layoutC, + gmS2[gmOffsetS], layoutScale, + actualBlockShape, syncLoopIdx, 3 + ); + } + } + preCurrentmSum += currentM; + gmGroupOffsetA += inGroupProblemShape.m() * 
inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + blockMmad.Finalize(params.expertPerRank - 1, 3); + } + + CATLASS_DEVICE + void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const ¶ms, int64_t localTokenPerExpertOffset){ + uint64_t flag_offset = (shmem.SegmentSize() - MB_SIZE) / sizeof(int32_t); + __gm__ int32_t* sync_base = shmem.SyncBaseAddr(); + int count = gm_load(sync_base) + 1; + if (coreIdx < params.EP && coreIdx != params.rank) { + AscendC::GlobalTensor srcAddress; + srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset)); + AscendC::GlobalTensor dstAddress; + __gm__ void* dstPeermemPtr = shmem(localTokenPerExpertOffset, coreIdx); + dstAddress.SetGlobalBuffer((__gm__ int32_t * )dstPeermemPtr); + + AscendC::SetFlag(EVENT_ID0); + using TType = Gemm::GemmType; + using CopyGmToUb = Epilogue::Tile::CopyGm2Ub; + using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; + CopyGmToUb copyGmToUb; + CopyUbToGm copyUbToGm; + AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); + AscendC::WaitFlag(EVENT_ID0); + uint32_t tmp = params.EP * params.expertPerRank; + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, tmp}, + layout::RowMajor{1, tmp}); + + tmpBuffer.SetValue(params.EP * params.expertPerRank, count); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, tmp + 1}, + layout::RowMajor{1, tmp + 1}); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(coreIdx, params.EP, 0); + 
gm_signal_wait_until_eq_for_barrier(sync_check, count); + } + AscendC::SyncAll(); + gm_store(sync_base, count); + } + + + CATLASS_DEVICE + void Dispatch(Params const ¶ms) { + icache_preload(8); + int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t); + GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // 把通信矩阵全部放到peermem + uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t); + + //---initRouting------ + moe_init_routing_quant_v2(reinterpret_cast (params.ptrA), params.expertIdx, + params.moeInitRoutingQuantV2Scale, params.moeInitRoutingQuantV2Offset, shmem() + peermemInfo.offsetA, + workspaceInfo.expandedRowIdx, localTokenPerExpert, params.expertTokensBeforeCapacity, + shmem() + peermemInfo.offsetPeerPerTokenScale, + params.ptrWorkspace + expandedRowIdxOffset, + ¶ms.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey); + + AscendC::SyncAll(); + CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset); + if (coreIdx == 0) { + GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); + + uint32_t curGroupOffset = 0; + int32_t prevSumBeforeRank = 0; + int32_t groupIdxDeq = 0; + if (coreIdx < params.EP) { + for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) { + prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i)); + } + m_prevSumBeforeRank = prevSumBeforeRank; + } + int prevSum = prevSumBeforeRank; + uint32_t prevGroupSum1 = 0; + uint32_t dequantSum = 0; + int32_t syncLoopIdx = -1; + BlockEpilogue1 blockEpilogue(resource); + for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + // 第i个core从第i个rank的peermem读数据 + groupIdxDeq = groupIdx - 2; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + uint32_t 
rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; + if (rowStart < params.maxOutputSize) { + uint32_t rows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + if (rowStart + rows > params.maxOutputSize) { + rows = params.maxOutputSize - rowStart; + } + uint32_t rowSrc = prevSum; + prevSum += rows; + GM_ADDR otherRankPtr = shmem(0, dstEpIdx); + AscendC::GlobalTensor gmRemoteA; + gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); + AscendC::GlobalTensor gmRemotePerTokenScale; + gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale)); + MatrixCoord offsetA{rowStart, 0}; + MatrixCoord shapeA{rows, params.problemShape.k()}; + MatrixCoord offsetPeer{rowSrc, 0}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); + // 通信Data + CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); + // 通信scale + CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows); + } + } + + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { + syncLoopIdx++; + AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1); + } + AscendC::SyncAll(); + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V通知C当前轮的通信已完成 + + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { + uint32_t rowStartThisCore = 0; + MatrixCoord offsetC{0U, 0}; + uint32_t dequantLen = prevGroupSum1 - dequantSum; + if (dequantLen >= params.maxOutputSize) { + dequantLen = dequantLen - params.maxOutputSize; + } + + MatrixCoord shapeC{dequantLen, params.problemShape.n()}; + LayoutC layoutC{dequantLen, 
params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); + } + prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) { + dequantSum = 0; + } + } + syncLoopIdx ++; + AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx /8 + 1); + AscendC::SyncAll(); + + uint32_t lastDequantExpertNum = params.expertPerRank; + if (params.epilogueGranularity < params.expertPerRank) { + lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; + } + if (lastDequantExpertNum < params.expertPerRank) { + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(2); + } + if (prevGroupSum1 - dequantSum < params.maxOutputSize) { + uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; + MatrixCoord offsetC{rowStartThisCore, 0}; + uint32_t dequantLen = dequantSum; + if (prevGroupSum1 >= params.maxOutputSize) { + dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize); + } + MatrixCoord shapeC{dequantLen, params.problemShape.n()}; + LayoutC layoutC{dequantLen, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + } + blockEpilogue.Finalize(); + } + + CATLASS_DEVICE + void Combine(Params const ¶ms) { + int32_t prevSumBeforeRank = 0; + if (coreIdx < params.EP) { + prevSumBeforeRank = m_prevSumBeforeRank; + } + + int prevSum = prevSumBeforeRank; + uint32_t n2 = params.problemShape.k(); + uint32_t k2 = params.problemShape.n() / 2; + + // TODO 
计算tokenperexpert的cumsum + typename BlockEpilogue2::Params epilogueParams{ + static_cast(params.EP), + static_cast(params.expertPerRank), + reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace) + }; + BlockEpilogue2 blockEpilogue(resource, epilogueParams); + int32_t prevGroupSum2 = 0; + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { + AscendC::CrossCoreWaitFlag<0x2>(groupIdx / 8 + 3); + AscendC::SyncAll(); + + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); + AscendC::GlobalTensor gmRemotePeer; + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); + uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; + if (srcRowOffset < params.maxOutputSize) { + uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + if (srcRowOffset + dataRows > params.maxOutputSize) { + dataRows = params.maxOutputSize - srcRowOffset; + } + uint32_t dstRowOffset = prevSum; + prevSum += dataRows; + MatrixCoord offsetC{srcRowOffset, 0}; + MatrixCoord offsetPeer{dstRowOffset, 0}; + MatrixCoord shapeC{dataRows, n2}; + int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC); + int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer); + if constexpr (std::is_same_v) { + blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]); + } else { + blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); + } + } + } + prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + } + blockEpilogue.Finalize(); + AscendC::SyncAll(); + shmem.CrossRankSync(); + MoeTokenUnpermuteTilingData tilingData; + MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); + KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; + + 
kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); + kernelMoeTokenUnpermuteOp.Process(); + } + +private: + struct WorkspaceInfo { + GM_ADDR ptrA; + GM_ADDR ptrPerTokenScale; + GM_ADDR ptrcumsumMM; + GM_ADDR ptrC; + GM_ADDR ptrC2; + GM_ADDR ptrPermutedToken; + GM_ADDR ptrPerTokenScale2; + GM_ADDR expandedRowIdx; + GM_ADDR ptrTokenPerExpert; + + CATLASS_DEVICE + WorkspaceInfo(){} + + CATLASS_DEVICE + WorkspaceInfo(const Params & params) { + uint32_t k2 = params.problemShape.n() / 2; + uint32_t n2 = params.problemShape.k(); + int64_t workspaceOffset = 0; + expandedRowIdx = params.ptrWorkspace; + + workspaceOffset += AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t); + ptrcumsumMM = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + ptrPerTokenScale = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale); + ptrPerTokenScale2 = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * sizeof(ElementPerTokenScale); + ptrTokenPerExpert = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); + ptrC = params.ptrWorkspace + workspaceOffset; + ptrC2 = ptrC; + + workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC), + params.maxOutputSize * n2 * sizeof(ElementC)); + + ptrA = params.ptrWorkspace + workspaceOffset; + ptrPermutedToken = ptrA; + workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA), + params.maxOutputSize * k2 * sizeof(ElementA)); + } + }; + + struct PeermemInfo { + int64_t offsetA; + int64_t offsetPeerPerTokenScale; + int64_t 
offsetPeerTokenPerExpert; + int64_t offsetD; + + CATLASS_DEVICE + PeermemInfo(){} + + CATLASS_DEVICE + PeermemInfo(const Params & params, const HcclShmem & shmem) { + offsetA = 0; // 占用1/3的BUFFSIZE + offsetPeerPerTokenScale = offsetA + AlignUp(shmem.SegmentSize() / 3, 512); // 占用1MB + offsetD = offsetPeerPerTokenScale + MB_SIZE; // 占用剩下的 + offsetPeerTokenPerExpert = shmem.SegmentSize() - 2 * MB_SIZE; // 占用最后2MB + } + }; + + Arch::Resource resource; + + uint32_t coreIdx; + uint32_t coreNum; + + Params params; + WorkspaceInfo workspaceInfo; + PeermemInfo peermemInfo; + + int64_t m_prevSumBeforeRank; + + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmC; + AscendC::GlobalTensor gmS; + + AscendC::GlobalTensor gmPermutedToken; + AscendC::GlobalTensor gmS2; + AscendC::GlobalTensor gmC2; + + AscendC::GlobalTensor gmPerTokenScale1; + AscendC::GlobalTensor gmPerTokenScale2; + + AscendC::GlobalTensor tokenPerExpert; + AscendC::GlobalTensor cumsumMM; + Layout3D tokenPerExpertLayout; + HcclShmem shmem; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // DISPATH_FFN_COMBINE_KERNEL_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h new file mode 100644 index 00000000000..de891e9f026 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dispatch_ffn_combine_tiling.h + * \brief + */ + +#include "moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h" +#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" + +#ifndef ASCENDC_DISPATCH_FFN_COMBINE_TILING_H +#define ASCENDC_DISPATCH_FFN_COMBINE_TILING_H +struct DispatchFFNCombineInfo { + uint32_t M; + uint32_t K; + uint32_t N; + uint32_t expertPerRank; + uint32_t maxOutputSize; + uint32_t isTransposeB; + uint32_t isWeightNz; + uint32_t aivNum; + uint32_t totalUbSize; + uint32_t topK; + uint32_t worldSize; +}; + +struct CoCTiling { + int32_t m0 = -1; + int32_t k0 = -1; + int32_t n0 = -1; + int32_t swizzleDirect = -1; + int32_t swizzleOffset = -1; + int32_t ubMoveNum = -1; + int32_t pValue = -1; + int32_t commNpuSplit = -1; + int32_t commDataSplit = -1; + int32_t lenPerLoop = -1; + uint64_t initRoutingQuantTilingKey; + optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData; +}; + +struct DispatchFFNCombineTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + DispatchFFNCombineInfo dispatchFFNCombineInfo; + CoCTiling cocTiling; +}; +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp new file mode 100644 index 00000000000..8453c810d13 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp @@ -0,0 +1,134 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_init_routing_quant_v2.cpp + * \brief + */ +#include "moe_v2_sort_one_core.h" +#include "moe_v2_sort_multi_core.h" +#include "moe_v2_mrgsort_out.h" +#include "moe_v2_mrgsort.h" +#include "moe_v2_expert_token_out.h" +#include "moe_v2_src_to_dst_op.h" +#include "moe_v2_src_to_dst_with_capacity.h" +#include "moe_v2_fullload_quant.h" +#include "moe_v2_fullload_dynamic_quant.h" +#include "moe_v2_gather_quant.h" +#include "moe_v2_gather_dynamic_quant.h" +#include "moe_v2_src_to_dst_and_gather.h" + +using namespace AscendC; +using namespace MoeInitRoutingQuantV2; +using namespace optiling; + +template +__aicore__ inline void moe_init_routing_quant_v2( + GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, GM_ADDR dynamicQuantScale, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, uint64_t tilingKey) { + + if (g_coreType == AIC) { + return; + } + + if (workspace == nullptr) { + return; + } + + if (tilingKey == 20000) { // quant full load + TPipe sortPipe; + MoeV2FullLoadQuant op; + op.Init(x, expertIdx, scale, offset, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, &sortPipe); + op.Process(); + sortPipe.Destroy(); + return; + } + + + else if (tilingKey == 21000) { // dynamic quant full load + TPipe sortPipe; + MoeV2FullLoadDynamicQuant op; + op.Init(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, scale, dynamicQuantScale, workspace, tilingData, + &sortPipe); + op.Process(); + sortPipe.Destroy(); + return; + } + + // sort + if (tilingKey == 10000 || tilingKey 
== 10100 || tilingKey == 11000 || tilingKey == 11100) { + TPipe sortPipe; + MoeV2SortOneCore op; + op.Init(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace, + tilingData, &sortPipe); + op.Process(); + sortPipe.Destroy(); + } else if (tilingKey == 10010 || tilingKey == 10110 || tilingKey == 11010 || tilingKey== 11110) { + TPipe sortPipe; + MoeV2SortMultiCore op; + op.Init(expertIdx, expertTokensCountOrCumsum, expertTokensBeforeCapacity, workspace, + tilingData, &sortPipe); + op.Process(); + sortPipe.Destroy(); + } + + if (tilingKey == 10000 || tilingKey == 10010 || tilingKey ==11000 || tilingKey ==11010) { //没有drop的情况 + if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) { + TPipe expertTokenOutPipe; + MoeV2ExpertTokenOut expertTokenOutOp; + expertTokenOutOp.Init(expertTokensCountOrCumsum, expertTokensBeforeCapacity, + expandedRowIdx, workspace, tilingData, &expertTokenOutPipe); + expertTokenOutOp.Process(); + expertTokenOutPipe.Destroy(); + } + TPipe srcToDstPipe; + MoeV2SrcToDstOp srcToDstOp; + srcToDstOp.Init(expandedRowIdx, workspace, tilingData, &srcToDstPipe); + srcToDstOp.Process(); + srcToDstPipe.Destroy(); + } else if (tilingKey ==10100 || tilingKey ==10110 || tilingKey ==11100 || tilingKey ==11110) { //有drop的情况 + TPipe expertTokenOutPipe; + MoeV2ExpertTokenOut expertTokenOutOp; + expertTokenOutOp.Init(expertTokensCountOrCumsum, expertTokensBeforeCapacity, + expandedRowIdx, workspace, tilingData, &expertTokenOutPipe); + expertTokenOutOp.Process(); + expertTokenOutPipe.Destroy(); + + if (tilingKey == 10100 || tilingKey == 10110) { + TPipe srcToDstPipe; + MoeV2SrcToDstWithCapacity srcToDstWithCapacityOp; + srcToDstWithCapacityOp.Init(expandedRowIdx, expandedX, workspace, tilingData, &srcToDstPipe); + srcToDstWithCapacityOp.Process(); + srcToDstPipe.Destroy(); + } else { + TPipe srcToDstGatherPipe; + MoeV2SrcToDstAndGather srcToDstAndGatherOp; + srcToDstAndGatherOp.Init(x, scale, expandedRowIdx, expandedX, 
dynamicQuantScale, workspace, tilingData, &srcToDstGatherPipe); + srcToDstAndGatherOp.Process(); + srcToDstGatherPipe.Destroy(); + return; + } + } + + if (tilingKey == 10000 || tilingKey == 10010 || tilingKey == 10100 || tilingKey == 10110) { + TPipe gatherPipe; + MoeV2GatherQuant gatherQuantOp; + gatherQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, workspace, tilingData, &gatherPipe); + gatherQuantOp.Process(); + gatherPipe.Destroy(); + } else if (tilingKey == 11000 || tilingKey == 11010) { + TPipe gatherPipe; + MoeV2GatherDynamicQuant gatherDynamicQuantOp; + gatherDynamicQuantOp.Init(x, scale, expandedRowIdx, expandedX, dynamicQuantScale, workspace, tilingData, &gatherPipe); + gatherDynamicQuantOp.Process(); + gatherPipe.Destroy(); + } +} diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h new file mode 100644 index 00000000000..395a789ce3c --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h @@ -0,0 +1,429 @@ +#pragma once +#include "moe_init_routing_v2_tiling.h" + +namespace optiling { + +const static int64_t ATTR_QUANT_MODE = 6; +const static int64_t TILING_KEY_BASE = 10000; +const static int64_t TILING_KEY_PERF_BASE = 20000; +const static int64_t TILING_KEY_QUANT_BASE = 1000; +const static int64_t TILING_KEY_DROP_MODE_BASE = 100; +const static int64_t TILING_KEY_SORT_BASE = 10; +const static int64_t FOUR_BLOCK_BYTE = 128; +const static int64_t MAX_COLS_ONE_LOOP_QUANT = 8192; +const static int64_t INDEX_SCALE = 2; +const static int64_t INDEX_OFFSET = 3; +const static int64_t SMOOTH_NONE = 0; +const static int64_t SMOOTH_1H = 1; +const static int64_t SMOOTH_EH = 2; +const static int64_t MAX_COLS_DYNAMIC_QUANT = 6144; +const static int64_t DYNAMIC_QUANT_SRC_TO_DST_BUFFER = 15; +const static int64_t DYNAMIC_QUANT_COLS_BUFFER = 21; 
+const static int64_t DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER = 13; +const static int64_t DYNAMIC_QUANT_SCALE_SIZE_64 = 64; +const static int64_t DYNAMIC_QUANT_SCALE_SIZE_128 = 128; +const static int64_t OUTOUT_DYNAMIC_QUANT_SCALE = 4; +const static int64_t FULLLOAD_H_LIMIT = 7168; + + +inline static int64_t AlignOneBlockByte(int64_t x) { + return (x + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; +} + +inline static int64_t AlignOneBlockByteCeil(int64_t x) { + return x / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; +} + +struct MoeInitRoutingQuantV2TilingData { + int64_t coreNum; + int64_t n; + int64_t cols; + int64_t k; + int64_t expertCapacity; + int64_t expertNum; + int64_t dropPadMode; + int64_t expertTokensCountOrCumsumFlag; + int64_t expertTokensBeforeCapacityFlag; + int64_t smoothType; + InnerMoeV2VBSComputeTilingData vbsComputeParamsOp; + InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp; + InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp; + InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp; + InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp; + InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp; +}; + + + +class MoeInitRoutingQuantV2TilingBase : public InnerMoeInitRoutingV2TilingBase { +public: +protected: + + bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override; + uint64_t GetTilingKey() const override; + bool GetWorkspaceSize() override; + bool PostTiling() override; +public: + //bool CheckOutShape() override; + bool IsFullLoadQuant(int64_t space); + bool IsFullLoadDynamicQuant(int64_t space); + bool IsFullLoad() override; + void SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows, int64_t lastCoreRows, + int64_t cols); + 
void SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t baseMaxCols, int64_t cols); + void SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData, int64_t perCoreRows, + int64_t lastCoreRows, int64_t basePerLoopMaxRows); + void Tiling4GatherQuant(); + void Tiling4GatherDynamicQuant(); + void Tiling4SrcToDstCapacityCompute() override; + void Tiling4GatherOutCompute() override; + void CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst, InnerMoeV2GatherOutComputeTilingData& src); + void CopyTilingData(); + + + int64_t quantMode; + MoeInitRoutingQuantV2TilingData quantTilingData; +}; + + +bool MoeInitRoutingQuantV2TilingBase::IsFullLoadQuant(int64_t space) { + int64_t perCoreXRows = moeInitRoutingTilingData.n / aivNum; + int64_t remainder = moeInitRoutingTilingData.n % aivNum; + // NUM_TWO is Max xRows need add 2 becauseof the left and right row may be another row. + perCoreXRows = remainder <= 1 ? perCoreXRows + 1 : perCoreXRows + NUM_TWO; + int64_t quantBaseSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols); + int64_t quantSpace = + quantBaseSpace * (inuptXDtypeSize_ + sizeof(int8_t) + sizeof(float) + sizeof(int16_t)) * perCoreXRows; + int64_t remainUbAfterSort = aicoreParams_.ubSize - space - quantSpace; + return remainUbAfterSort > 0; +} + +bool MoeInitRoutingQuantV2TilingBase::IsFullLoadDynamicQuant(int64_t space) { + int64_t quantSpace = AlignOneBlockByte(moeInitRoutingTilingData.cols) * DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER; + int64_t scaleOutSpace = 64; + int64_t remainUbAfterSort = aicoreParams_.ubSize - space - scaleOutSpace - quantSpace; + return remainUbAfterSort > 0; +} + +bool MoeInitRoutingQuantV2TilingBase::IsFullLoad() { + if (totalLength > sortLoopMaxElement || moeInitRoutingTilingData.cols > MAX_COLS_ONE_LOOP_QUANT || + this->dropPadMode == 1) { + return false; + } + int64_t sortSpace = AlignOneBlockByte(this->totalLength) * sizeof(int32_t) * ONE_CORE_SORT_BUFFER; + int64_t otherSpace = 
AlignOneBlockByte(this->totalLength) * sizeof(int32_t) * NUM_THREE; + int64_t expertSpace = AlignOneBlockByte(this->expertNum * sizeof(int32_t)); + if (quantMode == 0) { + return IsFullLoadQuant(sortSpace + otherSpace + expertSpace); + } else { + return IsFullLoadDynamicQuant(sortSpace + otherSpace + expertSpace); + } +} + +bool MoeInitRoutingQuantV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) { + + InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode, + expertTokensCountOrCumsumFlag, expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0); + this -> quantMode = quantMode; + if (quantMode == 0) { + } else { + if (scaleDim0 > 0) { + quantTilingData.smoothType = ((scaleDim0 == 1) ? 
SMOOTH_1H : SMOOTH_EH); + } else { + quantTilingData.smoothType = SMOOTH_NONE; + } + } + return true; +} + + +uint64_t MoeInitRoutingQuantV2TilingBase::GetTilingKey() const { + if (isFullLoad) { + return TILING_KEY_PERF_BASE + quantMode * TILING_KEY_QUANT_BASE; + } + return TILING_KEY_BASE + quantMode * TILING_KEY_QUANT_BASE + dropPadMode * TILING_KEY_DROP_MODE_BASE + + (totalLength > sortLoopMaxElement) * TILING_KEY_SORT_BASE; +} + + +bool MoeInitRoutingQuantV2TilingBase::PostTiling() { + CopyTilingData(); + return true; +} +void MoeInitRoutingQuantV2TilingBase::CopyGatherOutTiling(InnerMoeV2GatherOutComputeTilingData& dst, + InnerMoeV2GatherOutComputeTilingData& src) { + dst.needCoreNum = (src.needCoreNum); + dst.activateRows = (src.activateRows); + dst.perCoreRows = (src.perCoreRows); + dst.perCorePerLoopRows = (src.perCorePerLoopRows); + dst.perCoreLastLoopRows = (src.perCoreLastLoopRows); + dst.lastCoreRows = (src.lastCoreRows); + dst.lastCorePerLoopRows = (src.lastCorePerLoopRows); + dst.lastCoreLastLoopRows = (src.lastCoreLastLoopRows); + dst.perCoreLoops = (src.perCoreLoops); + dst.lastCoreLoops = (src.lastCoreLoops); + dst.perLoopCols = (src.perLoopCols); + dst.lastLoopCols = (src.lastLoopCols); + dst.colLoops = (src.colLoops); +} + +void MoeInitRoutingQuantV2TilingBase::CopyTilingData() { + quantTilingData.coreNum = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.coreNum); + quantTilingData.n = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.n); + quantTilingData.cols = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols); + quantTilingData.k = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.k); + quantTilingData.expertCapacity = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertCapacity); + quantTilingData.expertNum = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertNum); + quantTilingData.dropPadMode = (InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.dropPadMode); + 
quantTilingData.expertTokensCountOrCumsumFlag = ( + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertTokensCountOrCumsumFlag); + quantTilingData.expertTokensBeforeCapacityFlag = ( + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.expertTokensBeforeCapacityFlag); + + auto vbsTilingData = &InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.vbsComputeParamsOp; + quantTilingData.vbsComputeParamsOp.needCoreNum = (vbsTilingData->needCoreNum); + quantTilingData.vbsComputeParamsOp.perCoreElements = (vbsTilingData->perCoreElements); + quantTilingData.vbsComputeParamsOp.perCoreLoops = (vbsTilingData->perCoreLoops); + quantTilingData.vbsComputeParamsOp.perCorePerLoopElements = (vbsTilingData->perCorePerLoopElements); + quantTilingData.vbsComputeParamsOp.perCoreLastLoopElements = (vbsTilingData->perCoreLastLoopElements); + quantTilingData.vbsComputeParamsOp.lastCoreElements = (vbsTilingData->lastCoreElements); + quantTilingData.vbsComputeParamsOp.lastCoreLoops = (vbsTilingData->lastCoreLoops); + quantTilingData.vbsComputeParamsOp.lastCorePerLoopElements = (vbsTilingData->lastCorePerLoopElements); + quantTilingData.vbsComputeParamsOp.lastCoreLastLoopElements = (vbsTilingData->lastCoreLastLoopElements); + quantTilingData.vbsComputeParamsOp.oneLoopMaxElements = (vbsTilingData->oneLoopMaxElements); + + quantTilingData.vmsMiddleComputeParamsOp.needCoreNum = ( + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.vmsMiddleComputeParamsOp.needCoreNum); + quantTilingData.sortOutComputeParamsOp.oneLoopMaxElements = ( + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.sortOutComputeParamsOp.oneLoopMaxElements); + + CopyGatherOutTiling(quantTilingData.srcToDstComputeParamsOp, + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstComputeParamsOp); + CopyGatherOutTiling(quantTilingData.srcToDstCapacityComputeParamsOp, + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp); +} + + +bool 
MoeInitRoutingQuantV2TilingBase::GetWorkspaceSize() { + InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize(); + bool useCols = + (dropPadMode == 0 && quantTilingData.gatherOutComputeParamsOp.colLoops > 1) || + (dropPadMode == 1 && + InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp.colLoops > 1); + if (quantMode == 1 && useCols) { + workspaceSize_ += aivNum * InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols * sizeof(float); + } + return true; +} + +void MoeInitRoutingQuantV2TilingBase::SetGatherTilingData(InnerMoeV2GatherOutComputeTilingData* tilingData, + int64_t perCoreRows, int64_t lastCoreRows, int64_t cols) { + tilingData->perCorePerLoopRows = perCoreRows; + tilingData->perCoreLastLoopRows = perCoreRows; + tilingData->lastCorePerLoopRows = lastCoreRows; + tilingData->lastCoreLastLoopRows = lastCoreRows; + tilingData->perCoreLoops = 1; + tilingData->lastCoreLoops = 1; + tilingData->perLoopCols = cols; + tilingData->lastLoopCols = cols; + tilingData->colLoops = 1; +} + +void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataCols(InnerMoeV2GatherOutComputeTilingData* tilingData, + int64_t baseMaxCols, int64_t cols) { + tilingData->perLoopCols = (std::min(baseMaxCols, cols)); + tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols)); + tilingData->colLoops = (baseMaxCols == 0 ? 0 : (cols + baseMaxCols - 1) / baseMaxCols); +} + +void MoeInitRoutingQuantV2TilingBase::SetGatherTilingDataRows(InnerMoeV2GatherOutComputeTilingData* tilingData, + int64_t perCoreRows, int64_t lastCoreRows, + int64_t basePerLoopMaxRows) { + tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLoops = (basePerLoopMaxRows == 0 ? 
0 + : (perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLoops = (basePerLoopMaxRows == 0 ? 0 + : (lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); +} + +void MoeInitRoutingQuantV2TilingBase::Tiling4SrcToDstCapacityCompute() { + if (quantMode == 0 || dropPadMode == 0) { + InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute(); + return; + } + + auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp; + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + tilingData->needCoreNum = CeilDiv(totalLength, perCoreRows); + int64_t cols = moeInitRoutingTilingData.cols; + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + + int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR; + int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64; + if (rowSize + colSize + scaleSize < static_cast(aicoreParams_.ubSize)) { + + SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols); + } else { + + int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT; + int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + int64_t ubSize = static_cast(aicoreParams_.ubSize); + int64_t basePerLoopMaxRows = + AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR; + if (cols < MAX_COLS_DYNAMIC_QUANT) { + basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = 
AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + } + SetGatherTilingDataCols(tilingData, baseMaxCols, cols); + SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows); + } +} + + +void MoeInitRoutingQuantV2TilingBase::Tiling4GatherQuant() { + auto tilingData = &quantTilingData.gatherOutComputeParamsOp; + tilingData->activateRows = totalLength; + if (dropPadMode == 0 && activateNum > 0) { + tilingData->activateRows = (std::min(activateNum, totalLength)); + } + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows)); + int64_t cols = moeInitRoutingTilingData.cols; + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + int64_t sizeOfCol = sizeof(int8_t) * NUM_TWO + sizeof(float) + sizeof(int16_t) + inuptXDtypeSize_ * NUM_TWO; + int64_t rowSize = AlignOneBlockByte((perCoreRows * sizeof(int32_t) * NUM_TWO)); + int64_t colSize = AlignOneBlockByte(cols * sizeOfCol); + if (rowSize + colSize < static_cast(aicoreParams_.ubSize) / NUM_TWO) { + SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols); + } else { + int64_t baseMaxCols = MAX_COLS_ONE_LOOP_QUANT; + int64_t baseMaxColsSize = AlignOneBlockByte(baseMaxCols * sizeOfCol); + int64_t ubSize = static_cast(aicoreParams_.ubSize); + int64_t basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - baseMaxColsSize) / NUM_TWO / sizeof(int32_t)); + if (cols < MAX_COLS_ONE_LOOP_QUANT) { + basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize) / NUM_TWO / sizeof(int32_t)); + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = AlignOneBlockByteCeil((ubSize - rowSize) / sizeOfCol); + } + SetGatherTilingDataCols(tilingData, baseMaxCols, cols); + SetGatherTilingDataRows(tilingData, perCoreRows, 
lastCoreRows, basePerLoopMaxRows); + } +} + + + +void SetGatherTilingDatawithloop(InnerMoeV2GatherOutComputeTilingData* tilingData, + int64_t perCorePerLoopRows, int64_t lastCorePerLoopRows, int64_t cols, + int64_t perCoreLastLoopRows = 1, int64_t lastCoreLastLoopRows = 1, + int64_t perCoreLoops = 1, int64_t lastCoreLoops = 1) { + tilingData-> perCorePerLoopRows = perCorePerLoopRows; + tilingData-> perCoreLastLoopRows = perCoreLastLoopRows; + tilingData-> lastCorePerLoopRows = lastCorePerLoopRows; + tilingData-> lastCoreLastLoopRows = lastCoreLastLoopRows; + tilingData-> perCoreLoops = perCoreLoops; + tilingData-> lastCoreLoops = lastCoreLoops; + tilingData-> perLoopCols = cols; + tilingData-> lastLoopCols = cols; + tilingData-> colLoops = 1; +} + +void MoeInitRoutingQuantV2TilingBase::Tiling4GatherDynamicQuant() { + + auto tilingData = &quantTilingData.gatherOutComputeParamsOp; + tilingData->activateRows = totalLength; + if (dropPadMode == 0 && activateNum > 0) { + tilingData->activateRows = (std::min(activateNum, totalLength)); + } + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + tilingData->needCoreNum = (CeilDiv(totalLength, perCoreRows)); + + int64_t cols = InnerMoeInitRoutingV2TilingBase::moeInitRoutingTilingData.cols; + + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + + + int64_t rowSize = AlignOneBlockByte(perCoreRows * sizeof(int32_t)) * NUM_FOUR; + int64_t colSize = AlignOneBlockByte(cols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER; + int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64; + int64_t onceRowSize = (static_cast(aicoreParams_.ubSize) - + colSize - scaleSize - + ONE_BLOCK_BYTE * NUM_FOUR * NUM_THREE) / + (sizeof(int32_t) * NUM_FOUR); + int64_t oneBlockNumInt = static_cast(ONE_BLOCK_BYTE) / static_cast(sizeof(int32_t)); + onceRowSize = onceRowSize / 
oneBlockNumInt * oneBlockNumInt; + bool ifOneLoop = ((static_cast(aicoreParams_.ubSize) > colSize + + scaleSize + ONE_BLOCK_BYTE * NUM_FOUR * NUM_FOUR) && + quantTilingData.smoothType == SMOOTH_NONE && + cols == FULLLOAD_H_LIMIT); + + int64_t perCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, perCoreRows) : perCoreRows; + int64_t lastCoreOnceRowSize = ifOneLoop ? std::min(onceRowSize, lastCoreRows) : lastCoreRows; + int64_t perCoreLoops = ifOneLoop ? CeilDiv(perCoreRows, perCoreOnceRowSize) : 1; + int64_t lastCoreLoops = ifOneLoop ? CeilDiv(lastCoreRows, lastCoreOnceRowSize) : 1; + int64_t perCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(perCoreRows, perCoreOnceRowSize) : perCoreRows; + int64_t lastCoreLastLoopRows = ifOneLoop ? GetPerOrLastValue(lastCoreRows, lastCoreOnceRowSize) : lastCoreRows; + + if (rowSize + colSize + scaleSize < static_cast(aicoreParams_.ubSize) || ifOneLoop) { + + SetGatherTilingDatawithloop(tilingData, perCoreOnceRowSize, lastCoreOnceRowSize, cols, + perCoreLastLoopRows, lastCoreLastLoopRows, + perCoreLoops, lastCoreLoops); + } else { + int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT; + int64_t totalColSize = AlignOneBlockByte(baseMaxCols * sizeof(int8_t)) * DYNAMIC_QUANT_COLS_BUFFER; + int64_t ubSize = static_cast(aicoreParams_.ubSize); + int64_t basePerLoopMaxRows = + AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR; + if (cols < MAX_COLS_DYNAMIC_QUANT) { + basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / sizeof(int32_t)) / NUM_FOUR; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_COLS_BUFFER; + } + SetGatherTilingDataCols(tilingData, baseMaxCols, cols); + SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows); + } +} + + +void MoeInitRoutingQuantV2TilingBase::Tiling4GatherOutCompute() { + if (quantMode == 0) { + Tiling4GatherQuant(); + } else { + 
Tiling4GatherDynamicQuant(); + } +} + + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h new file mode 100644 index 00000000000..37130b17814 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_init_routing_v2_tiling.h @@ -0,0 +1,410 @@ +#pragma once + +#include "tiling_base.h" + + +namespace optiling { +const static int64_t TILING_KEY_DROPLESS_SORT_ONE_CORE = 10001; +const static int64_t TILING_KEY_DROPLESS_SORT_MULTI_CORE = 10002; +const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE = 10011; +const static int64_t TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE = 10012; +const static int64_t TILING_KEY_HIGH_PERFORMANCE = 20000; +const static int64_t NUM_TWO = 2; +const static int64_t NUM_THREE = 3; +const static int64_t NUM_FOUR = 4; +const static int64_t MRG_LIST_NUM = 4; +const static int64_t SORT32_ALIGN_ELEMENT = 32; +const static int64_t ONE_BLOCK_BYTE = 32; +const static size_t DIM_ONE = 1; +const static size_t DIM_TWO = 2; +const static size_t DIM_THREE = 3; +const static int32_t SIZE_16 = 16; +const static int32_t LENGTH_1024 = 1024; +const static int64_t MAX_COLS_ONE_LOOP = 16376; +const static int64_t ASSIST_NUM = 256; +const static int64_t INDEX_INPUT_X = 0; +const static int64_t INDEX_INPUT_EXPERT_IDX = 1; +const static int64_t ATTR_ACTIVE_ROWS = 0; +const static int64_t ATTR_EXPERT_CAPACITY = 1; +const static int64_t ATTR_EXPERT_NUM = 2; +const static int64_t ATTR_DROP_PAD_MODE = 3; +const static int64_t ATTR_EXPERT_TOKENS_COUNT_OR_CUMSUM_FLAG = 4; +const static int64_t ATTR_EXPERT_TOKENS_BEFORE_CAPACITY_FLAG = 5; +const static int64_t OUTOUT_EXPANDED_X = 0; +const static int64_t OUTOUT_EXPANDED_ROW_IDX = 1; +const static int64_t OUTOUT_EXPERT_TOKENS_COUNT_OR_CUMSUM = 2; +const static int64_t OUTOUT_EXPERT_TOKENS_BEFORE_CAPACITY = 3; +const 
static int64_t KV_FACTOR = 2; +const static int64_t ONE_CORE_SORT_BUFFER = 6; +const static int64_t EXPERT_TOKENS_COUNT = 2; + + +inline static int64_t CeilLog4(int64_t x) { + return static_cast(std::ceil(std::log(x) / std::log(NUM_FOUR))); +} + +inline static int64_t GetPerOrLastValue(int64_t x, int64_t y) { + if (y == 0) { + return 0; + } + return x <= y ? x : x % y; +} + +template +constexpr T CeilDiv(const T dividend, const T divisor) +{ + return (dividend + divisor - 1) / divisor; +} + + +struct InnerMoeV2VBSComputeTilingData { + int64_t needCoreNum = 0; + int64_t perCoreElements = 0; + int64_t perCoreLoops = 0; + int64_t perCorePerLoopElements = 0; + int64_t perCoreLastLoopElements = 0; + int64_t lastCoreElements = 0; + int64_t lastCoreLoops = 0; + int64_t lastCorePerLoopElements = 0; + int64_t lastCoreLastLoopElements = 0; + int64_t oneLoopMaxElements = 0; +}; + +struct InnerMoeV2VMSMiddleComputeTilingData { + int64_t needCoreNum = 0; +}; + +struct InnerMoeV2SortOutComputeTilingData { + int64_t oneLoopMaxElements = 0; +}; + +struct InnerMoeV2GatherOutComputeTilingData { + int64_t needCoreNum = 0; + int64_t activateRows = 0; + int64_t perCoreRows = 0; + int64_t perCorePerLoopRows = 0; + int64_t perCoreLastLoopRows = 0; + int64_t lastCoreRows = 0; + int64_t lastCorePerLoopRows = 0; + int64_t lastCoreLastLoopRows = 0; + int64_t perCoreLoops = 0; + int64_t lastCoreLoops = 0; + int64_t perLoopCols = 0; + int64_t lastLoopCols = 0; + int64_t colLoops = 0; +}; + +struct InnerMoeInitRoutingV2TilingData { + int64_t coreNum; + int64_t n; + int64_t cols; + int64_t k; + int64_t expertCapacity; + int64_t expertNum; + int64_t dropPadMode; + int64_t expertTokensCountOrCumsumFlag; + int64_t expertTokensBeforeCapacityFlag; + InnerMoeV2VBSComputeTilingData vbsComputeParamsOp; + InnerMoeV2VMSMiddleComputeTilingData vmsMiddleComputeParamsOp; + InnerMoeV2SortOutComputeTilingData sortOutComputeParamsOp; + InnerMoeV2GatherOutComputeTilingData srcToDstComputeParamsOp; + 
InnerMoeV2GatherOutComputeTilingData srcToDstCapacityComputeParamsOp; + InnerMoeV2GatherOutComputeTilingData gatherOutComputeParamsOp; +}; + + +class InnerMoeInitRoutingV2TilingBase : public TilingBaseClass { + +protected: + bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) override; + bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) override; + + bool DoOpTiling() override; + uint64_t GetTilingKey() const override; + bool GetWorkspaceSize() override; + + +protected: + bool CheckTokenCount(int64_t num, const char* tag); + + virtual void Tiling4GatherOutCompute() = 0; + void Tiling4SrcToDstCompute(); + virtual void Tiling4SrcToDstCapacityCompute(); + void Tiling4SortOutCompute(); + void Tiling4VMSMiddleCompute(); + void Tiling4VBSCompute(); + void ShowTilingData(); + void Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData); + void Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData); + virtual bool IsFullLoad() = 0; + + + + + int64_t aivNum = 0; + int64_t sortLoopMaxElement = 0; + int64_t mrgSortListMaxElement = 2040; + int64_t totalLength = 0; + int64_t activateNum = 0; + int64_t expertCapacity = 0; + int64_t expertNum = 0; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 0; + bool expertTokensBeforeCapacityFlag = false; + int64_t inuptXDtypeSize_ = 0; + bool isFullLoad = false; + + InnerMoeInitRoutingV2TilingData moeInitRoutingTilingData; +}; + + +bool InnerMoeInitRoutingV2TilingBase::DoOpTiling() { + sortLoopMaxElement = + (aicoreParams_.ubSize) / (sizeof(int32_t) * NUM_TWO * NUM_FOUR) / SORT32_ALIGN_ELEMENT * SORT32_ALIGN_ELEMENT; + isFullLoad = IsFullLoad(); + Tiling4VBSCompute(); + Tiling4VMSMiddleCompute(); + Tiling4SortOutCompute(); + 
Tiling4SrcToDstCompute(); + Tiling4SrcToDstCapacityCompute(); + Tiling4GatherOutCompute(); + return true; +}; + +uint64_t InnerMoeInitRoutingV2TilingBase::GetTilingKey() const { + if (isFullLoad) { + return TILING_KEY_HIGH_PERFORMANCE; + } + if (dropPadMode == 0) { + if (totalLength <= sortLoopMaxElement) { // 排序只用到一个核排序 + return TILING_KEY_DROPLESS_SORT_ONE_CORE; + } else { + return TILING_KEY_DROPLESS_SORT_MULTI_CORE; + } + } else { + if (totalLength <= sortLoopMaxElement) { + return TILING_KEY_DROP_PAD_MODE_SORT_ONE_CORE; + } else { + return TILING_KEY_DROP_PAD_MODE_SORT_MULTI_CORE; + } + } + return tilingKey_; +} + + + +bool InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activateNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) { + + this->activateNum = activateNum; + this->expertCapacity = expertCapacity; + this->expertNum = expertNum; + this->dropPadMode = dropPadMode; + this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag; + if (dropPadMode == 1) { + // droppad场景下不输出expertTokensCountOrCumsum + expertTokensCountOrCumsumFlag = 0; + } else { + // dropless场景下不输出expertTokensBeforeCapacity + expertTokensBeforeCapacityFlag = false; + } + moeInitRoutingTilingData.cols = cols; + moeInitRoutingTilingData.n = m; + moeInitRoutingTilingData.k = topK; + moeInitRoutingTilingData.expertCapacity = expertCapacity; + moeInitRoutingTilingData.expertNum = expertNum; + moeInitRoutingTilingData.dropPadMode = dropPadMode; + moeInitRoutingTilingData.expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag; + moeInitRoutingTilingData.expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag; + totalLength = moeInitRoutingTilingData.n * moeInitRoutingTilingData.k; + 
inuptXDtypeSize_ = inuptXDtypeSize; + return true; +} + +bool InnerMoeInitRoutingV2TilingBase::GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) { + aivNum = aivCoreNum; + aicoreParams_.blockDim = aivCoreNum; + aicoreParams_.ubSize = ubSizePlatForm; + moeInitRoutingTilingData.coreNum = aivCoreNum; + return true; +} + + +bool InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize() { + // 计算workspace大小 + size_t sortWorkspaceSize = totalLength * sizeof(float) * NUM_TWO * NUM_THREE; // 排序需要的空间 + size_t scatterWorkspaceSize = totalLength * sizeof(int32_t) * NUM_TWO; + size_t expertTokenFlagSize = aivNum * 2 * sizeof(int32_t); + workspaceSize_ = sortWorkspaceSize + scatterWorkspaceSize + expertTokenFlagSize + SIZE_16 * LENGTH_1024 * LENGTH_1024; + return true; +} + +void InnerMoeInitRoutingV2TilingBase::Tiling4VBSOneCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) { + tilingData->needCoreNum = 1; + tilingData->perCoreElements = totalLength; + tilingData->perCoreLoops = 1; + tilingData->perCorePerLoopElements = tilingData->perCoreElements; + tilingData->perCoreLastLoopElements = tilingData->perCoreElements; + tilingData->lastCoreElements = tilingData->perCoreElements; + tilingData->lastCoreLoops = 1; + tilingData->lastCorePerLoopElements = tilingData->perCoreElements; + tilingData->lastCoreLastLoopElements = tilingData->perCoreElements; +} + +void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) { + //Tiling4VBSMultiCoreCompute + int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement); // 向上取整 + needCoreNum = static_cast(std::pow(4, CeilLog4(needCoreNum))); + needCoreNum = std::min(needCoreNum, aivNum); // 不能超过物理核数 + if (needCoreNum > 0) { + int64_t perCoreElements = totalLength / needCoreNum; // 每个核处理的元素数 + int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT; + int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements; + int64_t 
alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT; + if (lastCoreElement > alineCeilPerCoreElements) { + perCoreElements = alineCeilPerCoreElements; + needCoreNum = CeilDiv(totalLength, perCoreElements); + } else { + perCoreElements = alineFloorPerCoreElements; + } + tilingData->needCoreNum = needCoreNum; + do { + tilingData->perCoreElements = perCoreElements; + tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement); // 每个核处理的loop数 + tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement); + tilingData->perCoreLastLoopElements = tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements; + tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements; + tilingData->lastCoreLoops = tilingData->perCoreLoops; + int64_t tmp = CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops); + int64_t lastCorePerLoopElements = + CeilDiv(CeilDiv(tilingData->lastCoreElements, tilingData->lastCoreLoops), SORT32_ALIGN_ELEMENT) * + SORT32_ALIGN_ELEMENT; + tilingData->lastCorePerLoopElements = lastCorePerLoopElements; + tilingData->lastCoreLastLoopElements = tilingData-> lastCoreElements - (tilingData->lastCoreLoops - 1) * tilingData->lastCorePerLoopElements; + perCoreElements -= SORT32_ALIGN_ELEMENT; + } while (tilingData->lastCoreLastLoopElements <= 0 && perCoreElements > 0); + } +} + + +void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() { + auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp; + tilingData->oneLoopMaxElements = sortLoopMaxElement; + if (totalLength <= sortLoopMaxElement) { // 只用到一个核 + Tiling4VBSOneCoreCompute(tilingData); + return; + } + Tiling4VBSMultiCoreCompute(tilingData); +} + +void InnerMoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() { + auto vbsComputeTilingData = &moeInitRoutingTilingData.vbsComputeParamsOp; + auto 
tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp; + if (vbsComputeTilingData->needCoreNum <= MRG_LIST_NUM) { // 队列数小于一次vms则没有中间归并 + tilingData->needCoreNum = 0; // 需要的核数 + } else { + int64_t needCoreNum = CeilDiv(vbsComputeTilingData->needCoreNum, MRG_LIST_NUM); + tilingData->needCoreNum = needCoreNum; // 需要的核数 + } +} + +void InnerMoeInitRoutingV2TilingBase::Tiling4SortOutCompute() { + auto tilingData = &moeInitRoutingTilingData.sortOutComputeParamsOp; + tilingData->oneLoopMaxElements = mrgSortListMaxElement; +} + + +void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCompute() { + auto tilingData = &moeInitRoutingTilingData.srcToDstComputeParamsOp; + + int64_t perLoopMaxRows = (aicoreParams_.ubSize - ASSIST_NUM * sizeof(float) - aivNum * SORT32_ALIGN_ELEMENT) / + (SORT32_ALIGN_ELEMENT * NUM_TWO) / NUM_TWO; + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + int64_t needCoreNum = CeilDiv(totalLength, perCoreRows); + tilingData->needCoreNum = needCoreNum; + int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1); + tilingData->perCoreRows = perCoreRows; + if (perLoopMaxRows >= tilingData->perCoreRows) { // 一个loop结束 + tilingData->perCorePerLoopRows = tilingData->perCoreRows; + tilingData->perCoreLastLoopRows = tilingData->perCoreRows; + } else { + tilingData->perCorePerLoopRows = perLoopMaxRows; + tilingData->perCoreLastLoopRows = tilingData->perCoreRows - (CeilDiv(tilingData->perCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows; + } + tilingData->lastCoreRows = lastCoreNum; + if (perLoopMaxRows >= tilingData->lastCoreRows) { + tilingData->lastCorePerLoopRows = tilingData->lastCoreRows; + tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows; + } else { + tilingData->lastCorePerLoopRows = perLoopMaxRows; + tilingData->lastCoreLastLoopRows = tilingData->lastCoreRows - (CeilDiv(tilingData->lastCoreRows, perLoopMaxRows) - 1) * perLoopMaxRows; + } +} + 
+ +void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute() { + auto tilingData = &moeInitRoutingTilingData.srcToDstCapacityComputeParamsOp; + int64_t perCoreRows = CeilDiv(totalLength, aivNum); + + if (perCoreRows <= 0) { + tilingData->needCoreNum = 0; + return; + } + + int64_t needCoreNum = CeilDiv(totalLength, perCoreRows); + tilingData->needCoreNum = needCoreNum; + int64_t cols = moeInitRoutingTilingData.cols; + tilingData->perCoreRows = perCoreRows; + int64_t lastCoreRows = totalLength - perCoreRows * (needCoreNum - 1); + tilingData->lastCoreRows = lastCoreRows; + + + int64_t rowSize = + (perCoreRows * sizeof(int32_t) * 2 + ONE_BLOCK_BYTE + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t colSize = (cols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + + if (rowSize + colSize < static_cast(aicoreParams_.ubSize)) { + tilingData->perCorePerLoopRows = perCoreRows; + tilingData->perCoreLastLoopRows = perCoreRows; + tilingData->lastCorePerLoopRows = lastCoreRows; + tilingData->lastCoreLastLoopRows = lastCoreRows; + tilingData->perCoreLoops = 1; + tilingData->lastCoreLoops = 1; + tilingData->perLoopCols = cols; + tilingData->lastLoopCols = cols; + tilingData->colLoops = 1; + + } else { + int64_t baseMaxCols = MAX_COLS_ONE_LOOP; + int64_t baseMaxColsSize = (baseMaxCols * inuptXDtypeSize_ + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - baseMaxColsSize - ONE_BLOCK_BYTE) / + static_cast(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + if (cols < MAX_COLS_ONE_LOOP) { + basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - colSize - ONE_BLOCK_BYTE) / + static_cast(sizeof(int32_t)) / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = + (static_cast(aicoreParams_.ubSize) - rowSize) / inuptXDtypeSize_ / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } + tilingData->perLoopCols = 
(std::min(baseMaxCols, cols)); + tilingData->lastLoopCols = (GetPerOrLastValue(cols, baseMaxCols)); + tilingData->colLoops = ((cols + baseMaxCols - 1) / baseMaxCols); + tilingData->perCorePerLoopRows = (std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLastLoopRows = (GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->perCoreLoops = ((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + tilingData->lastCorePerLoopRows = (std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLastLoopRows = (GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->lastCoreLoops = ((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + } +} + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h new file mode 100644 index 00000000000..c190033ade8 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h @@ -0,0 +1,94 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_common.h + * \brief + */ +#ifndef INNER_MOE_V2_COMMON_H +#define INNER_MOE_V2_COMMON_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +constexpr int64_t SPLIT_N = 0; +constexpr int64_t SPLIT_K = 1; +constexpr float MIN_FP32 = -3.4e38; +constexpr int64_t ONE_REPEAT_SORT_NUM = 32; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t INT32_ONE_BLOCK_NUM = 8; + +constexpr int64_t ASSIST_NUM = 256; +constexpr int64_t ASSIST_INDEX_NUM = 32; + +constexpr int64_t MERGE_LIST_TWO = 2; +constexpr int64_t MERGE_LIST_THREE = 3; +constexpr int64_t MERGE_LIST_FOUR = 4; + +constexpr int64_t MERGE_LIST_IDX_TWO = 2; +constexpr int64_t MERGE_LIST_IDX_THREE = 3; + +constexpr int64_t MAX_EXPERT_NUM = 5120; +constexpr int64_t DROPLESS_MODE = 0; +constexpr int64_t DROP_PAD_MODE = 1; +constexpr int64_t EXERPT_TOKENS_COUNT = 2; +constexpr int64_t EXERPT_TOKENS_CUMSUM = 1; +constexpr int64_t EXERPT_TOKENS_NONE = 0; +constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1; + +const __gm__ int32_t assist[256] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, + 4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, + 16, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, + 20, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0, + 24, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0}; + +__aicore__ inline int64_t Ceil(int64_t a, int64_t b) { + if (b == 0) { + return 0; + } + return 
(a + b - 1) / b; +} + +__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) { + if (bytes == 0) { + return 0; + } + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes; +} + +__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) { + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES; +} + +template +__aicore__ inline T Min(T a, T b) { + return a > b ? b : a; +} + +template +__aicore__ inline T Max(T a, T b) { + return a < b ? b : a; +} + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) { + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + SetFlag(eventId); + WaitFlag(eventId); +} + +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_COMMON_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_expert_token_out.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_expert_token_out.h new file mode 100644 index 00000000000..fe05765623c --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_expert_token_out.h @@ -0,0 +1,310 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_expert_token_out.h + * \brief + */ +#ifndef INNER_MOE_V2_EXPERT_TOKEN_OUT_H +#define INNER_MOE_V2_EXPERT_TOKEN_OUT_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +constexpr int64_t EXPERT_ID_VALUE_NUM = 2; + +class MoeV2ExpertTokenOut { + public: + __aicore__ inline MoeV2ExpertTokenOut(){}; + template + __aicore__ inline void Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR expandedRowIdx, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void Compute(int64_t progress); + __aicore__ inline void SyncAll(); + __aicore__ inline void InitLocal(); + __aicore__ inline void GetExpertTokenCount(int32_t curExpertId); + __aicore__ inline void CopyOutTokenGm(); + __aicore__ inline void CopyOutExpertTokensCumsum(bool isTail); + __aicore__ inline void CopyOutExpertTokensCount(bool isTail); + + private: + TPipe* pipe; + TQue copyInQueue; + TQue expertTokenIdxCopyInQueue; + TQue expertTokenIdxCopyOutQueue; + + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expertIdxValueGm; + GlobalTensor expandedRowIdxGm; + + LocalTensor expertTokenIdxOutLocal; + + const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t expertNum; + int64_t expertNumUbAlign; + int64_t dropPadMode = 0; + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + + int64_t tokenCount = 0; + int64_t expertIdx = 0; + int32_t lastExpertId = -1; + int32_t firstExpertId = -1; + + int32_t expertTokenValue = 0; +}; + +__aicore__ inline void 
MoeV2ExpertTokenOut::InitLocal() { + LocalTensor tokenIdxLocal = expertTokenIdxCopyOutQueue.AllocTensor(); + Duplicate(tokenIdxLocal, 0, this->expertNumUbAlign); + expertTokenIdxCopyOutQueue.EnQue(tokenIdxLocal); + + // expandedRowIdx initialized to -1, which is used in the src_to_dst_with_capacity step. + // use this step SyncAll to synchronize every core data + if (this->dropPadMode == 0) { + return; + } + LocalTensor outLocal = copyInQueue.AllocTensor(); + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + Duplicate(outLocal, -1, perLoopRows); + SetWaitFlag(HardEvent::V_MTE3); + for (int64_t loop = 0; loop < loops; loop++) { + int64_t copyLength = perLoopRows; + if (loop == loops - 1) { + copyLength = lastLoopRows; + } + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, + 0}; + DataCopyPad(expandedRowIdxGm[this->blockIdx * this->srcToDstTilingData->perCoreRows + loop * perLoopRows], outLocal, + copyParams); + } + SetWaitFlag(HardEvent::MTE3_MTE2); + copyInQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyIn(int64_t progress) { + LocalTensor inLocal = copyInQueue.AllocTensor(); + DataCopy(inLocal, expandedExpertIdxGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t))); + copyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::GetExpertTokenCount(int32_t curExpertId) { + this->tokenCount++; + if (this->lastExpertId < curExpertId) { + this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount - 1); + this->tokenCount = 1; + this->expertIdx += (curExpertId - this->lastExpertId); + while (curExpertId - this->firstExpertId + 1 > this->expertNumUbAlign) { + SetWaitFlag(HardEvent::S_MTE3); + CopyOutExpertTokensCumsum(false); + CopyOutExpertTokensCount(false); + SetWaitFlag(HardEvent::MTE3_V); + Duplicate(this->expertTokenIdxOutLocal, 0, this->expertNumUbAlign); + SetWaitFlag(HardEvent::V_S); + this->firstExpertId += 
this->expertNumUbAlign; + this->expertIdx = curExpertId - this->firstExpertId; + } + this->lastExpertId = curExpertId; + } +} + +__aicore__ inline void MoeV2ExpertTokenOut::Compute(int64_t progress) { + LocalTensor inLocal = copyInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = inLocal.GetValue(0); + this->firstExpertId = this->lastExpertId; + } + for (int64_t i = 0; i < currentLoopRows; i++) { + int32_t expertId = inLocal.GetValue(i); + GetExpertTokenCount(expertId); + } + this->expertTokenIdxOutLocal.SetValue(this->expertIdx, this->tokenCount); + copyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCumsum(bool isTail) { + if (this->dropPadMode != DROPLESS_MODE || expertTokensCountOrCumsumFlag != EXERPT_TOKENS_CUMSUM) { + return; + } +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) { + return; + } +#endif + int64_t copyLength = isTail ? 
this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign; + int64_t end = this->expertNum - this->firstExpertId; + for (int64_t i = 0; i < copyLength; i++) { + this->expertTokenValue += this->expertTokenIdxOutLocal.GetValue(i); + this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue); + } + // if the remianing UB is sufficient, use the UB space to copy + // otherwise, copy the calculated data first, and then copy the last tokenValue to remaining expert position + if (isTail && end <= this->expertNumUbAlign) { + int64_t startAlign = Min(Align(copyLength, sizeof(int32_t)), end); + for (int64_t i = copyLength; i < startAlign; i++) { + this->expertTokenIdxOutLocal.SetValue(i, this->expertTokenValue); + } + if (startAlign < end) { + Duplicate(this->expertTokenIdxOutLocal[startAlign], this->expertTokenValue, end - startAlign); + } + copyLength = end; + SetWaitFlag(HardEvent::V_MTE3); + } + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, 0}; + SetAtomicAdd(); +#ifndef __CCE_KT_TEST__ + DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); +#endif + SetAtomicNone(); + if (isTail && end > this->expertNumUbAlign) { + int64_t remainderLength = end - copyLength; + SetWaitFlag(HardEvent::MTE3_V); + Duplicate(this->expertTokenIdxOutLocal, this->expertTokenValue, this->expertNumUbAlign); + SetWaitFlag(HardEvent::V_MTE3); + int64_t loopTimes = remainderLength / this->expertNumUbAlign + 1; + for (int64_t i = 0; i < loopTimes; i++) { + copyLength = i == loopTimes - 1 ? 
remainderLength - this->expertNumUbAlign * i : this->expertNumUbAlign; + DataCopyExtParams params{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, 0}; + SetAtomicAdd(); + DataCopyPad(expertTokensCountOrCumsumGm[this->lastExpertId + 1 + this->expertNumUbAlign * i], + this->expertTokenIdxOutLocal, params); + SetAtomicNone(); + } + } +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCount(bool isTail) { + int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign; + DataCopyExtParams copyParams{static_cast(1), static_cast(copyLength * sizeof(int32_t)), 0, 0, 0}; +#ifdef __CCE_KT_TEST__ + // CPU孪生调试不进行输出拷贝 + return; +#endif + SetAtomicAdd(); + if (this->dropPadMode == DROP_PAD_MODE && expertTokensBeforeCapacityFlag > EXERPT_TOKENS_NONE) { + DataCopyPad(expertTokensBeforeCapacityGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); + } + if (this->dropPadMode == DROPLESS_MODE && expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { + DataCopyPad(expertTokensCountOrCumsumGm[this->firstExpertId], this->expertTokenIdxOutLocal, copyParams); + } + SetAtomicNone(); +} + +__aicore__ inline void MoeV2ExpertTokenOut::CopyOutTokenGm() { + if (this->dropPadMode == DROPLESS_MODE) { + SetWaitFlag(HardEvent::S_MTE3); + CopyOutExpertTokensCumsum(true); + CopyOutExpertTokensCount(true); + return; + } + this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign, this->lastExpertId); + this->expertTokenIdxOutLocal.SetValue(this->expertNumUbAlign + 1, this->tokenCount); + DataCopyExtParams copyParams{static_cast(1), static_cast(EXPERT_ID_VALUE_NUM * sizeof(int32_t)), + 0, 0, 0}; + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expertIdxValueGm[this->blockIdx * EXPERT_ID_VALUE_NUM], + this->expertTokenIdxOutLocal[this->expertNumUbAlign], copyParams); + CopyOutExpertTokensCount(true); +} + +__aicore__ inline void MoeV2ExpertTokenOut::SyncAll() { + if (coreNum == 1) { + return; + } +#ifndef 
__CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2ExpertTokenOut::Init(GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR expandedRowIdx, GM_ADDR workspace, + const TilingData* tilingData, TPipe* tPipe) { + int64_t blockNum = GetBlockNum(); + this->pipe = tPipe; + //this->blockIdx = GetBlockIdx(); + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp); + this->expertNum = tilingData->expertNum; + this->dropPadMode = tilingData->dropPadMode; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + } + + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, Align(this->totalLength, sizeof(int32_t))); + if (this->dropPadMode == DROPLESS_MODE && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, this->expertNum); + } + if (this->dropPadMode == DROP_PAD_MODE && this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { + expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity, this->expertNum); + } + + expandedExpertIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + this->blockIdx * 
this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + this->coreNum * 2); + + this->expertNumUbAlign = Min(Align(this->expertNum, sizeof(int32_t)), MAX_EXPERT_NUM); + pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES); + pipe->InitBuffer(expertTokenIdxCopyInQueue, 1, this->expertNumUbAlign * sizeof(int32_t)); + pipe->InitBuffer(expertTokenIdxCopyOutQueue, 1, (this->expertNumUbAlign + EXPERT_ID_VALUE_NUM) * sizeof(int32_t)); +} + +__aicore__ inline void MoeV2ExpertTokenOut::Process() { + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + currentLoopRows = perLoopRows; + InitLocal(); + this->expertTokenIdxOutLocal = expertTokenIdxCopyOutQueue.DeQue(); + for (int64_t loop = 0; loop < loops - 1; loop++) { + CopyIn(loop); + Compute(loop); + } + currentLoopRows = lastLoopRows; + CopyIn(loops - 1); + Compute(loops - 1); + CopyOutTokenGm(); + expertTokenIdxCopyOutQueue.FreeTensor(this->expertTokenIdxOutLocal); + } + this->SyncAll(); +} + +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_EXPERT_TOKEN_OUT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h new file mode 100644 index 00000000000..824e9af303a --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h @@ -0,0 +1,468 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_v2_fullload_dynamic_quant.h + * \brief + */ +#ifndef MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H +#define MOE_V2_FULL_LOAD_DYNAMIC_QUANT_H + +#include "moe_v2_mrgsort.h" +#include "moe_v2_sort_base.h" +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { + public: + __aicore__ inline MoeV2FullLoadDynamicQuant(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale, + GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOutIdx(); + __aicore__ inline void CopyOutEmpty(); + __aicore__ inline void CopyOutXQuant1H(); + __aicore__ inline void CopyOutXQuantEH(); + __aicore__ inline void ComputeExpertTokenCountOrCumsum(); + __aicore__ inline void Compute(LocalTensor& smoothLocal); + + private: + int64_t sortNum_; + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t coreRows_; + int64_t perCoreRows_; + int64_t k_; + int64_t n_; + int64_t cols_; + int64_t activateRows_; + int64_t expertNum; + int64_t expertCapacity; + int64_t smoothType; + int64_t colsAlign; + + TQue xCopyInQueue_; + TQue expandedRowIdxCopyOutQueue_; + TQue expandedExpertIdxCopyOutQueue_; + TQue expandDstToSrcRowQueue_; + TQue expertTokensCopyOutQueue_; + TQue smoothInQueue; + TQue calcQueue; + TQue 
inputXOutQueue; + TQue scaleOutQueue; + + GlobalTensor xGm_; + GlobalTensor expertIdxGm_; + GlobalTensor quantSmoothGm; + GlobalTensor dynamicQuantScaleGm; + + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedExpertIdxGm_; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + int64_t dropPadMode = 0; + + LocalTensor expandDstToSrcRowLocal; + LocalTensor expandedExpertIdxLocal; +}; + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyIn() { + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(this->totalLength * sizeof(int32_t)), + 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams); + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + sortDataCopyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::SortCompute() { + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + LocalTensor concatLocal; + LocalTensor tempTensor = 
tempBuffer.Get(GetSortLen(this->sortNum_)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast(); + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum_)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor(); + expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, + this->totalLength); + pipe_barrier(PIPE_V); + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdxLocalInt32); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandedRowIdxU32 = expandedRowIdx.ReinterpretCast(); + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + pipe_barrier(PIPE_V); + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + 
Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutIdx() { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::ComputeExpertTokenCountOrCumsum() { + expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue(); + LocalTensor expertTokensCount = expertTokensCopyOutQueue_.AllocTensor(); + + int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t)); + Duplicate(expertTokensCount, 0, expertNumAlign); + SetWaitFlag(HardEvent::V_S); + + int32_t lastExpertId = expandedExpertIdxLocal.GetValue(0); + int64_t tokenCount = 0; + int64_t lastExpertCount = 0; + for (int64_t i = 0; i < this->totalLength; i++) { + int32_t curExpertId = expandedExpertIdxLocal.GetValue(i); + tokenCount++; + while (lastExpertId < curExpertId) { + expertTokensCount.SetValue(lastExpertId, tokenCount - 1); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { + tokenCount = 1; + } + lastExpertId++; + } + } +#ifndef __CCE_KT_TEST__ + expertTokensCount.SetValue(lastExpertId, tokenCount); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) { + 
lastExpertId++; + while (lastExpertId < this->expertNum) { + expertTokensCount.SetValue(lastExpertId, tokenCount); + lastExpertId++; + } + } + DataCopyExtParams copyParams{static_cast(1), static_cast(this->expertNum * sizeof(int32_t)), 0, 0, + 0}; + if (this->expertTokensCountOrCumsumFlag > 0) { + DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams); + } + expertTokensCopyOutQueue_.FreeTensor(expertTokensCount); +#endif +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutEmpty() { + expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.DeQue(); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& smoothLocal) { + LocalTensor inLocal = xCopyInQueue_.DeQue(); + + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor outLocal = inputXOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[colsAlign], RoundMode::CAST_NONE, this->cols_); + pipe_barrier(PIPE_V); + } + + if (smoothType != 0) { + Mul(inLocal, inLocal, smoothLocal, this->cols_); + pipe_barrier(PIPE_V); + } + + Abs(tempLocal, inLocal, this->cols_); + pipe_barrier(PIPE_V); + + ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols_); + pipe_barrier(PIPE_V); + + float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f; + + Duplicate(dynamicQuantLocal, maxValue, 8); + Duplicate(tempLocal, maxValue, this->cols_); + pipe_barrier(PIPE_V); + + Div(tempLocal, inLocal, tempLocal, this->cols_); + pipe_barrier(PIPE_V); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, this->cols_); + pipe_barrier(PIPE_V); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->cols_); + + calcQueue.FreeTensor(tempLocal); + inputXOutQueue.EnQue(outLocal); + scaleOutQueue.EnQue(dynamicQuantLocal); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { + 
expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); + expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = curRowsEnd / this->k_; + + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + + LocalTensor smoothLocal; + if (smoothType == 1) { + smoothLocal = smoothInQueue.AllocTensor(); + DataCopyPad(smoothLocal, quantSmoothGm, smoothCopyParams, {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + } + for (int64_t row = startXRow; row <= endXRow; row++) { + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(xLocal, xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } else { + DataCopyPad(xLocal[colsAlign], xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } + + xCopyInQueue_.EnQue(xLocal); + Compute(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + LocalTensor outLocal = inputXOutQueue.DeQue(); + while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) { + int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); + curRowsStart++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) { + continue; + } + DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams); + DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + } + + xCopyInQueue_.FreeTensor(xLocal); + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + if 
(smoothType == 1) { + smoothInQueue.FreeTensor(smoothLocal); + } + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuantEH() { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + + Muls(expandDstToSrcRowLocal.ReinterpretCast(), expandDstToSrcRowLocal.ReinterpretCast(), (float)-1, + this->totalLength); + pipe_barrier(PIPE_V); + LocalTensor sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast(); + Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->totalLength); + + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1; + + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + + for (int64_t row = curRowsStart; row <= curRowsEnd; row++) { + if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) { + break; + } + int32_t srcIdx = sortedRowIdx.GetValue(row); + int32_t expertIdx = expandedExpertIdxLocal.GetValue(row); + + LocalTensor inLocal = xCopyInQueue_.AllocTensor(); + LocalTensor smoothLocal = smoothInQueue.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0}); + xCopyInQueue_.EnQue(inLocal); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + + Compute(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + 
DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0}); + + LocalTensor outLocal = inputXOutQueue.DeQue(); + DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams); + + xCopyInQueue_.FreeTensor(inLocal); + smoothInQueue.FreeTensor(smoothLocal); + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); + expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR quantSmooth, GM_ADDR dynamicQuantScale, + GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, + TPipe* tPipe) { + this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + //this->blockIdx_ = GetBlockIdx(); + this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num(); + this->k_ = tilingData->k; + this->n_ = tilingData->n; + this->cols_ = tilingData->cols; + this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; + this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; + this->activateRows_ = this->gatherOutTilingData_->activateRows; + if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) { + this->coreRows_ = this->gatherOutTilingData_->lastCoreRows; + } else { + this->coreRows_ = this->gatherOutTilingData_->perCoreRows; + } + this->expertNum = tilingData->expertNum; + this->dropPadMode = tilingData->dropPadMode; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->smoothType = tilingData->smoothType; + this->colsAlign = Align(this->cols_, 
sizeof(T)); + this->pipe = tPipe; + + xGm_.SetGlobalBuffer((__gm__ T*)x); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength); + + expandedXGm_.SetGlobalBuffer((__gm__ int8_t*)expandedX); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength); + if (this->expertTokensCountOrCumsumFlag > 0) { + // dropless + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); + } + quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); + dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); + + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum_ * sizeof(int32_t); + + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_; + + pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t))); + pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor); + pipe->InitBuffer(tempBuffer, buffSize * kvFactor); + pipe->InitBuffer(sortedBuffer, buffSize * kvFactor); + + if constexpr (IsSameType::value) { + pipe->InitBuffer(xCopyInQueue_, 1, AlignBytes(this->cols_, sizeof(float))); + } else { + pipe->InitBuffer(xCopyInQueue_, 1, 2 * AlignBytes(this->cols_, sizeof(T))); + } + pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float))); + pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float))); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t))); + pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); +} + +template +__aicore__ inline void MoeV2FullLoadDynamicQuant::Process() { + if 
(this->blockIdx_ < this->needCoreNum_) { + CopyIn(); + SortCompute(); + if (this->blockIdx_ == 0) { + CopyOutIdx(); + } + if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + ComputeExpertTokenCountOrCumsum(); + } else { + CopyOutEmpty(); + } + if (smoothType == 2) { + CopyOutXQuantEH(); + } else { + CopyOutXQuant1H(); + } + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // MOE_V2_DYNAMIC_QUANT_FULL_LOAD_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant.h new file mode 100644 index 00000000000..4d83d6642f1 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant.h @@ -0,0 +1,155 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! 
+ * \file moe_v2_fullload_quant.h + * \brief + */ +#ifndef MOE_V2_FULL_LOAD_QUANT_H +#define MOE_V2_FULL_LOAD_QUANT_H + +#include "moe_v2_fullload_quant_base.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2FullLoadQuant : public MoeV2FullLoadQuantBase { + public: + __aicore__ inline MoeV2FullLoadQuant(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void Compute(int64_t xLocalLength); + __aicore__ inline void CopyOutX(); + + private: + TQue floatQueue; + TQue halfQueue; + TQue inputXCopyOutQueue; + + GlobalTensor xGm; + GlobalTensor scaleGm; + GlobalTensor offsetGm; + + float scale; + float offset; +}; + +template +__aicore__ inline void MoeV2FullLoadQuant::Compute(int64_t xLocalLength) { + LocalTensor inLocal = xCopyInQueue.DeQue(); + LocalTensor outLocal = inputXCopyOutQueue.AllocTensor(); + LocalTensor floatLocal = floatQueue.AllocTensor(); + LocalTensor halfLocal = halfQueue.AllocTensor(); + + uint32_t elements = Align(this->cols, sizeof(int8_t)) * xLocalLength; + if constexpr (IsSameType::value) { + Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Muls(halfLocal, halfLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(halfLocal, halfLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + LocalTensor intLocal = floatLocal.ReinterpretCast(); + Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements); + pipe_barrier(PIPE_V); + SetDeqScale((half)1.000000e+00f); + pipe_barrier(PIPE_V); + Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements); + 
pipe_barrier(PIPE_V); + Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements); + } else if constexpr (IsSameType::value) { + Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Muls(halfLocal, halfLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(halfLocal, halfLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements); + } else { + Muls(inLocal, inLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(inLocal, inLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements); + } + inputXCopyOutQueue.EnQue(outLocal); + xCopyInQueue.FreeTensor(inLocal); + floatQueue.FreeTensor(floatLocal); + halfQueue.FreeTensor(halfLocal); +} + +template +__aicore__ inline void MoeV2FullLoadQuant::CopyOutX() { + LocalTensor xLocal = xCopyInQueue.AllocTensor(); + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue.DeQue(); + int64_t inFactor = Align(this->cols, sizeof(int8_t)); + int64_t curRowsStart = this->blockIdx * this->perCoreRows; + int64_t startXRow = curRowsStart / this->k; + int64_t endXRow = (curRowsStart + this->coreRows - 1) / this->k; + + uint32_t dstStride = (inFactor * sizeof(T) - AlignBytes(this->cols, sizeof(T))) / BLOCK_BYTES; + DataCopyExtParams dataXCopyParams{static_cast(endXRow - startXRow + 1), + static_cast(this->cols * sizeof(T)), 0, dstStride, 0}; + DataCopyPadExtParams dataXCopyPadParams{false, 0, 0, 0}; + DataCopyPad(xLocal, xGm[startXRow * this->cols], dataXCopyParams, dataXCopyPadParams); + xCopyInQueue.EnQue(xLocal); + Compute(endXRow - startXRow + 1); + LocalTensor outLocal = inputXCopyOutQueue.DeQue(); + int64_t k = 0; + DataCopyExtParams intriParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + for (int64_t i = startXRow; i <= endXRow; i++) { + for (; k < this->perCoreRows && curRowsStart / this->k == i; curRowsStart++, 
k++) { + int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); + if (outIndex < this->activateRows) { + DataCopyPad(expandedXGm[outIndex * this->cols], outLocal[(i - startXRow) * inFactor], intriParams); + } + } + } + expandedRowIdxCopyOutQueue.FreeTensor(expandedRowIdx); + inputXCopyOutQueue.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2FullLoadQuant::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe) { + this->InitBase(x, expertIdx, expandedX, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe); + xGm.SetGlobalBuffer((__gm__ T*)x); + scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1); + offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1); + this->scale = scaleGm.GetValue(0); + this->offset = offsetGm.GetValue(0); + + int64_t curRowsStart = this->blockIdx * this->perCoreRows; + int64_t rowLength = (curRowsStart + this->coreRows - 1) / this->k - curRowsStart / this->k + 1; + int64_t xAlignedCount = Align(this->cols, sizeof(int8_t)); + pipe->InitBuffer(xCopyInQueue, bufferNum, xAlignedCount * sizeof(T) * rowLength); + pipe->InitBuffer(inputXCopyOutQueue, 1, xAlignedCount * sizeof(int8_t) * rowLength); + pipe->InitBuffer(floatQueue, 1, xAlignedCount * sizeof(float) * rowLength); + pipe->InitBuffer(halfQueue, 1, xAlignedCount * sizeof(half) * rowLength); +} + +template +__aicore__ inline void MoeV2FullLoadQuant::Process() { + if (this->blockIdx < this->needCoreNum) { + this->ProcessBase(); + CopyOutX(); + } +} +} // namespace MoeInitRoutingQuantV2 +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant_base.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant_base.h new file mode 100644 index 00000000000..8e8195c995a --- /dev/null +++ 
b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_quant_base.h @@ -0,0 +1,279 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_v2_fullload_quant_base.h + * \brief + */ +#ifndef MOE_V2_FULL_LOAD_QUANT_BASE_H +#define MOE_V2_FULL_LOAD_QUANT_BASE_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2FullLoadQuantBase { + public: + __aicore__ inline MoeV2FullLoadQuantBase(){}; + + protected: + __aicore__ inline void InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void ProcessBase(); + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOutIdx(); + __aicore__ inline void CopyOutEmpty(); + __aicore__ inline void ComputeExpertTokenCountOrCumsum(); + + protected: + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData; + + TPipe* pipe; + int64_t tileLength; + int64_t bufferNum = 1; + int64_t totalLength; + int64_t coreNum; + int64_t sortNum; + int64_t blockIdx; + int64_t needCoreNum; + int64_t coreRows; + int64_t perCoreRows; + int64_t k; + int64_t n; + int64_t cols; + int64_t activateRows; + int64_t expertNum; + int64_t expertCapacity; + + TQue 
sortDataCopyInQueue; + TBuf tempBuffer; + TBuf sortedBuffer; + TQue xCopyInQueue; + TQue expandedRowIdxCopyOutQueue; + TQue expandedExpertIdxCopyOutQueue; + TQue expandDstToSrcRowQueue; + TQue expertTokensCopyOutQueue; + + GlobalTensor expertIdxGm; + GlobalTensor expandedXGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + int64_t dropPadMode = 0; + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; + static constexpr int64_t FOUR_BLOCK_BYTES = 128; +}; + +__aicore__ inline void MoeV2FullLoadQuantBase::CopyIn() { + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(this->totalLength * sizeof(int32_t)), + 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); + ArithProgression(inLocal[this->sortNum], 0, 1, this->totalLength); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::SortCompute() { + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, 
DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + LocalTensor concatLocal; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor rowIdxLocal = inLocal[this->sortNum].template ReinterpretCast(); + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue.AllocTensor(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue.AllocTensor(); + LocalTensor expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, + this->totalLength); + pipe_barrier(PIPE_V); + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + expandedExpertIdxCopyOutQueue.EnQue(expandedExpertIdxLocalInt32); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue.AllocTensor(); + LocalTensor expandedRowIdxU32 = expandedRowIdx.ReinterpretCast(); + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + ArithProgression(inLocal[this->sortNum], 0, 1, this->totalLength); + pipe_barrier(PIPE_V); + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + 
mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + expandedRowIdxCopyOutQueue.EnQue(expandedRowIdx); + sortDataCopyInQueue.FreeTensor(inLocal); + + expandDstToSrcRowQueue.FreeTensor(expandDstToSrcRowLocal); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutIdx() { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); + DataCopyPad(expandedRowIdxGm, expandedRowIdx, intriParams); + expandedRowIdxCopyOutQueue.EnQue(expandedRowIdx); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::ComputeExpertTokenCountOrCumsum() { + LocalTensor expandedExpertIdx = expandedExpertIdxCopyOutQueue.DeQue(); + LocalTensor expertTokensCount = expertTokensCopyOutQueue.AllocTensor(); + + int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t)); + Duplicate(expertTokensCount, 0, expertNumAlign); + SetWaitFlag(HardEvent::V_S); + + int32_t lastExpertId = expandedExpertIdx.GetValue(0); + int64_t tokenCount = 0; + int64_t lastExpertCount = 0; + for (int64_t i = 0; i < this->totalLength; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + tokenCount++; + while (lastExpertId < curExpertId) { + expertTokensCount.SetValue(lastExpertId, tokenCount - 1); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { + tokenCount = 1; + } + lastExpertId++; + } + } + expertTokensCount.SetValue(lastExpertId, 
tokenCount); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) { + lastExpertId++; + while (lastExpertId < this->expertNum) { + expertTokensCount.SetValue(lastExpertId, tokenCount); + lastExpertId++; + } + } + DataCopyExtParams copyParams{static_cast(1), static_cast(this->expertNum * sizeof(int32_t)), 0, 0, + 0}; + if (this->expertTokensCountOrCumsumFlag > 0) { + DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams); + } + expertTokensCopyOutQueue.FreeTensor(expertTokensCount); + expandedExpertIdxCopyOutQueue.FreeTensor(expandedExpertIdx); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::CopyOutEmpty() { + LocalTensor outLocal = expandedExpertIdxCopyOutQueue.DeQue(); + expandedExpertIdxCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::InitBase(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, + TPipe* tPipe) { + this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp); + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->k = tilingData->k; + this->n = tilingData->n; + this->cols = tilingData->cols; + this->needCoreNum = this->gatherOutTilingData->needCoreNum; + this->perCoreRows = this->gatherOutTilingData->perCoreRows; + this->activateRows = this->gatherOutTilingData->activateRows; + if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) { + this->coreRows = this->gatherOutTilingData->lastCoreRows; + } else { + this->coreRows = this->gatherOutTilingData->perCoreRows; + } + this->expertNum = tilingData->expertNum; + this->dropPadMode = tilingData->dropPadMode; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * 
ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->pipe = tPipe; + + expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength); + + expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX); + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength); + if (this->expertTokensCountOrCumsumFlag > 0) { + // dropless + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); + } + + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum * sizeof(int32_t); + + pipe->InitBuffer(expandedRowIdxCopyOutQueue, bufferNum, buffSize); + pipe->InitBuffer(expandedExpertIdxCopyOutQueue, bufferNum, buffSize); + pipe->InitBuffer(expertTokensCopyOutQueue, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t))); + pipe->InitBuffer(expandDstToSrcRowQueue, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor); + pipe->InitBuffer(tempBuffer, buffSize * kvFactor); + pipe->InitBuffer(sortedBuffer, buffSize * kvFactor); +} + +__aicore__ inline void MoeV2FullLoadQuantBase::ProcessBase() { + if (this->blockIdx < this->needCoreNum) { + CopyIn(); + SortCompute(); + if (this->blockIdx == 0) { + CopyOutIdx(); + } + if (this->blockIdx == this->needCoreNum - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + ComputeExpertTokenCountOrCumsum(); + } else { + CopyOutEmpty(); + } + } +} + +} // namespace MoeInitRoutingQuantV2 +#endif // MOE_V2_FULL_LOAD_QUANT_BASE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h new file mode 100644 index 00000000000..9f3ef220525 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h @@ -0,0 +1,568 @@ +/** + * Copyright (c) 2025 Huawei 
Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_gather_dynamic_quant.h + * \brief + */ +#ifndef MOE_V2_GATHER_DYNAMIC_QUANT_H +#define MOE_V2_GATHER_DYNAMIC_QUANT_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2GatherDynamicQuant { + public: + __aicore__ inline MoeV2GatherDynamicQuant(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR dynamicQuantScale, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyInExpandedRowIdx(int64_t progress); + __aicore__ inline void CopyInExpandedExpertIdx(int64_t progress); + __aicore__ inline void CopyOutXQuant1H(int64_t progress); + __aicore__ inline void CopyOutXQuantEH(int64_t progress); + __aicore__ inline void Compute(LocalTensor& smoothLocal); + __aicore__ inline void CopyOutPartialXQuantEH(int64_t progress); + __aicore__ inline void CopyOutPartialXQuant1H(int64_t progress); + __aicore__ inline float ComputeMax(LocalTensor& inLocal, LocalTensor& tempLocal, + LocalTensor& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx, + int64_t j); + __aicore__ inline void ComputeScale(LocalTensor& inLocal, LocalTensor& tempLocal, float scaleTemp, + int64_t dstIndex, int64_t j); + + private: + TPipe* pipe; + TQue 
inputXInQueue; + TQue smoothInQueue; + TQue expandRowIdxInQueue; + TQue calcQueue; + TQue inputXOutQueue; + TQue scaleOutQueue; + + GlobalTensor inputXGm; + GlobalTensor expandedXGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor quantSmoothGm; + GlobalTensor dynamicQuantScaleGm; + GlobalTensor quantSrcGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor sortedRowIdxGm; + + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData; + + int64_t needCoreNum; + int64_t blockIdx; + int64_t cols; + int64_t n; + int64_t k; + int64_t totalLength; + int64_t activateRows; + int64_t currentLoopRows; + int64_t currentLoopRowsAlign; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t colsTileLength; + int64_t perLoopCols; + int64_t perLoopColsAlign; + int64_t lastLoopCols; + int64_t colLoops; + int64_t dropPadMode; + int64_t smoothType; + + int64_t indicesOffset; + int64_t inputOffset; + int64_t outOffset; +}; + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyInExpandedRowIdx(int64_t progress) { + this->indicesOffset = progress * this->perLoopRows; + LocalTensor indicesLocal = expandRowIdxInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); + expandRowIdxInQueue.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyInExpandedExpertIdx(int64_t progress) { + this->indicesOffset = progress * this->perLoopRows; + LocalTensor indicesLocal = expandRowIdxInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, sortedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); + 
DataCopyPad(indicesLocal[currentLoopRowsAlign], expandedExpertIdxGm[indicesOffset], dataCopyParams, + dataCopyPadParams); + expandRowIdxInQueue.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& smoothLocal) { + LocalTensor inLocal = inputXInQueue.DeQue(); + + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor outLocal = inputXOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); + pipe_barrier(PIPE_V); + } + + if (smoothType != 0) { + Mul(inLocal, inLocal, smoothLocal, this->cols); + pipe_barrier(PIPE_V); + } + + Abs(tempLocal, inLocal, this->cols); + pipe_barrier(PIPE_V); + + ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols); + pipe_barrier(PIPE_V); + + float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f; + + Duplicate(dynamicQuantLocal, maxValue, 8); + Duplicate(tempLocal, maxValue, this->cols); + pipe_barrier(PIPE_V); + + Div(tempLocal, inLocal, tempLocal, this->cols); + pipe_barrier(PIPE_V); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, this->cols); + pipe_barrier(PIPE_V); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->cols); + + calcQueue.FreeTensor(tempLocal); + inputXOutQueue.EnQue(outLocal); + scaleOutQueue.EnQue(dynamicQuantLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progress) { + LocalTensor indicesLocal = expandRowIdxInQueue.DeQue(); + + int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + int64_t curLoopRow = 0; + int64_t currentLoopStartRow = initialRow / this->k; + int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; + DataCopyExtParams copyInParams{1, static_cast(this->cols * sizeof(T)), 0, 0, 0}; + 
DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; + + LocalTensor smoothLocal; + if (smoothType == 1) { + smoothLocal = smoothInQueue.AllocTensor(); + DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + } + + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = inputXInQueue.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[perLoopColsAlign], inputXGm[row * this->cols], copyInParams, {false, 0, 0, 0}); + } + + inputXInQueue.EnQue(inLocal); + + // 计算quant + Compute(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + LocalTensor outLocal = inputXOutQueue.DeQue(); + + while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { + continue; + } + DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams); + DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + } + inputXInQueue.FreeTensor(inLocal); + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + if (smoothType == 1) { + smoothInQueue.FreeTensor(smoothLocal); + } + expandRowIdxInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuantEH(int64_t progress) { + LocalTensor indicesLocal = expandRowIdxInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + + DataCopyExtParams copyInParams{1, static_cast(this->perLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, 
static_cast(this->perLoopCols * sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(this->perLoopCols * sizeof(int8_t)), 0, 0, 0}; + + int32_t lastExpertIdx = -1; + LocalTensor inLocal = inputXInQueue.AllocTensor(); + LocalTensor smoothLocal = smoothInQueue.AllocTensor(); + for (int64_t i = 0; i < this->currentLoopRows; i++) { + int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) { + break; + } + int32_t srcIdx = indicesLocal.GetValue(i); + int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i); + + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0}); + } + inputXInQueue.EnQue(inLocal); + + if (expertIdx != lastExpertIdx) { + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + lastExpertIdx = expertIdx; + } + + Compute(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0}); + + LocalTensor outLocal = inputXOutQueue.DeQue(); + DataCopyPad(expandedXGm[(rowOffset + i) * this->cols], outLocal, copyOutParams); + + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + inputXInQueue.FreeTensor(inLocal); + smoothInQueue.FreeTensor(smoothLocal); + expandRowIdxInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline float MoeV2GatherDynamicQuant::ComputeMax(LocalTensor& inLocal, + LocalTensor& tempLocal, + LocalTensor& dynamicQuantLocal, int32_t srcIdx, + int32_t expertIdx, int64_t j) { + LocalTensor smoothLocal = 
smoothInQueue.AllocTensor(); + + DataCopyExtParams intriParamsT{1, static_cast(colsTileLength * sizeof(T)), 0, 0, 0}; + DataCopyExtParams intriParamsFp32{1, static_cast(colsTileLength * sizeof(float)), 0, 0, 0}; + + if constexpr (!IsSameType::value) { + DataCopyPad(inLocal.ReinterpretCast()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols], + intriParamsT, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0}); + } + + inputXInQueue.EnQue(inLocal); + inLocal = inputXInQueue.DeQue(); + + if (smoothType != 0) { + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32, + {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + } + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength); + pipe_barrier(PIPE_V); + } + + if (smoothType != 0) { + Mul(inLocal, inLocal, smoothLocal, colsTileLength); + pipe_barrier(PIPE_V); + } + + Abs(tempLocal, inLocal, colsTileLength); + pipe_barrier(PIPE_V); + + ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength); + + DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32); + smoothInQueue.FreeTensor(smoothLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); + + return dynamicQuantLocal.GetValue(8); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::ComputeScale(LocalTensor& inLocal, + LocalTensor& tempLocal, float scaleTemp, + int64_t dstIndex, int64_t j) { + DataCopyExtParams copyInParams{1, static_cast(colsTileLength * sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(colsTileLength * sizeof(int8_t)), 0, 0, 0}; + + LocalTensor outLocal = inputXOutQueue.AllocTensor(); + + DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0}); + inputXInQueue.EnQue(inLocal); + inLocal 
= inputXInQueue.DeQue(); + + Duplicate(tempLocal, scaleTemp, colsTileLength); + pipe_barrier(PIPE_V); + + Div(tempLocal, inLocal, tempLocal, colsTileLength); + pipe_barrier(PIPE_V); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength); + pipe_barrier(PIPE_V); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, colsTileLength); + + inputXOutQueue.EnQue(outLocal); + outLocal = inputXOutQueue.DeQue(); + DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams); + + inputXOutQueue.FreeTensor(outLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyOutPartialXQuantEH(int64_t progress) { + LocalTensor indicesLocal = expandRowIdxInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + + for (int64_t i = 0; i < this->currentLoopRows; i++) { + int64_t rowOffset = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + if (this->dropPadMode == DROPLESS_MODE && rowOffset + i >= this->activateRows) { + break; + } + int32_t srcIdx = indicesLocal.GetValue(i); + int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign + i); + + LocalTensor inLocal = inputXInQueue.AllocTensor(); + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor quantScaleLocal = scaleOutQueue.AllocTensor(); + + uint32_t tmp = 0xFF7FFFFF; + float reduceMax = *((float*)&tmp); + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j); + reduceMax = (reduceMax > tileMax) ? 
reduceMax : tileMax; + } + + float scaleTemp = reduceMax / 127.0f; + Duplicate(quantScaleLocal, scaleTemp, 8); + scaleOutQueue.EnQue(quantScaleLocal); + quantScaleLocal = scaleOutQueue.DeQue(); + + DataCopyPad(dynamicQuantScaleGm[(rowOffset + i)], quantScaleLocal, {1, 4, 0, 0, 0}); + + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + + ComputeScale(inLocal, tempLocal, scaleTemp, rowOffset + i, j); + } + + inputXInQueue.FreeTensor(inLocal); + calcQueue.FreeTensor(tempLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + expandRowIdxInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::CopyOutPartialXQuant1H(int64_t progress) { + LocalTensor indicesLocal = expandRowIdxInQueue.DeQue(); + + int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + int64_t curLoopRow = 0; + + int64_t currentLoopStartRow = initialRow / this->k; + int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; + + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = inputXInQueue.AllocTensor(); + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor quantScaleLocal = scaleOutQueue.AllocTensor(); + + uint32_t tmp = 0xFF7FFFFF; + float reduceMax = *((float*)&tmp); + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + + float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, row, 0, j); + reduceMax = (reduceMax > tileMax) ? 
reduceMax : tileMax; + } + + float scaleTemp = reduceMax / 127.0f; + Duplicate(quantScaleLocal, scaleTemp, 8); + scaleOutQueue.EnQue(quantScaleLocal); + quantScaleLocal = scaleOutQueue.DeQue(); + + while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { + continue; + } + DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + + ComputeScale(inLocal, tempLocal, scaleTemp, outIndex, j); + } + } + inputXInQueue.FreeTensor(inLocal); + calcQueue.FreeTensor(tempLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); + } + + expandRowIdxInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR dynamicQuantScale, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* tilingData, + TPipe* tPipe) { + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp); + + this->needCoreNum = this->gatherOutTilingData->needCoreNum; + this->activateRows = this->gatherOutTilingData->activateRows; + this->cols = tilingData->cols; + this->n = tilingData->n; + this->k = tilingData->k; + this->totalLength = tilingData->n * tilingData->k; + this->dropPadMode = tilingData->dropPadMode; + this->smoothType = tilingData->smoothType; + + if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) { + this->coreRows = this->gatherOutTilingData->lastCoreRows; + this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows; + this->lastLoopRows = 
this->gatherOutTilingData->lastCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->lastCoreLoops; + } else { + this->coreRows = this->gatherOutTilingData->perCoreRows; + this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->perCoreLoops; + } + this->perLoopCols = this->gatherOutTilingData->perLoopCols; + this->lastLoopCols = this->gatherOutTilingData->lastLoopCols; + this->colLoops = this->gatherOutTilingData->colLoops; + this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T)); + + inputXGm.SetGlobalBuffer((__gm__ T*)inputX); + expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX); + + expandedRowIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + + quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); + dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); + + expandedExpertIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + sortedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) + + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + if (this->cols > 1) { + quantSrcGm.SetGlobalBuffer( + (__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2 + this->blockIdx * this->cols, + this->cols * sizeof(float)); + } + + this->currentLoopRowsAlign = Align(this->perLoopRows, sizeof(int32_t)); + + int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T)); + perLoopColsAlignBytes = + Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES)); + + pipe->InitBuffer(expandRowIdxInQueue, BUFFER_NUM, 2 * AlignBytes(this->perLoopRows, 
sizeof(int32_t))); + pipe->InitBuffer(inputXInQueue, BUFFER_NUM, perLoopColsAlignBytes); + pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float))); + pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); + pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); +} + +template +__aicore__ inline void MoeV2GatherDynamicQuant::Process() { + if (this->blockIdx < this->needCoreNum) { + currentLoopRows = perLoopRows; + if (colLoops > 1) { // 一行无法全载,需要workspace + if (smoothType == 2) { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedExpertIdx(loop); + CopyOutPartialXQuantEH(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedExpertIdx(this->rowLoops - 1); + CopyOutPartialXQuantEH(this->rowLoops - 1); + } else { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedRowIdx(loop); + CopyOutPartialXQuant1H(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedRowIdx(this->rowLoops - 1); + CopyOutPartialXQuant1H(this->rowLoops - 1); + } + } else { // 一行可以全载 + if (smoothType == 2) { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedExpertIdx(loop); + CopyOutXQuantEH(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedExpertIdx(this->rowLoops - 1); + CopyOutXQuantEH(this->rowLoops - 1); + } else { + for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { + CopyInExpandedRowIdx(loop); + CopyOutXQuant1H(loop); + } + currentLoopRows = lastLoopRows; + CopyInExpandedRowIdx(this->rowLoops - 1); + CopyOutXQuant1H(this->rowLoops - 1); + } + } + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // MOE_V2_GATHER_DYNAMIC_QUANT_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_out.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_out.h new file mode 100644 index 
00000000000..bbdb1338908 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_out.h @@ -0,0 +1,181 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_gather_out.h + * \brief + */ +#ifndef INNER_MOE_V2_GATHER_OUT_H +#define INNER_MOE_V2_GATHER_OUT_H + +#include "moe_v2_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +constexpr int64_t BUFFER_NUM = 2; + +template +class MoeV2GatherOut { + public: + __aicore__ inline MoeV2GatherOut(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyInIndices(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + + private: + TPipe* pipe; + TQueBind inputActivationsCopyInQueue; + TQue expandDstToSrcRowCopyInQueue; + + GlobalTensor inputXGm; + GlobalTensor expandedXGm; + GlobalTensor expandedRowIdxGm; + + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData; + + int64_t needCoreNum; + int64_t blockIdx; + int64_t cols; + int64_t n; + int64_t k; + int64_t activateRows; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t colsTileLength; 
+ int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + int64_t dropPadMode; + + int64_t indicesOffset; + int64_t inputOffset; + int64_t outOffset; +}; + +template +__aicore__ inline void MoeV2GatherOut::CopyInIndices(int64_t progress) { + this->indicesOffset = progress * this->perLoopRows; + LocalTensor indicesLocal = expandDstToSrcRowCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); + + expandDstToSrcRowCopyInQueue.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherOut::CopyOut(int64_t progress) { + LocalTensor indicesLocal = expandDstToSrcRowCopyInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength = this->perLoopCols; + for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) { + int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + int64_t curLoopRow = 0; + if (colsLoop == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + int64_t currentLoopStartRow = initialRow / this->k; + int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = inputActivationsCopyInQueue.AllocTensor(); + // input row position + inputOffset = row * this->cols + colsLoop * this->perLoopCols; + DataCopyExtParams dataCopyParams{1, static_cast(this->colsTileLength * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + + DataCopyExtParams intriParams{1, static_cast(this->colsTileLength * sizeof(T)), 0, 0, 0}; + while (curLoopRow < this->currentLoopRows && initialRow / 
this->k == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { + continue; + } + outOffset = outIndex * cols + colsLoop * this->perLoopCols; +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (outOffset > expandedXGm.GetSize()) { + continue; + } +#endif + DataCopyPad(expandedXGm[outOffset], inLocal, intriParams); + } + inputActivationsCopyInQueue.FreeTensor(inLocal); + } + } + expandDstToSrcRowCopyInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherOut::Init(GM_ADDR inputX, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR workspace, const InnerMoeInitRoutingV2TilingData* tilingData, + TPipe* tPipe) { + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp); + + this->needCoreNum = this->gatherOutTilingData->needCoreNum; + this->activateRows = this->gatherOutTilingData->activateRows; + this->cols = tilingData->cols; + this->n = tilingData->n; + this->k = tilingData->k; + this->dropPadMode = tilingData->dropPadMode; + + if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) { + this->coreRows = this->gatherOutTilingData->lastCoreRows; + this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->lastCoreLoops; + } else { + this->coreRows = this->gatherOutTilingData->perCoreRows; + this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->perCoreLoops; + } + this->perLoopCols = this->gatherOutTilingData->perLoopCols; + this->lastLoopCols = this->gatherOutTilingData->lastLoopCols; + this->colLoops = 
this->gatherOutTilingData->colLoops; + + inputXGm.SetGlobalBuffer((__gm__ T*)inputX, this->coreRows * this->cols); + expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, tilingData->n * tilingData->k * this->cols); + expandedRowIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + + pipe->InitBuffer(inputActivationsCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T))); + pipe->InitBuffer(expandDstToSrcRowCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, sizeof(int32_t))); +} + +template +__aicore__ inline void MoeV2GatherOut::Process() { + if (this->blockIdx < this->needCoreNum) { + currentLoopRows = perLoopRows; + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyInIndices(loop); + CopyOut(loop); + } + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_GATHER_OUT_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_quant.h new file mode 100644 index 00000000000..68b7c927afb --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_quant.h @@ -0,0 +1,235 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_gather_quant.h + * \brief + */ +#ifndef MOE_V2_GATHER_QUANT_H +#define MOE_V2_GATHER_QUANT_H + +#include "moe_v2_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +constexpr int64_t BUFFER_NUM = 2; + +template +class MoeV2GatherQuant { + public: + __aicore__ inline MoeV2GatherQuant(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR workspace, const MoeInitRoutingQuantV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyInIndices(int64_t progress); + __aicore__ inline void Compute(); + __aicore__ inline void CopyOut(int64_t progress); + + private: + TPipe* pipe; + TQue inputXCopyInQueue; + TQue expandRowIdxCopyInQueue; + TQue inputXCopyOutQueue; + TQue floatQueue; + TQue halfQueue; + + GlobalTensor inputXGm; + GlobalTensor expandedXGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor scaleGm; + GlobalTensor offsetGm; + + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData; + + int64_t needCoreNum; + int64_t blockIdx; + int64_t cols; + int64_t n; + int64_t k; + int64_t activateRows; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t colsTileLength; + int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + int64_t dropPadMode; + float scale; + float offset; + + int64_t indicesOffset; + int64_t inputOffset; + int64_t outOffset; +}; + +template +__aicore__ inline void MoeV2GatherQuant::CopyInIndices(int64_t progress) { + this->indicesOffset = progress * this->perLoopRows; + LocalTensor indicesLocal = expandRowIdxCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(this->currentLoopRows * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, 
expandedRowIdxGm[indicesOffset], dataCopyParams, dataCopyPadParams); + expandRowIdxCopyInQueue.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherQuant::Compute() { + LocalTensor inLocal = inputXCopyInQueue.DeQue(); + LocalTensor outLocal = inputXCopyOutQueue.AllocTensor(); + LocalTensor floatLocal = floatQueue.AllocTensor(); + LocalTensor halfLocal = halfQueue.AllocTensor(); + uint32_t elements = Align(this->colsTileLength, sizeof(T)); + if constexpr (IsSameType::value) { + Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Cast(halfLocal, floatLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Muls(halfLocal, halfLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(halfLocal, halfLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + LocalTensor intLocal = floatLocal.ReinterpretCast(); + Cast(intLocal, halfLocal, RoundMode::CAST_RINT, elements); + pipe_barrier(PIPE_V); + SetDeqScale((half)1.000000e+00f); + pipe_barrier(PIPE_V); + Cast(halfLocal, intLocal, RoundMode::CAST_RINT, elements); + pipe_barrier(PIPE_V); + Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements); + } else if constexpr (IsSameType::value) { + Cast(halfLocal, inLocal, RoundMode::CAST_NONE, elements); + pipe_barrier(PIPE_V); + Muls(halfLocal, halfLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(halfLocal, halfLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + Cast(outLocal, halfLocal, RoundMode::CAST_RINT, elements); + } else { + Muls(inLocal, inLocal, static_cast(this->scale), elements); + pipe_barrier(PIPE_V); + Adds(inLocal, inLocal, static_cast(this->offset), elements); + pipe_barrier(PIPE_V); + Cast(outLocal, inLocal, RoundMode::CAST_RINT, elements); + } + inputXCopyOutQueue.EnQue(outLocal); + floatQueue.FreeTensor(floatLocal); + halfQueue.FreeTensor(halfLocal); +} + +template +__aicore__ inline void MoeV2GatherQuant::CopyOut(int64_t 
progress) { + LocalTensor indicesLocal = expandRowIdxCopyInQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength = this->perLoopCols; + for (int64_t colsLoop = 0; colsLoop < this->colLoops; colsLoop++) { + int64_t initialRow = this->gatherOutTilingData->perCoreRows * this->blockIdx + this->perLoopRows * progress; + int64_t curLoopRow = 0; + if (colsLoop == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + int64_t currentLoopStartRow = initialRow / this->k; + int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = inputXCopyInQueue.AllocTensor(); + // input row position + inputOffset = row * this->cols + colsLoop * this->perLoopCols; + DataCopyExtParams dataCopyParams{1, static_cast(this->colsTileLength * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal, inputXGm[inputOffset], dataCopyParams, dataCopyPadParams); + inputXCopyInQueue.EnQue(inLocal); + Compute(); + LocalTensor outLocal = inputXCopyOutQueue.DeQue(); + DataCopyExtParams intriParams{1, static_cast(this->colsTileLength * sizeof(int8_t)), 0, 0, 0}; + while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { + continue; + } + outOffset = outIndex * cols + colsLoop * this->perLoopCols; + DataCopyPad(expandedXGm[outOffset], outLocal, intriParams); + } + inputXCopyInQueue.FreeTensor(inLocal); + inputXCopyOutQueue.FreeTensor(outLocal); + } + } + expandRowIdxCopyInQueue.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeV2GatherQuant::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR workspace, + const MoeInitRoutingQuantV2TilingData* 
tilingData, TPipe* tPipe) { + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->gatherOutTilingData = &(tilingData->gatherOutComputeParamsOp); + + this->needCoreNum = this->gatherOutTilingData->needCoreNum; + this->activateRows = this->gatherOutTilingData->activateRows; + this->cols = tilingData->cols; + this->n = tilingData->n; + this->k = tilingData->k; + this->dropPadMode = tilingData->dropPadMode; + + if (this->blockIdx == this->gatherOutTilingData->needCoreNum - 1) { + this->coreRows = this->gatherOutTilingData->lastCoreRows; + this->perLoopRows = this->gatherOutTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->lastCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->lastCoreLoops; + } else { + this->coreRows = this->gatherOutTilingData->perCoreRows; + this->perLoopRows = this->gatherOutTilingData->perCorePerLoopRows; + this->lastLoopRows = this->gatherOutTilingData->perCoreLastLoopRows; + this->rowLoops = this->gatherOutTilingData->perCoreLoops; + } + this->perLoopCols = this->gatherOutTilingData->perLoopCols; + this->lastLoopCols = this->gatherOutTilingData->lastLoopCols; + this->colLoops = this->gatherOutTilingData->colLoops; + + inputXGm.SetGlobalBuffer((__gm__ T*)inputX); + expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX); + expandedRowIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)expandedRowIdx + this->blockIdx * this->gatherOutTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + scaleGm.SetGlobalBuffer((__gm__ float*)scale, 1); + offsetGm.SetGlobalBuffer((__gm__ float*)offset, 1); + this->scale = scaleGm.GetValue(0); + this->offset = offsetGm.GetValue(0); + + pipe->InitBuffer(inputXCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(T))); + pipe->InitBuffer(inputXCopyOutQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(int8_t))); + pipe->InitBuffer(expandRowIdxCopyInQueue, BUFFER_NUM, AlignBytes(this->perLoopRows, 
sizeof(int32_t))); + pipe->InitBuffer(floatQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); + pipe->InitBuffer(halfQueue, 1, AlignBytes(this->perLoopCols, sizeof(half))); +} + +template +__aicore__ inline void MoeV2GatherQuant::Process() { + if (this->blockIdx < this->needCoreNum) { + currentLoopRows = perLoopRows; + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyInIndices(loop); + CopyOut(loop); + } + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // MOE_V2_GATHER_QUANT_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_init_routing_fullload.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_init_routing_fullload.h new file mode 100644 index 00000000000..539bc69c58f --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_init_routing_fullload.h @@ -0,0 +1,312 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*!
+ * \file moe_v2_init_routing_fullload.h + * \brief + */ +#ifndef INNER_MOE_V2_FULL_LOAD_H +#define INNER_MOE_V2_FULL_LOAD_H + +#include "moe_v2_mrgsort.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2FullLoad : public MoeV2SortBase { + public: + __aicore__ inline MoeV2FullLoad(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOutIdx(); + __aicore__ inline void CopyOutEmpty(); + __aicore__ inline void CopyOutX(); + __aicore__ inline void ComputeExpertTokenCountOrCumsum(); + + private: + int64_t sortNum_; + const InnerMoeV2GatherOutComputeTilingData* gatherOutTilingData_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t coreRows_; + int64_t perCoreRows_; + int64_t k_; + int64_t n_; + int64_t cols_; + int64_t activateRows_; + int64_t expertNum; + int64_t expertCapacity; + + TQue xCopyInQueue_; + TQue expandedRowIdxCopyOutQueue_; + TQue expandedExpertIdxCopyOutQueue_; + TQue expandDstToSrcRowQueue_; + TQue expertTokensCopyOutQueue_; + + GlobalTensor xGm_; + GlobalTensor expertIdxGm_; + + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedExpertIdxGm_; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + int64_t dropPadMode = 0; +}; + +template +__aicore__ inline void MoeV2FullLoad::CopyIn() { + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(this->totalLength * sizeof(int32_t)), + 0, 0, 0}; + DataCopyPadExtParams 
dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams); + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + sortDataCopyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::SortCompute() { + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + LocalTensor concatLocal; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum_)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast(); + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum_)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + 
Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, + this->totalLength); + pipe_barrier(PIPE_V); + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, this->totalLength); + pipe_barrier(PIPE_V); + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdxLocalInt32); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandedRowIdxU32 = expandedRowIdx.ReinterpretCast(); + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, this->totalLength); + pipe_barrier(PIPE_V); + ArithProgression(inLocal[this->sortNum_], 0, 1, this->totalLength); + pipe_barrier(PIPE_V); + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Extract(tempTensor, expandedRowIdxU32, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + sortDataCopyInQueue.FreeTensor(inLocal); + + expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutIdx() { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + DataCopyParams intriParams; + 
intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, intriParams); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); +} + +template +__aicore__ inline void MoeV2FullLoad::ComputeExpertTokenCountOrCumsum() { + LocalTensor expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue(); + LocalTensor expertTokensCount = expertTokensCopyOutQueue_.AllocTensor(); + + int64_t expertNumAlign = Align(this->expertNum, sizeof(int32_t)); + Duplicate(expertTokensCount, 0, expertNumAlign); + SetWaitFlag(HardEvent::V_S); + + int32_t lastExpertId = expandedExpertIdx.GetValue(0); + int64_t tokenCount = 0; + int64_t lastExpertCount = 0; + for (int64_t i = 0; i < this->totalLength; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + tokenCount++; + while (lastExpertId < curExpertId) { + expertTokensCount.SetValue(lastExpertId, tokenCount - 1); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_COUNT) { + tokenCount = 1; + } + lastExpertId++; + } + } + expertTokensCount.SetValue(lastExpertId, tokenCount); + if (this->expertTokensCountOrCumsumFlag == EXERPT_TOKENS_CUMSUM) { + lastExpertId++; + while (lastExpertId < this->expertNum) { + expertTokensCount.SetValue(lastExpertId, tokenCount); + lastExpertId++; + } + } + DataCopyExtParams copyParams{static_cast(1), static_cast(this->expertNum * sizeof(int32_t)), 0, 0, + 0}; + if (this->expertTokensCountOrCumsumFlag > 0) { + DataCopyPad(expertTokensCountOrCumsumGm, expertTokensCount, copyParams); + } + expertTokensCopyOutQueue_.FreeTensor(expertTokensCount); + expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutX() { + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->cols_ * sizeof(T); + int64_t inFactor = 
Align(this->cols_, sizeof(T)); + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_; + + DataCopyExtParams dataXCopyParams{static_cast(endXRow - startXRow + 1), + static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataXCopyPadParams{false, 0, 0, 0}; + DataCopyPad(xLocal, xGm_[startXRow * this->cols_], dataXCopyParams, dataXCopyPadParams); + SetWaitFlag(HardEvent::MTE2_S); + + int64_t k = 0; + for (int64_t i = startXRow; i <= endXRow; i++) { + for (; k < this->perCoreRows_ && curRowsStart / this->k_ == i; curRowsStart++, k++) { + int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); + if (outIndex < this->activateRows_) { + DataCopyPad(expandedXGm_[outIndex * this->cols_], xLocal[(i - startXRow) * inFactor], intriParams); + } + } + } + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + xCopyInQueue_.FreeTensor(xLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::CopyOutEmpty() { + LocalTensor outLocal = expandedExpertIdxCopyOutQueue_.DeQue(); + expandedExpertIdxCopyOutQueue_.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2FullLoad::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const InnerMoeInitRoutingV2TilingData* tilingData, TPipe* tPipe) { + this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + this->blockIdx_ = get_block_idx() + get_subblockid() * get_block_num(); + this->k_ = tilingData->k; + this->n_ = tilingData->n; + this->cols_ = tilingData->cols; + this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; + this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; + this->activateRows_ = this->gatherOutTilingData_->activateRows; + if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) { + this->coreRows_ = this->gatherOutTilingData_->lastCoreRows; + 
} else { + this->coreRows_ = this->gatherOutTilingData_->perCoreRows; + } + this->expertNum = tilingData->expertNum; + this->dropPadMode = tilingData->dropPadMode; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum_ = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->pipe = tPipe; + + xGm_.SetGlobalBuffer((__gm__ T*)x); + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength); + + expandedXGm_.SetGlobalBuffer((__gm__ T*)expandedX); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, this->tileLength); + if (this->expertTokensCountOrCumsumFlag > 0) { + // dropless + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); + } + + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum_ * sizeof(int32_t); + + int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; + int64_t startXRow = curRowsStart / this->k_; + int64_t endXRow = (curRowsStart + this->coreRows_ - 1) / this->k_; + pipe->InitBuffer(xCopyInQueue_, bufferNum, AlignBytes(this->cols_, sizeof(T)) * (endXRow - startXRow + 1)); + pipe->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum, buffSize); + pipe->InitBuffer(expertTokensCopyOutQueue_, bufferNum, AlignBytes(this->expertNum, sizeof(int32_t))); + pipe->InitBuffer(expandDstToSrcRowQueue_, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize * kvFactor); + pipe->InitBuffer(tempBuffer, buffSize * kvFactor); + pipe->InitBuffer(sortedBuffer, buffSize * kvFactor); +} + +template +__aicore__ inline void MoeV2FullLoad::Process() { + if (this->blockIdx_ < this->needCoreNum_) { + CopyIn(); + SortCompute(); + if 
(this->blockIdx_ == 0) { + CopyOutIdx(); + } + if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + ComputeExpertTokenCountOrCumsum(); + } else { + CopyOutEmpty(); + } + CopyOutX(); + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_FULL_LOAD_H diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort.h new file mode 100644 index 00000000000..f72b66373c2 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort.h @@ -0,0 +1,189 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*!
+ * \file moe_v2_mrgsort.h + * \brief + */ +#ifndef INNER_MOE_V2_MRGSORT_H +#define INNER_MOE_V2_MRGSORT_H + +#include "moe_v2_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +struct MoeV2MrgsortParam { + int64_t perListElements; + int64_t lastListElements; + int64_t oneLoopMaxElements; +}; + +class MoeV2Mrgsort { + public: + __aicore__ inline MoeV2Mrgsort(){}; + __aicore__ inline void Init(MoeV2MrgsortParam* param); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor& gmInput, LocalTensor& ubInput); + __aicore__ inline void SetOutput(GlobalTensor& gmOutput, LocalTensor& ubOutput); + + private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + + private: + MoeV2MrgsortParam* param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput; + + LocalTensor ubInputs[4]; + LocalTensor ubOutput; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail{0}; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeV2Mrgsort::ClearCache() { + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeV2Mrgsort::SetInput(GlobalTensor& gmInput, LocalTensor& ubInput) { + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeV2Mrgsort::SetOutput(GlobalTensor& gmOutput, LocalTensor& ubOutput) { + this->gmOutput = gmOutput; + this->ubOutput = ubOutput; +} + +__aicore__ inline void 
MoeV2Mrgsort::UpdateMrgParam() { + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeV2Mrgsort::CopyIn() { + this->remainListNum = 0; + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeV2Mrgsort::MrgsortCompute() { + SetWaitFlag(HardEvent::MTE2_V); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->ubOutput, this->tmpUbInputs[0], Align(GetSortLen(elementCountListTail[0]), sizeof(float))); 
+ listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeV2Mrgsort::UpdateSortInfo() { + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeV2Mrgsort::CopyOut() { + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = GetSortLen(curLoopSortedNum) * sizeof(float); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams); + outOffset += GetSortLen(curLoopSortedNum); +} + +__aicore__ inline void MoeV2Mrgsort::Init(MoeV2MrgsortParam* param) { + this->param = param; + this->remainListNum = listNum; + + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeV2Mrgsort::Process() { + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + CopyOut(); + } + + ClearCache(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_MRGSORT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort_out.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort_out.h new file mode 100644 index 00000000000..f08e56de0db --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_mrgsort_out.h @@ -0,0 +1,213 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_mrgsort_out.h + * \brief + */ +#ifndef INNER_MOE_V2_MRGSORT_OUT_H +#define INNER_MOE_V2_MRGSORT_OUT_H + +#include "moe_v2_mrgsort.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2MrgsortOut { + public: + __aicore__ inline MoeV2MrgsortOut(){}; + __aicore__ inline void Init(MoeV2MrgsortParam* param, TPipe* tPipe); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor& gmInput, LocalTensor& ubInput); + __aicore__ inline void SetOutput(GlobalTensor& gmOutput1, GlobalTensor& gmOutput2, + LocalTensor& ubOutput1, LocalTensor& ubOutput2); + __aicore__ inline void SetBuffer(LocalTensor& tempBuffer); + + private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void Extract(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + + private: + MoeV2MrgsortParam* param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput1; + GlobalTensor gmOutput2; + + LocalTensor ubInputs[4]; + LocalTensor tempBuffer; + + // for extract + LocalTensor ubOutput1; + LocalTensor ubOutput2; + + // for copy out + LocalTensor ubOutputInt1; + LocalTensor ubOutputInt2; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t 
offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeV2MrgsortOut::ClearCache() { + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeV2MrgsortOut::SetInput(GlobalTensor& gmInput, LocalTensor& ubInput) { + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeV2MrgsortOut::SetOutput(GlobalTensor& gmOutput1, GlobalTensor& gmOutput2, + LocalTensor& ubOutput1, LocalTensor& ubOutput2) { + this->gmOutput1 = gmOutput1; + this->ubOutput1 = ubOutput1; + this->ubOutputInt1 = ubOutput1.ReinterpretCast(); + + this->gmOutput2 = gmOutput2; + this->ubOutput2 = ubOutput2.ReinterpretCast(); + this->ubOutputInt2 = ubOutput2.ReinterpretCast(); +} + +__aicore__ inline void MoeV2MrgsortOut::SetBuffer(LocalTensor& tempBuffer) { + this->tempBuffer = tempBuffer; +} + +__aicore__ inline void MoeV2MrgsortOut::UpdateMrgParam() { + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeV2MrgsortOut::CopyIn() { + this->remainListNum = 0; + SetWaitFlag(HardEvent::MTE3_MTE2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], Align(GetSortLen(lengths[i]), sizeof(float))); + 
tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeV2MrgsortOut::MrgsortCompute() { + SetWaitFlag(HardEvent::MTE2_V); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->tempBuffer, this->tmpUbInputs[0], Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeV2MrgsortOut::UpdateSortInfo() { + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeV2MrgsortOut::Extract() { + AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM)); + pipe_barrier(PIPE_V); + Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float))); + pipe_barrier(PIPE_V); + 
Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float))); +} + +__aicore__ inline void MoeV2MrgsortOut::CopyOut() { + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = curLoopSortedNum * sizeof(int32_t); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams); + DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams); + outOffset += curLoopSortedNum; +} + +__aicore__ inline void MoeV2MrgsortOut::Init(MoeV2MrgsortParam* param, TPipe* tPipe) { + this->param = param; + this->allRemainElements = 0; + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeV2MrgsortOut::Process() { + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + Extract(); + CopyOut(); + } + ClearCache(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_MRGSORT_OUT_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_base.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_base.h new file mode 100644 index 00000000000..203afb6a6a1 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_base.h @@ -0,0 +1,70 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_sort_base.h + * \brief + */ +#ifndef INNER_MOE_V2_SORT_BASE_H +#define INNER_MOE_V2_SORT_BASE_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2SortBase { + public: + __aicore__ inline MoeV2SortBase(){}; + + protected: + __aicore__ inline void SyncAll(); + + protected: + TPipe* pipe; + TQue sortDataCopyInQueue; + TQue sortDataCopyOutQueue; + TBuf tempBuffer; + TBuf sortedBuffer; + + GlobalTensor expertIdxGm; + GlobalTensor sortedexpertIdxGm; + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expertTokensCountOrCumsumGm; + GlobalTensor expertTokensBeforeCapacityGm; + + int64_t tileLength; + int64_t bufferNum = 1; + int64_t totalLength; + int64_t coreNum; + int64_t n; + int64_t k; + int64_t existRowIdx; + int64_t expertNum; + int64_t expertTokensCountOrCumsumFlag = 0; + int64_t expertTokensBeforeCapacityFlag = 0; + + static constexpr int64_t SYNC_GM_NUM = 2; + static constexpr int64_t WORK_GM_NUM = 2; + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; +}; + +__aicore__ inline void MoeV2SortBase::SyncAll() { + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_SORT_BASE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_multi_core.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_multi_core.h new file mode 100644 index 00000000000..8484e837a3e --- /dev/null +++ 
b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_multi_core.h @@ -0,0 +1,373 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_sort_multi_core.h + * \brief + */ +#ifndef INNER_MOE_V2_VBS_ONE_CORE_H +#define INNER_MOE_V2_VBS_ONE_CORE_H + +#include "moe_v2_sort_base.h" +#include "moe_v2_mrgsort.h" +#include "moe_v2_mrgsort_out.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2SortMultiCore : public MoeV2SortBase { + public: + __aicore__ inline MoeV2SortMultiCore(){}; + template + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void VBSProcess(); + __aicore__ inline void UBSortProcess(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void OneCoreVMSProcess(int64_t listNum, int64_t perListElements, int64_t lastListElements); + __aicore__ inline void VMSProcess(); + __aicore__ inline void SortOutProcess(); + __aicore__ inline void VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void UBSortCompute(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void 
InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset, int64_t loopOffset); + __aicore__ inline void InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum, int64_t coreOffset); + __aicore__ inline void InitExpertTokensGlobalMemory(); + + private: + GlobalTensor workspaceGms[2]; + + const InnerMoeV2VBSComputeTilingData* vbsTilingData; + const InnerMoeV2VMSMiddleComputeTilingData* vmsTilingData; + const InnerMoeV2SortOutComputeTilingData* sortOutTilingData; + + // for MoeMrgsort + MoeV2Mrgsort mrgsorter; + MoeV2MrgsortParam mrgsortParam; + + int64_t coreNum; + int64_t blockIdx; + int64_t srcWsIndex = 0; + + int64_t listNum; + int64_t perListElements; + int64_t lastListElements; + + int64_t sortTotalLength; + int64_t sortCoreLoops; + int64_t sortCoreLoopElements; + int64_t sortCoreLastLoopElements; + + int64_t perCoreExpert; + int64_t needInitExpertCore; + int64_t currentCoreExpert; + + static constexpr int64_t MAX_MRGSORT_LIST = 4; +}; + +__aicore__ inline void MoeV2SortMultiCore::InitExpertTokensGlobalMemory() { + if (this->blockIdx < this->needInitExpertCore) { + if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + InitGlobalMemory(expertTokensCountOrCumsumGm, currentCoreExpert, 0); + } + if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { + InitGlobalMemory(expertTokensBeforeCapacityGm, currentCoreExpert, 0); + } + } +} + +__aicore__ inline void MoeV2SortMultiCore::VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum) { + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + int64_t inOffset = progress * sortCoreLoopElements; + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(size * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams); + + LocalTensor rowIdxLocal = inLocal[sortNum]; + int64_t startValue = this->blockIdx * this->vbsTilingData->perCoreElements + 
inOffset; + SetWaitFlag(HardEvent::MTE3_S); + ArithProgression(rowIdxLocal, startValue, 1, size); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::UBSortCompute(int64_t progress, int64_t size, int64_t sortNum) { + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertForSourceRowLocal = inLocal[0]; + LocalTensor expertForSourceRowLocalFp32; + + expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, sortNum); + pipe_barrier(PIPE_V); + Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, sortNum); + pipe_barrier(PIPE_V); + + int64_t duplicateNum = size % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = size - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + + LocalTensor concatLocal = expertForSourceRowLocalFp32; + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(sortNum)); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[sortNum].ReinterpretCast(); + Sort(outLocal, concatLocal, sourceRowLocal, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM); + + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum) { + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + DataCopy(workspaceGms[0][this->blockIdx * GetSortLen(this->vbsTilingData->perCoreElements) + + GetSortLen(progress * sortCoreLoopElements)], + outLocal, Align(GetSortLen(size), sizeof(float))); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline 
void MoeV2SortMultiCore::InitMoeMrgSort(MoeV2Mrgsort* sorter, int64_t listNum, int64_t coreOffset, + int64_t loopOffset) { + GlobalTensor srcWsGm = workspaceGms[srcWsIndex][blockIdx * coreOffset + loopOffset]; + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + GlobalTensor dstWsGm = workspaceGms[1 - srcWsIndex][blockIdx * coreOffset + loopOffset]; + sorter->SetOutput(dstWsGm, outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::InitMoeMrgSortOut(MoeV2MrgsortOut* sorter, int64_t listNum, + int64_t coreOffset) { + GlobalTensor srcWsGm = workspaceGms[srcWsIndex]; + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + + LocalTensor outLocalV = outLocal[this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST]; + sorter->SetOutput(this->sortedexpertIdxGm, this->expandDstToSrcRowGm, outLocal, outLocalV); + + LocalTensor tempBuffer = + sortedBuffer.Get(GetSortLen(this->sortOutTilingData->oneLoopMaxElements) * MAX_MRGSORT_LIST); + sorter->SetBuffer(tempBuffer); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2SortMultiCore::OneCoreVMSProcess(int64_t listNum, int64_t perListElements, + int64_t lastListElements) { + int64_t coreOffset = GetSortLen(this->vbsTilingData->perCoreElements); + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + + for (int64_t i = 0; listNum >= 1; i++) { + 
int64_t loops = (listNum + MAX_MRGSORT_LIST - 1) / MAX_MRGSORT_LIST; + int64_t remainListNum = listNum - (loops - 1) * MAX_MRGSORT_LIST; + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + + int64_t loopOffset = GetSortLen(mrgsortParam.perListElements * MAX_MRGSORT_LIST); + for (int64_t loop = 0; loop < loops - 1; loop++) { + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, loop * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, (loops - 1) * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + + listNum = loops; + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + perListElements = perListElements * MAX_MRGSORT_LIST; + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + + if (loops == 1) { + break; + } + } +} + +__aicore__ inline void MoeV2SortMultiCore::UBSortProcess(int64_t progress, int64_t size, int64_t sortNum) { + VBSCopyIn(progress, size, sortNum); + UBSortCompute(progress, size, sortNum); + VBSCopyOut(progress, size, sortNum); +} + +__aicore__ inline void MoeV2SortMultiCore::VBSProcess() { + if (this->blockIdx < this->vbsTilingData->needCoreNum) { + int64_t sortNum = Ceil(sortCoreLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + for (int64_t loop = 0; loop < sortCoreLoops - 1; loop++) { + UBSortProcess(loop, sortCoreLoopElements, sortNum); + } + + sortNum = Ceil(sortCoreLastLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + UBSortProcess(sortCoreLoops - 1, sortCoreLastLoopElements, sortNum); + if (sortCoreLoops > 1) { + OneCoreVMSProcess(sortCoreLoops, sortCoreLoopElements, sortCoreLastLoopElements); + } + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +__aicore__ inline void MoeV2SortMultiCore::VMSProcess() { + int64_t 
currentStageNeedCoreNum = this->vmsTilingData->needCoreNum; + perListElements = this->vbsTilingData->perCoreElements; + lastListElements = this->vbsTilingData->lastCoreElements; + listNum = this->vbsTilingData->needCoreNum; + + for (; listNum > MAX_MRGSORT_LIST;) { + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + int64_t coreOffset = GetSortLen(perListElements * MAX_MRGSORT_LIST); + int64_t remainListNum = listNum - (currentStageNeedCoreNum - 1) * MAX_MRGSORT_LIST; + + if (this->blockIdx < currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } else if (this->blockIdx == currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + listNum = currentStageNeedCoreNum; + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + perListElements = perListElements * MAX_MRGSORT_LIST; +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif + } +} + +__aicore__ inline void MoeV2SortMultiCore::SortOutProcess() { + if (this->blockIdx < 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = this->sortOutTilingData->oneLoopMaxElements; + + MoeV2MrgsortOut sorter; + InitMoeMrgSortOut(&sorter, listNum, GetSortLen(perListElements)); + sorter.Init(&mrgsortParam, pipe); + sorter.Process(); + } +#ifndef 
__CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, + const TilingData* tilingData, TPipe* tPipe) { + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->vbsTilingData = &(tilingData->vbsComputeParamsOp); + this->vmsTilingData = &(tilingData->vmsMiddleComputeParamsOp); + this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp); + + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->tileLength = this->vbsTilingData->perCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->perCoreElements; + if (this->blockIdx == tilingData->vbsComputeParamsOp.needCoreNum - 1) { + this->tileLength = this->vbsTilingData->lastCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->lastCoreElements; + } + this->n = tilingData->n; + this->k = tilingData->k; + this->expertNum = tilingData->expertNum; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + + // VBS param init + if (this->blockIdx == this->vbsTilingData->needCoreNum - 1) { + sortCoreLoops = this->vbsTilingData->lastCoreLoops; + sortCoreLoopElements = this->vbsTilingData->lastCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->lastCoreLastLoopElements; + } else { + sortCoreLoops = this->vbsTilingData->perCoreLoops; + sortCoreLoopElements = this->vbsTilingData->perCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->perCoreLastLoopElements; + } + + this->pipe = tPipe; + expertIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)expertIdx + this->blockIdx * tilingData->vbsComputeParamsOp.perCoreElements, + this->sortTotalLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace), + 
Align(this->totalLength, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer( + reinterpret_cast<__gm__ int32_t*>(workspace) + Align(this->totalLength, sizeof(int32_t)), + Align(this->totalLength, sizeof(int32_t))); + + this->perCoreExpert = Align((this->expertNum + this->coreNum - 1) / this->coreNum, sizeof(int32_t)); + this->needInitExpertCore = (this->expertNum + this->perCoreExpert - 1) / this->perCoreExpert; + this->currentCoreExpert = this->perCoreExpert; + if (this->blockIdx == needInitExpertCore - 1) { + this->currentCoreExpert = this->expertNum - (this->needInitExpertCore - 1) * this->perCoreExpert; + } + if (this->expertTokensCountOrCumsumFlag > EXERPT_TOKENS_NONE) { + expertTokensCountOrCumsumGm.SetGlobalBuffer( + (__gm__ int32_t*)expertTokensCountOrCumsum + this->blockIdx * this->perCoreExpert, this->currentCoreExpert); + } + if (this->expertTokensBeforeCapacityFlag == EXERPT_TOKENS_BEFORE_CAPACITY) { + expertTokensBeforeCapacityGm.SetGlobalBuffer( + (__gm__ int32_t*)expertTokensBeforeCapacity + this->blockIdx * this->perCoreExpert, this->currentCoreExpert); + } + // key and value + int64_t kvFactor = 2; + workspaceGms[0].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + Align(this->totalLength, sizeof(int32_t)) * kvFactor); + workspaceGms[1].SetGlobalBuffer((__gm__ float*)workspace + Align(this->totalLength, sizeof(int32_t)) * (kvFactor + 2), + Align(this->totalLength, sizeof(int32_t)) * kvFactor); + + int64_t bufferSize = Ceil(Max(this->sortOutTilingData->oneLoopMaxElements * MAX_MRGSORT_LIST, sortCoreLoopElements), + ONE_REPEAT_SORT_NUM) * + ONE_REPEAT_SORT_NUM * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortedBuffer, bufferSize); +} + +__aicore__ inline void MoeV2SortMultiCore::Process() { + InitExpertTokensGlobalMemory(); + VBSProcess(); + VMSProcess(); + 
SortOutProcess(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_VBS_ONE_CORE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_one_core.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_one_core.h new file mode 100644 index 00000000000..0778308d944 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_sort_one_core.h @@ -0,0 +1,162 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_sort_one_core.h + * \brief + */ +#ifndef INNER_MOE_V2_SORT_ONE_CORE_H +#define INNER_MOE_V2_SORT_ONE_CORE_H + +#include "moe_v2_mrgsort.h" +#include "moe_v2_sort_base.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2SortOneCore : public MoeV2SortBase { + public: + __aicore__ inline MoeV2SortOneCore(){}; + template + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expertTokensBeforeCapacity, + GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOut(); + + private: + int64_t sortNum; + int64_t blockIdx; +}; + +__aicore__ inline void MoeV2SortOneCore::CopyIn() { + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(this->totalLength * sizeof(int32_t)), + 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); + + LocalTensor rowIdxLocal = inLocal[this->sortNum]; + ArithProgression(rowIdxLocal, 0, 1, this->sortNum); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SortOneCore::SortCompute() { + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertForSourceRowLocal = inLocal[0]; + LocalTensor expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); + pipe_barrier(PIPE_V); + Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, this->tileLength); + pipe_barrier(PIPE_V); + + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << 
duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + pipe_barrier(PIPE_V); + } + + LocalTensor concatLocal; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum)); + Concat(concatLocal, expertForSourceRowLocalFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum)); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[this->sortNum].ReinterpretCast(); + Sort(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sortedExpertForSourceRowLocal = outLocal[0]; + LocalTensor expandDstToSrcRowLocal; + expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast(); + Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + pipe_barrier(PIPE_V); + Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength); + pipe_barrier(PIPE_V); + + LocalTensor expertForSourceRowLocalInt32; + expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeV2SortOneCore::CopyOut() { + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); + DataCopyPad(sortedexpertIdxGm, outLocal[0], intriParams); + DataCopyPad(expandDstToSrcRowGm, outLocal[this->sortNum], intriParams); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +template 
+__aicore__ inline void MoeV2SortOneCore::Init(GM_ADDR expertIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expertTokensBeforeCapacity, GM_ADDR workspace, + const TilingData* tilingData, TPipe* tPipe) { + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->pipe = tPipe; + this->n = tilingData->n; + this->k = tilingData->k; + this->expertNum = tilingData->expertNum; + this->expertTokensCountOrCumsumFlag = tilingData->expertTokensCountOrCumsumFlag; + this->expertTokensBeforeCapacityFlag = tilingData->expertTokensBeforeCapacityFlag; + + expertIdxGm.SetGlobalBuffer((__gm__ int32_t*)expertIdx, this->tileLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace), this->tileLength); + expandDstToSrcRowGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspace) + this->tileLength, + this->tileLength); + + if (this->blockIdx == this->coreNum - 1) { + if (this->expertTokensCountOrCumsumFlag > 0) { + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensCountOrCumsum, + Align(this->expertNum, sizeof(int32_t))); + InitGlobalMemory(expertTokensCountOrCumsumGm, this->expertNum, 0); + } + if (this->expertTokensBeforeCapacityFlag == 1) { + expertTokensBeforeCapacityGm.SetGlobalBuffer((__gm__ int32_t*)expertTokensBeforeCapacity, + Align(this->expertNum, sizeof(int32_t))); + InitGlobalMemory(expertTokensBeforeCapacityGm, this->expertNum, 0); + } + } + // key and value + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize); + pipe->InitBuffer(tempBuffer, buffSize); + 
pipe->InitBuffer(sortedBuffer, buffSize); +} + +__aicore__ inline void MoeV2SortOneCore::Process() { + if (get_block_idx() + get_subblockid() * get_block_num() < 1) { + CopyIn(); + SortCompute(); + CopyOut(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_SORT_ONE_CORE_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_and_gather.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_and_gather.h new file mode 100644 index 00000000000..2fb99194ce1 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_and_gather.h @@ -0,0 +1,560 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_src_to_dst_and_gather.h + * \brief + */ +#ifndef MOE_V2_SRC_TO_DST_AND_GATHER_H +#define MOE_V2_SRC_TO_DST_AND_GATHER_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2SrcToDstAndGather { + public: + __aicore__ inline MoeV2SrcToDstAndGather(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR dynamicQuantScale, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + __aicore__ inline void CopyOutLoops(int64_t progress); + __aicore__ inline void Compute(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx); + __aicore__ inline float ComputeMax(LocalTensor& inLocal, LocalTensor& tempLocal, + LocalTensor& dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx, + int64_t j); + __aicore__ inline void ComputeScale(LocalTensor& inLocal, LocalTensor& tempLocal, float scaleTemp, + int64_t dstIndex, int64_t j); + __aicore__ inline void ComputeLoops(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx); + + __aicore__ inline void CopyOutRemain(); + __aicore__ inline void SyncAll(); + __aicore__ inline void AssistInit(); + + private: + TPipe* pipe; + TQue copyInQueue; + TQue copyOutQueue; + TQue copyOutZeroQueue; + + TQue inputXInQueue; + TQue smoothInQueue; + TQue calcQueue; + TQue inputXOutQueue; + TQue scaleOutQueue; + TQue scaleOutZeroQueue; + + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor expertIdxValueGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expandedXGm; + + GlobalTensor inputXGm; + GlobalTensor quantSmoothGm; + GlobalTensor dynamicQuantScaleGm; + GlobalTensor quantSrcGm; + + LocalTensor outTmpLocal; + LocalTensor scaleOutTmpLocal; + LocalTensor smoothLocal; + + const 
InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t expertCapacity; + int64_t expertNum; + int64_t cols; + int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + int64_t perLoopColsAlign; + int64_t k; + int64_t colsTileLength; + int64_t smoothType; + + int64_t tokenCount = 0; + int32_t lastExpertId = -1; + int32_t lastCoreExpertId = 0; + int32_t lastCoreExpertIdNum = 0; +}; + +template +__aicore__ inline void MoeV2SrcToDstAndGather::AssistInit() { + LocalTensor outLocal = copyOutZeroQueue.AllocTensor(); + Duplicate(outLocal, static_cast(0), this->perLoopCols); + copyOutZeroQueue.EnQue(outLocal); + LocalTensor scaleOutLocal = scaleOutZeroQueue.AllocTensor(); + Duplicate(scaleOutLocal, 0.0f, 8); + scaleOutZeroQueue.EnQue(scaleOutLocal); + + if (this->blockIdx != 0) { + this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2); + this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1); + for (int64_t i = this->blockIdx - 2; i >= 0; i--) { + int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2); + if (lastExpertIdx < this->lastCoreExpertId) { + break; + } + int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1); + this->lastCoreExpertIdNum += lastExpertNum; + } + } +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::CopyIn(int64_t progress) { + LocalTensor inLocal = copyInQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length); + DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length); + + copyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::Compute(int32_t srcIdx, int32_t dstIdx, + int32_t expertIdx) { + DataCopyExtParams copyInParams{1, 
static_cast(this->cols * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + + LocalTensor inLocal = inputXInQueue.AllocTensor(); + + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal.template ReinterpretCast()[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols], + copyInParams, {false, 0, 0, 0}); + } + + if (smoothType == 2) { + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0}); + } + + inputXInQueue.EnQue(inLocal); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + + inLocal = inputXInQueue.DeQue(); + + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor outLocal = inputXOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.template ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); + pipe_barrier(PIPE_V); + } + + if (smoothType != 0) { + Mul(inLocal, inLocal, smoothLocal, this->cols); + pipe_barrier(PIPE_V); + } + + Abs(tempLocal, inLocal, this->cols); + pipe_barrier(PIPE_V); + + ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols); + pipe_barrier(PIPE_V); + + float maxValue = dynamicQuantLocal.GetValue(0) / 127.0f; + + Duplicate(dynamicQuantLocal, maxValue, 8); + Duplicate(tempLocal, maxValue, this->cols); + pipe_barrier(PIPE_V); + + Div(tempLocal, inLocal, tempLocal, this->cols); + pipe_barrier(PIPE_V); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, this->cols); + pipe_barrier(PIPE_V); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->cols); + + calcQueue.FreeTensor(tempLocal); + inputXOutQueue.EnQue(outLocal); + 
scaleOutQueue.EnQue(dynamicQuantLocal); + + LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); + DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0}); + + outLocal = inputXOutQueue.DeQue(); +#ifndef __CCE_KT_TEST__ + DataCopyPad(expandedXGm[dstIdx * this->cols], outLocal, copyOutParams); +#endif + inputXInQueue.FreeTensor(inLocal); + inputXOutQueue.FreeTensor(outLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::CopyOut(int64_t progress) { + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0}; + DataCopyExtParams copyParams1{static_cast(1), static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = this->lastCoreExpertId; + this->tokenCount = this->lastCoreExpertIdNum; + } + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + int32_t expertIdx = inLocal[length].GetValue(idx); + SetWaitFlag(HardEvent::S_MTE3); + int32_t index = 0; + while (this->lastExpertId < expertIdx) { + while (this->tokenCount < this->expertCapacity) { + index = this->lastExpertId * this->expertCapacity + this->tokenCount; + DataCopyPad(expandedXGm[index * this->cols], this->outTmpLocal, copyParams1); + DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0}); + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + + if (this->tokenCount < this->expertCapacity) { + int32_t outOffset = inLocal.GetValue(idx); + index = expertIdx * this->expertCapacity + this->tokenCount; + outLocal.SetValue(0, index); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams); + Compute(outOffset, index, expertIdx); + SetWaitFlag(HardEvent::MTE3_S); + 
this->tokenCount++; + } + } + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +template +__aicore__ inline float MoeV2SrcToDstAndGather::ComputeMax(LocalTensor& inLocal, + LocalTensor& tempLocal, + LocalTensor& dynamicQuantLocal, + int32_t srcIdx, int32_t expertIdx, + int64_t j) { + LocalTensor smoothLocal = smoothInQueue.AllocTensor(); + + DataCopyExtParams intriParamsT{1, static_cast(colsTileLength * sizeof(T)), 0, 0, 0}; + DataCopyExtParams intriParamsFp32{1, static_cast(colsTileLength * sizeof(float)), 0, 0, 0}; + + if constexpr (!IsSameType::value) { + DataCopyPad(inLocal.ReinterpretCast()[perLoopColsAlign], inputXGm[srcIdx * this->cols + j * this->perLoopCols], + intriParamsT, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0}); + } + + inputXInQueue.EnQue(inLocal); + inLocal = inputXInQueue.DeQue(); + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength); + pipe_barrier(PIPE_V); + } + + if (smoothType != 0) { + DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32, + {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + + Mul(inLocal, inLocal, smoothLocal, colsTileLength); + pipe_barrier(PIPE_V); + } + + Abs(tempLocal, inLocal, colsTileLength); + pipe_barrier(PIPE_V); + + ReduceMax(dynamicQuantLocal[8], tempLocal, tempLocal, colsTileLength); + + DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32); + smoothInQueue.FreeTensor(smoothLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); + + return dynamicQuantLocal.GetValue(8); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::ComputeScale(LocalTensor& inLocal, + LocalTensor& tempLocal, + float scaleTemp, int64_t dstIndex, + int64_t j) { + DataCopyExtParams copyInParams{1, static_cast(colsTileLength * 
sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(colsTileLength * sizeof(int8_t)), 0, 0, 0}; + + LocalTensor outLocal = inputXOutQueue.AllocTensor(); + + DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0}); + inputXInQueue.EnQue(inLocal); + inLocal = inputXInQueue.DeQue(); + + Duplicate(tempLocal, scaleTemp, colsTileLength); + pipe_barrier(PIPE_V); + + Div(tempLocal, inLocal, tempLocal, colsTileLength); + pipe_barrier(PIPE_V); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength); + pipe_barrier(PIPE_V); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, colsTileLength); + + inputXOutQueue.EnQue(outLocal); + outLocal = inputXOutQueue.DeQue(); + DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams); + + inputXOutQueue.FreeTensor(outLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::ComputeLoops(int32_t srcIdx, int32_t dstIdx, + int32_t expertIdx) { + LocalTensor inLocal = inputXInQueue.AllocTensor(); + LocalTensor tempLocal = calcQueue.AllocTensor(); + LocalTensor quantScaleLocal = scaleOutQueue.AllocTensor(); + + uint32_t tmp = 0xFF7FFFFF; + float reduceMax = *((float*)&tmp); + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j); + reduceMax = (reduceMax > tileMax) ? 
reduceMax : tileMax; + } + + float scaleTemp = reduceMax / 127.0f; + Duplicate(quantScaleLocal, scaleTemp, 8); + scaleOutQueue.EnQue(quantScaleLocal); + quantScaleLocal = scaleOutQueue.DeQue(); + + DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, {1, 4, 0, 0, 0}); + + for (int64_t j = 0; j < this->colLoops; j++) { + colsTileLength = this->perLoopCols; + if (j == this->colLoops - 1) { + colsTileLength = this->lastLoopCols; + } + ComputeScale(inLocal, tempLocal, scaleTemp, dstIdx, j); + } + + inputXInQueue.FreeTensor(inLocal); + calcQueue.FreeTensor(tempLocal); + scaleOutQueue.FreeTensor(quantScaleLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::CopyOutLoops(int64_t progress) { + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0}; + + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = this->lastCoreExpertId; + this->tokenCount = this->lastCoreExpertIdNum; + } + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + int32_t expertIdx = inLocal[length].GetValue(idx); + SetWaitFlag(HardEvent::S_MTE3); + int32_t index = 0; + while (this->lastExpertId < expertIdx) { + while (this->tokenCount < this->expertCapacity) { + index = this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0}); + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } + DataCopyExtParams copyParams1{static_cast(1), static_cast(col * sizeof(int8_t)), 0, 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + 
this->lastExpertId++; + } + + if (this->tokenCount < this->expertCapacity) { + int32_t outOffset = inLocal.GetValue(idx); + index = expertIdx * this->expertCapacity + this->tokenCount; + outLocal.SetValue(0, index); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams); + if (smoothType == 2) { + ComputeLoops(outOffset, index, expertIdx); + } else { + ComputeLoops(outOffset, index, 0); + } + SetWaitFlag(HardEvent::MTE3_S); + this->tokenCount++; + } + } + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::CopyOutRemain() { + if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) { + copyOutZeroQueue.FreeTensor(this->outTmpLocal); + scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal); + return; + } + while (this->lastExpertId < this->expertNum) { + while (this->tokenCount < this->expertCapacity) { + int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, {1, 4, 0, 0, 0}); + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } + DataCopyExtParams copyParams{static_cast(1), static_cast(col * sizeof(int8_t)), 0, 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + copyOutZeroQueue.FreeTensor(this->outTmpLocal); + scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR dynamicQuantScale, + GM_ADDR workspace, const TilingData* tilingData, + TPipe* tPipe) { + int64_t blockNum = GetBlockNum(); + this->pipe = tPipe; + this->blockIdx = 
get_block_idx() + get_subblockid() * get_block_num(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp); + this->expertNum = tilingData->expertNum; + this->expertCapacity = tilingData->expertCapacity; + this->cols = tilingData->cols; + this->k = tilingData->k; + this->smoothType = tilingData->smoothType; + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->lastCoreLoops; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->perCoreLoops; + } + this->perLoopCols = this->srcToDstTilingData->perLoopCols; + this->lastLoopCols = this->srcToDstTilingData->lastLoopCols; + this->colLoops = this->srcToDstTilingData->colLoops; + this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T)); + + inputXGm.SetGlobalBuffer((__gm__ T*)x); + quantSmoothGm.SetGlobalBuffer((__gm__ float*)scale); + dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); + + int64_t length = Align(this->totalLength, sizeof(int32_t)); + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length); + expandedXGm.SetGlobalBuffer((__gm__ int8_t*)expandedX, this->expertNum * this->expertCapacity * this->cols); + + expandedExpertIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows, 
+ Align(this->coreRows, sizeof(int32_t))); + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2); + if (this->colLoops > 1) { + quantSrcGm.SetGlobalBuffer((__gm__ float*)workspace + length * 2 + this->coreNum * 2 + this->blockIdx * this->cols, + this->cols * sizeof(float)); + } + + pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2); + pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t))); + pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t))); + + int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T)); + perLoopColsAlignBytes = + Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES)); + + pipe->InitBuffer(inputXInQueue, 1, perLoopColsAlignBytes); + pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); + pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); + pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); + pipe->InitBuffer(scaleOutZeroQueue, 1, BLOCK_BYTES); +} + +template +__aicore__ inline void MoeV2SrcToDstAndGather::Process() { + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + AssistInit(); + this->outTmpLocal = copyOutZeroQueue.DeQue(); + this->scaleOutTmpLocal = scaleOutZeroQueue.DeQue(); + currentLoopRows = perLoopRows; + if (colLoops > 1) { + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyIn(loop); + CopyOutLoops(loop); + } + } else { + smoothLocal = smoothInQueue.AllocTensor(); + if (smoothType == 1) { + DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; + DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0}); + } + for (int64_t loop = 0; loop < 
this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyIn(loop); + CopyOut(loop); + } + smoothInQueue.FreeTensor(smoothLocal); + } + CopyOutRemain(); + } +} +} // namespace MoeInitRoutingQuantV2 +#endif // MOE_V2_SRC_TO_DST_AND_GATHER_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_op.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_op.h new file mode 100644 index 00000000000..521a032e9b2 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_op.h @@ -0,0 +1,164 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_v2_src_to_dst_op.h + * \brief + */ +#ifndef INNER_MOE_V2_SRC_TO_DST_H +#define INNER_MOE_V2_SRC_TO_DST_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +class MoeV2SrcToDstOp { + public: + __aicore__ inline MoeV2SrcToDstOp(){}; + template + __aicore__ inline void Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void Compute(int64_t progress); + __aicore__ inline void CopyOut(); + __aicore__ inline void SyncAll(); + __aicore__ inline void AssistInit(); + + private: + TPipe* pipe; + TQue copyInQueue; + TQue copyOutQueue; + TBuf assistBuffer; + + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expandSrcToDstRowGm; + GlobalTensor assistGm; + + const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; +}; + +__aicore__ inline void MoeV2SrcToDstOp::AssistInit() { +#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1 + OOMCheckAddrRange(assistGm.GetPhyAddr(), ASSIST_NUM * sizeof(int32_t)); +#endif + LocalTensor assistTensor = assistBuffer.Get(ASSIST_NUM); + DataCopy(assistTensor, assistGm, ASSIST_NUM); + SetWaitFlag(HardEvent::MTE2_V); + Adds(assistTensor, assistTensor, (int32_t)(this->blockIdx * this->srcToDstTilingData->perCoreRows), ASSIST_NUM); +} + +__aicore__ inline void MoeV2SrcToDstOp::CopyIn(int64_t progress) { + LocalTensor inLocal = copyInQueue.AllocTensor(); + DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], Align(currentLoopRows, sizeof(int32_t))); + copyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::Compute(int64_t progress) { + LocalTensor outLocal = copyOutQueue.AllocTensor(); + LocalTensor assistTensor = 
assistBuffer.Get(ASSIST_NUM); + + pipe_barrier(PIPE_V); + int64_t loops = Ceil(currentLoopRows, ASSIST_INDEX_NUM); + for (int64_t i = 0; i < loops; i++) { + Adds(outLocal[i * ASSIST_NUM], assistTensor, + static_cast(this->perLoopRows * progress + i * ASSIST_INDEX_NUM), ASSIST_NUM); + } + pipe_barrier(PIPE_V); + copyOutQueue.EnQue(outLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::CopyOut() { + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = sizeof(int32_t); + uint32_t outOffset; + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + outOffset = inLocal.GetValue(idx); + DataCopyPad(expandSrcToDstRowGm[outOffset], outLocal[idx * INT32_ONE_BLOCK_NUM], intriParams); + } + + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeV2SrcToDstOp::SyncAll() { + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SrcToDstOp::Init(GM_ADDR expandSrcToDstRow, GM_ADDR workspace, const TilingData* tilingData, + TPipe* tPipe) { + int64_t blockNum = GetBlockNum(); + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstComputeParamsOp); + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + } + 
+ expandSrcToDstRowGm.SetGlobalBuffer((__gm__ int32_t*)expandSrcToDstRow, Align(this->totalLength, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t*)workspace + Align(this->totalLength, sizeof(int32_t)) + + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + assistGm.SetGlobalBuffer((__gm__ int32_t*)assist, ASSIST_NUM); + + pipe->InitBuffer(copyInQueue, 1, this->perLoopRows * BLOCK_BYTES); + pipe->InitBuffer(copyOutQueue, 1, Ceil(this->perLoopRows, ASSIST_NUM) * ASSIST_NUM * BLOCK_BYTES); + pipe->InitBuffer(assistBuffer, ASSIST_NUM * sizeof(int32_t)); +} + +__aicore__ inline void MoeV2SrcToDstOp::Process() { + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + int64_t loops = (coreRows + perLoopRows - 1) / perLoopRows; + currentLoopRows = perLoopRows; + AssistInit(); + for (int64_t loop = 0; loop < loops - 1; loop++) { + CopyIn(loop); + Compute(loop); + CopyOut(); + } + currentLoopRows = lastLoopRows; + CopyIn(loops - 1); + Compute(loops - 1); + CopyOut(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_SRC_TO_DST_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_with_capacity.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_with_capacity.h new file mode 100644 index 00000000000..850e66b06bb --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_src_to_dst_with_capacity.h @@ -0,0 +1,269 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_v2_src_to_dst_with_capacity.h + * \brief + */ +#ifndef INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H +#define INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H + +#include "moe_v2_common.h" + +namespace MoeInitRoutingQuantV2 { +using namespace AscendC; +using namespace optiling; +template +class MoeV2SrcToDstWithCapacity { + public: + __aicore__ inline MoeV2SrcToDstWithCapacity(){}; + __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const TilingData* tilingData, TPipe* tPipe); + __aicore__ inline void Process(); + + private: + __aicore__ inline void CopyIn(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + __aicore__ inline void CopyOutRemain(); + __aicore__ inline void SyncAll(); + __aicore__ inline void AssistInit(); + + private: + TPipe* pipe; + TQue copyInQueue; + TQue copyOutQueue; + TQue copyOutZeroQueue; + + GlobalTensor expandDstToSrcRowGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor expertIdxValueGm; + GlobalTensor expandedExpertIdxGm; + GlobalTensor expandedXGm; + + LocalTensor outTmpLocal; + + const InnerMoeV2GatherOutComputeTilingData* srcToDstTilingData; + + int64_t coreNum; + int64_t blockIdx; + int64_t totalLength; + int64_t currentLoopRows; + int64_t coreRows; + int64_t perLoopRows; + int64_t lastLoopRows; + int64_t rowLoops; + int64_t expertCapacity; + int64_t expertNum; + int64_t cols; + int64_t perLoopCols; + int64_t lastLoopCols; + int64_t colLoops; + + int64_t tokenCount = 0; + int32_t lastExpertId = -1; + int32_t lastCoreExpertId = 0; + int32_t lastCoreExpertIdNum = 0; +}; + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::AssistInit() { + if constexpr 
(IsSameType::value) { + LocalTensor outLocal = copyOutZeroQueue.AllocTensor(); + Duplicate(outLocal, static_cast(0), this->perLoopCols); + copyOutZeroQueue.EnQue(outLocal); + } else { + LocalTensor outLocal = copyOutZeroQueue.AllocTensor(); + Duplicate(outLocal, static_cast(0), this->perLoopCols); + copyOutZeroQueue.EnQue(outLocal); + } + + if (this->blockIdx != 0) { + this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2); + this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1); + for (int64_t i = this->blockIdx - 2; i >= 0; i--) { + int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2); + if (lastExpertIdx < this->lastCoreExpertId) { + break; + } + int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1); + this->lastCoreExpertIdNum += lastExpertNum; + } + } +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyIn(int64_t progress) { + LocalTensor inLocal = copyInQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length); + DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length); + copyInQueue.EnQue(inLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyOut(int64_t progress) { + LocalTensor inLocal = copyInQueue.DeQue(); + LocalTensor outLocal = copyOutQueue.AllocTensor(); + int64_t length = Align(currentLoopRows, sizeof(int32_t)); + DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0}; + + SetWaitFlag(HardEvent::MTE2_S); + if (this->lastExpertId == -1) { + this->lastExpertId = this->lastCoreExpertId; + this->tokenCount = this->lastCoreExpertIdNum; + } + for (int64_t idx = 0; idx < currentLoopRows; idx++) { + int32_t expertIdx = inLocal[length].GetValue(idx); + SetWaitFlag(HardEvent::S_MTE3); + int32_t index = 0; + while (this->lastExpertId < expertIdx) { + while (this->tokenCount < this->expertCapacity) { + index = 
this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } +#ifdef __CCE_KT_TEST__ + // CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理 + if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) { + continue; + } +#endif + DataCopyExtParams copyParams1{static_cast(1), static_cast(col * sizeof(T)), 0, 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams1); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + + if (this->tokenCount < this->expertCapacity) { + int32_t outOffset = inLocal.GetValue(idx); + index = expertIdx * this->expertCapacity + this->tokenCount; + outLocal.SetValue(0, index); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams); + SetWaitFlag(HardEvent::MTE3_S); + this->tokenCount++; + } + } + copyInQueue.FreeTensor(inLocal); + copyOutQueue.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::CopyOutRemain() { + if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) { + copyOutZeroQueue.FreeTensor(this->outTmpLocal); + return; + } + while (this->lastExpertId < this->expertNum) { + while (this->tokenCount < this->expertCapacity) { + int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount; + int64_t col = this->perLoopCols; + for (int64_t i = 0; i < this->colLoops; i++) { + if (i == this->colLoops - 1) { + col = this->lastLoopCols; + } + DataCopyExtParams copyParams{static_cast(1), static_cast(col * sizeof(T)), 0, 0, 0}; + DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams); + SetWaitFlag(HardEvent::MTE3_S); + } + this->tokenCount++; + } + this->tokenCount = 0; + this->lastExpertId++; + } + 
copyOutZeroQueue.FreeTensor(this->outTmpLocal); +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::SyncAll() { + if (coreNum == 1) { + return; + } +#ifndef __CCE_KT_TEST__ + AscendC::SyncAll(); +#endif +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR workspace, const TilingData* tilingData, + TPipe* tPipe) { + int64_t blockNum = GetBlockNum(); + this->pipe = tPipe; + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + this->coreNum = tilingData->coreNum; + this->totalLength = tilingData->n * tilingData->k; + this->srcToDstTilingData = &(tilingData->srcToDstCapacityComputeParamsOp); + this->expertNum = tilingData->expertNum; + this->expertCapacity = tilingData->expertCapacity; + this->cols = tilingData->cols; + + if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) { + this->coreRows = this->srcToDstTilingData->lastCoreRows; + this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->lastCoreLoops; + } else { + this->coreRows = this->srcToDstTilingData->perCoreRows; + this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows; + this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows; + this->rowLoops = this->srcToDstTilingData->perCoreLoops; + } + this->perLoopCols = this->srcToDstTilingData->perLoopCols; + this->lastLoopCols = this->srcToDstTilingData->lastLoopCols; + this->colLoops = this->srcToDstTilingData->colLoops; + + int64_t length = Align(this->totalLength, sizeof(int32_t)); + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t*)expandedRowIdx, length); + expandedXGm.SetGlobalBuffer((__gm__ T*)expandedX, this->expertNum * this->expertCapacity * this->cols); + + expandedExpertIdxGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + this->blockIdx * this->srcToDstTilingData->perCoreRows, + 
Align(this->coreRows, sizeof(int32_t))); + expandDstToSrcRowGm.SetGlobalBuffer( + (__gm__ int32_t*)workspace + length + this->blockIdx * this->srcToDstTilingData->perCoreRows, + Align(this->coreRows, sizeof(int32_t))); + expertIdxValueGm.SetGlobalBuffer((__gm__ int32_t*)workspace + length * 2, this->coreNum * 2); + + pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2); + pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t))); + if constexpr (IsSameType::value) { + pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t))); + } else { + pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(T))); + } +} + +template +__aicore__ inline void MoeV2SrcToDstWithCapacity::Process() { + if (this->blockIdx < this->srcToDstTilingData->needCoreNum) { + AssistInit(); + this->outTmpLocal = copyOutZeroQueue.DeQue(); + currentLoopRows = perLoopRows; + for (int64_t loop = 0; loop < this->rowLoops; loop++) { + if (loop == this->rowLoops - 1) { + currentLoopRows = lastLoopRows; + } + CopyIn(loop); + CopyOut(loop); + } + CopyOutRemain(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingQuantV2 +#endif // INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/tiling_base.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/tiling_base.h new file mode 100644 index 00000000000..7e6c0e5e3a4 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/tiling_base.h @@ -0,0 +1,66 @@ +#pragma once +namespace optiling { +struct AiCoreParams { + uint64_t ubSize; + uint64_t blockDim; + uint64_t aicNum; + + uint64_t l1Size; + uint64_t l0aSize; + uint64_t l0bSize; + uint64_t l0cSize; +}; + +class TilingBaseClass { +public: + bool DoTiling( + int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, 
int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0, + int64_t aivCoreNum, int64_t ubSizePlatForm) + { + bool ret = GetShapeAttrsInfo(m, cols, topK, expertCapacity, expertNum, activeNum, dropPadMode, expertTokensCountOrCumsumFlag, + expertTokensBeforeCapacityFlag, inuptXDtypeSize, quantMode, scaleDim0); + + if (!ret){ + return ret; + } + ret = GetPlatformInfo(aivCoreNum, ubSizePlatForm); + if (!ret){ + return ret; + } + ret = DoOpTiling(); + if (!ret){ + return ret; + } + ret = GetWorkspaceSize(); + if (!ret){ + return ret; + } + ret = PostTiling(); + if (!ret){ + return ret; + } + tilingKey_ = GetTilingKey(); + + return true; + } + +//protected: + virtual bool GetPlatformInfo(int64_t aivCoreNum, int64_t ubSizePlatForm) = 0; + virtual bool GetShapeAttrsInfo(int64_t m, int64_t cols, int64_t topK, int64_t expertCapacity, + int64_t expertNum, int64_t activeNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, + bool expertTokensBeforeCapacityFlag, int64_t inuptXDtypeSize, int64_t quantMode, int64_t scaleDim0) = 0; + + virtual bool DoOpTiling() = 0; + virtual bool GetWorkspaceSize() = 0; + virtual bool PostTiling() = 0; + virtual uint64_t GetTilingKey() const = 0; +//protected: + uint32_t blockDim_{0}; + uint64_t workspaceSize_{0}; + uint64_t tilingKey_{0}; + AiCoreParams aicoreParams_{0, 0, 0, 0, 0, 0, 0}; +}; + +} \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h new file mode 100644 index 00000000000..12b35a29d1e --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h @@ -0,0 +1,376 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file moe_token_unpermute.h + * \brief + */ + +#ifndef MOE_TOKEN_UNPERMUTE +#define MOE_TOKEN_UNPERMUTE + +#include "kernel_operator.h" +#include "moe_token_unpermute_tiling.h" +using namespace AscendC; + + +template class KernelMoeTokenUnpermute { +public: + __aicore__ inline KernelMoeTokenUnpermute() + { + } + + __aicore__ inline void Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs, + GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data); + __aicore__ inline void Process(); + +protected: + __aicore__ inline void CalMultiOutToken(const int64_t out_offset, const int64_t out_tokens_number); + __aicore__ inline void CalSingleOutToken(const int64_t start_token, const int64_t out_token_idx); + __aicore__ inline void CalPartOutToken(const int64_t start_token, const int64_t h_index, const int64_t h_length, + const int64_t out_token_index); + __aicore__ inline void CopyTokenIn(const T2 in_token_index, const int64_t h_index, const int64_t h_length); + __aicore__ inline void CalFirstToken(const float prob_value, const int64_t h_length); + __aicore__ inline void CalToken(const float prob_value, const int64_t h_length); + __aicore__ inline void CopyOut(const int64_t out_token_index, const int64_t h_index, const int64_t h_length); + + TPipe pipe; + TQue tokens_inque, indices_inque, probs_inque; + TBuf temp_buffer0, temp_buffer1, temp_buffer2; + TQue outque; + GlobalTensor tokensGM, outGM; + GlobalTensor indicesGM; + GlobalTensor probsGM; + LocalTensor indicesLocal; + LocalTensor token_tensor0, 
token_tensor1, probs_tensor; + DataCopyPadExtParams extParams1{false, 0, 0, 0}; + DataCopyPadExtParams extParams2{false, 0, 0, 0}; + DataCopyPadExtParams extParams3{false, 0, 0, 0}; + DataCopyExtParams copyParams{1, 0, 0, 0, 0}; + + constexpr static uint32_t BLOCK_SIZE = 32; + constexpr static uint32_t ALIGN_512 = 512; + + int64_t hidden_size; + int64_t top_k; + int64_t num_out_tokens; + int64_t hidden_splited_length; + int64_t hidden_splited_num; + int64_t hidden_splited_remain; + int64_t tokens_core_length; + int64_t tokens_core_remain; + int64_t tokens_splited_length; + int64_t tokens_splited_num; + int64_t tokens_splited_remain; + int32_t blockIdx; + int32_t blockNum; +}; + +template +__aicore__ inline void +KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADDR sorted_indices, GM_ADDR probs, + GM_ADDR unpermuted_tokens, + const MoeTokenUnpermuteTilingData *__restrict tiling_data) +{ + this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); + this->blockNum = get_block_num() * get_subblockdim(); + + if (blockIdx >= blockNum) { + return; + } + ASSERT(blockNum != 0 && "block dim can not be zero!"); + // row_input + this->hidden_size = tiling_data->hidden_size; + this->top_k = tiling_data->top_k; + this->num_out_tokens = tiling_data->num_out_tokens; + // hidden_tiling + this->hidden_splited_length = tiling_data->hidden_splited_length; + this->hidden_splited_num = tiling_data->hidden_splited_num; + this->hidden_splited_remain = tiling_data->hidden_splited_remain; + // token_tiling + this->tokens_core_length = tiling_data->tokens_core_length; + this->tokens_core_remain = tiling_data->tokens_core_remain; + this->tokens_splited_length = tiling_data->tokens_splited_length; + this->tokens_splited_num = tiling_data->tokens_splited_num; + this->tokens_splited_remain = tiling_data->tokens_splited_remain; + + // 处理token_by_core尾块 + if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) { + this->tokens_core_length += 1; + 
this->tokens_splited_remain += 1; + } + + int64_t hidden_splited_length_align512 = (this->hidden_splited_length + ALIGN_512 - 1) & ~(ALIGN_512 - 1); + + int64_t block_length = this->tokens_core_length * this->top_k; + int64_t block_splited_length = this->tokens_splited_length * this->top_k; + + int64_t block_offset; + if (this->tokens_core_remain > 0) { + if (blockIdx < this->tokens_core_remain) { + block_offset = block_length * blockIdx; + } else { + block_offset = (block_length + this->top_k) * this->tokens_core_remain + + block_length * (blockIdx - this->tokens_core_remain); + } + } else { + block_offset = block_length * blockIdx; + } + + this->tokensGM.SetGlobalBuffer((__gm__ T1 *)permuted_tokens); + this->indicesGM.SetGlobalBuffer((__gm__ T2 *)sorted_indices + block_offset, block_length); + + + int64_t out_block_offset; + if (this->tokens_core_remain > 0) { + if (blockIdx < this->tokens_core_remain) { + out_block_offset = this->tokens_core_length * blockIdx * hidden_size; + } else { + out_block_offset = (this->tokens_core_length + 1) * this->tokens_core_remain + + this->tokens_core_length * (blockIdx - this->tokens_core_remain); + out_block_offset *= this->hidden_size; + } + } else { + out_block_offset = this->tokens_core_length * blockIdx * hidden_size; + } + + this->outGM.SetGlobalBuffer((__gm__ T1 *)unpermuted_tokens + out_block_offset, + this->tokens_core_length * this->hidden_size); + + this->pipe.InitBuffer(tokens_inque, tiling_data->buffer_num, hidden_splited_length_align512 * sizeof(T1)); + this->pipe.InitBuffer(indices_inque, 1, block_splited_length * (sizeof(T2))); + this->pipe.InitBuffer(outque, 1, hidden_splited_length_align512 * sizeof(T1)); + + if constexpr (!IsSameType::value) { + this->pipe.InitBuffer(temp_buffer0, hidden_splited_length_align512 * sizeof(float) + 256); + this->pipe.InitBuffer(temp_buffer1, hidden_splited_length_align512 * sizeof(float)); + this->token_tensor0 = this->temp_buffer0.template Get(); + this->token_tensor1 = 
this->temp_buffer1.template Get(); + } + + if constexpr (PROBS) { + this->probsGM.SetGlobalBuffer((__gm__ T3 *)probs + block_offset, block_length); + this->pipe.InitBuffer(probs_inque, 1, block_splited_length * (sizeof(T3))); + if constexpr (!IsSameType::value) { + this->pipe.InitBuffer(temp_buffer2, block_splited_length * sizeof(float)); + this->probs_tensor = this->temp_buffer2.template Get(); + } + } +}; + +template +__aicore__ inline void KernelMoeTokenUnpermute::Process() +{ + + if (blockIdx >= blockNum) { + return; + } + for (int64_t i = 0; i < this->tokens_splited_num; ++i) { + CalMultiOutToken(i * this->tokens_splited_length, this->tokens_splited_length); + } + // 处理tokens_num不能均匀分核数的尾块 + if (this->tokens_splited_remain > 0) { + CalMultiOutToken(this->tokens_splited_num * this->tokens_splited_length, this->tokens_splited_remain); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalMultiOutToken(const int64_t out_offset, + const int64_t out_tokens_number) +{ + this->indicesLocal = this->indices_inque.template AllocTensor(); + int64_t in_offset = out_offset * this->top_k; + this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T2); + DataCopyPad(this->indicesLocal, this->indicesGM[in_offset], this->copyParams, this->extParams2); + this->indices_inque.template EnQue(this->indicesLocal); + + if constexpr (PROBS) { + LocalTensor temp_probs_tensor = this->probs_inque.template AllocTensor(); + this->copyParams.blockLen = out_tokens_number * this->top_k * sizeof(T3); + DataCopyPad(temp_probs_tensor, this->probsGM[in_offset], this->copyParams, this->extParams3); + this->probs_inque.template EnQue(temp_probs_tensor); + temp_probs_tensor = this->probs_inque.template DeQue(); + if constexpr (!IsSameType::value) { + Cast(this->probs_tensor, temp_probs_tensor, RoundMode::CAST_NONE, out_tokens_number * this->top_k); + this->probs_inque.FreeTensor(temp_probs_tensor); + PipeBarrier(); + } else { + this->probs_tensor = temp_probs_tensor; + } 
+ } + this->indicesLocal = this->indices_inque.template DeQue(); + + + for (int64_t out_token_idx = 0; out_token_idx < out_tokens_number; ++out_token_idx) { + CalSingleOutToken(out_token_idx * this->top_k, out_offset + out_token_idx); + } + // Free Tensor + this->indices_inque.FreeTensor(this->indicesLocal); + if constexpr (PROBS && IsSameType::value) { + this->probs_inque.FreeTensor(this->probs_tensor); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalSingleOutToken(const int64_t start_token, + const int64_t out_token_idx) +{ + for (int64_t h_index = 0; h_index < this->hidden_splited_num; ++h_index) { + CalPartOutToken(start_token, h_index, this->hidden_splited_length, out_token_idx); + } + // 一次不能完整容纳完整的hidden_size, 处理尾块 + if (this->hidden_splited_remain > 0) { + CalPartOutToken(start_token, this->hidden_splited_num, this->hidden_splited_remain, out_token_idx); + } +} + +template +__aicore__ inline void +KernelMoeTokenUnpermute::CalPartOutToken(const int64_t start_token, const int64_t h_index, + const int64_t h_length, const int64_t out_token_index) +{ + if constexpr (IsSameType::value) { + this->token_tensor0 = this->outque.template AllocTensor(); + } + int64_t end_token = start_token + this->top_k; + T2 cal_token_idx = this->indicesLocal.GetValue(start_token); + + // 处理第一个Token数据 + if (cal_token_idx < this->num_out_tokens) { + float probsValue = 0; + if constexpr (PROBS) { + probsValue = this->probs_tensor.GetValue(start_token); + } + + CopyTokenIn(cal_token_idx, h_index, h_length); + PipeBarrier(); + CalFirstToken(probsValue, h_length); + } else { + PipeBarrier(); + Duplicate(this->token_tensor0, static_cast(0), h_length); + } + + // 处理剩余的Token数据 + for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) { + cal_token_idx = this->indicesLocal.GetValue(token_index); + if (cal_token_idx < this->num_out_tokens) { + float probsValue = 0; + if constexpr (PROBS) { + probsValue = 
this->probs_tensor.GetValue(token_index); + } + + CopyTokenIn(cal_token_idx, h_index, h_length); + PipeBarrier(); + CalToken(probsValue, h_length); + } + } + + // 输出计算结果 + CopyOut(out_token_index, h_index, h_length); +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CopyTokenIn(const T2 in_token_index, + const int64_t h_index, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template AllocTensor(); + int64_t offset = in_token_index * this->hidden_size + h_index * this->hidden_splited_length; + + if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) { + DataCopy(tokensLocal, this->tokensGM[offset], h_length); + } else { + this->copyParams.blockLen = h_length * sizeof(T1); + DataCopyPad(tokensLocal, this->tokensGM[offset], this->copyParams, this->extParams1); + } + + this->tokens_inque.template EnQue(tokensLocal); +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalFirstToken(const float prob_value, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template DeQue(); + + if constexpr (!IsSameType::value) { + Cast(this->token_tensor0, tokensLocal, RoundMode::CAST_NONE, h_length); + } else { + uint64_t byteAlign32 = (h_length * sizeof(float) + BLOCK_SIZE - 1) & ~(BLOCK_SIZE - 1); + DataCopy(this->token_tensor0, tokensLocal, byteAlign32 / sizeof(float)); + } + + this->tokens_inque.FreeTensor(tokensLocal); + + if constexpr (PROBS) { + PipeBarrier(); + Muls(this->token_tensor0, this->token_tensor0, prob_value, h_length); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CalToken(const float prob_value, + const int64_t h_length) +{ + LocalTensor tokensLocal = this->tokens_inque.template DeQue(); + + if constexpr (!IsSameType::value) { + Cast(this->token_tensor1, tokensLocal, RoundMode::CAST_NONE, h_length); + this->tokens_inque.FreeTensor(tokensLocal); + if constexpr (PROBS) { + PipeBarrier(); + Muls(this->token_tensor1, this->token_tensor1, prob_value, h_length); + } + 
PipeBarrier(); + Add(this->token_tensor0, this->token_tensor0, this->token_tensor1, h_length); + } else { + if constexpr (PROBS) { + Muls(tokensLocal, tokensLocal, prob_value, h_length); + PipeBarrier(); + } + Add(this->token_tensor0, this->token_tensor0, tokensLocal, h_length); + this->tokens_inque.FreeTensor(tokensLocal); + } +} + +template +__aicore__ inline void KernelMoeTokenUnpermute::CopyOut(const int64_t out_token_index, + const int64_t h_index, + const int64_t h_length) +{ + LocalTensor temp_out_tensors; + if constexpr (!IsSameType::value) { + temp_out_tensors = this->outque.template AllocTensor(); + PipeBarrier(); + Cast(temp_out_tensors, this->token_tensor0, RoundMode::CAST_RINT, h_length); + } else { + temp_out_tensors = this->token_tensor0; + } + + this->outque.template EnQue(temp_out_tensors); + temp_out_tensors = this->outque.template DeQue(); + + int64_t offset = out_token_index * this->hidden_size + h_index * this->hidden_splited_length; + if (likely((h_length * sizeof(T1)) % BLOCK_SIZE == 0)) { + DataCopy(this->outGM[offset], temp_out_tensors, h_length); + } else { + this->copyParams.blockLen = h_length * sizeof(T1); + DataCopyPad(this->outGM[offset], temp_out_tensors, this->copyParams); + } + + this->outque.FreeTensor(temp_out_tensors); +} +#endif // MOE_TOKEN_UNPERMUTE diff --git a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute_tiling.h b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute_tiling.h new file mode 100644 index 00000000000..df47f6db8d0 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute_tiling.h @@ -0,0 +1,38 @@ +#ifndef MOE_TOKEN_UNPERMUTE_TILING +#define MOE_TOKEN_UNPERMUTE_TILING + +struct MoeTokenUnpermuteTilingData { + int64_t hidden_size; + int64_t top_k; + int64_t num_out_tokens; + int64_t hidden_splited_length; + int64_t hidden_splited_num; + int64_t hidden_splited_remain; + int64_t tokens_core_length; + int64_t tokens_core_remain; + int64_t 
tokens_splited_length; + int64_t tokens_splited_num; + int64_t tokens_splited_remain; + int64_t buffer_num; +}; + +__forceinline__ [host, aicore] void +MoeTokenUnpermuteTiling(int32_t m, int32_t n, int32_t topK, MoeTokenUnpermuteTilingData &tilingData, uint32_t coreNum) +{ + #define I64(x) static_cast(x) + tilingData.hidden_size = I64(n); + tilingData.top_k = I64(topK); + tilingData.num_out_tokens = I64(m); + tilingData.hidden_splited_length = tilingData.hidden_size; + tilingData.hidden_splited_num = 1; + tilingData.hidden_splited_remain = 0; + uint32_t outTokens = m / topK; + tilingData.tokens_core_length = I64(outTokens / coreNum); + tilingData.tokens_core_remain = I64(outTokens % coreNum); + tilingData.tokens_splited_length = I64(min(tilingData.tokens_core_length, 600)); + tilingData.tokens_splited_num = I64(tilingData.tokens_core_length / tilingData.tokens_splited_length); + tilingData.tokens_splited_remain = I64(tilingData.tokens_core_length % tilingData.tokens_splited_length); + tilingData.buffer_num = 4; +} + +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp new file mode 100644 index 00000000000..8bdd017d387 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/epilogue/block/block_epilogue.hpp" + +namespace Catlass::Epilogue::Block { + +// float scale, dequant per expert +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequant, + CType_, + Gemm::GemmType, + DType_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + std::is_same_v && (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong" + ); + static_assert( + std::is_same_v && + std::is_same_v && std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong" + ); + + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + struct Params { + __gm__ int32_t *ptrTokenPerExpert{nullptr}; + int32_t EP; + int32_t expertPerRank; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : 
ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 4096; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + constexpr int32_t blockN = 12000; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + ubCFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(float); + } + } + CATLASS_DEVICE + void Finalize() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + CATLASS_DEVICE + ~BlockEpilogue() + { + + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + MatrixCoord const &shapeC, + AscendC::GlobalTensor const &gmPerTokenScale, + AscendC::GlobalTensor const &gmD + ) + { + uint32_t blockM = shapeC.row(); + uint32_t blockN = shapeC.column(); + + uint32_t tileLoops = blockM; + + for (uint32_t loopIdx = 0; loopIdx < tileLoops; loopIdx ++) { + auto gmTileC = gmC[loopIdx * blockN]; + auto &ubC = ubCList[ubListId]; + auto &ubCFp32 = ubCFp32List[ubListId]; + auto &ubMul = ubMulList[ubListId]; + auto &ubD = ubDList[ubListId]; + auto gmTileD = gmD[loopIdx * blockN]; + LayoutC layoutUbC{1, blockN}; + + // 把C从GM workspace搬到UB + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + 
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + //在UB上做把C cast成FP32 + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + // 获取pertoken scale值,gmPerTokenScale的第loopIdx行 + ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + // pertoken scale值与FP32的C做Muls乘法 + AscendC::PipeBarrier(); + AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN); + AscendC::PipeBarrier(); + + // 将muls结果转回fp16/bf16 + LayoutD layoutUbD{1, blockN}; + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + + AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, blockN); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32List[UB_STAGES]; + AscendC::LocalTensor ubMulList[UB_STAGES]; + + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_ROW_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp new file mode 100644 index 00000000000..adca19f6ebd --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp @@ -0,0 +1,316 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +namespace Catlass::Epilogue::Block { + +// float scale, dequant per expert +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileElemWiseMuls_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequantSwigluQuant, + CType_, + Gemm::GemmType, + DType_, + TileElemWiseMuls_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwigluQuant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + std::is_same_v && (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong" + ); + static_assert( + std::is_same_v && + std::is_same_v && std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong" + ); + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + using CopyUbToGmDequantScale = Epilogue::Tile::CopyUb2Gm>; + + struct Params { + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementPerTokenScale 
*ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_ + ) : ptrPerTokenScale(ptrPerTokenScale_), layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), layoutD(layoutD_) {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + constexpr uint32_t blockN = 4096; + constexpr uint32_t ChunkTileLen = blockN / 2; + constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2; + + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(ElementD); + ubCFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += blockN * sizeof(float); + ubCFp32ChunkNList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ChunkTileLen * sizeof(float); + ubCFp32ChunkNAbsList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ChunkTileLen * sizeof(float); + ubCFp32ChunkNMaxList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += HalfChunkTileLen * sizeof(float); + ubQuantS32List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast(); + ubQuantF16List[i] = ubCFp32ChunkNAbsList[i].template ReinterpretCast(); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + + ubPerTokenScaleOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + } + CATLASS_DEVICE + void Finalize() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + 
AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + CATLASS_DEVICE + ~BlockEpilogue() + { + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + // 每个tile就是1*7168,每个block是一个expert的所有token=[group[i], 7168] + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + MatrixCoord const &shapeC, + AscendC::GlobalTensor const &gmPerTokenScale1, + AscendC::GlobalTensor const &gmD, + AscendC::GlobalTensor const &gmPerTokenScale2, + + uint32_t epilogueCoreNum = 40, + Callback &&callback = Callback{} + ) + { + callback(); + uint32_t blockM = shapeC.row(); + uint32_t blockN = shapeC.column(); + + uint32_t tileLoops = blockM; + uint32_t subblockIdx = get_block_idx() + get_subblockid() * get_block_num(); + + uint32_t subblockNum = get_block_num() * 2; + uint32_t moveDataCoreNum = subblockNum - epilogueCoreNum; + + if (subblockIdx < moveDataCoreNum) { + return; + } + uint32_t epilogueCoreIdx = subblockIdx - moveDataCoreNum; + + uint32_t perCoreData = blockM / epilogueCoreNum; + uint32_t remainderData = blockM % epilogueCoreNum; + + uint32_t tasksForIdx = epilogueCoreIdx < remainderData ? perCoreData + 1 : perCoreData; + uint32_t loopStartIdx = epilogueCoreIdx * perCoreData + (epilogueCoreIdx < remainderData? 
epilogueCoreIdx : remainderData); + + uint32_t alignedPerCoreData = RoundUp(perCoreData + 1); + + uint32_t ChunkTileLen = blockN / 2; + uint32_t HalfChunkTileLen = ChunkTileLen / 2; + + + for (uint32_t loopIdx = loopStartIdx; loopIdx < loopStartIdx + tasksForIdx; ++loopIdx) { + + auto gmTileC = gmC[loopIdx * blockN]; + + auto &ubC = ubCList[ubListId]; + auto &ubD = ubDList[ubListId]; + + auto &ubCFp32 = ubCFp32List[ubListId]; + auto &ubCFp32ChunkN = ubCFp32ChunkNList[ubListId]; + auto &ubAbs = ubCFp32ChunkNAbsList[ubListId]; + // auto &ubMax = ubCFp32ChunkNMaxList[ubListId]; + auto &ubReduceMax = ubCFp32ChunkNMaxList[ubListId]; + auto &ubOutputTmp = ubAbs; + auto &sharedUbTmpBuffer = ubReduceMax; + auto &ubQuantS32 = ubQuantS32List[ubListId]; + auto &ubQuantF16 = ubQuantF16List[ubListId]; + + auto gmTileD = gmD[loopIdx * ChunkTileLen]; + LayoutC layoutUbC{1, blockN}; + + // 把C从GM workspace搬到UB + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + // 在UB上做把C cast成FP32 + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + // 获取pertoken scale值,gmPerTokenScale的第loopIdx行 + ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + // pertoken scale值与FP32的C做Muls乘法 + AscendC::PipeBarrier(); + AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN); + AscendC::PipeBarrier(); + + //swiglue计算过程 + AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen); + AscendC::PipeBarrier(); + //TODO除的时候是否会对之后的数据有影响; + AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::Mul(ubCFp32ChunkN, 
ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen); + + //quant过程,两种方式区别; + AscendC::PipeBarrier(); + AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::ReduceMax(ubReduceMax, ubAbs, sharedUbTmpBuffer, ChunkTileLen, false); + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + //TODO两种计算方法的效率比较 + ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0); + AscendC::SetFlag(0); + + auto ubPerTokenScaleOutputOffset = loopIdx - loopStartIdx; + ubPerTokenScaleOutput.SetValue(ubPerTokenScaleOutputOffset, GMubDequantScale / 127.f); + + AscendC::WaitFlag(0); + AscendC::Muls(ubOutputTmp, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::Cast(ubQuantS32, ubOutputTmp, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + AscendC::PipeBarrier(); + AscendC::SetDeqScale(static_cast(1.0)); + AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + AscendC::Cast(ubD, ubQuantF16, AscendC::RoundMode::CAST_RINT, ChunkTileLen); + // AscendC::Muls(ubD, ubCFp32ChunkN, 127.f / GMubDequantScale, ChunkTileLen); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + LayoutD layoutUbD{1, ChunkTileLen}; + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutUbD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } + + if(tasksForIdx > 0){ + LayoutPerTokenScale layoutGmPerTokenScale2{tasksForIdx}; + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + copyUbToGmDequantScale(gmPerTokenScale2[loopStartIdx], ubPerTokenScaleOutput[0], layoutGmPerTokenScale2, layoutGmPerTokenScale2); + } + + + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32List[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNList[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNAbsList[UB_STAGES]; + AscendC::LocalTensor ubCFp32ChunkNMaxList[UB_STAGES]; + AscendC::LocalTensor ubQuantS32List[UB_STAGES]; + AscendC::LocalTensor ubQuantF16List[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleOutput; + + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; + CopyUbToGmDequantScale copyUbToGmDequantScale; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_SWIGLU_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp new file mode 100644 index 00000000000..c15e11b208a --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -0,0 +1,502 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "dispatch_policy_custom.hpp" + + +namespace Catlass::Gemm::Block { + +template +__aicore__ inline void SyncFlagFunc(int32_t eventID) +{ + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +template < + uint32_t PRELOAD_STAGES_, + uint32_t L1_STAGES_, + uint32_t L0A_STAGES_, + uint32_t L0B_STAGES_, + uint32_t L0C_STAGES_, + bool ENABLE_UNIT_FLAG_, + bool ENABLE_SHUFFLE_K_, + class L1TileShape_, + class L0TileShape_, + class AType_, + class BType_, + class CType_, + class BiasType_, + class TileCopy_, + class TileMmad_ +> +struct BlockMmad < + MmadAtlasA2PreloadAsyncFixpipe< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + >, + L1TileShape_, + L0TileShape_, + AType_, + BType_, + CType_, + BiasType_, + TileCopy_, + TileMmad_ +> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsyncFixpipe< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + >; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + 
using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyGmToL1S = Gemm::Tile::CopyGmToL1>; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using CopyL0CToGm = typename std::conditional< + std::is_same_v, + Gemm::Tile::CopyL0CToGm, + typename TileCopy_::CopyL0CToGm + >::type; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L1S_TILE_SIZE = L1TileShape::N * sizeof(int64_t); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); 
+ static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert( + (std::is_same_v + ? (L1A_TILE_SIZE + L1B_TILE_SIZE + L1S_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE + : (L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE), + "L1TileShape exceeding the L1 space for the given data type" + ); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout( + L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout( + L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + syncGroupIdx = 0; + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + if constexpr (std::is_same_v) { + AscendC::WaitFlag(0); + } + } + + CATLASS_DEVICE + void operator()( + 
AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmBlockS, layout::VectorLayout const &layoutScale, + GemmCoord const &actualShape, int32_t syncLoopIdx = -1, int32_t flag = 0 + ) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? + (startTileIdx + kLoopIdx) : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = (kTileIdx < kTileCount - 1) ? + L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + 
L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) ? + (l1TileMmadParamsId + preloadCount) : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + l1TileMmadParams.flag = flag; + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.gmBlockS = gmBlockS; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.layoutScale = layoutScale; + l1TileMmadParams.syncLoopIdx = syncLoopIdx; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? 
(l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + + CATLASS_DEVICE + void Finalize(int32_t target, int32_t flag = 0) + { + for(;syncGroupIdx <= target; syncGroupIdx++) { + int32_t flagId = syncGroupIdx / 8 + flag; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } + } +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + AscendC::GlobalTensor gmBlockS; + LayoutC layoutCInGm; + layout::VectorLayout layoutScale; + int32_t syncLoopIdx; + int32_t flag; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + if constexpr (std::is_same_v) { + uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; + l1STensor = resource.l1Buf.template GetBufferByByte(l1SOffset); + AscendC::SetFlag(0); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource 
&resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ? + L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ? + L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ? 
+ L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && + (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? 
(l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + if constexpr (std::is_same_v) { + auto layoutScale = params.layoutScale; + auto layoutTileS = layoutScale.GetTileLayout(MakeCoord(layoutCInGm.shape(1))); + AscendC::WaitFlag(0); + copyGmToL1S(l1STensor, params.gmBlockS, layoutTileS, layoutTileS); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + } + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + if constexpr (std::is_same_v) { + copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + } + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + if constexpr (std::is_same_v) { + copyL0CToGm(params.gmBlockC, l0CTensor, l1STensor, layoutCInGm, layoutCInL0, 0b11); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0; + if constexpr (std::is_same_v) { + AscendC::SetFlag(0); + } + Finalize(params.syncLoopIdx, params.flag); + } + } + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + AscendC::LocalTensor l1STensor; + int32_t syncGroupIdx; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyGmToL1S copyGmToL1S; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp new file mode 100644 index 00000000000..71b422c924d --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp @@ -0,0 +1,6 @@ + +#ifndef CONST_ARGS_HPP +#define CONST_ARGS_HPP +constexpr static uint64_t MB_SIZE = 1024 * 1024UL; +constexpr static int32_t NUMS_PER_FLAG = 16; +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/copy_gm_to_l1_custom.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/copy_gm_to_l1_custom.hpp new file mode 100644 index 00000000000..84789073d5a --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/copy_gm_to_l1_custom.hpp @@ -0,0 +1,40 @@ +#ifndef COPY_GM_TO_L1_CUSTOM_HPP +#define COPY_GM_TO_L1_CUSTOM_HPP + +namespace 
Catlass::Gemm::Tile { + /// Partial specialization for nZ in and nZ out. + template < + class ArchTag, + class Element + > + struct CopyGmToL1> { + using LayoutDst = layout::VectorLayout; + using LayoutSrc = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); // int64, 32/8=4 + + // Mehtods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()( + AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = 1; + uint32_t blockLen = CeilDiv(layoutSrc.shape(0)); + + AscendC::DataCopyParams repeatParams; + + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } + }; +} +#endif // COPY_GM_TO_L1_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/copy_l0c_to_gm_custom.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/copy_l0c_to_gm_custom.hpp new file mode 100644 index 00000000000..ba47798409a --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/copy_l0c_to_gm_custom.hpp @@ -0,0 +1,47 @@ +#ifndef COPY_L0C_TO_GM_CUSTOM_HPP +#define COPY_L0C_TO_GM_CUSTOM_HPP + +namespace Catlass::Gemm::Tile { + template < + class ElementAccumulator_, + class ElementDst_, + bool ReluEnable_ + > + struct CopyL0CToGm, + ScaleGranularity::PER_CHANNEL, + ReluEnable_> + { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Catlass::layout::zN; + using LayoutDst = Catlass::layout::RowMajor; + static constexpr auto quantPre = CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, AscendC::LocalTensor cbufWorkspace, + LayoutDst const 
&dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(1); + intriParams.mSize = dstLayout.shape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); + intriParams.dstStride = dstLayout.stride(0); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, cbufWorkspace, intriParams); + } + }; +} +#endif // COPY_L0C_TO_GM_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp new file mode 100644 index 00000000000..31fdbad1c27 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp @@ -0,0 +1,47 @@ +#ifndef DISPATH_POLICY_CUSTOM_HPP +#define DISPATH_POLICY_CUSTOM_HPP + +namespace Catlass::Gemm { + template + struct MmadAtlasA2PreloadFixpipeQuant : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; + }; + + template + struct MmadAtlasA2PreloadAsyncFixpipe : + public MmadAtlasA2PreloadAsync< + PRELOAD_STAGES_, + L1_STAGES_, + L0A_STAGES_, + L0B_STAGES_, + L0C_STAGES_, + ENABLE_UNIT_FLAG_, + ENABLE_SHUFFLE_K_ + > { + }; +} + +namespace Catlass::Epilogue { + + template + struct EpilogueAtlasA2UnQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; + + template + struct EpilogueAtlasA2PerTokenDequantQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; + + template + struct EpilogueAtlasA2PerTokenDequantSwigluQuant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + }; +} 
+#endif // DISPATH_POLICY_CUSTOM_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp new file mode 100644 index 00000000000..b66268fbf94 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp @@ -0,0 +1,178 @@ +#ifndef SYNC_UTIL_HPP +#define SYNC_UTIL_HPP + + +#include "kernel_operator.h" +#include "const_args.hpp" + +#include "moe_distribute_base.h" + +#ifndef HCCL_COMM +#include "shmem_api.h" +#endif + +#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ + +template +FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) { + *((__gm__ T *)addr) = val; +} + +template +FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) { + return *((__gm__ T *)cache); +} + +FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { + using namespace AscendC; + GlobalTensor global; + global.SetGlobalBuffer(addr); + + // Important: add hint to avoid dcci being optimized by compiler + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); +} + +FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) { + do { + gm_dcci((__gm__ uint8_t *)sig_addr); + + if (*sig_addr == cmp_val) { + return *sig_addr; + } + + // in case when peer pe enters next barrier + if (*sig_addr == cmp_val + 1) { + return *sig_addr; + } + } while (true); + + // never reach + return -1; +} + + +constexpr int32_t MAX_RANK_SIZE = 32; +class HcclShmem { +public: + #ifdef HCCL_COMM // hccl需要初始化hccl context + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + Hccl hccl_; + GM_ADDR m_ptrArray[MAX_RANK_SIZE]; + size_t m_segmentSize; + int32_t m_rank; + int32_t m_rankSize; + + FORCE_INLINE_AICORE + HcclShmem(){ + auto contextGM0 = AscendC::GetHcclContext(); + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; + + m_rank = WinContext_->localUsrRankId; + m_rankSize = 
WinContext_->rankSize; + m_segmentSize = WinContext_->winSize; + + for (int i = 0; i < m_rankSize; i++) { + m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn); + } + + } + + FORCE_INLINE_AICORE + size_t SegmentSize() const { + return m_segmentSize; + } + + FORCE_INLINE_AICORE + int32_t RankSize() const { + return m_rankSize; + } + #endif + + FORCE_INLINE_AICORE + GM_ADDR operator() () const { // 无参数,返回本地peermem + #ifdef HCCL_COMM + return m_ptrArray[m_rank]; + #else + return reinterpret_cast(shmemi_get_state()->heap_base); + #endif + } + + FORCE_INLINE_AICORE + GM_ADDR operator() (int32_t index) const { // 带index参数,返回远端peermem首地址 + #ifdef HCCL_COMM + return m_ptrArray[index]; + #else + return reinterpret_cast(shmem_ptr(shmemi_get_state()->heap_base, index)); + #endif + } + + + + FORCE_INLINE_AICORE + GM_ADDR operator () (int64_t offset, int32_t rankId) const { + #ifdef HCCL_COMM + if (offset < 0 || offset >= m_segmentSize) { + return nullptr; + } + if (rankId < 0 || rankId >= m_rankSize) { + return nullptr; + } + return m_ptrArray[rankId] + offset; + #else + return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId); + #endif + } + + // FORCE_INLINE_AICORE + // GM_ADDR operator () (GM_ADDR ptr, int32_t index) const { // shmem_ptr相同用法 + // #ifdef HCCL_COMM + // size_t offset = ptr - m_ptrArray[m_rank]; + // if (offset < 0 || offset >= m_segmentSize) { + // return nullptr; + // } + // if (index < 0 || index >= m_rankSize) { + // return nullptr; + // } + // return m_ptrArray[index] + offset; + // #else + // return shmem_ptr(ptr, index); + // #endif + // } + + + FORCE_INLINE_AICORE + ~HcclShmem() { + } + + FORCE_INLINE_AICORE + void CrossRankSync() { + uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); + __gm__ int32_t* sync_counter = (__gm__ int32_t*)(*this)() + flag_offset; + __gm__ int32_t* sync_base = (__gm__ int32_t*)(*this)() + 
flag_offset + 2048; + int count = gm_load(sync_base) + 1; + int vec_id = AscendC::GetBlockIdx(); + int vec_size = AscendC::GetBlockNum() * AscendC::GetTaskRation(); + for(int i = vec_id; i < m_rankSize; i += vec_size) { + __gm__ int32_t* sync_remote = (__gm__ int32_t*)((*this)(i)) + flag_offset + m_rank * 16; + gm_store(sync_remote, count); + gm_dcci((__gm__ uint8_t*)sync_remote); + auto sync_check = sync_counter + i * 16; + gm_signal_wait_until_eq_for_barrier(sync_check, count); + } + + AscendC::SyncAll(); + gm_store(sync_base, count); + } + + FORCE_INLINE_AICORE + __gm__ int32_t* SyncBaseAddr() { + uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); + return (__gm__ int32_t*)(*this)() + flag_offset + 2048; + } +}; + + +#endif diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/layout3d.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/layout3d.hpp new file mode 100644 index 00000000000..7cc3a9c1fce --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/layout3d.hpp @@ -0,0 +1,20 @@ +#ifndef LAYOUT_3D_HPP +#define LAYOUT_3D_HPP +#include "kernel_operator.h" +#include "catlass/catlass.hpp" +class Layout3D { + int64_t strides[2]; + public: + CATLASS_DEVICE + Layout3D() {} + CATLASS_DEVICE + Layout3D(int64_t stride0, int64_t stride1) { + strides[0] = stride0; + strides[1] = stride1; + } + CATLASS_DEVICE + int64_t operator() (int64_t dim0, int64_t dim1, int64_t dim2) { + return dim0 * strides[0] + dim1 * strides[1] + dim2; + } +}; +#endif // LAYOUT_3D_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/select_helper.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/select_helper.hpp new file mode 100644 index 00000000000..574ab3351c8 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/select_helper.hpp @@ -0,0 +1,25 @@ +#ifndef SELECT_HELPER_HPP +#define SELECT_HELPER_HPP + +#include "catlass/layout/layout.hpp" +using namespace AscendC; +using namespace Catlass; + +template +struct LayoutBInitializer { + CATLASS_DEVICE + static 
Layout create(uint32_t k, uint32_t n) { + return Layout{k, n}; + } +}; + +template +struct LayoutBInitializer> +> { + CATLASS_DEVICE + static Layout create(uint32_t k, uint32_t n) { + return Layout::template MakeLayout(k, n); + } +}; +#endif // SELECT_HELPER_HPP \ No newline at end of file diff --git a/csrc/third_party/catlass b/csrc/third_party/catlass new file mode 160000 index 00000000000..716fd7baa7f --- /dev/null +++ b/csrc/third_party/catlass @@ -0,0 +1 @@ +Subproject commit 716fd7baa7fb7f6cac0488bb628fd1dd0e875641 diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 68cefc15a44..351c67458e4 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -37,59 +37,60 @@ namespace vllm_ascend { const int64_t INT4_NUMS_IN_INT32 = 8; void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst, - const torch::Tensor& block_mapping, aclrtStream stream) { - torch::Device src_device = src.device(); - torch::Device dst_device = dst.device(); - aclrtMemcpyKind memcpy_type; - - if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) { - TORCH_CHECK(src_device.index() == dst_device.index(), - "src and dst must be on the same npu"); - memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE; - } else if ((!src_device.is_cpu()) && dst_device.is_cpu()) { - memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST; - } else if (src_device.is_cpu() && (!dst_device.is_cpu())) { - memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE; - } else { - TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device); - } - - TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); - - char* src_ptr = static_cast(src.data_ptr()); - char* dst_ptr = static_cast(dst.data_ptr()); - - const int64_t block_size_in_bytes = src.element_size() * src.stride(0); - - const int64_t num_blocks = block_mapping.size(0); - const int64_t max_src_block = src.size(0); - const int64_t max_dst_block = dst.size(0); - for (size_t i = 0; i < num_blocks; i++) { - 
int64_t src_block_number = block_mapping[i][0].item(); - int64_t dst_block_number = block_mapping[i][1].item(); - TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block, + const torch::Tensor& block_mapping, aclrtStream stream) +{ + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + aclrtMemcpyKind memcpy_type; + + if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same npu"); + memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE; + } else if ((!src_device.is_cpu()) && dst_device.is_cpu()) { + memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST; + } else if (src_device.is_cpu() && (!dst_device.is_cpu())) { + memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE; + } else { + TORCH_CHECK(false, "Invalid device combination, src tensor device: ", src_device, ", dst tensor device: ", dst_device); + } + + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); + + const int64_t block_size_in_bytes = src.element_size() * src.stride(0); + + const int64_t num_blocks = block_mapping.size(0); + const int64_t max_src_block = src.size(0); + const int64_t max_dst_block = dst.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + TORCH_CHECK(src_block_number >= 0 && src_block_number <= max_src_block, "src block index ", src_block_number, " out of range (max: ", max_src_block, ")"); - TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block, + TORCH_CHECK(dst_block_number >= 0 && dst_block_number <= max_dst_block, "dst block index ", dst_block_number, " out of range (max: ", max_dst_block, ")"); - - int64_t src_offset = src_block_number * block_size_in_bytes; - int64_t dst_offset = dst_block_number * block_size_in_bytes; + + int64_t 
src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; - aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes, - src_ptr + src_offset, block_size_in_bytes, - memcpy_type, stream); - } + aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes, + src_ptr + src_offset, block_size_in_bytes, + memcpy_type, stream); + } } void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z) { - const c10_npu::OptionalNPUGuard npuGuard( + const c10_npu::OptionalNPUGuard npuGuard( (!x.device().is_cpu()) ? x.device() : y.device() ); - aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); - swap_blocks_impl(x, y, z, stream); - return; + aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + swap_blocks_impl(x, y, z, stream); + return; } AscendType get_dtype_from_torch(at::ScalarType scalarType) @@ -617,7 +618,33 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor }); cmd.Run(); return; +} +at::Tensor& dispatch_ffn_combine( + const at::Tensor& x, + const at::Tensor& weight1, + const at::Tensor& weight2, + const at::Tensor& expert_idx, + const at::Tensor& scale1, + const at::Tensor& scale2, + const at::Tensor& probs, + c10::string_view group, + int64_t max_output_size, + at::Tensor& out +) { + char *group_ep_ptr = const_cast(group.data()); + EXEC_NPU_CMD(aclnnDispatchFFNCombine, + x, + weight1, + weight2, + expert_idx, + scale1, + scale2, + probs, + group_ep_ptr, + max_output_size, + out); + return out; } at::Tensor npu_lightning_indexer( @@ -810,4 +837,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " int sparse_mode=3) -> Tensor" ); ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention); + + ops.def( + "dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx," + " Tensor scale1, Tensor scale2, Tensor probs, str group," + " int max_output_size, Tensor! 
out) -> Tensor" + ); + ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index b84779e2dab..4a09c8dfc14 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -156,7 +156,21 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor c10::optional quant_mode) { return; +} +at::Tensor& dispatch_ffn_combine_meta( + const at::Tensor& x, + const at::Tensor& weight1, + const at::Tensor& weight2, + const at::Tensor& expert_idx, + const at::Tensor& scale1, + const at::Tensor& scale2, + const at::Tensor& probs, + c10::string_view group, + int64_t max_output_size, + at::Tensor& out +) { + return out; } at::Tensor npu_lightning_indexer_meta( @@ -244,5 +258,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("npu_lightning_indexer", &vllm_ascend::meta::npu_lightning_indexer_meta); // Sparse flash attention ops.impl("npu_sparse_flash_attention", &vllm_ascend::meta::npu_sparse_flash_attention_meta); + // MoE dispatch-ffn-combine + ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta); } } diff --git a/tests/e2e/nightly/ops/test_dispatch_ffn_combine.py b/tests/e2e/nightly/ops/test_dispatch_ffn_combine.py new file mode 100644 index 00000000000..90ce1f07fcf --- /dev/null +++ b/tests/e2e/nightly/ops/test_dispatch_ffn_combine.py @@ -0,0 +1,168 @@ +import random + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch_npu +from torch.distributed.distributed_c10d import _get_default_group + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + + +class TestDisptachFFNCombine: + + def __init__(self, rank, world_size, port): + self.rank = rank + self.world_size = world_size + self.master_ip = "127.0.0.1" + self.port = port + + def get_hcomm(self, comm_group): + hcomm_info = None + if torch.__version__ > "2.0.1": + 
hcomm_info = comm_group._get_backend( + torch.device("npu")).get_hccl_comm_name(self.rank) + else: + hcomm_info = comm_group.get_hccl_comm_name(self.rank) + return hcomm_info + + def setup_ep_tp( + self, + rank, + tp_size, + ep_size, + backend_type, + ep_ranks_list=None, + tp_ranks_list=None, + ): + for i in range(tp_size): + if ep_ranks_list: + ep_ranks = ep_ranks_list[i] + else: + ep_ranks = [x + ep_size * i for x in range(ep_size)] + ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks) + if rank in ep_ranks: + ep_group_tmp = ep_group + for i in range(ep_size): + if tp_ranks_list: + tp_ranks = tp_ranks_list[i] + else: + tp_ranks = [x * ep_size + i for x in range(tp_size)] + tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks) + if rank in tp_ranks: + tp_group_tmp = tp_group + return ep_group_tmp, tp_group_tmp + + def generate_hcom(self): + torch_npu.npu.set_device(self.rank) + dist.init_process_group( + backend="hccl", + rank=self.rank, + world_size=self.world_size, + init_method=f"tcp://127.0.0.1:{self.port}", + ) + + ep_size = 0 + tp_size = self.world_size + hcomm_info_dist = { + "default_pg_info": None, + "ep_hcomm_info": None, + "group_ep": None, + "tp_hcomm_info": None, + "group_tp": None, + } + if ep_size and tp_size: + group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size, + "hccl", None, None) + hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep) + hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp) + hcomm_info_dist["group_ep"] = group_ep + hcomm_info_dist["group_tp"] = group_tp + else: + if dist.is_available(): + default_pg = _get_default_group() + hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg) + hcomm_info = hcomm_info_dist["default_pg_info"] + self.hcomm_info = hcomm_info + + def run_npu_out(self) -> bool: + torch_npu.npu.set_device(self.rank) + m = 2 # token-num 32 + k = 4 # hidden_size 7168 + n = 4 # mid-hidden-size 4096 + topk = 2 + e = 2 # expert-num-per-rank 16 + k2 = n // 2 + n2 
= k + + torch_npu.npu.config.allow_internal_format = True + x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + weight1 = self.generate_random_tensor((e, k, n), + dtype=torch.int8).npu() + weight1 = torch_npu.npu_format_cast(weight1, 29) + weight2 = self.generate_random_tensor((e, k2, n2), + dtype=torch.int8).npu() + weight2 = torch_npu.npu_format_cast(weight2, 29) + + expert_idx = torch.randint(0, + self.world_size * e, (m, topk), + dtype=torch.int32).npu() + scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu() + scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu() + probs = torch.randn(size=(m, topk), dtype=torch.float32).npu() + out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + + torch.ops._C_ascend.dispatch_ffn_combine( + x=x, + weight1=weight1, + weight2=weight2, + expert_idx=expert_idx, + scale1=scale1, + scale2=scale2, + probs=probs, + group=self.hcomm_info, + max_output_size=512, + out=out, + ) + return True + + def generate_random_tensor(self, size, dtype): + if dtype in [torch.float16, torch.bfloat16, torch.float32]: + return torch.randn(size=size, dtype=dtype) + elif dtype is torch.int8: + return torch.randint(-16, 16, size=size, dtype=dtype) + elif dtype is torch.int32: + return torch.randint(-1024, 1024, size=size, dtype=dtype) + else: + raise ValueError(f"Invalid dtype: {dtype}") + + +def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue): + op = TestDisptachFFNCombine(rank, world_size, port) + op.generate_hcom() + out = op.run_npu_out() + q.put(out) + + +@torch.inference_mode() +def test_dispatch_ffn_combine_kernel(): + world_size = 2 + mp.set_start_method("fork", force=True) + + q = mp.SimpleQueue() + p_list = [] + port = 29501 + random.randint(0, 10000) + + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, port, q)) + p.start() + p_list.append(p) + + results = [q.get() for _ in range(world_size)] + + for p in p_list: + p.join() + + assert 
all(results) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index bd4a3509e26..70780f7492a 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -52,6 +52,7 @@ class MoECommType(Enum): ALLGATHER = 0 MC2 = 1 ALLTOALL = 2 + FUSED_ALLTOALL = 3 @contextmanager diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 945ea19743c..ef7c380c674 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -520,7 +520,7 @@ def forward_impl(self, hidden_states: torch.Tensor, # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \ and not shared_expert_dp_enabled(): shared_out = tensor_model_parallel_all_reduce(shared_out) return shared_out, fused_output diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 802dbe5d66e..138dcddf5da 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -44,6 +44,8 @@ def setup_moe_comm_method(moe_config): _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config) _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config) _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config) + _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl( + moe_config) class MoECommMethod(ABC): @@ -243,3 +245,69 @@ def _get_token_dispatcher(self): def _get_prepare_finalize(self): return PrepareAndFinalizeWithAll2All(self.moe_config) + + +class FusedAlltoAllCommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. 
`npu_grouped_matmul` is available. + + This implementation uses all-to-all communication to exchange tokens + between data parallel ranks before and after the MLP computation. It should + have better performance than AllGatherCommImpl when DP size > 1. + """ + + def _get_token_dispatcher(self): + return TokenDispatcherWithAll2AllV( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + + def _get_prepare_finalize(self): + return PrepareAndFinalizeWithAll2All(self.moe_config) + + def fused_experts( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + global_num_experts: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + # For TorchAir graph + is_torchair: bool = False, + # For Cube/Vector parallel + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + # For load balance + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + need_trans: bool = False, + dynamic_eplb: bool = False, + mc2_mask: torch.Tensor = None, + pertoken_scale: Optional[torch.Tensor] = None): + out = torch.empty_like(hidden_states) + + torch.ops._C_ascend.dispatch_ffn_combine( + x=hidden_states, + weight1=w1, + weight2=w2, + expert_idx=topk_ids, + scale1=w1_scale, + scale2=w2_scale, + probs=topk_weights.to(torch.float32), + group=self.token_dispatcher.moe_all_to_all_group_name, + max_output_size=65536, + out=out, + ) + return out diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py 
b/vllm_ascend/ops/fused_moe/token_dispatcher.py index 57f26046072..75035a47a62 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -513,6 +513,11 @@ def __init__(self, **kwargs): self.local_expert_indices[i + 1] - 1), "local_expert_indices must be continuous" + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = torch.distributed.get_rank(group=self.ep_group) + backend = self.ep_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + def token_dispatch(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index 03bea55460a..06a52ae1fe2 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -249,8 +249,9 @@ def _maybe_all_reduce_tensor_model_parallel_impl( final_hidden_states: torch.Tensor) -> torch.Tensor: forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2 - } or forward_context.sp_enabled: + if moe_comm_type in { + MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL + } or forward_context.sp_enabled: return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 2901d17504e..00c42cd87d0 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -24,6 +24,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.utils import 
ACL_FORMAT_FRACTAL_NZ, is_enable_nz @@ -232,13 +233,15 @@ def apply( w2 = [layer.w2_weight] w2_scale = [layer.w2_weight_scale] + fused_flag = get_forward_context( + ).moe_comm_type == MoECommType.FUSED_ALLTOALL return moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, + w1=w1[0] if fused_flag else w1, + w1_scale=layer.fused_w1_scale if fused_flag else w1_scale, + w2=w2[0] if fused_flag else w2, + w2_scale=layer.fused_w2_scale if fused_flag else w2_scale, topk_weights=topk_weights, topk_ids=topk_ids, use_int8_w8a8=True, @@ -270,6 +273,12 @@ def process_weights_after_loading(self, layer): layer.w2_weight_scale.data.shape[0], -1) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1) + + layer.fused_w1_scale = scale_from_float_to_int64( + layer.w13_weight_scale.data) + layer.fused_w2_scale = scale_from_float_to_int64( + layer.w2_weight_scale.data) + if self.dynamic_eplb: layer.w13_weight_list = [ weight.clone() @@ -292,3 +301,11 @@ def process_weights_after_loading(self, layer): del layer.w13_weight_scale_fp32 del layer.w2_weight_scale torch.npu.empty_cache() + + +def scale_from_float_to_int64(scale): + import numpy as np + scale = torch.from_numpy( + np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), + dtype=np.int32).astype(np.int64)).to(scale.device) + return scale diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index b4b8269be2f..ae5944d429c 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -911,6 +911,9 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: "dp": { "hccl_buffer_size": calculate_dp_buffer_size() }, + "ep": { + "hccl_buffer_size": calculate_ep_buffer_size() + }, } return hccl_config_map.get(group_name, get_default_buffer_config()) @@ -932,6 +935,30 @@ def calculate_dp_buffer_size() -> int: return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE) +def 
calculate_ep_buffer_size() -> int: + """ + formula of ep buffer size: + batch_size * hidden_size * topk * 4 + """ + ep_buffer_size = _DEFAULT_BUFFER_SIZE + try: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + hf_config = vllm_config.model_config.hf_config + + hidden_size = hf_config.hidden_size + topk = getattr(hf_config, "num_experts_per_token", 1) + batch_size = vllm_config.scheduler_config.max_num_batched_tokens + int8_size = torch.iinfo(torch.int8).bits // 8 + bf16_size = torch.finfo(torch.bfloat16).bits // 8 + ep_buffer_size = math.ceil( + (batch_size * hidden_size * topk * + (int8_size * 2 + bf16_size)) / (1024 * 1024)) + except Exception: + pass + return max(ep_buffer_size, _DEFAULT_BUFFER_SIZE) + + # Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 # and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and # significantly improve communication performance of MC2 ops dispatch/combine. 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8aff73f1d28..3558dd42db5 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2217,8 +2217,9 @@ def _select_moe_comm_method(self, return None soc_version = get_ascend_device_type() - quant_type = getattr(self.vllm_config.model_config.hf_config, - 'moe_quantize', None) + quant_type = getattr( + self.vllm_config.model_config.hf_config, 'moe_quantize', + getattr(self.vllm_config.model_config.hf_config, 'quantize', None)) model_type = self.vllm_config.model_config.hf_config.model_type if not self.parallel_config.enable_expert_parallel: @@ -2237,7 +2238,8 @@ def _select_moe_comm_method(self, elif soc_version in {AscendDeviceType._910_93}: moe_comm_type = (MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else - MoECommType.ALLTOALL) + MoECommType.FUSED_ALLTOALL if quant_type + == "w8a8_dynamic" else MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}")