diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 905400887f8..f24a9f025d9 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -24,7 +24,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) export CPATH=${ABSOLUTE_CATLASS_PATH}:${CPATH} - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm" + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;matmul_allreduce_add_rmsnorm;moe_init_routing_custom" SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series @@ -69,6 +69,7 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then "moe_dispatch_normal" "dispatch_layout" "notify_dispatch" + "moe_init_routing_custom" ) CUSTOM_OPS=$(IFS=';'; echo "${CUSTOM_OPS_ARRAY[*]}") SOC_ARG="ascend910_93" diff --git a/csrc/moe_init_routing_custom/op_host/CMakeLists.txt b/csrc/moe_init_routing_custom/op_host/CMakeLists.txt new file mode 100644 index 00000000000..3d9e67caffd --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/CMakeLists.txt @@ -0,0 +1,55 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME MoeInitRoutingCustom + OPTIONS --cce-auto-sync=on + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnExc PRIVATE + moe_init_routing_custom_def.cpp +) + +target_sources(opapi PRIVATE + moe_init_routing_custom.cpp + aclnn_moe_init_routing_custom.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + moe_init_routing_custom.cpp + aclnn_moe_init_routing_custom.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + moe_init_routing_custom.cpp + aclnn_moe_init_routing_custom.cpp + ) +endif () + +target_sources(optiling PRIVATE + moe_init_routing_custom_tiling_base.cpp + moe_init_routing_custom_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + moe_init_routing_custom_infershape.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_moe_init_routing_custom.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.cpp b/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.cpp new file mode 100644 index 00000000000..6564a58e98f --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.cpp @@ -0,0 +1,143 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include +#include +#include "opdev/make_op_executor.h" +#include "aclnn_kernels/contiguous.h" +#include "opdev/tensor_view_utils.h" +#include "aclnn_kernels/common/op_error_check.h" +#include "opdev/op_log.h" +#include "aclnn_kernels/cast.h" +#include "opdev/common_types.h" +#include "moe_init_routing_custom.h" +#include "aclnn_moe_init_routing_custom.h" + +using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +namespace { + static const int64_t MOE_DIM_2 = 2; + static const int64_t MOE_DIM_1 = 1; +} + +static const std::initializer_list DTYPE_SUPPORT_LIST_X= {DataType::DT_FLOAT16, DataType::DT_BF16, DataType::DT_FLOAT, DataType::DT_INT8}; +static const std::initializer_list DTYPE_SUPPORT_LIST_EXPERT_IDX = {DataType::DT_INT32}; +static const std::initializer_list DTYPE_SUPPORT_LIST_SCALE = {DataType::DT_FLOAT}; +static const std::initializer_list DTYPE_SUPPORT_LIST_OFFSET= {DataType::DT_FLOAT}; +static const std::initializer_list DTYPE_SUPPORT_LIST_EXPANDED_X_OUT = {DataType::DT_FLOAT16, DataType::DT_BF16, DataType::DT_FLOAT, DataType::DT_INT8}; +static const std::initializer_list DTYPE_SUPPORT_LIST_EXPANDED_ROW_IDX_OUT = {DataType::DT_INT32}; +static const std::initializer_list DTYPE_SUPPORT_LIST_EXPERT_TOKENS_COUNT_OR_CUMSUMOUT = {DataType::DT_INT64}; +static const std::initializer_list DTYPE_SUPPORT_LIST_EXPANDED_SCALE_OUT = {DataType::DT_FLOAT}; + +static inline bool CheckNotNull(const aclTensor *x, + const aclTensor *expertIdx, + const aclTensor *expandedXOut, + const aclTensor *expandedRowIdxOut, + const aclTensor *expertTokensCountOrCumsumOut, + const aclTensor *expandedScaleOut) { + OP_CHECK_NULL(x, return false); + 
OP_CHECK_NULL(expertIdx, return false); + OP_CHECK_NULL(expandedXOut, return false); + OP_CHECK_NULL(expandedRowIdxOut, return false); + OP_CHECK_NULL(expertTokensCountOrCumsumOut, return false); + OP_CHECK_NULL(expandedScaleOut, return false); + + return true; +} + +aclnnStatus aclnnMoeInitRoutingCustomGetWorkspaceSize(const aclTensor *x, + const aclTensor *expertIdx, + const aclTensor *scaleOptional, + const aclTensor *offsetOptional, + int64_t activeNum, + int64_t expertCapacity, + int64_t expertNum, + int64_t dropPadMode, + int64_t expertTokensNumType, + bool expertTokensNumFlag, + int64_t quantMode, + const aclIntArray *activeExpertRangeOptional, + int64_t rowIdxType, + const aclTensor *expandedXOut, + const aclTensor *expandedRowIdxOut, + const aclTensor *expertTokensCountOrCumsumOut, + const aclTensor *expandedScaleOut, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + L2_DFX_PHASE_1(aclnnMoeInitRoutingCustom, + DFX_IN(x, expertIdx, scaleOptional, offsetOptional, + activeNum, expertCapacity, expertNum, dropPadMode, + expertTokensNumType, expertTokensNumFlag, quantMode, activeExpertRangeOptional, rowIdxType), + DFX_OUT(expandedXOut, expandedRowIdxOut, expertTokensCountOrCumsumOut, expandedScaleOut)); + auto ret = CheckNotNull(x, expertIdx, expandedXOut, expandedRowIdxOut, + expertTokensCountOrCumsumOut, expandedScaleOut); + + CHECK_RET(ret, ACLNN_ERR_PARAM_NULLPTR); + + auto uniqueExecutor = CREATE_EXECUTOR(); + CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + + auto xContiguous = l0op::Contiguous(x, uniqueExecutor.get()); + CHECK_RET(xContiguous != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + auto expertIdxContiguous = l0op::Contiguous(expertIdx, uniqueExecutor.get()); + CHECK_RET(expertIdxContiguous != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + + const aclTensor* scaleContiguous = nullptr; + const aclTensor* offsetContiguous = nullptr; + if (scaleOptional != nullptr) { + scaleContiguous = 
l0op::Contiguous(scaleOptional, uniqueExecutor.get()); + CHECK_RET(scaleContiguous != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + } + + if (offsetOptional != nullptr) { + offsetContiguous = l0op::Contiguous(offsetOptional, uniqueExecutor.get()); + CHECK_RET(offsetContiguous != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR); + } + + auto routingResult = std::tuple(nullptr, nullptr, nullptr, nullptr); + routingResult = l0op::MoeInitRoutingCustom(xContiguous, expertIdxContiguous, scaleContiguous, offsetContiguous, + activeNum, expertCapacity, expertNum, dropPadMode, expertTokensNumType, expertTokensNumFlag, + quantMode, activeExpertRangeOptional, rowIdxType, expandedXOut, expandedRowIdxOut, + expertTokensCountOrCumsumOut, expandedScaleOut, uniqueExecutor.get()); + auto [expandedXOut_, expandedRowIdxOut_, expertTokensCountOrCumsumOut_, expandedScaleOut_] = routingResult; + bool hasNullptr = (expandedXOut_ == nullptr) || (expandedRowIdxOut_ == nullptr) || (expertTokensCountOrCumsumOut_ == nullptr) || (expandedScaleOut_ == nullptr); + CHECK_RET(hasNullptr != true, ACLNN_ERR_INNER_NULLPTR); + + auto viewCopyExpandedXOutResult = l0op::ViewCopy(expandedXOut_, expandedXOut, uniqueExecutor.get()); + CHECK_RET(viewCopyExpandedXOutResult != nullptr, ACLNN_ERR_INNER_NULLPTR); + auto viewCopyExpandedRowIdxOutResult = l0op::ViewCopy(expandedRowIdxOut_, expandedRowIdxOut, uniqueExecutor.get()); + CHECK_RET(viewCopyExpandedRowIdxOutResult != nullptr, ACLNN_ERR_INNER_NULLPTR); + + auto viewCopyExpertTokensCountOrCumsumOutResult = l0op::ViewCopy(expertTokensCountOrCumsumOut_, expertTokensCountOrCumsumOut, uniqueExecutor.get()); + CHECK_RET(viewCopyExpertTokensCountOrCumsumOutResult != nullptr, ACLNN_ERR_INNER_NULLPTR); + + auto viewCopyExpandedScaleOutResult = l0op::ViewCopy(expandedScaleOut_, expandedScaleOut, uniqueExecutor.get()); + CHECK_RET(viewCopyExpandedScaleOutResult != nullptr, ACLNN_ERR_INNER_NULLPTR); + + *workspaceSize = uniqueExecutor->GetWorkspaceSize(); + 
uniqueExecutor.ReleaseTo(executor); + return ACLNN_SUCCESS; +} +aclnnStatus aclnnMoeInitRoutingCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, + aclrtStream stream) +{ + L2_DFX_PHASE_2(aclnnMoeInitRoutingCustom); + return CommonOpExecutorRun(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.h b/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.h new file mode 100644 index 00000000000..5c7106b5d61 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/aclnn_moe_init_routing_custom.h @@ -0,0 +1,47 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef OP_API_INC_MOE_INIT_ROUTING_CUSTOM_H_ +#define OP_API_INC_MOE_INIT_ROUTING_CUSTOM_H_ + +#include "aclnn/aclnn_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeInitRoutingCustomGetWorkspaceSize(const aclTensor *x, + const aclTensor *expertIdx, + const aclTensor *scaleOptional, + const aclTensor *offsetOptional, + int64_t activeNum, + int64_t expertCapacity, + int64_t expertNum, + int64_t dropPadMode, + int64_t expertTokensNumType, + bool expertTokensNumFlag, + int64_t quantMode, + const aclIntArray *activeExpertRangeOptional, + int64_t rowIdxType, + const aclTensor *expandedXOut, + const aclTensor *expandedRowIdxOut, + const aclTensor *expertTokensCountOrCumsumOut, + const aclTensor *expandedScaleOut, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnMoeInitRoutingCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.cpp b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.cpp new file mode 100644 index 00000000000..df36f9d4169 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.cpp @@ -0,0 +1,50 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include "moe_init_routing_custom.h" +#include "opdev/make_op_executor.h" +#include "opdev/op_def.h" +#include "opdev/op_dfx.h" +#include "opdev/op_executor.h" +#include "opdev/op_log.h" +#include "opdev/shape_utils.h" +#include "aclnn_kernels/common/op_error_check.h" + +using namespace op; + +namespace l0op { +OP_TYPE_REGISTER(MoeInitRoutingCustom); + +std::tuple MoeInitRoutingCustom(const aclTensor *x, const aclTensor *expertIdx, const aclTensor *scale, + const aclTensor *offset, int64_t activeNum, int64_t expertCapacity, + int64_t expertNum, int64_t dropPadMode, int64_t expertTokensNumType, + bool expertTokensNumFlag, int64_t quantMode, const aclIntArray *activeExpertRange, + int64_t rowIdxType, const aclTensor *expandedX, const aclTensor *expandedRowIdx, + const aclTensor *expertTokensCountOrCumsum, const aclTensor *expandedScale, aclOpExecutor *executor) +{ + L0_DFX(MoeInitRoutingCustom, x, expertIdx, scale, offset, activeNum, expertCapacity, expertNum, dropPadMode, expertTokensNumType, expertTokensNumFlag, + quantMode, activeExpertRange, rowIdxType, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale); + + auto expandedXOut = executor->AllocTensor(expandedX->GetViewShape(), expandedX->GetDataType(), Format::FORMAT_ND); + auto expandedRowIdxOut = executor->AllocTensor(expandedRowIdx->GetViewShape(), expandedRowIdx->GetDataType(), Format::FORMAT_ND); + auto expertTokensCountOrCumsumOut = executor->AllocTensor(expertTokensCountOrCumsum->GetViewShape(), expertTokensCountOrCumsum->GetDataType(), Format::FORMAT_ND); + auto expandedScaleOut = executor->AllocTensor(expandedScale->GetViewShape(), expandedScale->GetDataType(), Format::FORMAT_ND); + if (expandedXOut == nullptr || expandedRowIdxOut == nullptr || expertTokensCountOrCumsumOut == nullptr || expandedScaleOut == nullptr) { + OP_LOGE(ACLNN_ERR_INNER_NULLPTR, "alloc expandedXOut or 
expandedRowIdxOut or expertTokensCountOrCumsumOut or expandedScaleOut tensor failed."); + return std::tuple(nullptr, nullptr, nullptr, nullptr); + } + + ADD_TO_LAUNCHER_LIST_AICORE( + MoeInitRoutingCustom, OP_INPUT(x, expertIdx, scale, offset), OP_OUTPUT(expandedXOut, expandedRowIdxOut, expertTokensCountOrCumsumOut, expandedScaleOut), OP_ATTR(activeNum, expertCapacity, expertNum, dropPadMode, expertTokensNumType, expertTokensNumFlag, quantMode, activeExpertRange, rowIdxType)); + return std::tuple(expandedXOut, expandedRowIdxOut, expertTokensCountOrCumsumOut, expandedScaleOut); //OP_OUTPUT +} + +} // namespace l0op \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.h b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.h new file mode 100644 index 00000000000..65da3ff332a --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom.h @@ -0,0 +1,25 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef OP_API_INC_LEVEL0_MOE_INIT_ROUTING_CUSTOM_H +#define OP_API_INC_LEVEL0_MOE_INIT_ROUTING_CUSTOM_H + +#include +#include "opdev/op_executor.h" + +namespace l0op { +std::tuple MoeInitRoutingCustom(const aclTensor *x, const aclTensor *expertIdx, const aclTensor *scale, + const aclTensor *offset, int64_t activeNum, int64_t expertCapacity, + int64_t expertNum, int64_t dropPadMode, int64_t expertTokensNumType, + bool expertTokensNumFlag, int64_t quantMode, const aclIntArray *activeExpertRange, + int64_t rowIdxType, const aclTensor *expandedX, const aclTensor *expandedRowIdx, + const aclTensor *expertTokensCountOrCumsum, const aclTensor *expandedScale, aclOpExecutor *executor); +} // namespace l0op +#endif // OP_API_INC_LEVEL0_MOE_INIT_ROUTING_CUSTOM_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_def.cpp b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_def.cpp new file mode 100644 index 00000000000..c1d980b9de6 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_def.cpp @@ -0,0 +1,105 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
 + * \file moe_init_routing_custom_def.cpp + * \brief + */ +#include "register/op_def_registry.h" + +namespace ops { +class MoeInitRoutingCustom : public OpDef { +public: + explicit MoeInitRoutingCustom(const char *name) : OpDef(name) + { + this->Input("x") + .ParamType(REQUIRED) + .DataType( + {ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expert_idx") + .ParamType(REQUIRED) + .DataType( + {ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("scale") + .ParamType(OPTIONAL) + .DataType( + {ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("offset") + .ParamType(OPTIONAL) + .DataType( + {ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Output("expanded_x") + 
.ParamType(REQUIRED) + .DataType({ge::DT_INT8, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expanded_row_idx") + .ParamType(REQUIRED) + .DataType( + {ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expert_tokens_count_or_cumsum") + .ParamType(REQUIRED) + .DataType( + {ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expanded_scale") + .ParamType(REQUIRED) + .DataType( + {ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("active_num").AttrType(OPTIONAL).Int(-1); + this->Attr("expert_capacity").AttrType(OPTIONAL).Int(-1); + this->Attr("expert_num").AttrType(OPTIONAL).Int(-1); + this->Attr("drop_pad_mode").AttrType(OPTIONAL).Int(0); + this->Attr("expert_tokens_num_type").AttrType(OPTIONAL).Int(0); + 
this->Attr("expert_tokens_num_flag").AttrType(OPTIONAL).Bool(false); + this->Attr("quant_mode").AttrType(OPTIONAL).Int(-1); + this->Attr("active_expert_range").AttrType(OPTIONAL).ListInt({}); + this->Attr("row_idx_type").AttrType(OPTIONAL).Int(0); + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + + } +}; + +OP_ADD(MoeInitRoutingCustom); +} // namespace ops diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_infershape.cpp b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_infershape.cpp new file mode 100644 index 00000000000..77e3d283352 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_infershape.cpp @@ -0,0 +1,797 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/* ! + * \file moe_init_routing_custom_infershape.cpp + * \brief + */ + +#include +#include +#include +#include "register/op_def_registry.h" +#include "log/ops_log.h" +#include "platform/platform_info.h" + +#define unlikely(x) __builtin_expect((x), 0) +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if (unlikely((ptr) == nullptr)) { \ + const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? 
\ + "nil" : \ + (context)->GetNodeName(); \ + OPS_LOG_E(name, "%s is nullptr!", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +using namespace ge; +namespace ops { +static constexpr size_t DIM_ONE = 1U; +static constexpr size_t DIM_TWO = 2U; +static constexpr size_t DIM_THREE = 3U; +static constexpr int64_t NEG_ONE = static_cast(-1); +static constexpr int64_t NEG_TWO = static_cast(-2); +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_INPUT_X = 0; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_INPUT_EXPERT_IDX = 1; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_INPUT_SCALE = 2; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_INPUT_OFFSET = 3; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_ACTIVE_NUM = 0; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_CAPACITY = 1; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_NUM = 2; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_DROP_PAD_MODE = 3; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_TOKEN_NUM_TYPE = 4; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_TOKEN_NUM_FLAG = 5; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_QUANT_MODE = 6; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_ACTIVE_EXPERT_RANGE = 7; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_ATTR_ROW_IDX_TYPE = 8; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_X = 0; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_ROW_IDX = 1; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPERT_TOKEN_CUMSUM_OR_COUNT = 2; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_SCALE = 3; +static constexpr int64_t MOE_INIT_ROUTING_CUSTOM_EXPERT_END_BOUND = 10240; +static constexpr int64_t KEY_VALUE_MODE_DIM0_NUM = 2; +enum DropPadMode : int8_t { + NO_DROP_PAD = 0, + DROP_PAD = 1, +}; +enum QuantMode : int8_t { + NON_QUANT = -1, + STATIC_QUANT = 0, + DYNAMIC_QUANT = 1 +}; +enum ExpertTokenNumType : int8_t { + 
CUMSUM = 0, + COUNT = 1, + KEY_VALUE = 2 +}; + +static bool isSameDim(int64_t dim1, int64_t dim2) +{ + if (dim1 <= NEG_ONE || dim2 <= NEG_ONE) { + return true; + } + return dim1 == dim2; +} + +static ge::graphStatus GetAndCheckAttrActiveExpertRange(const gert::RuntimeAttrs *attrs, + gert::InferShapeContext *context, int64_t &expertStart, + int64_t &expertEnd, int64_t &experNum) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckAttrActiveExpertRange."); + // Check if active_expert_range size is 2 and if expert_start < expert_end + auto activeExpertRangePtr = attrs->GetListInt(MOE_INIT_ROUTING_CUSTOM_ATTR_ACTIVE_EXPERT_RANGE); + if (nullptr == activeExpertRangePtr) { + OPS_LOG_E(context->GetNodeName(), "The active_expert_range should be list int. But it is none."); + return ge::GRAPH_FAILED; + } + int64_t activeExpertRangeSize = activeExpertRangePtr->GetSize(); + if (activeExpertRangePtr->GetSize() == DIM_TWO) { + expertStart = activeExpertRangePtr->GetData()[0]; + expertEnd = activeExpertRangePtr->GetData()[1]; + if (expertStart >= expertEnd || expertStart < 0 || expertEnd > MOE_INIT_ROUTING_CUSTOM_EXPERT_END_BOUND) { + OPS_LOG_E(context->GetNodeName(), + "The active_expert_range should be in [0, %ld), but the active_expert_range is [%ld, %ld).", + MOE_INIT_ROUTING_CUSTOM_EXPERT_END_BOUND, expertStart, expertEnd); + return ge::GRAPH_FAILED; + } + } else if (activeExpertRangePtr->GetSize() == 0) { + expertStart = 0; + expertEnd = experNum; + } else { + OPS_LOG_E(context->GetNodeName(), "The active_expert_range size should be 2, but its size is %ld.", activeExpertRangeSize); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrActiveExpertRange."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrActiveNum(const gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + int64_t &activeNum, int64_t &dropPadMode) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do 
GetAndCheckAttrActiveNum."); + const int64_t *activeNumPtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_ACTIVE_NUM); + if (nullptr == activeNumPtr) { + OPS_LOG_E(context->GetNodeName(), "The active_num should not be none."); + return ge::GRAPH_FAILED; + } + activeNum = *activeNumPtr; + if (dropPadMode == DropPadMode::NO_DROP_PAD && activeNum < -1) { + OPS_LOG_E(context->GetNodeName(), "The active_num should be greater than or equal to -1. But it is %ld.", activeNum); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrActiveNum."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrExpertCapacity(const gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + const gert::Shape *xShape, int64_t &expertCapacity, + int64_t &dropPadMode) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckAttrExpertCapacity."); + const int64_t *expertCapacityPtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_CAPACITY); + if (nullptr == expertCapacityPtr) { + OPS_LOG_E(context->GetNodeName(), "The expert_capacity should not be none."); + return ge::GRAPH_FAILED; + } + expertCapacity = *expertCapacityPtr; + if (dropPadMode == DropPadMode::DROP_PAD && xShape->GetDim(0) > 0 && expertCapacity > xShape->GetDim(0)) { + OPS_LOG_E(context->GetNodeName(), "The expert_capacity should be between 0 and n. 
But it is %ld.", expertCapacity); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrExpertCapacity."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrExpertNum(const gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + int64_t &experNum) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckexperNum."); + const int64_t *experNumPtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_NUM); + if (nullptr == experNumPtr) { + OPS_LOG_E(context->GetNodeName(), "The expert_num should not be none."); + return ge::GRAPH_FAILED; + } + experNum = *experNumPtr; + if (experNum <= 0 || experNum > MOE_INIT_ROUTING_CUSTOM_EXPERT_END_BOUND) { + OPS_LOG_E(context->GetNodeName(), "The expert_num should be greater than 0. But it is %ld.", experNum); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrExpertNum."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrDropPadMode(const gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + int64_t &dropPadMode) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckAttrDropPadMode."); + const int64_t *dropPadModePtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_DROP_PAD_MODE); + if (nullptr == dropPadModePtr) { + OPS_LOG_E(context->GetNodeName(), "The RuntimeAttrs for drop_pad_mode is none."); + return ge::GRAPH_FAILED; + } + + dropPadMode = *dropPadModePtr; + if (dropPadMode < DropPadMode::NO_DROP_PAD || dropPadMode > DropPadMode::DROP_PAD) { + OPS_LOG_E(context->GetNodeName(), "The drop_pad_mode should be %d or %d. 
But it is %ld.", DropPadMode::NO_DROP_PAD, + DropPadMode::DROP_PAD, dropPadMode); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrDropPadMode."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrExpertTokenNumType(const gert::RuntimeAttrs *attrs, gert::InferShapeContext* context, + int64_t &experTokenNumType) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckexperTokenNumType."); + const int64_t *experTokenNumTypePtr = + attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_TOKEN_NUM_TYPE); + if (nullptr == experTokenNumTypePtr) { + OPS_LOG_E(context->GetNodeName(), "The expert_token_num_type should not be none."); + return ge::GRAPH_FAILED; + } + experTokenNumType = *experTokenNumTypePtr; + if (experTokenNumType < ExpertTokenNumType::CUMSUM || experTokenNumType > ExpertTokenNumType::KEY_VALUE) { + OPS_LOG_E(context->GetNodeName(), "The expert_token_num_type should be %d, %d or %d. But it is %ld.", + ExpertTokenNumType::CUMSUM, ExpertTokenNumType::COUNT, ExpertTokenNumType::KEY_VALUE, + experTokenNumType); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrExpertTokenNumType."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrExpertTokenNumFlag(const gert::RuntimeAttrs *attrs, + gert::InferShapeContext *context, bool &experTokenNumFlag) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckexperTokenNumType."); + const bool *experTokenNumFlagPtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_EXPERT_TOKEN_NUM_FLAG); + if (nullptr == experTokenNumFlagPtr) { + OPS_LOG_E(context->GetNodeName(), "The expert_token_num_flag should not be none."); + return ge::GRAPH_FAILED; + } + experTokenNumFlag = *experTokenNumFlagPtr; + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrExpertTokenNumType."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrQuantMode(const 
gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + int64_t &quantMode) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckQuantMode."); + if (nullptr == attrs) { + OPS_LOG_E(context->GetNodeName(), "The RuntimeAttrs for quant_mode is none."); + return ge::GRAPH_FAILED; + } + const int64_t *quantModePtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_QUANT_MODE); + if (nullptr == quantModePtr) { + OPS_LOG_E(context->GetNodeName(), "The quant_mode should be %d, %d or %d. But it is none.", QuantMode::NON_QUANT, + QuantMode::STATIC_QUANT, QuantMode::DYNAMIC_QUANT); + return ge::GRAPH_FAILED; + } + quantMode = *quantModePtr; + if (quantMode < QuantMode::NON_QUANT || quantMode > QuantMode::DYNAMIC_QUANT) { + OPS_LOG_E(context->GetNodeName(), "The quant_mode should be %d, %d or %d. But it is %ld.", QuantMode::NON_QUANT, + QuantMode::STATIC_QUANT, QuantMode::DYNAMIC_QUANT, quantMode); + return ge::GRAPH_FAILED; + } + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckQuantMode."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAndCheckAttrRowIdxType(const gert::RuntimeAttrs *attrs, gert::InferShapeContext *context, + int64_t &rowIdxType, int64_t &dropPadMode) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do GetAndCheckAttrRowIdxType."); + if (nullptr == attrs) { + OPS_LOG_E(context->GetNodeName(), "The RuntimeAttrs for row_Idx_type is none."); + return ge::GRAPH_FAILED; + } + const int64_t *dropPadModePtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_DROP_PAD_MODE); + dropPadMode = *dropPadModePtr; + + const int64_t *rowIdxTypePtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_ROW_IDX_TYPE); + if (nullptr == rowIdxTypePtr) { + OPS_LOG_E(context->GetNodeName(), "The row_Idx_type should be 0 or 1. 
But it is none."); + return ge::GRAPH_FAILED; + } + rowIdxType = *rowIdxTypePtr; + if (dropPadMode == DropPadMode::DROP_PAD && rowIdxType != 0) { + OPS_LOG_E(context->GetNodeName(), "The row_Idx_type should be 0 when dropPadMode is equal to 1 But it is %ld.", rowIdxType); + return ge::GRAPH_FAILED; + } + + if (rowIdxType < 0 || rowIdxType > 1) { + OPS_LOG_E(context->GetNodeName(), "The row_Idx_type should be 0 or 1 But it is %ld.", rowIdxType); + return ge::GRAPH_FAILED; + } + + OPS_LOG_D(context->GetNodeName(), "End to do GetAndCheckAttrRowIdxType."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckInputScaleShape(gert::InferShapeContext *context, const gert::Shape *xShape, + const gert::Shape *scaleShape, const int64_t expertStart, + const int64_t expertEnd, const int64_t quantMode) +{ + // When quant_mode is STATIC_QUANT, scale cannot be none. + OP_CHECK((nullptr == scaleShape && QuantMode::STATIC_QUANT == quantMode), + OPS_LOG_E(context->GetNodeName(), "The scale cannot be none when quant_mode is %ld.", quantMode), + return ge::GRAPH_FAILED); + + // When quant_mode is NON_QUANT or DYNAMIC_QUANT, scale can be none. 
+ OP_CHECK((nullptr == scaleShape && (QuantMode::NON_QUANT == quantMode || QuantMode::DYNAMIC_QUANT == quantMode)), + OPS_LOG_I(context->GetNodeName(), "When quant_mode is NON_QUANT or DYNAMIC_QUANT, scale can be none."), + return ge::GRAPH_SUCCESS); + + if (QuantMode::NON_QUANT == quantMode) { + if (scaleShape->GetDimNum() == DIM_ONE) { + OP_CHECK(scaleShape->GetDim(0) < 0 && scaleShape->GetDim(0) != NEG_ONE && scaleShape->GetDim(0) != NEG_TWO, + OPS_LOG_E(context->GetNodeName(), + "When quant_mode is %ld and use scale in dynamic graph, The shape of scale should be (-1) or (-2), current shape is (%s).", + quantMode, ops::Shape2String(*scaleShape).c_str()), + return ge::GRAPH_FAILED); + OP_CHECK(scaleShape->GetDim(0) > 0 && !isSameDim(scaleShape->GetDim(0), xShape->GetDim(0)), + OPS_LOG_E(context->GetNodeName(), + "When quant_mode is %ld and use scale in static graph, The shape of scale should be (%ld,), current shape is (%s).", + quantMode, xShape->GetDim(0), ops::Shape2String(*scaleShape).c_str()), + return ge::GRAPH_FAILED); + } else { + OPS_LOG_E(context->GetNodeName(), "When quant_mode is %ld, The dimNum of scale should be 1, current shape is (%ld).", quantMode, + scaleShape->GetDimNum()); + return ge::GRAPH_FAILED; + } + } else if (QuantMode::STATIC_QUANT == quantMode) { + if (scaleShape->GetDimNum() == DIM_ONE) { + OP_CHECK( + scaleShape->GetDim(0) != NEG_ONE && scaleShape->GetDim(0) != NEG_TWO && + !isSameDim(scaleShape->GetDim(0), DIM_ONE), + OPS_LOG_E( + context->GetNodeName(), + "When quant_mode is %ld, the shape of scale should be (-1) or (-2) or (1,), current shape is (%s).", + quantMode, ops::Shape2String(*scaleShape).c_str()), + return ge::GRAPH_FAILED); + } else { + OPS_LOG_E(context->GetNodeName(), "When quant_mode is %ld, the dimNum of scale should be (1,), current shape is (%ld).", + quantMode, scaleShape->GetDimNum()); + return ge::GRAPH_FAILED; + } + } else if (QuantMode::DYNAMIC_QUANT == quantMode) { + int64_t activeExpertRange = expertEnd - 
expertStart; + if (scaleShape->GetDimNum() == DIM_ONE) { + OP_CHECK(scaleShape->GetDim(0) != NEG_TWO, + OPS_LOG_E(context->GetNodeName(), + "When quant_mode is %ld and scale dim is 1 in dynamic graph, the first dim of scale should be -2, but " + "its shape is (%ld).", + quantMode, scaleShape->GetDim(0)), + return ge::GRAPH_FAILED); + } else if (scaleShape->GetDimNum() == DIM_TWO) { + if (scaleShape->GetDim(0) > 0) { + OP_CHECK( + !isSameDim(scaleShape->GetDim(0), activeExpertRange) && !isSameDim(scaleShape->GetDim(0), DIM_ONE), + OPS_LOG_E( + context->GetNodeName(), + "When quant_mode is %ld in static graph, the first dim of scale should be 1 or %ld, but its shape is (%ld).", + quantMode, activeExpertRange, scaleShape->GetDim(0)), + return ge::GRAPH_FAILED); + OP_CHECK( + !isSameDim(scaleShape->GetDim(1), xShape->GetDim(1)), + OPS_LOG_E( + context->GetNodeName(), + "When quant_mode is %ld in static graph, the second dim of scale should or %ld, but its shape is (%ld).", + quantMode, xShape->GetDim(1), scaleShape->GetDim(0)), + return ge::GRAPH_FAILED); + } else { + OP_CHECK( + scaleShape->GetDim(0) != NEG_ONE || (scaleShape->GetDim(1) != NEG_ONE && scaleShape->GetDim(1) != xShape->GetDim(1)), + OPS_LOG_E(context->GetNodeName(), + "When quant_mode is %ld and scale dim is 2 in dynamic graph, the shape of scale should be (-1, -1) or (-1, %d), but its shape is (%s).", + quantMode, xShape->GetDim(1), ops::Shape2String(*scaleShape).c_str()), + return ge::GRAPH_FAILED); + } + } else { + OPS_LOG_E( + context->GetNodeName(), + "When quant_mode is %ld, the dimNum of scale should be 1(dynamic graph) or 2, but its shape is (%ld).", + scaleShape->GetDimNum()); + return ge::GRAPH_FAILED; + } + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckInputOffsetShape(gert::InferShapeContext *context, + const gert::Shape *offsetShape, const int64_t expertStart, + const int64_t expertEnd, const int64_t quantMode) +{ + // The shape of offset can be none. 
+ if (quantMode != QuantMode::STATIC_QUANT) { + return ge::GRAPH_SUCCESS; + } else if (nullptr == offsetShape) { + return ge::GRAPH_FAILED; + } + + if (offsetShape->GetDimNum() != DIM_ONE) { + OPS_LOG_E(context->GetNodeName(), "The dimNum of offset should be 1, current shape is (%ld).", offsetShape->GetDimNum()); + return ge::GRAPH_FAILED; + } + if (offsetShape->GetDim(0) != NEG_ONE && offsetShape->GetDim(0) != NEG_TWO && !isSameDim(offsetShape->GetDim(0), DIM_ONE)) { + OPS_LOG_E(context->GetNodeName(), + "The shape of offset should be (1,) in static graph or (-2), (-1,) in dynamic graph, current shape is (%s).", + ops::Shape2String(*offsetShape).c_str()); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckInputShape(gert::InferShapeContext *context, const gert::Shape *xShape, + const gert::Shape *expertIdxShape, const gert::Shape *scaleShape, + const gert::Shape *offsetShape, const int64_t expertStart, + const int64_t expertEnd, const int64_t quantMode) +{ + // Check the shape of input_x + if (xShape->GetDimNum() == DIM_ONE) { + if (xShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) { + OPS_LOG_E(context->GetNodeName(), "The dynamic dim of x should be -2, current shape is %s.", + ops::Shape2String(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + } else if (xShape->GetDimNum() != DIM_TWO) { + OPS_LOG_E(context->GetNodeName(), "The dim of x should be 2 or dynamic, current shape is %s.", + ops::Shape2String(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + + int64_t x_n = xShape->GetDimNum() == DIM_ONE ? NEG_ONE : xShape->GetDim(0); + int64_t cols = xShape->GetDimNum() == DIM_ONE ? 
NEG_ONE : xShape->GetDim(1); + if (x_n < NEG_ONE || cols < NEG_ONE) { + OPS_LOG_E(context->GetNodeName(), "Invalid x shape, shape is %s.", ops::Shape2String(*xShape).c_str()); + return ge::GRAPH_FAILED; + } + + // Check the shape of expert_idx + if (expertIdxShape->GetDimNum() == DIM_ONE) { + if (expertIdxShape->GetDim(0) != ge::UNKNOWN_DIM_NUM) { + OPS_LOG_E(context->GetNodeName(), "The dynamic dim of expert_idx should be -2, current shape is %s.", + ops::Shape2String(*expertIdxShape).c_str()); + return ge::GRAPH_FAILED; + } + } else if (expertIdxShape->GetDimNum() != DIM_TWO) { + OPS_LOG_E(context->GetNodeName(), "The dim of expert_idx should be 2 or dynamic, current shape is %s.", + ops::Shape2String(*expertIdxShape).c_str()); + return ge::GRAPH_FAILED; + } + + int64_t expert_idx_n = expertIdxShape->GetDimNum() == DIM_ONE ? NEG_ONE : expertIdxShape->GetDim(0); + int64_t expert_idx_k = expertIdxShape->GetDimNum() == DIM_ONE ? NEG_ONE : expertIdxShape->GetDim(1); + if (expert_idx_n < NEG_ONE || expert_idx_k < NEG_ONE) { + OPS_LOG_E(context->GetNodeName(), "Invalid expert_idx shape, shape is %s.", + ops::Shape2String(*expertIdxShape).c_str()); + return ge::GRAPH_FAILED; + } + + if (!isSameDim(x_n, expert_idx_n)) { + OPS_LOG_E(context->GetNodeName(), "The first dim of x and expert_idx should be same."); + return ge::GRAPH_FAILED; + } + // Check the shape of scale + if (CheckInputScaleShape(context, xShape, scaleShape, expertStart, expertEnd, quantMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // Check the shape of offset + if (CheckInputOffsetShape(context, offsetShape, expertStart, expertEnd, quantMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +static void ShowInputShapeAndAttrInfo(gert::InferShapeContext *context, const gert::Shape *xShape, + const gert::Shape *expertIdxShape, const gert::Shape *scaleShape, + const gert::Shape *offsetShape, const int64_t expertStart, + const int64_t expertEnd, 
const int64_t quantMode, const int64_t rowIdxType) +{ + // input_x and expert_idx are all required. + OPS_LOG_D(context->GetNodeName(), "x shape is: %s.", ops::Shape2String(*xShape).c_str()); + OPS_LOG_D(context->GetNodeName(), "expert_idx shape is: %s.", ops::Shape2String(*expertIdxShape).c_str()); + + // scale is optional and can be none. + if (nullptr == scaleShape) { + OPS_LOG_D(context->GetNodeName(), "scale_shape is: none."); + } else { + OPS_LOG_D(context->GetNodeName(), "scale_shape is: %s.", ops::Shape2String(*scaleShape).c_str()); + } + + // offset is optional and can be none. + OPS_LOG_D(context->GetNodeName(), "Begin print offset_shape."); + if (nullptr == offsetShape) { + OPS_LOG_D(context->GetNodeName(), "offset_shape is: none."); + } else { + OPS_LOG_D(context->GetNodeName(), "offset_shape is: %s.", ops::Shape2String(*offsetShape).c_str()); + } + OPS_LOG_D(context->GetNodeName(), "End print offset_shape."); + + // Attrs are all required. + OPS_LOG_D(context->GetNodeName(), "active_expert_range is: [%ld, %ld).", expertStart, expertEnd); + OPS_LOG_D(context->GetNodeName(), "quant_mode is: %ld.", quantMode); + OPS_LOG_D(context->GetNodeName(), "row_Idx_type is: %ld.", rowIdxType); +} + +static void ShowOutputShapeInfo(gert::InferShapeContext *context, const gert::Shape *expandedXShape, + const gert::Shape *expandedRowIdxShape, + const gert::Shape *expertTokenCumsumOrCountShape, const gert::Shape *expandedScaleShape) +{ + OPS_LOG_D(context->GetNodeName(), "expanded_x shape is: %s after infershape.", + ops::Shape2String(*expandedXShape).c_str()); + OPS_LOG_D(context->GetNodeName(), "expanded_row_idx shape is: %s after infershape.", + ops::Shape2String(*expandedRowIdxShape).c_str()); + OPS_LOG_D(context->GetNodeName(), "expert_token_cumsum_or_count shape is: %s after infershape.", + ops::Shape2String(*expertTokenCumsumOrCountShape).c_str()); + OPS_LOG_D(context->GetNodeName(), "expanded_scale shape is: %s after infershape.", + 
ops::Shape2String(*expandedScaleShape).c_str()); +} + +static ge::graphStatus InferShape4MoeInitRoutingCustom(gert::InferShapeContext *context) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do MoeInitRoutingCustomInfershape."); + // 1. Get and check input shape + // 1.1 Get and check input_x + const gert::Shape *xShape = context->GetInputShape(MOE_INIT_ROUTING_CUSTOM_INPUT_X); + OP_CHECK_NULL_WITH_CONTEXT(context, xShape); + + // 1.2 Get and check expert_idx + const gert::Shape *expertIdxShape = context->GetInputShape(MOE_INIT_ROUTING_CUSTOM_INPUT_EXPERT_IDX); + OP_CHECK_NULL_WITH_CONTEXT(context, expertIdxShape); + + // 1.3 Get scale shape without checking null, because scale is optional and can be none. + const gert::Shape *scaleShape = context->GetOptionalInputShape(MOE_INIT_ROUTING_CUSTOM_INPUT_SCALE); + + // 1.4 Get offset shape without checking null, because offset is optional and can be none. + const gert::Shape *offsetShape = context->GetOptionalInputShape(MOE_INIT_ROUTING_CUSTOM_INPUT_OFFSET); + // 2. 
Get and check attrs + const gert::RuntimeAttrs *attrs = context->GetAttrs(); + OP_CHECK_NULL_WITH_CONTEXT(context, attrs); + + // 2.1 Get and check expert_num attr + int64_t experNum = static_cast(-1); + if (GetAndCheckAttrExpertNum(attrs, context, experNum) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.2 Get and check active_expert_range attr + int64_t expertStart = static_cast(-1); + int64_t expertEnd = static_cast(-1); + if (GetAndCheckAttrActiveExpertRange(attrs, context, expertStart, expertEnd, experNum) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + if (nullptr == attrs) { + OPS_LOG_E(context->GetNodeName(), "The attrs is none."); + return ge::GRAPH_FAILED; + } + + // 2.3 Get and check drop_pad_mode attr + int64_t dropPadMode = static_cast(-1); + if (GetAndCheckAttrDropPadMode(attrs, context, dropPadMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.4 Get and check active_num attr + int64_t activeNum = static_cast(-1); + if (GetAndCheckAttrActiveNum(attrs, context, activeNum, dropPadMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.5 Get and check expert_capacity attr + int64_t expertCapacity = static_cast(-1); + if (GetAndCheckAttrExpertCapacity(attrs, context, xShape, expertCapacity, dropPadMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.6 Get and check expert_token_num_type attr + int64_t expertTokenNumType = static_cast(-1); + if (GetAndCheckAttrExpertTokenNumType(attrs, context, expertTokenNumType) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.7 Get and check expert_token_num_type attr + bool expertTokenNumFlag = false; + if (GetAndCheckAttrExpertTokenNumFlag(attrs, context, expertTokenNumFlag) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 2.8 Get and check quant_mode attr + int64_t quantMode = static_cast(-1); + if (GetAndCheckAttrQuantMode(attrs, context, quantMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 
2.9 Get and check row_Idx_type attr + int64_t rowIdxType = static_cast(-1); + if (GetAndCheckAttrRowIdxType(attrs, context, rowIdxType, dropPadMode) != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // Check input shape + if (CheckInputShape(context, xShape, expertIdxShape, scaleShape, offsetShape, expertStart, expertEnd, quantMode) != + ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + // 3. Infer output shape + // 3.1 Prepare output shape + gert::Shape *expandedXShape = context->GetOutputShape(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_X); + OP_CHECK_NULL_WITH_CONTEXT(context, expandedXShape); + gert::Shape *expandedRowIdxShape = context->GetOutputShape(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_ROW_IDX); + OP_CHECK_NULL_WITH_CONTEXT(context, expandedRowIdxShape); + gert::Shape *expertTokenCumsumOrCountShape = + context->GetOutputShape(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPERT_TOKEN_CUMSUM_OR_COUNT); + OP_CHECK_NULL_WITH_CONTEXT(context, expertTokenCumsumOrCountShape); + gert::Shape *expandedScaleShape = context->GetOutputShape(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_SCALE); + OP_CHECK_NULL_WITH_CONTEXT(context, expandedScaleShape); + + int64_t x_n = xShape->GetDimNum() == DIM_ONE ? NEG_ONE : xShape->GetDim(0); + int64_t cols = xShape->GetDimNum() == DIM_ONE ? NEG_ONE : xShape->GetDim(1); + + int64_t expert_idx_n = expertIdxShape->GetDimNum() == DIM_ONE ? NEG_ONE : expertIdxShape->GetDim(0); + int64_t k = expertIdxShape->GetDimNum() == DIM_ONE ? NEG_ONE : expertIdxShape->GetDim(1); + int64_t n = x_n > expert_idx_n ? x_n : expert_idx_n; + if (activeNum == 0 || activeNum == -1) { + activeNum = n * k; + } else { + activeNum = std::min(activeNum, n * k); + } + + int64_t xOutDimNum = activeNum < n * k ? activeNum : n * k; + int64_t outNum = (n == NEG_ONE || k == NEG_ONE) ? NEG_ONE : n * k; + int64_t xOutNum = (n == NEG_ONE || k == NEG_ONE) ? 
NEG_ONE : xOutDimNum; + // 3.2 Set output expanded_x shape + if (dropPadMode == DropPadMode::NO_DROP_PAD) { + expandedXShape->SetDimNum(DIM_TWO); + expandedXShape->SetDim(0U, xOutNum); + expandedXShape->SetDim(DIM_ONE, cols); + } else { + expandedXShape->SetDimNum(DIM_THREE); + expandedXShape->SetDim(0U, experNum); + expandedXShape->SetDim(DIM_ONE, expertCapacity); + expandedXShape->SetDim(DIM_TWO, cols); + } + + // 3.3 Set output expanded_row_idx shape + expandedRowIdxShape->SetDimNum(DIM_ONE); + expandedRowIdxShape->SetDim(0U, outNum); + + // 3.4 Set output expert_token_cumsum_or_count shape + if (expertTokenNumFlag) { + if (expertTokenNumType == ExpertTokenNumType::KEY_VALUE) { + expertTokenCumsumOrCountShape->SetDimNum(DIM_TWO); + expertTokenCumsumOrCountShape->SetDim(0U, experNum); + expertTokenCumsumOrCountShape->SetDim(DIM_ONE, KEY_VALUE_MODE_DIM0_NUM); + } else { + expertTokenCumsumOrCountShape->SetDimNum(DIM_ONE); + expertTokenCumsumOrCountShape->SetDim(0U, expertEnd - expertStart); + } + } + + // 3.5 Set output expanded_scale shape + // When scale_shape=(b*s) and non-quant, or it is dynamic quant mode, the shape of expanded_scale should be (b*s*k) + if (QuantMode::NON_QUANT == quantMode || QuantMode::DYNAMIC_QUANT == quantMode) { + expandedScaleShape->SetDimNum(DIM_ONE); + if (dropPadMode == DropPadMode::NO_DROP_PAD) { + expandedScaleShape->SetDim(0U, xOutNum); + } else { + expandedScaleShape->SetDim(0U, experNum * expertCapacity); + } + } + + ShowOutputShapeInfo(context, expandedXShape, expandedRowIdxShape, expertTokenCumsumOrCountShape, + expandedScaleShape); + OPS_LOG_D(context->GetNodeName(), "End to do MoeInitRoutingCustomInfershape."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataType4MoeInitRoutingCustom(gert::InferDataTypeContext *context) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do MoeInitRoutingCustomInferDataType."); + + // Get and check quant_mode attr + const gert::RuntimeAttrs *attrs = context->GetAttrs(); + 
OP_CHECK_NULL_WITH_CONTEXT(context, attrs); + int64_t quantMode = static_cast(-1); + const int64_t *quantModePtr = attrs->GetAttrPointer(MOE_INIT_ROUTING_CUSTOM_ATTR_QUANT_MODE); + if (nullptr == quantModePtr) { + OPS_LOG_E(context->GetNodeName(), "The quant_mode should be %d, %d or %d. But it is none.", QuantMode::NON_QUANT, + QuantMode::STATIC_QUANT, QuantMode::DYNAMIC_QUANT); + return ge::GRAPH_FAILED; + } + quantMode = *quantModePtr; + // Infer output dtype according quant_mode + auto xDtype = context->GetInputDataType(MOE_INIT_ROUTING_CUSTOM_INPUT_X); + if (QuantMode::NON_QUANT == quantMode) { + context->SetOutputDataType(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_X, xDtype); + } else if (QuantMode::STATIC_QUANT == quantMode || QuantMode::DYNAMIC_QUANT == quantMode) { + if (ge::DT_INT8 == xDtype) { + OPS_LOG_E(context->GetNodeName(), "When quant_mode=%ld, xDtype cannot be int_8.", quantMode); + return ge::GRAPH_FAILED; + } + context->SetOutputDataType(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_X, ge::DT_INT8); + } + context->SetOutputDataType(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_ROW_IDX, ge::DT_INT32); + context->SetOutputDataType(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPERT_TOKEN_CUMSUM_OR_COUNT, ge::DT_INT64); + context->SetOutputDataType(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_SCALE, ge::DT_FLOAT); + OPS_LOG_D(context->GetNodeName(), "End to do MoeInitRoutingCustomInferDataType."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus InferShapeRange4MoeInitRoutingCustom(gert::InferShapeRangeContext *context) +{ + OPS_LOG_D(context->GetNodeName(), "Begin to do MoeInitRoutingCustomInferRange."); + + // Get and check the pointers of all the outputs' shape range object + auto expanded_x = context->GetOutputShapeRange(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_X); + OP_CHECK_NULL_WITH_CONTEXT(context, expanded_x); + auto expanded_row_idx = context->GetOutputShapeRange(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_ROW_IDX); + OP_CHECK_NULL_WITH_CONTEXT(context, 
expanded_row_idx); + auto count = context->GetOutputShapeRange(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPERT_TOKEN_CUMSUM_OR_COUNT); + OP_CHECK_NULL_WITH_CONTEXT(context, count); + auto expanded_scale = context->GetOutputShapeRange(MOE_INIT_ROUTING_CUSTOM_OUTPUT_EXPANDED_SCALE); + OP_CHECK_NULL_WITH_CONTEXT(context, expanded_scale); + + // Print the shape ranges of the outputs before InferShapeRange + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_x->GetMin() = %s", + ops::Shape2String(*(expanded_x->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_x->GetMax() = %s", + ops::Shape2String(*(expanded_x->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_row_idx->GetMin() = %s", + ops::Shape2String(*(expanded_row_idx->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_row_idx->GetMax() = %s", + ops::Shape2String(*(expanded_row_idx->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, count->GetMin() = %s", + ops::Shape2String(*(count->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, count->GetMax() = %s", + ops::Shape2String(*(count->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_scale->GetMin() = %s", + ops::Shape2String(*(expanded_scale->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "Before InferShapeRange, expanded_scale->GetMax() = %s", + ops::Shape2String(*(expanded_scale->GetMax())).c_str()); + + // Set the dim num and dim of the outputs' shape range object + if (expanded_x->GetMin() != nullptr && expanded_x->GetMax() != nullptr) { + expanded_x->GetMin()->SetDimNum(DIM_TWO); + expanded_x->GetMax()->SetDimNum(DIM_TWO); + for (size_t i = 0; i < DIM_TWO; i++) { + expanded_x->GetMin()->SetDim(i, 0); + expanded_x->GetMax()->SetDim(i, -1); + } + } + + if (expanded_row_idx->GetMin() != nullptr && 
expanded_row_idx->GetMax() != nullptr) { + expanded_row_idx->GetMin()->SetDimNum(DIM_ONE); + expanded_row_idx->GetMax()->SetDimNum(DIM_ONE); + expanded_row_idx->GetMin()->SetDim(0, 0); + expanded_row_idx->GetMax()->SetDim(0, -1); + } + + if (count->GetMin() != nullptr && count->GetMax() != nullptr) { + count->GetMin()->SetDimNum(DIM_ONE); + count->GetMax()->SetDimNum(DIM_ONE); + count->GetMin()->SetDim(0, 0); + count->GetMax()->SetDim(0, -1); + } + + if (expanded_scale->GetMin() != nullptr && expanded_scale->GetMax() != nullptr) { + expanded_scale->GetMin()->SetDimNum(DIM_ONE); + expanded_scale->GetMax()->SetDimNum(DIM_ONE); + expanded_scale->GetMin()->SetDim(0, 0); + expanded_scale->GetMax()->SetDim(0, -1); + } + + // Print the shape ranges of the outputs after InferShapeRange + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_x->GetMin() = %s", + ops::Shape2String(*(expanded_x->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_x->GetMax() = %s", + ops::Shape2String(*(expanded_x->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_row_idx->GetMin() = %s", + ops::Shape2String(*(expanded_row_idx->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_row_idx->GetMax() = %s", + ops::Shape2String(*(expanded_row_idx->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, count->GetMin() = %s", + ops::Shape2String(*(count->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, count->GetMax() = %s", + ops::Shape2String(*(count->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_scale->GetMin() = %s", + ops::Shape2String(*(expanded_scale->GetMin())).c_str()); + OPS_LOG_D(context->GetNodeName(), "After InferShapeRange, expanded_scale->GetMax() = %s", + ops::Shape2String(*(expanded_scale->GetMax())).c_str()); + + OPS_LOG_D(context->GetNodeName(), 
"End to do MoeInitRoutingCustomInferRange."); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_INFERSHAPE(MoeInitRoutingCustom) + .InferShape(InferShape4MoeInitRoutingCustom) + .InferDataType(InferDataType4MoeInitRoutingCustom) + .InferShapeRange(InferShapeRange4MoeInitRoutingCustom); +} // namespace ops \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.cpp b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.cpp new file mode 100644 index 00000000000..411ec97f7d3 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.cpp @@ -0,0 +1,1267 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_init_routing_custom_tiling.cpp + * \brief + */ +#include "moe_init_routing_custom_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_templates_registry.h" + +namespace optiling { +const static int64_t NUM_TWO = 2; +const static int64_t NUM_THREE = 3; +const static int64_t NUM_FOUR = 4; +const static int64_t NUM_FIVE = 5; +const static int64_t MRG_LIST_NUM = 4; +const static int64_t SORT32_ALIGN_ELEMENT = 32; +const static int64_t ONE_BLOCK_BYTE = 32; +const static size_t DIM_ONE = 1; +const static size_t DIM_TWO = 2; +const static int32_t SIZE_16 = 16; +const static int32_t SIZE_31 = 31; +const static int32_t LENGTH_1024 = 1024; +const static int64_t MAX_COLS_ONE_LOOP = 16376; +const static int64_t ASSIST_NUM = 256; +const static int64_t SPLIT_K_THRESHOLD = 512; +const static int64_t KV_FACTOR = 2; +const static int64_t ONE_CORE_SORT_BUFFER = 6; +const static int64_t EXPERT_IDX_MAX = 10240; +const static int64_t KV_MODE_EXPERT_IDX_MAX = EXPERT_IDX_MAX / KV_FACTOR; +const static int64_t ACTIVE_NUM_MIN_VALUE = static_cast(-1); +const static int64_t EXPERT_CAPACITY_MIN_VALUE = static_cast(0); + +const static int64_t INPUT_X_INDEX = 0; +const static int64_t INPUT_EXPERT_IDX_INDEX = 1; +const static int64_t INPUT_SCALE_INDEX = 2; +const static int64_t INPUT_OFFSET_INDEX = 3; +const static int64_t OUTPUT_EXPANDED_X_INDEX = 0; +const static int64_t OUTPUT_EXPANDED_ROW_IDX_INDEX = 1; +const static int64_t OUTPUT_EXPERT_TOKENS_COUNT_INDEX = 2; +const static int64_t OUTPUT_EXPANDED_SCALE_INDEX = 3; +const static int64_t ATTR_ACTIVE_NUM_INDEX = 0; +const static int64_t ATTR_EXPERT_CAPACITY_INDEX = 1; +const static int64_t ATTR_EXPERT_NUM_INDEX = 2; +const static int64_t ATTR_DROP_PAD_MODE_INDEX = 3; +const static int64_t ATTR_EXPERT_TOKEN_NUM_TYPE_INDEX = 4; +const static int64_t ATTR_EXPERT_TOKEN_NUM_FLAG_INDEX = 5; +const static int64_t ATTR_QUANT_MODE_INDEX = 6; +const static int64_t ATTR_EXPERT_RANGE_INDEX = 7; +const static int64_t 
ATTR_ROW_IDX_TYPE_INDEX = 8; +const static int64_t ATTR_EXPERT_RANGE_DIM = 2; +const static int64_t GATHER = 0; +const static int64_t SCATTER = 1; +const static int64_t UN_QUANT = -1L; +const static int64_t STATIC_QUANT = 0; +const static int64_t DYNAMIC_QUANT = 1; +const static int64_t CUMSUM = 0; +const static int64_t COUNT = 1; +const static int64_t KEY_VALUE = 2; +const static int64_t DROP_LESS = 0; +const static int64_t DROP_PAD = 1; +const static int64_t DYNAMIC_QUANT_COLS_BUFFER = 21; +const static int64_t DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER = 13; +const static int64_t STATIC_QUANT_FULLLOAD_COLS_BUFFER = 11; + +const static int64_t DYNAMIC_QUANT_SRC_TO_DST_BUFFER = 15; +const static int64_t DYNAMIC_QUANT_SCALE_SIZE_64 = 64; +const static int64_t MAX_COLS_DYNAMIC_QUANT = 6144; +const static int64_t SIZE_INT32 = 4; +const static int64_t SIZE_INT16 = 2; +const static int64_t SIZE_INT8 = 1; +const static int64_t SIZE_FP32 = 4; + +const static uint64_t TILINGKEY_BASE = 1000000; +const static uint64_t SORT_CORE_TILINGKEY_BASE = 100000; +const static uint64_t QUANT_MODE_TILINGKEY_BASE = 10000; +const static uint64_t ROWIDX_TYPE_TILINGKEY_BASE = 1000; +const static uint64_t DROP_MODE_TILINGKEY_BASE = 100; + +// Tiling Key for performance puncturing +const static uint64_t PERFORMANCE_TILINGKEY_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168 = 2000000; +const static uint64_t UNQUANTIZED_FULLLOAD_TILINGKEY = 2100000; +const static uint64_t STATIC_QUANT_FULLLOAD_TILINGKEY = 2200000; +const static uint64_t DYNAMIC_QUANT_FULLLOAD_TILINGKEY = 2300000; +const static uint64_t DYNAMIC_QUANT_EPFULLLOAD_TILINGKEY = 10000; +const static uint64_t DYNAMIC_QUANT_SMOOTHTYPE_FULLLOAD_TILINGKEY = 1000; + +const static int64_t PERFORMANCE_MODE_TOP_K = 8; +const static int64_t PERFORMANCE_MODE_BS_MIN = 384; +const static int64_t PERFORMANCE_MODE_BS_MAX = 8192; +const static int64_t PERFORMANCE_MODE_RANGE_MAX = 32; +const static int64_t PERFORMANCE_MODE_MAX_BATCH_SIZE_TOP_K = 
PERFORMANCE_MODE_BS_MAX * PERFORMANCE_MODE_TOP_K; +const static int64_t PERFORMANCE_MODE_MAX_ONE_CORE_GATHER = 21845; + +const static int64_t gatherFirstN = 100; +const static int64_t gatherFirstScale = 8; +const static int64_t scale1H = 1; +const static int64_t scaleEH = 2; +const static int64_t ONE_REPEAT_SORT_NUM = 32; + +enum class PerformanceMode : int32_t { + COMMON = 0, + ONE_CORE_GATHER_SORT = 1, + MULTI_CORE_GATHER_SORT = 2, +}; + +static constexpr int64_t KEY_VALUE_MODE_DIM0_NUM = 2; + +#define unlikely(x) __builtin_expect((x), 0) + +#define CHECK_FAIL(context, cond, ...) \ + do { \ + if (cond) { \ + OPS_LOG_E(context->GetNodeName(), ##__VA_ARGS__); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if (unlikely((ptr) == nullptr)) { \ + const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \ + "nil" : \ + (context)->GetNodeName(); \ + OPS_LOG_E(name, "%s is nullptr!", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +template +static T1 CeilDiv(T1 a, T2 b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +template +typename std::enable_if ::value, T>::type CeilAlign(T x, T align) { + return CeilDiv(x, align) * align; +} + +inline static int64_t CeilLog4(int64_t x) +{ + return static_cast(std::ceil(std::log(x) / std::log(NUM_FOUR))); +} + +inline static int64_t Align(int64_t elementNum, int64_t bytes) +{ + if (bytes == 0) { + return 0; + } + return (elementNum * bytes + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE / bytes; +} + +inline static int64_t AlignBytes(int64_t elementNum, int64_t bytes) +{ + return (elementNum * bytes + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; +} + +inline static int64_t GetPerOrLastValue(int64_t x, int64_t y) +{ + if (y == 0) { + return 0; + } + return x <= y ? 
x : x % y; +} + +inline static int64_t AlignOneBlockByteCeil(int64_t x) +{ + return x / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; +} + +class MoeInitRountingCustomTilingBase : public TilingBaseClass { +public: + explicit MoeInitRountingCustomTilingBase(gert::TilingContext *context) : TilingBaseClass(context) + { + Reset(); + } + ~MoeInitRountingCustomTilingBase() override = default; + + void Reset(gert::TilingContext *context) override + { + TilingBaseClass::Reset(context); + Reset(); + } + +protected: + bool IsCapable() override + { + return true; + } + ge::graphStatus GetPlatformInfo() override; + ge::graphStatus GetShapeAttrsInfo() override; + ge::graphStatus DoOpTiling() override; + ge::graphStatus DoLibApiTiling() override; + uint64_t GetTilingKey() const override; + ge::graphStatus GetWorkspaceSize() override; + ge::graphStatus PostTiling() override; + void Reset(); + +private: + ge::graphStatus CheckAttr(); + ge::graphStatus CheckOutShape(); + ge::graphStatus CheckInputShape(); + ge::graphStatus CheckDtype(); + void Tiling4GatherOutCompute(); + void Tiling4SortOutCompute(); + void Tiling4VMSMiddleCompute(); + void Tiling4VBSCompute(); + void Tiling4ExpertTokensCountCompute(); + void ShowTilingData(); + void Tinlig4VBSMultiCoreCompute(MoeCustomVBSComputeTilingData *tilingData); + void Tinlig4VBSOneCoreCompute(MoeCustomVBSComputeTilingData *tilingData); + bool IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168() const; + bool IsFullLoad(); + int64_t IsGatherFirstFullLoad(); + void SetGatherTilingData(MoeCustomSrcToDstCapacityComputeTilingData *tilingData, int64_t perCoreRows, + int64_t lastCoreRows, int64_t cols); + void SetGatherTilingDataCols(MoeCustomSrcToDstCapacityComputeTilingData *tilingData, int64_t baseMaxCols, int64_t cols); + void SetGatherTilingDataRows(MoeCustomSrcToDstCapacityComputeTilingData *tilingData, int64_t perCoreRows, + int64_t lastCoreRows, int64_t basePerLoopMaxRows); + void Tiling4SrcToDstDropPadCompute(); + void 
Tiling4SrcToDstDropPadDynamicCompute(); + void Tiling4SrcToDstCompute(); + PerformanceMode GetPerformanceMode() const; + + int64_t aivNum; + int64_t sortLoopMaxElement = 0; + int64_t mrgSortListMaxElement = 1504; + int64_t totalLength_ = 0; + int64_t n_ = 0; + int64_t k_ = 0; + int64_t cols_ = 0; + int64_t inuptXDtypeSize_; + + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + int64_t isInputScale_ = 0; + int64_t isInputOffset_ = 0; + + int64_t sortMode_ = 0; + int64_t rowIdxTytpe_ = 0; + int64_t activeNum_ = -1L; + int64_t expertCapacity_ = -1L; + int64_t expertNum_ = -1L; + int64_t dropPadMode_ = -1L; + int64_t expertTokensNumType_ = -1L; + bool expertTokensNumFlag_ = false; + int64_t quantMode_ = 0; + int64_t rowIdxType_ = -1L; + + bool isFullload_ = false; + int64_t gatherFirstFullload_ = 0; + int64_t ep_ = 0; + int64_t smoothType_ = 0; + + const gert::StorageShape *xShapePtr_ = nullptr; + const gert::StorageShape *expertIdxShapePtr_ = nullptr; + const gert::StorageShape *scaleShapePtr_ = nullptr; + const gert::StorageShape *offsetShapePtr_ = nullptr; + + const int64_t *activeNumPtr_ = nullptr; + const int64_t *expertCapacityPtr_ = nullptr; + const int64_t *expertNumPtr_ = nullptr; + const int64_t *dropPadModePtr_ = nullptr; + const int64_t *expertTokensNumTypePtr_ = nullptr; + const bool *expertTokensNumFlagPtr_ = nullptr; + const int64_t *quantModePtr_ = nullptr; + const gert::ContinuousVector *activeExpertRangeListPtr_; + const int64_t *rowIdxTypePtr_ = nullptr; + + const gert::StorageShape *expandedXShapePtr_ = nullptr; + const gert::StorageShape *expandedRowIdxShapePtr_ = nullptr; + const gert::StorageShape *expertTokensCountOrCumsumShapePtr_ = nullptr; + const gert::StorageShape *expandedScaleShapePtr_ = nullptr; + + const gert::Shape performXShape = gert::Shape({1, 7168}); + const gert::Shape performExpertIdxShape = gert::Shape({1, 8}); + const gert::Shape performScaleShape = gert::Shape({256, 7168}); + + const char *opName = ""; + 
MoeInitRoutingCustomTilingData moeInitRoutingCustomTilingData; +}; + +void MoeInitRountingCustomTilingBase::Reset() +{ + opName = nullptr; + return; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::GetPlatformInfo() +{ + auto compileInfoPtr = reinterpret_cast(context_->GetCompileInfo()); + CHECK_FAIL(context_, compileInfoPtr == nullptr, "fail to get platform info"); + aivNum = compileInfoPtr->aivNum; + aicoreParams_.blockDim = aivNum; + aicoreParams_.ubSize = compileInfoPtr->ubSize; + moeInitRoutingCustomTilingData.set_coreNum(aivNum); + OPS_LOG_I(context_->GetNodeName(), "---PlatformInfo--- aivNum is: %ld, ubSizePlatForm is: %ld ", aivNum, aicoreParams_.ubSize); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::CheckAttr() +{ + quantMode_ = *quantModePtr_; + moeInitRoutingCustomTilingData.set_quantMode(quantMode_); + OPS_LOG_I(context_->GetNodeName(), "quant_mode is: %ld ", quantMode_); + + dropPadMode_ = *dropPadModePtr_; + moeInitRoutingCustomTilingData.set_dropPadMode(dropPadMode_); + CHECK_FAIL(context_, (dropPadMode_ != DROP_LESS) && (dropPadMode_ != DROP_PAD), + "drop_pad_mode should be %ld or %ld", DROP_LESS, DROP_PAD); + + rowIdxTytpe_ = *rowIdxTypePtr_; + moeInitRoutingCustomTilingData.set_rowIdxType(rowIdxTytpe_); + OPS_LOG_I(context_->GetNodeName(), "row_idx_type is: %ld ", rowIdxTytpe_); + + activeNum_ = *activeNumPtr_; + if (dropPadMode_ == DROP_LESS) { + CHECK_FAIL(context_, activeNum_ < ACTIVE_NUM_MIN_VALUE, + "active_num should be greater than or equal to 0"); + } + + expertNum_ = *expertNumPtr_; + moeInitRoutingCustomTilingData.set_expertNum(expertNum_); + if (expertNum_ <= 0) { + OPS_LOG_E(context_->GetNodeName(), "expert_num should be greater than 0"); + return ge::GRAPH_FAILED; + } + if (activeExpertRangeListPtr_->GetSize() != ATTR_EXPERT_RANGE_DIM && activeExpertRangeListPtr_->GetSize() != 0) { + OPS_LOG_E(context_, "The dim number of expert_range should be %ld or 0(no input)", ATTR_EXPERT_RANGE_DIM); + 
return ge::GRAPH_FAILED; + } + if (activeExpertRangeListPtr_->GetSize() == 0) { + expertStart_ = 0; + expertEnd_ = expertNum_; + } else { + const int64_t *expertRangeList = reinterpret_cast(activeExpertRangeListPtr_->GetData()); + expertStart_ = expertRangeList[0]; + expertEnd_ = expertRangeList[1]; + } + moeInitRoutingCustomTilingData.set_expertStart(expertStart_); + moeInitRoutingCustomTilingData.set_expertEnd(expertEnd_); + moeInitRoutingCustomTilingData.set_actualExpertNum(expertEnd_ - expertStart_); + OPS_LOG_I(context_, "expert_start is: %ld, expert_end is: %ld, actualExpertNum is: %ld ", expertStart_, expertEnd_, + expertEnd_ - expertStart_); + + n_ = xShapePtr_->GetStorageShape().GetDim(0); + expertCapacity_ = *expertCapacityPtr_; + moeInitRoutingCustomTilingData.set_expertCapacity(expertCapacity_); + if (dropPadMode_ == DROP_PAD) { + CHECK_FAIL(context_, expertCapacity_ <= EXPERT_CAPACITY_MIN_VALUE || expertCapacity_ > n_, + "expert_Capacity should be greater than 0 and less than %ld", n_); + CHECK_FAIL(context_, rowIdxTytpe_ == SCATTER, "rowIdxTytpe should be 0 when droppadmode is 1"); + CHECK_FAIL(context_, expertStart_ != 0 || expertEnd_ != expertNum_, + "expert_range should be [0, %ld] when droppadmode is 1", expertNum_); + } + + expertTokensNumType_ = *expertTokensNumTypePtr_; + moeInitRoutingCustomTilingData.set_expertTokensNumType(expertTokensNumType_); + CHECK_FAIL(context_, (expertTokensNumType_ != COUNT) && (expertTokensNumType_ != KEY_VALUE) && (expertTokensNumType_ != CUMSUM), + "expert_tokens_num_type currently not support %ld", expertTokensNumType_); + + expertTokensNumFlag_ = *expertTokensNumFlagPtr_; + if (dropPadMode_ == DROP_PAD && expertTokensNumFlag_) { + CHECK_FAIL(context_, expertTokensNumType_ != COUNT, "In DROP_PAD mode and expert_tokens_num_flag is true, expert_tokens_num_type only supports COUNT, but got %ld", expertTokensNumType_);} + if (expertTokensNumFlag_) { + moeInitRoutingCustomTilingData.set_expertTokensNumFlag(1); + } 
else { + moeInitRoutingCustomTilingData.set_expertTokensNumFlag(0); + } + + CHECK_FAIL(context_, expertStart_ < 0, "expert_start should be greater than or equal to 0"); + CHECK_FAIL(context_, expertStart_ >= expertEnd_, "expert_start should be less than expert_end"); + CHECK_FAIL(context_, expertEnd_ > expertNum_, "expert_end should be less than or equal to %ld", expertNum_); + if (expertTokensNumType_ == KEY_VALUE) { + CHECK_FAIL(context_, expertEnd_ > KV_MODE_EXPERT_IDX_MAX, "expert_end should be less than or equal to %ld in KEY_VALUE mode", + KV_MODE_EXPERT_IDX_MAX); + } else { + CHECK_FAIL(context_, expertEnd_ > EXPERT_IDX_MAX, "expert_end should be less than or equal to %ld", EXPERT_IDX_MAX); + } + CHECK_FAIL(context_, quantMode_ != UN_QUANT && quantMode_ != DYNAMIC_QUANT && quantMode_ != STATIC_QUANT, "quant_mode currently support %ld, %ld or %ld", UN_QUANT, DYNAMIC_QUANT, STATIC_QUANT); + CHECK_FAIL(context_, rowIdxTytpe_ != SCATTER && rowIdxTytpe_ != GATHER, "row_idx_type currently support %ld or %ld", SCATTER, GATHER); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::CheckInputShape() +{ + const gert::Shape xShape = xShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input x shape: %s ", ops::Shape2String(xShape).c_str()); + const gert::Shape expertIdxShape = expertIdxShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input expert_idx shape: %s.", ops::Shape2String(expertIdxShape).c_str()); + + // 参数校验 + CHECK_FAIL(context_, xShape.GetDimNum() != DIM_TWO, "The dim number of x should be %lu.", DIM_TWO); + CHECK_FAIL(context_, expertIdxShape.GetDimNum() != DIM_TWO, "The dim number of expert_idx should be %lu.", DIM_TWO); + CHECK_FAIL(context_, xShape.GetDim(0) != expertIdxShape.GetDim(0), context_->GetNodeName(), "Input rows should be same."); + + n_ = expertIdxShape.GetDim(0); + k_ = expertIdxShape.GetDim(1); + cols_ = xShape.GetDim(1); + moeInitRoutingCustomTilingData.set_n(n_); + 
moeInitRoutingCustomTilingData.set_k(k_); + moeInitRoutingCustomTilingData.set_cols(cols_); + totalLength_ = n_ * k_; + if (activeNum_ == 0 || activeNum_ == ACTIVE_NUM_MIN_VALUE) { + activeNum_ = totalLength_; + } else { + activeNum_ = std::min(activeNum_, totalLength_); + } + moeInitRoutingCustomTilingData.set_activeNum(activeNum_); + + inuptXDtypeSize_ = + static_cast(ge::GetSizeByDataType(context_->GetInputDesc(INPUT_X_INDEX)->GetDataType())); + OPS_LOG_I(context_->GetNodeName(), "Input x dtype size is: %ld. ", inuptXDtypeSize_); + + if (quantMode_ == UN_QUANT && scaleShapePtr_ != nullptr) { + auto scaleShape = scaleShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input scale shape: %s", ops::Shape2String(scaleShape).c_str()); + auto scaleDimNum = static_cast(scaleShape.GetDimNum()); + CHECK_FAIL(context_, + scaleDimNum != 1, + context_->GetNodeName(), "The dim number of scale should be 1, current is %ld", scaleDimNum); + auto scaleDim0 = static_cast(scaleShape.GetDim(0)); + CHECK_FAIL(context_, + scaleDim0 != n_, + "The first dim of scale should be n_, current is %ld", scaleDim0); + } + + if (quantMode_ == STATIC_QUANT) { + CHECK_FAIL(context_, scaleShapePtr_ == nullptr, "scale is null"); + CHECK_FAIL(context_, offsetShapePtr_ == nullptr, "offset is null"); + auto scaleShape = scaleShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input scale shape: %s", ops::Shape2String(scaleShape).c_str()); + auto scaleDimNum = static_cast(scaleShape.GetDimNum()); + CHECK_FAIL(context_, + scaleDimNum != 1, + "The dim number of scale should be 1, current is %ld", scaleDimNum); + auto scaleDim0 = static_cast(scaleShape.GetDim(0)); + CHECK_FAIL(context_, + scaleDim0 != 1, + "The first dim of scale should be 1, current is %ld", scaleDim0); + auto offsetShape = offsetShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input offset shape: %s", ops::Shape2String(offsetShape).c_str()); + auto offsetDimNum = 
static_cast(offsetShape.GetDimNum()); + CHECK_FAIL(context_, + offsetDimNum != 1, + "The dim number of offset should be 1, current is %ld", offsetDimNum); + auto offsetDim0 = static_cast(offsetShape.GetDim(0)); + CHECK_FAIL(context_, + offsetDim0 != 1, + "The first dim of offset should be 1, current is %ld", offsetDim0); + } + + if (quantMode_ == DYNAMIC_QUANT && scaleShapePtr_ != nullptr) { + auto scaleShape = scaleShapePtr_->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "input scale shape: %s", ops::Shape2String(scaleShape).c_str()); + auto scaleDimNum = static_cast(scaleShape.GetDimNum()); + CHECK_FAIL(context_, + scaleDimNum != NUM_TWO, + "The dim number of scale should be 2, current is %ld", scaleDimNum); + auto scaleDim0 = static_cast(scaleShape.GetDim(0)); + CHECK_FAIL(context_, + scaleDim0 != (expertEnd_ - expertStart_) && scaleDim0 != 1, + "The first dim of scale should be %ld or 1, current is %ld", (expertEnd_ - expertStart_), scaleDim0); + auto scaleDim1 = static_cast(scaleShape.GetDim(1)); + CHECK_FAIL(context_, + scaleDim1 != cols_, + "The second dim of scale should be %ld, current is %ld", cols_, scaleDim0); + if (scaleDim0 == 1) { + smoothType_ = scale1H; + } else { + smoothType_ = scaleEH; + } + moeInitRoutingCustomTilingData.set_smoothType(smoothType_); + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::CheckOutShape() +{ + const gert::Shape expandedXShape = context_->GetOutputShape(0)->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "expanded_x shape: %s.", ops::Shape2String(expandedXShape).c_str()); + const gert::Shape expandedRowIdxShape = context_->GetOutputShape(1)->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "expanded_row_idx shape: %s.", ops::Shape2String(expandedRowIdxShape).c_str()); + const gert::Shape expertTokensCountOrCumsumShape = context_->GetOutputShape(NUM_TWO)->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "expert_tokens_count_or_cumsum shape: %s.", 
ops::Shape2String(expertTokensCountOrCumsumShape).c_str()); + + size_t expandedXDimNum = expandedXShape.GetDimNum(); + if (dropPadMode_ > 0) { + CHECK_FAIL(context_, expandedXDimNum != NUM_THREE, "The dim number of expandedX should be 3."); + CHECK_FAIL(context_, expandedXShape.GetDim(0) != expertNum_, "The first dim of expandedX should be %ld.", expertNum_); + CHECK_FAIL(context_, expandedXShape.GetDim(1) != expertCapacity_, "The second dim of expandedX should be %ld.", + expertCapacity_); + CHECK_FAIL(context_, + expandedXShape.GetDim(NUM_TWO) != cols_, + "The third dim of expandedX should be %ld.", cols_); + } else { + CHECK_FAIL(context_,expandedXDimNum != DIM_TWO, "The dim number of expandedX should be 2."); + int64_t firstDim = totalLength_; + firstDim = activeNum_ == 0 ? firstDim : std::min(firstDim, activeNum_); + CHECK_FAIL(context_, expandedXShape.GetDim(0) != firstDim, "The first dim of expandedX should be %ld.", firstDim); + CHECK_FAIL(context_, expandedXShape.GetDim(1) != cols_, + "The second dim of expandedX should be %ld.", cols_); + } + + CHECK_FAIL(context_, expandedRowIdxShape.GetDimNum() != DIM_ONE, + "The dim number of expanded_row_idx should be 1."); + CHECK_FAIL(context_, + expandedRowIdxShape.GetDim(0) != totalLength_, + "The first dim of expanded_row_idx and expanded_expert_idx should be %ld.", totalLength_); + + if(expertTokensNumFlag_){ + if (expertTokensNumType_ == KEY_VALUE) { + CHECK_FAIL(context_, + expertTokensCountOrCumsumShape.GetDimNum() != DIM_TWO, + "The dim number of expert_tokens_count_or_cumsum should be 2 when in KEY_VALUE mode."); + CHECK_FAIL(context_, expertTokensCountOrCumsumShape.GetDim(0) != expertNum_, + "The first dim of expert_tokens_count_or_cumsum should be %ld.", expertNum_); + CHECK_FAIL(context_, expertTokensCountOrCumsumShape.GetDim(1) != KEY_VALUE_MODE_DIM0_NUM, + "The second dim of expert_tokens_count_or_cumsum should be %ld.", + KEY_VALUE_MODE_DIM0_NUM); + } else { + CHECK_FAIL(context_, 
expertTokensCountOrCumsumShape.GetDimNum() != DIM_ONE, + "The dim number of expert_tokens_count_or_cumsum should be 1 when not in KEY_VALUE mode."); + CHECK_FAIL(context_, expertTokensCountOrCumsumShape.GetDim(0) != (expertEnd_ - expertStart_), + "The first dim of expert_tokens_count_or_cumsum should be %ld.", (expertEnd_ - expertStart_)); + } + } + + if (quantMode_ != STATIC_QUANT && scaleShapePtr_ != nullptr) { + const gert::Shape expandedScaleShape = context_->GetOutputShape(3)->GetStorageShape(); + OPS_LOG_I(context_->GetNodeName(), "expanded_scale shape: %s.", ops::Shape2String(expandedScaleShape).c_str()); + size_t expandedScaleDimNum = expandedScaleShape.GetDimNum(); + CHECK_FAIL(context_, expandedScaleDimNum != DIM_ONE, "The dim number of expanded_scale should be 1."); + if (dropPadMode_ > 0) { + CHECK_FAIL(context_, expandedScaleShape.GetDim(0) != expertNum_ * expertCapacity_, + "The first dim of expanded_scale should be %ld.", expertNum_ * expertCapacity_); + } else { + int64_t firstDim = totalLength_; + firstDim = activeNum_ == 0 ? 
firstDim : std::min(firstDim, activeNum_); + CHECK_FAIL(context_, expandedScaleShape.GetDim(0) != firstDim, + "The first dim of expanded_scale should be %ld.", firstDim); + } + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::GetShapeAttrsInfo() +{ + OPS_LOG_I(context_->GetNodeName(), "TilingContext: %s.", context_->GetNodeName()); + + xShapePtr_ = context_->GetInputShape(INPUT_X_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr_); + + expertIdxShapePtr_ = context_->GetInputShape(INPUT_EXPERT_IDX_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxShapePtr_); + + scaleShapePtr_ = context_->GetOptionalInputShape(INPUT_SCALE_INDEX); + if (scaleShapePtr_ == nullptr) { + OPS_LOG_I(context_->GetNodeName(), "optional input scale is null"); + } else { + isInputScale_ = 1; + } + moeInitRoutingCustomTilingData.set_isInputScale(isInputScale_); + + offsetShapePtr_ = context_->GetOptionalInputShape(INPUT_OFFSET_INDEX); + if (offsetShapePtr_ == nullptr) { + OPS_LOG_I(context_->GetNodeName(), "optional input offset is null"); + } else { + isInputOffset_ = 1; + } + moeInitRoutingCustomTilingData.set_isInputOffset(isInputOffset_); + + expandedXShapePtr_ = context_->GetOutputShape(OUTPUT_EXPANDED_X_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expandedXShapePtr_); + expandedRowIdxShapePtr_ = context_->GetOutputShape(OUTPUT_EXPANDED_ROW_IDX_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expandedRowIdxShapePtr_); + expertTokensCountOrCumsumShapePtr_ = context_->GetOutputShape(OUTPUT_EXPERT_TOKENS_COUNT_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertTokensCountOrCumsumShapePtr_); + expandedScaleShapePtr_ = context_->GetOutputShape(OUTPUT_EXPANDED_SCALE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expandedScaleShapePtr_); + + auto attrs = context_->GetAttrs(); + OP_CHECK_NULL_WITH_CONTEXT(context_, attrs); + activeNumPtr_ = attrs->GetAttrPointer(ATTR_ACTIVE_NUM_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, activeNumPtr_); + 
expertCapacityPtr_ = attrs->GetAttrPointer(ATTR_EXPERT_CAPACITY_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertCapacityPtr_); + expertNumPtr_ = attrs->GetAttrPointer(ATTR_EXPERT_NUM_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertNumPtr_); + dropPadModePtr_ = attrs->GetAttrPointer(ATTR_DROP_PAD_MODE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, dropPadModePtr_); + expertTokensNumTypePtr_ = attrs->GetAttrPointer(ATTR_EXPERT_TOKEN_NUM_TYPE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertTokensNumTypePtr_); + expertTokensNumFlagPtr_ = attrs->GetAttrPointer(ATTR_EXPERT_TOKEN_NUM_FLAG_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertTokensNumFlagPtr_); + quantModePtr_ = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, quantModePtr_); + activeExpertRangeListPtr_ = attrs->GetAttrPointer(ATTR_EXPERT_RANGE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, activeExpertRangeListPtr_); + rowIdxTypePtr_ = attrs->GetAttrPointer(ATTR_ROW_IDX_TYPE_INDEX); + OP_CHECK_NULL_WITH_CONTEXT(context_, rowIdxTypePtr_); + return ge::GRAPH_SUCCESS; +} + +void MoeInitRountingCustomTilingBase::ShowTilingData() +{ + int64_t isFullloadInt = 1 ? isFullload_ == true : 0; + OPS_LOG_I(context_->GetNodeName(), "isFullload: %ld, gatherFirstFullload: %ld, ep: %ld", isFullloadInt, gatherFirstFullload_, ep_); +} + +int64_t MoeInitRountingCustomTilingBase::IsGatherFirstFullLoad() { + if (ep_ == 0) { + return 0; + } else if (n_ >= gatherFirstN && (expertEnd_-expertStart_) * gatherFirstScale <= expertNum_) { + return 1; + } + return 0; +} + +bool MoeInitRountingCustomTilingBase::IsFullLoad() { + int64_t perCoreTokens = 1; + if (expertStart_ == 0 && expertEnd_ == expertNum_) { + ep_ = 0; + if (quantMode_ != 1) { + perCoreTokens = n_ / aivNum; + int64_t remainder = n_ % aivNum; + // NUM_TWO is Max xRows need add 2 becauseof the left and right row may be another row. + perCoreTokens = remainder <= 1 ? 
perCoreTokens + 1 : perCoreTokens + NUM_TWO; + } + } else { + ep_ = 1; + perCoreTokens = 1; + } + moeInitRoutingCustomTilingData.set_ep(ep_); + + if (totalLength_ > sortLoopMaxElement || this->dropPadMode_ == 1) { + return false; + } + + gatherFirstFullload_ = IsGatherFirstFullLoad(); + moeInitRoutingCustomTilingData.set_gatherFirstFullload(gatherFirstFullload_); + int64_t tileLength = Align(this->totalLength_, int64_t(sizeof(int32_t))); + int64_t sortNum = CeilDiv(tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + + int64_t sortSpace = sortNum * sizeof(int32_t) * ONE_CORE_SORT_BUFFER; + int64_t rowIdxSpace = sortNum * sizeof(int32_t) * NUM_THREE; + int64_t expertSpace = CeilDiv(this->expertNum_ * int64_t(sizeof(int64_t)), ONE_BLOCK_BYTE) * ONE_BLOCK_BYTE * NUM_TWO; + int64_t gatherSpace = CeilDiv(cols_ * inuptXDtypeSize_, ONE_BLOCK_BYTE) * ONE_BLOCK_BYTE * perCoreTokens; + int64_t remainUb = aicoreParams_.ubSize - sortSpace - rowIdxSpace - expertSpace - LENGTH_1024; + + if (quantMode_ == -1) { + remainUb -= (gatherSpace + ONE_BLOCK_BYTE); + } else if (quantMode_ == 0) { + int64_t quantSpace = 0; + int64_t xAlignedCount = Align(this->cols_, int64_t(sizeof(int8_t))); + quantSpace = xAlignedCount * STATIC_QUANT_FULLLOAD_COLS_BUFFER * perCoreTokens; + remainUb -= (gatherSpace + quantSpace); + } else { + int64_t quantSpace = CeilDiv(cols_, ONE_BLOCK_BYTE) * ONE_BLOCK_BYTE * DYNAMIC_QUANT_FULLLOAD_COLS_BUFFER; + int64_t scaleOutSpace = ONE_BLOCK_BYTE * NUM_TWO; + remainUb -= (quantSpace + scaleOutSpace); + } + return remainUb > 0; +} + +bool MoeInitRountingCustomTilingBase::IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168() const +{ + OPS_LOG_I(context_->GetNodeName(), "Begin IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168() ..."); + bool result = false; + + // expert_range [0,256), quant_mode=DYNAMIC_QUANT + const gert::Shape performXShape_X_1_7168 = gert::Shape({1, 7168}); + const gert::Shape performExpertIdxShape_X_1_7168 = gert::Shape({1, 
8}); + const gert::Shape performScaleShape_X_1_7168 = gert::Shape({256, 7168}); + + OP_CHECK_NULL_WITH_CONTEXT(context_, xShapePtr_); + OP_CHECK_NULL_WITH_CONTEXT(context_, expertIdxShapePtr_); + if (nullptr == scaleShapePtr_) { + result = false; + } else if (xShapePtr_->GetStorageShape() == performXShape_X_1_7168 && + expertIdxShapePtr_->GetStorageShape() == performExpertIdxShape_X_1_7168 && + scaleShapePtr_->GetStorageShape() == performScaleShape_X_1_7168 && offsetShapePtr_ == nullptr && + context_->GetInputDesc(INPUT_X_INDEX)->GetDataType() == ge::DT_BF16 && expertStart_ == 0 && + expertEnd_ == ASSIST_NUM && quantMode_ == DYNAMIC_QUANT && expertTokensNumType_ == KEY_VALUE) { + result = true; + } + OPS_LOG_I(context_->GetNodeName(), "End IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168() ..."); + return result; +} + +PerformanceMode MoeInitRountingCustomTilingBase::GetPerformanceMode() const +{ + PerformanceMode result = PerformanceMode::COMMON; + if (expertNum_ != ASSIST_NUM || (expertEnd_ - expertStart_) > PERFORMANCE_MODE_RANGE_MAX || + n_ < PERFORMANCE_MODE_BS_MIN || n_ > PERFORMANCE_MODE_BS_MAX || k_ != PERFORMANCE_MODE_TOP_K) { + return result; + } + + // Judge performance mode according to totalLength_ + if (totalLength_ < PERFORMANCE_MODE_MAX_ONE_CORE_GATHER) { + OPS_LOG_I(context_->GetNodeName(), "totalLength_: %ld, PerformanceMode::ONE_CORE_GATHER_SORT", totalLength_); + result = PerformanceMode::ONE_CORE_GATHER_SORT; + } else if (totalLength_ <= PERFORMANCE_MODE_MAX_BATCH_SIZE_TOP_K) { + OPS_LOG_I(context_->GetNodeName(), "totalLength_: %ld, PerformanceMode::MULTI_CORE_GATHER_SORT", totalLength_); + result = PerformanceMode::MULTI_CORE_GATHER_SORT; + } + return result; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::CheckDtype() +{ + auto inputXDtype_ = context_->GetInputDesc(INPUT_X_INDEX)->GetDataType(); + CHECK_FAIL(context_, inputXDtype_ != ge::DT_INT8 && inputXDtype_ != ge::DT_FLOAT16 && inputXDtype_ != ge::DT_BF16 && inputXDtype_ 
!= ge::DT_FLOAT, + "The data type of input_X should be INT8, FLOAT16, BF16, FLOAT."); + CHECK_FAIL(context_, inputXDtype_ == ge::DT_INT8 && quantMode_ != UN_QUANT, + "When input_X is INT8, quantization is not supported."); + + auto expertIdxDtype_ = context_->GetInputDesc(INPUT_EXPERT_IDX_INDEX)->GetDataType(); + CHECK_FAIL(context_, expertIdxDtype_ != ge::DT_INT32, + "The data type of input_expertIdx should be INT32."); + + if (quantMode_ == STATIC_QUANT) { + auto scaleDtype_ = context_->GetOptionalInputDesc(INPUT_SCALE_INDEX)->GetDataType(); + CHECK_FAIL(context_, scaleDtype_ != ge::DT_FLOAT, + "The data type of input_scale should be FLOAT."); + + auto offsetDtype_ = context_->GetOptionalInputDesc(INPUT_OFFSET_INDEX)->GetDataType(); + CHECK_FAIL(context_, offsetDtype_ != ge::DT_FLOAT, + "The data type of input_offset should be FLOAT."); + } else { + if (scaleShapePtr_ != nullptr) { + auto scaleDtype_ = context_->GetOptionalInputDesc(INPUT_SCALE_INDEX)->GetDataType(); + CHECK_FAIL(context_, scaleDtype_ != ge::DT_FLOAT, + "The data type of input_scale should be FLOAT."); + } + } + + auto expandedXDtype_ = context_->GetOutputDesc(OUTPUT_EXPANDED_X_INDEX)->GetDataType(); + CHECK_FAIL(context_,expandedXDtype_ != ge::DT_INT8 && expandedXDtype_ != ge::DT_FLOAT16 && expandedXDtype_ != ge::DT_BF16 && expandedXDtype_ != ge::DT_FLOAT, + "The data type of output_expanded_X should be INT8, FLOAT16, BF16, FLOAT."); + + auto expandedRowIdxDtype_ = context_->GetOutputDesc(OUTPUT_EXPANDED_ROW_IDX_INDEX)->GetDataType(); + CHECK_FAIL(context_,expandedRowIdxDtype_ != ge::DT_INT32, + "The data type of output_expanded_row_idx should be INT32."); + + auto expertTokensCountOrCusumDtype_ = context_->GetOutputDesc(OUTPUT_EXPERT_TOKENS_COUNT_INDEX)->GetDataType(); + CHECK_FAIL(context_,expertTokensCountOrCusumDtype_ != ge::DT_INT64, + "The data type of output_expert_tokens_count_or_cumsum should be INT64."); + + if (quantMode_ == DYNAMIC_QUANT || (quantMode_ == UN_QUANT && scaleShapePtr_ 
!= nullptr)) { + auto expandedScaleDtype_ = context_->GetOutputDesc(OUTPUT_EXPANDED_SCALE_INDEX)->GetDataType(); + CHECK_FAIL(context_,expandedScaleDtype_ != ge::DT_FLOAT, + "The data type of input_expanded_scale should be FLOAT."); + } + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::DoOpTiling() +{ + auto ret = CheckAttr(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckInputShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckOutShape(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + ret = CheckDtype(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + if (IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168()) { + aivNum = totalLength_; + } + + sortLoopMaxElement = (aicoreParams_.ubSize - aivNum * ONE_BLOCK_BYTE) / (NUM_FOUR * NUM_TWO * NUM_FOUR) / + SORT32_ALIGN_ELEMENT * SORT32_ALIGN_ELEMENT; + + Tiling4VBSCompute(); + Tiling4VMSMiddleCompute(); + Tiling4SortOutCompute(); + Tiling4ExpertTokensCountCompute(); + Tiling4SrcToDstCompute(); + Tiling4SrcToDstDropPadCompute(); + Tiling4GatherOutCompute(); + isFullload_ = IsFullLoad(); + ShowTilingData(); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::DoLibApiTiling() +{ + return ge::GRAPH_SUCCESS; +} + +uint64_t MoeInitRountingCustomTilingBase::GetTilingKey() const +{ + if (isFullload_) { + if (quantMode_ == UN_QUANT) { + return UNQUANTIZED_FULLLOAD_TILINGKEY; + } else if (quantMode_ == STATIC_QUANT) { + return STATIC_QUANT_FULLLOAD_TILINGKEY; + } else { + return (DYNAMIC_QUANT_FULLLOAD_TILINGKEY + ep_ * DYNAMIC_QUANT_EPFULLLOAD_TILINGKEY + + smoothType_ * DYNAMIC_QUANT_SMOOTHTYPE_FULLLOAD_TILINGKEY); + } + } + else if (IsPerformanceMode_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168()) { + return PERFORMANCE_TILINGKEY_X_1_7168_EXPERT_IDX_1_8_SCALE_256_7168; + } else if (PerformanceMode::ONE_CORE_GATHER_SORT == GetPerformanceMode() && quantMode_ == UN_QUANT && + rowIdxTytpe_ == SCATTER && 
expertTokensNumType_ == COUNT) { + uint64_t sortMode = NUM_TWO; + return static_cast(TILINGKEY_BASE + sortMode * SORT_CORE_TILINGKEY_BASE + + static_cast(quantMode_ + 1) * QUANT_MODE_TILINGKEY_BASE + + static_cast(rowIdxTytpe_) * ROWIDX_TYPE_TILINGKEY_BASE + + static_cast(dropPadMode_) * DROP_MODE_TILINGKEY_BASE); + } else if (PerformanceMode::MULTI_CORE_GATHER_SORT == GetPerformanceMode() && quantMode_ == UN_QUANT && + rowIdxTytpe_ == SCATTER && expertTokensNumType_ == COUNT) { + uint64_t sortMode = 3; + return static_cast(TILINGKEY_BASE + sortMode * SORT_CORE_TILINGKEY_BASE + + static_cast(quantMode_ + 1) * QUANT_MODE_TILINGKEY_BASE + + static_cast(rowIdxTytpe_) * ROWIDX_TYPE_TILINGKEY_BASE + + static_cast(dropPadMode_) * DROP_MODE_TILINGKEY_BASE); + } + return static_cast(TILINGKEY_BASE + static_cast(sortMode_) * SORT_CORE_TILINGKEY_BASE + + static_cast(quantMode_ + 1) * QUANT_MODE_TILINGKEY_BASE + + static_cast(rowIdxTytpe_) * ROWIDX_TYPE_TILINGKEY_BASE + + static_cast(dropPadMode_) * DROP_MODE_TILINGKEY_BASE); +} + +ge::graphStatus MoeInitRountingCustomTilingBase::GetWorkspaceSize() +{ + size_t sortWorkspaceSize = + sizeof(float) * static_cast(totalLength_ * NUM_TWO * NUM_THREE); + size_t coreSyncWorkspaceSize = + moeInitRoutingCustomTilingData.get_coreNum() * SORT32_ALIGN_ELEMENT * NUM_TWO; + size_t scatterWorkspaceSize = sizeof(int32_t) * static_cast(totalLength_); + size_t expertIdxValueWorkspaceSize = sizeof(int32_t) * static_cast(aivNum) * 2U; + size_t expertTokensCountWorkspaceSize = sizeof(int32_t) * static_cast((expertEnd_ - expertStart_)); + int64_t expertTokenTotalCountWorkspace = AlignBytes(1, static_cast(sizeof(int32_t))); + int64_t quantTempWorkspaceSize = aivNum * cols_ * static_cast(sizeof(float)); + workspaceSize_ = sortWorkspaceSize + coreSyncWorkspaceSize + scatterWorkspaceSize + expertTokensCountWorkspaceSize + + expertTokenTotalCountWorkspace + SIZE_16 * LENGTH_1024 * LENGTH_1024; + if (quantMode_ == DYNAMIC_QUANT) { + workspaceSize_ += 
quantTempWorkspaceSize; + } + if (dropPadMode_ == DROP_PAD) { + workspaceSize_ += expertIdxValueWorkspaceSize; + } + OPS_LOG_I(context_->GetNodeName(), "Allocate workspaceSize is: %ld.", workspaceSize_); + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus MoeInitRountingCustomTilingBase::PostTiling() +{ + context_->SetBlockDim(aivNum); + size_t *currentWorkspace = context_->GetWorkspaceSizes(1); + currentWorkspace[0] = workspaceSize_; + moeInitRoutingCustomTilingData.SaveToBuffer(context_->GetRawTilingData()->GetData(), + context_->GetRawTilingData()->GetCapacity()); + context_->GetRawTilingData()->SetDataSize(moeInitRoutingCustomTilingData.GetDataSize()); + return ge::GRAPH_SUCCESS; +} +void MoeInitRountingCustomTilingBase::Tinlig4VBSOneCoreCompute(MoeCustomVBSComputeTilingData *tilingData) +{ + tilingData->set_needCoreNum(1); + tilingData->set_perCoreElements(totalLength_); + tilingData->set_perCoreLoops(1); + tilingData->set_perCorePerLoopElements(tilingData->get_perCoreElements()); + tilingData->set_perCoreLastLoopElements(tilingData->get_perCoreElements()); + tilingData->set_lastCoreElements(tilingData->get_perCoreElements()); + tilingData->set_lastCoreLoops(1); + tilingData->set_lastCorePerLoopElements(tilingData->get_perCoreElements()); + tilingData->set_lastCoreLastLoopElements(tilingData->get_perCoreElements()); +} + +void MoeInitRountingCustomTilingBase::Tinlig4VBSMultiCoreCompute(MoeCustomVBSComputeTilingData *tilingData) +{ + int64_t needCoreNum = CeilDiv(totalLength_, sortLoopMaxElement); + needCoreNum = static_cast(std::pow(NUM_FOUR, CeilLog4(needCoreNum))); + needCoreNum = std::min(needCoreNum, aivNum); + + if (needCoreNum == 0) { + OPS_LOG_E(context_->GetNodeName(), "Variale needCoreNum cannot be 0."); + return; + } + int64_t perCoreElements = (needCoreNum == 0) ? 
0 : (totalLength_ / needCoreNum); + int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT; + int64_t lastCoreElement = totalLength_ - (needCoreNum - 1) * alineFloorPerCoreElements; + int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT; + if (lastCoreElement > alineCeilPerCoreElements) { + perCoreElements = alineCeilPerCoreElements; + needCoreNum = CeilDiv(totalLength_, perCoreElements); + } else { + perCoreElements = alineFloorPerCoreElements; + } + + tilingData->set_needCoreNum(needCoreNum); + do { + tilingData->set_perCoreElements(perCoreElements); + tilingData->set_perCoreLoops( + CeilDiv(tilingData->get_perCoreElements(), sortLoopMaxElement)); + tilingData->set_perCorePerLoopElements(std::min(tilingData->get_perCoreElements(), sortLoopMaxElement)); + + tilingData->set_perCoreLastLoopElements(tilingData->get_perCoreElements() - + (tilingData->get_perCoreLoops() - 1) * + tilingData->get_perCorePerLoopElements()); + + tilingData->set_lastCoreElements(totalLength_ - + (tilingData->get_needCoreNum() - 1) * tilingData->get_perCoreElements()); + tilingData->set_lastCoreLoops(tilingData->get_perCoreLoops()); + int64_t lastCorePerLoopElements = + CeilDiv(CeilDiv(tilingData->get_lastCoreElements(), tilingData->get_lastCoreLoops()), + SORT32_ALIGN_ELEMENT) * + SORT32_ALIGN_ELEMENT; + tilingData->set_lastCorePerLoopElements(lastCorePerLoopElements); + tilingData->set_lastCoreLastLoopElements(tilingData->get_lastCoreElements() - + (tilingData->get_lastCoreLoops() - 1) * + tilingData->get_lastCorePerLoopElements()); + perCoreElements -= SORT32_ALIGN_ELEMENT; + } while (tilingData->get_lastCoreLastLoopElements() <= 0 && perCoreElements > 0); + if (tilingData->get_lastCoreLastLoopElements() <= 0) { + OPS_LOG_E(context_->GetNodeName(), "vbs tiling failed"); + return; + } +} + +void MoeInitRountingCustomTilingBase::Tiling4VBSCompute() +{ + if (totalLength_ <= sortLoopMaxElement) 
{ + sortMode_ = 0; + } else { + sortMode_ = 1; + } + + auto tilingData = &moeInitRoutingCustomTilingData.vbsComputeParamsOp; + tilingData->set_oneLoopMaxElements(sortLoopMaxElement); + if (sortMode_ == 0UL) { + Tinlig4VBSOneCoreCompute(tilingData); + return; + } + Tinlig4VBSMultiCoreCompute(tilingData); +} + +void MoeInitRountingCustomTilingBase::Tiling4VMSMiddleCompute() +{ + auto vbsComputeTilingData = &moeInitRoutingCustomTilingData.vbsComputeParamsOp; + auto tilingData = &moeInitRoutingCustomTilingData.vmsMiddleComputeParamsOp; + if (vbsComputeTilingData->get_needCoreNum() <= MRG_LIST_NUM) { + tilingData->set_needCoreNum(0); + return; + } + int64_t needCoreNum = CeilDiv(vbsComputeTilingData->get_needCoreNum(), MRG_LIST_NUM); + tilingData->set_needCoreNum(needCoreNum); +} + +void MoeInitRountingCustomTilingBase::Tiling4SortOutCompute() +{ + auto tilingData = &moeInitRoutingCustomTilingData.sortOutComputeParamsOp; + tilingData->set_oneLoopMaxElements(mrgSortListMaxElement); +} + +void MoeInitRountingCustomTilingBase::Tiling4ExpertTokensCountCompute() +{ + auto tilingData = &moeInitRoutingCustomTilingData.expertTokensCountTilingDataOp; + int64_t totalElements = moeInitRoutingCustomTilingData.get_n() * moeInitRoutingCustomTilingData.get_k(); + int64_t perCoreElements = CeilDiv(totalElements, aivNum); + int64_t needCoreNum = CeilDiv(totalElements, perCoreElements); + int64_t lastCoreElements = totalElements - (needCoreNum - 1) * perCoreElements; + tilingData->set_needCoreNum(needCoreNum); + tilingData->set_perCoreElements(perCoreElements); + tilingData->set_lastCoreElements(lastCoreElements); + + int64_t expertNumElement = (moeInitRoutingCustomTilingData.get_expertTokensNumType() != KEY_VALUE) ? 
+ moeInitRoutingCustomTilingData.get_actualExpertNum() : + (moeInitRoutingCustomTilingData.get_actualExpertNum() + 1) * DIM_TWO; + + int64_t maxElementsPerLoop = + (static_cast(aicoreParams_.ubSize) - + CeilAlign(expertNumElement, ONE_BLOCK_BYTE) * + (static_cast(sizeof(int32_t)) * NUM_TWO + static_cast(sizeof(int64_t))) - + ONE_BLOCK_BYTE) / static_cast(sizeof(int32_t)); + int64_t perCoreLoops = CeilDiv(perCoreElements, maxElementsPerLoop); + int64_t perCorePerLoopElements = CeilDiv(perCoreElements, perCoreLoops); + int64_t perCoreLastLoopElements = perCoreElements - (perCoreLoops - 1) * perCorePerLoopElements; + + tilingData->set_perCoreLoops(perCoreLoops); + tilingData->set_perCorePerLoopElements(perCorePerLoopElements); + tilingData->set_perCoreLastLoopElements(perCoreLastLoopElements); + + int64_t lastCoreLoops = CeilDiv(lastCoreElements, maxElementsPerLoop); + int64_t lastCorePerLoopElements = CeilDiv(lastCoreElements, lastCoreLoops); + int64_t lastCoreLastLoopElements = lastCoreElements - (lastCoreLoops - 1) * lastCorePerLoopElements; + + tilingData->set_lastCoreLoops(lastCoreLoops); + tilingData->set_lastCorePerLoopElements(lastCorePerLoopElements); + tilingData->set_lastCoreLastLoopElements(lastCoreLastLoopElements); + + OPS_LOG_I(context_->GetNodeName(), + "ExpertTokensCountCompute Tilingdata, needCoreNum is: %ld, perCoreElements is: %ld, lastCoreElements is: " + "%ld, maxElementsPerLoop is: %ld, perCoreLoops is: %ld, perCorePerLoopElements is: %ld, " + "perCoreLastLoopElements " + "is: %ld, lastCoreLoops is: %ld, lastCorePerLoopElements is: %ld, lastCoreLastLoopElements is: %ld.", + needCoreNum, perCoreElements, lastCoreElements, maxElementsPerLoop, perCoreLoops, perCorePerLoopElements, + perCoreLastLoopElements, lastCoreLoops, lastCorePerLoopElements, lastCoreLastLoopElements); +} + +void MoeInitRountingCustomTilingBase::Tiling4SrcToDstDropPadCompute() +{ + if (quantMode_ == DYNAMIC_QUANT && dropPadMode_ == DROP_PAD) { + 
MoeInitRountingCustomTilingBase::Tiling4SrcToDstDropPadDynamicCompute(); + return; + } + + auto tilingData = &moeInitRoutingCustomTilingData.srcToDstDropPadParamsOp; + + int64_t perCoreRows = CeilDiv(totalLength_, aivNum); + if (perCoreRows <= 0) { + tilingData->set_needCoreNum(0); + return; + } + int64_t needCoreNum = CeilDiv(totalLength_, perCoreRows); + tilingData->set_needCoreNum(needCoreNum); + int64_t cols = moeInitRoutingCustomTilingData.get_cols(); + tilingData->set_perCoreRows(perCoreRows); + int64_t lastCoreRows = totalLength_ - perCoreRows * (needCoreNum - 1); + tilingData->set_lastCoreRows(lastCoreRows); + bool needScaleCopy = (isInputScale_ != 0 && quantMode_ == -1); + int64_t inuptXDtypeSize = inuptXDtypeSize_ == SIZE_INT8 ? SIZE_INT16 : inuptXDtypeSize_; + + int64_t rowSize = + (perCoreRows * sizeof(int32_t) * NUM_TWO + ONE_BLOCK_BYTE + ONE_BLOCK_BYTE * needScaleCopy + ONE_BLOCK_BYTE - 1) / + ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t colSize = (cols * inuptXDtypeSize + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + + if (rowSize + colSize < static_cast(aicoreParams_.ubSize)) { + SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols); + } else { + int64_t baseMaxCols = MAX_COLS_ONE_LOOP; + int64_t baseMaxColsSize = + (baseMaxCols * inuptXDtypeSize + ONE_BLOCK_BYTE - 1) / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + int64_t basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - baseMaxColsSize - ONE_BLOCK_BYTE - + ONE_BLOCK_BYTE * needScaleCopy) /static_cast(sizeof(int32_t)) + / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + if (cols < MAX_COLS_ONE_LOOP) { + basePerLoopMaxRows = (static_cast(aicoreParams_.ubSize) - colSize - ONE_BLOCK_BYTE - + ONE_BLOCK_BYTE * needScaleCopy) / static_cast(sizeof(int32_t)) + / NUM_TWO / ONE_BLOCK_BYTE * ONE_BLOCK_BYTE; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = (static_cast(aicoreParams_.ubSize) - rowSize) / inuptXDtypeSize / ONE_BLOCK_BYTE * + ONE_BLOCK_BYTE; + } + 
tilingData->set_perLoopCols(std::min(baseMaxCols, cols)); + tilingData->set_lastLoopCols(GetPerOrLastValue(cols, baseMaxCols)); + tilingData->set_colLoops((cols + baseMaxCols - 1) / baseMaxCols); + + tilingData->set_perCorePerLoopRows(std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->set_perCoreLastLoopRows(GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->set_perCoreLoops((perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + + tilingData->set_lastCorePerLoopRows(std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->set_lastCoreLastLoopRows(GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->set_lastCoreLoops((lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + } +} + +void MoeInitRountingCustomTilingBase::SetGatherTilingData( + MoeCustomSrcToDstCapacityComputeTilingData* tilingData, int64_t perCoreRows, int64_t lastCoreRows, int64_t cols) +{ + tilingData->set_perCorePerLoopRows(perCoreRows); + tilingData->set_perCoreLastLoopRows(perCoreRows); + tilingData->set_lastCorePerLoopRows(lastCoreRows); + tilingData->set_lastCoreLastLoopRows(lastCoreRows); + tilingData->set_perCoreLoops(1); + tilingData->set_lastCoreLoops(1); + tilingData->set_perLoopCols(cols); + tilingData->set_lastLoopCols(cols); + tilingData->set_colLoops(1); +} + +void MoeInitRountingCustomTilingBase::SetGatherTilingDataCols( + MoeCustomSrcToDstCapacityComputeTilingData* tilingData, int64_t baseMaxCols, int64_t cols) +{ + tilingData->set_perLoopCols(std::min(baseMaxCols, cols)); + tilingData->set_lastLoopCols(GetPerOrLastValue(cols, baseMaxCols)); + tilingData->set_colLoops(baseMaxCols == 0 ? 
0 : (cols + baseMaxCols - 1) / baseMaxCols); +} + +void MoeInitRountingCustomTilingBase::SetGatherTilingDataRows( + MoeCustomSrcToDstCapacityComputeTilingData* tilingData, int64_t perCoreRows, int64_t lastCoreRows, + int64_t basePerLoopMaxRows) +{ + tilingData->set_perCorePerLoopRows(std::min(perCoreRows, basePerLoopMaxRows)); + tilingData->set_perCoreLastLoopRows(GetPerOrLastValue(perCoreRows, basePerLoopMaxRows)); + tilingData->set_perCoreLoops( + basePerLoopMaxRows == 0 ? 0 : (perCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); + + tilingData->set_lastCorePerLoopRows(std::min(lastCoreRows, basePerLoopMaxRows)); + tilingData->set_lastCoreLastLoopRows(GetPerOrLastValue(lastCoreRows, basePerLoopMaxRows)); + tilingData->set_lastCoreLoops( + basePerLoopMaxRows == 0 ? 0 : (lastCoreRows + basePerLoopMaxRows - 1) / basePerLoopMaxRows); +} + +void MoeInitRountingCustomTilingBase::Tiling4SrcToDstDropPadDynamicCompute() +{ + auto tilingData = &moeInitRoutingCustomTilingData.srcToDstDropPadDynamicParamsOp; + + int64_t perCoreRows = CeilDiv(totalLength_, aivNum); + if (perCoreRows <= 0) { + tilingData->set_needCoreNum(0); + return; + } + tilingData->set_needCoreNum(CeilDiv(totalLength_, perCoreRows)); + int64_t cols = moeInitRoutingCustomTilingData.get_cols(); + tilingData->set_perCoreRows(perCoreRows); + int64_t lastCoreRows = totalLength_ - perCoreRows * (tilingData->get_needCoreNum() - 1); + tilingData->set_lastCoreRows(lastCoreRows); + + int64_t rowSize = AlignBytes(perCoreRows, static_cast(sizeof(int32_t))) * NUM_FOUR; + int64_t colSize = AlignBytes(cols, static_cast(sizeof(int8_t))) * DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + int64_t scaleSize = DYNAMIC_QUANT_SCALE_SIZE_64; + if (rowSize + colSize + scaleSize < static_cast(aicoreParams_.ubSize)) { + SetGatherTilingData(tilingData, perCoreRows, lastCoreRows, cols); + } else { + int64_t baseMaxCols = MAX_COLS_DYNAMIC_QUANT; + int64_t totalColSize = AlignBytes(baseMaxCols, static_cast(sizeof(int8_t))) * 
DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + int64_t ubSize = static_cast(aicoreParams_.ubSize); + int64_t basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - totalColSize - scaleSize) / SIZE_INT32) / NUM_FOUR; + if (cols < MAX_COLS_DYNAMIC_QUANT) { + basePerLoopMaxRows = AlignOneBlockByteCeil((ubSize - colSize - scaleSize) / SIZE_INT32) / NUM_FOUR; + } else if (perCoreRows < basePerLoopMaxRows) { + baseMaxCols = AlignOneBlockByteCeil(ubSize - rowSize - scaleSize) / DYNAMIC_QUANT_SRC_TO_DST_BUFFER; + } + SetGatherTilingDataCols(tilingData, baseMaxCols, cols); + SetGatherTilingDataRows(tilingData, perCoreRows, lastCoreRows, basePerLoopMaxRows); + } +} + +void MoeInitRountingCustomTilingBase::Tiling4SrcToDstCompute() +{ + auto tilingData = &moeInitRoutingCustomTilingData.srcToDstComputeParamsOp; + + int64_t useCore = aivNum; + int64_t remainUbSize = aicoreParams_.ubSize - ASSIST_NUM * sizeof(int32_t) - ONE_BLOCK_BYTE * (ASSIST_NUM + 1); + int64_t perLoopMaxElements = remainUbSize / (ONE_BLOCK_BYTE + SIZE_INT32); + int64_t perCoreElements = CeilDiv(totalLength_, useCore); + if (perCoreElements <= 0) { + tilingData->set_needCoreNum(0); + return; + } + int64_t needCoreNum = CeilDiv(totalLength_, perCoreElements); + tilingData->set_needCoreNum(needCoreNum); + int64_t lastCoreElements = totalLength_ - perCoreElements * (needCoreNum - 1); + + tilingData->set_perCoreElements(perCoreElements); + tilingData->set_lastCoreElements(lastCoreElements); + int64_t perCoreLoops = CeilDiv(perCoreElements, perLoopMaxElements); + int64_t perCorePerLoopElements = CeilDiv(perCoreElements, perCoreLoops); + int64_t perCoreLastLoopElements = perCoreElements - (perCoreLoops - 1) * perCorePerLoopElements; + + int64_t lastCoreLoops = CeilDiv(lastCoreElements, perLoopMaxElements); + int64_t lastCorePerLoopElements = CeilDiv(lastCoreElements, lastCoreLoops); + int64_t lastCoreLastLoopElements = lastCoreElements - (lastCoreLoops - 1) * lastCorePerLoopElements; + + 
tilingData->set_perCoreLoops(perCoreLoops); + tilingData->set_perCorePerLoopElements(perCorePerLoopElements); + tilingData->set_perCoreLastLoopElements(perCoreLastLoopElements); + tilingData->set_lastCoreLoops(lastCoreLoops); + tilingData->set_lastCorePerLoopElements(lastCorePerLoopElements); + tilingData->set_lastCoreLastLoopElements(lastCoreLastLoopElements); +} + +void MoeInitRountingCustomTilingBase::Tiling4GatherOutCompute() +{ + auto tilingData = &moeInitRoutingCustomTilingData.gatherOutComputeParamsOp; + int64_t perCoreIndicesElements = CeilDiv(totalLength_, aivNum); + if (perCoreIndicesElements <= 0) { + tilingData->set_needCoreNum(0); + return; + } + int64_t needCoreNum = CeilDiv(totalLength_, perCoreIndicesElements); + int64_t lastCoreIndicesElements = totalLength_ - (needCoreNum - 1) * perCoreIndicesElements; + + int64_t perLoopCols = moeInitRoutingCustomTilingData.get_cols(); + int64_t colMultiple = NUM_TWO * inuptXDtypeSize_; + int64_t rowMultiple = NUM_TWO; + if (quantMode_ == DYNAMIC_QUANT) { + colMultiple = DYNAMIC_QUANT_COLS_BUFFER; + rowMultiple = NUM_FOUR; + } + if (quantMode_ == STATIC_QUANT) { + colMultiple = SIZE_INT8 * NUM_TWO + SIZE_FP32 + SIZE_INT16 + inuptXDtypeSize_ * NUM_TWO; + rowMultiple = NUM_TWO; + } + int64_t perLoopMaxIndicesElements = + (static_cast(aicoreParams_.ubSize) - Align(perLoopCols, inuptXDtypeSize_) * colMultiple - + ONE_BLOCK_BYTE * NUM_TWO) / + rowMultiple / static_cast(sizeof(int32_t)); + while (perLoopMaxIndicesElements <= 0) { + perLoopCols = CeilDiv(perLoopCols, NUM_TWO); + perLoopMaxIndicesElements = (static_cast(aicoreParams_.ubSize) - + Align(perLoopCols, inuptXDtypeSize_) * colMultiple - ONE_BLOCK_BYTE * NUM_TWO) / + rowMultiple / static_cast(sizeof(int32_t)); + OPS_LOG_I(context_->GetNodeName(), "perLoopCols is: %ld, perLoopMaxIndicesElements is: %ld", perLoopCols, + perLoopMaxIndicesElements); + } + int64_t colsLoops = CeilDiv(moeInitRoutingCustomTilingData.get_cols(), perLoopCols); + int64_t lastLoopCols = 
moeInitRoutingCustomTilingData.get_cols() - (colsLoops - 1) * perLoopCols; + tilingData->set_needCoreNum(needCoreNum); + tilingData->set_perCoreIndicesElements(perCoreIndicesElements); + tilingData->set_lastCoreIndicesElements(lastCoreIndicesElements); + tilingData->set_colsLoops(colsLoops); + tilingData->set_perLoopCols(perLoopCols); + tilingData->set_lastLoopCols(lastLoopCols); + + int64_t perCorePerLoopIndicesElements = std::min(perLoopMaxIndicesElements, perCoreIndicesElements); + int64_t perCoreIndicesLoops = CeilDiv(perCoreIndicesElements, perCorePerLoopIndicesElements); + int64_t perCoreLastLoopIndicesElements = + perCoreIndicesElements - (perCoreIndicesLoops - 1) * perCorePerLoopIndicesElements; + tilingData->set_perCoreIndicesLoops(perCoreIndicesLoops); + tilingData->set_perCorePerLoopIndicesElements(perCorePerLoopIndicesElements); + tilingData->set_perCoreLastLoopIndicesElements(perCoreLastLoopIndicesElements); + + int64_t lastCorePerLoopIndicesElements = std::min(perLoopMaxIndicesElements, lastCoreIndicesElements); + int64_t lastCoreIndicesLoops = CeilDiv(lastCoreIndicesElements, lastCorePerLoopIndicesElements); + int64_t lastCoreLastLoopIndicesElements = + lastCoreIndicesElements - (lastCoreIndicesLoops - 1) * lastCorePerLoopIndicesElements; + tilingData->set_lastCoreIndicesLoops(lastCoreIndicesLoops); + tilingData->set_lastCorePerLoopIndicesElements(lastCorePerLoopIndicesElements); + tilingData->set_lastCoreLastLoopIndicesElements(lastCoreLastLoopIndicesElements); + + OPS_LOG_I( + context_->GetNodeName(), + "GatherOut Tilingdata, needCoreNum is: %ld, perCoreIndicesElements is: %ld, lastCoreIndicesElements is: %ld, " + "colsLoops is: %ld, perLoopCols is: %ld, lastLoopCols is: %ld, perCoreIndicesLoops is: %ld, " + "perCorePerLoopIndicesElements is: %ld, perCoreLastLoopIndicesElements is: %ld, lastCoreIndicesLoops is: " + "%ld, lastCorePerLoopIndicesElements is: " + "%ld, lastCoreLastLoopIndicesElements is: %ld.", + needCoreNum, perCoreIndicesElements, 
lastCoreIndicesElements, colsLoops, perLoopCols, lastLoopCols, + perCoreIndicesLoops, perCorePerLoopIndicesElements, perCoreLastLoopIndicesElements, lastCoreIndicesLoops, + lastCorePerLoopIndicesElements, lastCoreLastLoopIndicesElements); +} + +REGISTER_TILING_TEMPLATE("MoeInitRoutingCustom", MoeInitRountingCustomTilingBase, 10000); // If not 910_95, fallback to this. +} // namespace optiling \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.h b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.h new file mode 100644 index 00000000000..64a72d94732 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling.h @@ -0,0 +1,143 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_init_routing_custom_tiling.h + * \brief + */ +#ifndef AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_INIT_ROUTING_CUSTOM_H +#define AIR_CXX_RUNTIME_V2_OP_IMPL_MOE_INIT_ROUTING_CUSTOM_H +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" + + +namespace optiling { +BEGIN_TILING_DATA_DEF(MoeCustomVBSComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, perCoreElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, perCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLastLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, lastCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLastLoopElements); +TILING_DATA_FIELD_DEF(int64_t, oneLoopMaxElements); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomVBSComputeTilingDataOp, MoeCustomVBSComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomVMSMiddleComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomVMSMiddleComputeTilingDataOp, MoeCustomVMSMiddleComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomSortOutComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, oneLoopMaxElements); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomSortOutComputeTilingDataOp, MoeCustomSortOutComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomExpertTokensCountTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, perCoreElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, perCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLastLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, lastCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLastLoopElements); 
+END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomExpertTokensCountTilingDataOp, MoeCustomExpertTokensCountTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomGatherOutComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, perCoreIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreIndicesLoops); +TILING_DATA_FIELD_DEF(int64_t, perCorePerLoopIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLastLoopIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreIndicesLoops); +TILING_DATA_FIELD_DEF(int64_t, lastCorePerLoopIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLastLoopIndicesElements); +TILING_DATA_FIELD_DEF(int64_t, colsLoops); +TILING_DATA_FIELD_DEF(int64_t, perLoopCols); +TILING_DATA_FIELD_DEF(int64_t, lastLoopCols); +TILING_DATA_FIELD_DEF(int64_t, activeNum); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomGatherOutComputeTilingDataOp, MoeCustomGatherOutComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomSrcToDstCapacityComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, perCoreRows); +TILING_DATA_FIELD_DEF(int64_t, perCorePerLoopRows); +TILING_DATA_FIELD_DEF(int64_t, perCoreLastLoopRows); +TILING_DATA_FIELD_DEF(int64_t, lastCoreRows); +TILING_DATA_FIELD_DEF(int64_t, lastCorePerLoopRows); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLastLoopRows); +TILING_DATA_FIELD_DEF(int64_t, perCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, perLoopCols); +TILING_DATA_FIELD_DEF(int64_t, lastLoopCols); +TILING_DATA_FIELD_DEF(int64_t, colLoops); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomSrcToDstCapacityComputeTilingDataOp, MoeCustomSrcToDstCapacityComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeCustomSrcToDstComputeTilingData) +TILING_DATA_FIELD_DEF(int64_t, needCoreNum); +TILING_DATA_FIELD_DEF(int64_t, perCoreElements); 
+TILING_DATA_FIELD_DEF(int64_t, perCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLastLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreElements); +TILING_DATA_FIELD_DEF(int64_t, lastCorePerLoopElements); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLastLoopElements); +TILING_DATA_FIELD_DEF(int64_t, perCoreLoops); +TILING_DATA_FIELD_DEF(int64_t, lastCoreLoops) +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeCustomSrcToDstComputeTilingDataOp, MoeCustomSrcToDstComputeTilingData) + +BEGIN_TILING_DATA_DEF(MoeInitRoutingCustomTilingData) +TILING_DATA_FIELD_DEF(int64_t, coreNum); +TILING_DATA_FIELD_DEF(int64_t, n); +TILING_DATA_FIELD_DEF(int64_t, cols); +TILING_DATA_FIELD_DEF(int64_t, k); +TILING_DATA_FIELD_DEF(int64_t, expertStart); +TILING_DATA_FIELD_DEF(int64_t, expertEnd); +TILING_DATA_FIELD_DEF(int64_t, actualExpertNum); +TILING_DATA_FIELD_DEF(int64_t, quantMode); +TILING_DATA_FIELD_DEF(int64_t, rowIdxType); +TILING_DATA_FIELD_DEF(int64_t, isInputScale); +TILING_DATA_FIELD_DEF(int64_t, isInputOffset); +TILING_DATA_FIELD_DEF(int64_t, expertNum); +TILING_DATA_FIELD_DEF(int64_t, expertTokensNumType); +TILING_DATA_FIELD_DEF(int64_t, expertTokensNumFlag); +TILING_DATA_FIELD_DEF(int64_t, gatherFirstFullload); +TILING_DATA_FIELD_DEF(int64_t, ep); +TILING_DATA_FIELD_DEF(int64_t, activeNum); +TILING_DATA_FIELD_DEF(int64_t, dropPadMode); +TILING_DATA_FIELD_DEF(int64_t, smoothType); +TILING_DATA_FIELD_DEF(int64_t, expertCountElements); +TILING_DATA_FIELD_DEF(int64_t, expertCapacity); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomVBSComputeTilingData, vbsComputeParamsOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomVMSMiddleComputeTilingData, vmsMiddleComputeParamsOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomSortOutComputeTilingData, sortOutComputeParamsOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomExpertTokensCountTilingData, expertTokensCountTilingDataOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomGatherOutComputeTilingData, gatherOutComputeParamsOp); 
+TILING_DATA_FIELD_DEF_STRUCT(MoeCustomSrcToDstCapacityComputeTilingData, srcToDstDropPadParamsOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomSrcToDstCapacityComputeTilingData, srcToDstDropPadDynamicParamsOp); +TILING_DATA_FIELD_DEF_STRUCT(MoeCustomSrcToDstComputeTilingData, srcToDstComputeParamsOp); +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(MoeInitRoutingCustom, MoeInitRoutingCustomTilingData) +struct MoeInitRoutingCustomCompileInfo { + int32_t aivNum = 0; + uint64_t ubSize = 0; + platform_ascendc::SocVersion socVersion = platform_ascendc::SocVersion::ASCEND910B; + }; +} // namespace optiling +#endif \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling_base.cpp b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling_base.cpp new file mode 100644 index 00000000000..65e3442ae89 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_host/moe_init_routing_custom_tiling_base.cpp @@ -0,0 +1,68 @@ +/** + * This program is free software, you can redistribute it and/or modify. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_init_routing_custom_tiling_base.cpp + * \brief + */ +#include "moe_init_routing_custom_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_templates_registry.h" + +#define unlikely(x) __builtin_expect((x), 0) + +#define OP_CHECK_NULL_WITH_CONTEXT(context, ptr) \ + do { \ + if (unlikely((ptr) == nullptr)) { \ + const char* name = (unlikely(((context) == nullptr) || (context)->GetNodeName() == nullptr)) ? \ + "nil" : \ + (context)->GetNodeName(); \ + OPS_LOG_E(name, "%s is nullptr!", #ptr); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +namespace optiling { +static ge::graphStatus TilingForMoeInitRoutingCustom(gert::TilingContext *context) +{ + return TilingRegistry::GetInstance().DoTilingImpl(context); +} + +static ge::graphStatus TilingPrepareForMoeInitRountingCustom(gert::TilingParseContext* context) +{ + OPS_LOG_D(context, "TilingPrepareForMoeInitRountingCustom enter."); + + auto compileInfo = context->GetCompiledInfo(); + OP_CHECK_NULL_WITH_CONTEXT(context, compileInfo); + auto platformInfo = context->GetPlatformInfo(); + OP_CHECK_NULL_WITH_CONTEXT(context, platformInfo); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); + compileInfo->aivNum = ascendcPlatform.GetCoreNumAiv(); + if (compileInfo->aivNum <= 0) { + OPS_LOG_E(context, "TilingPrepareForMoeInitRountingCustom fail to get core num."); + return ge::GRAPH_FAILED; + } + + uint64_t ubSize; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + compileInfo->ubSize = static_cast(ubSize); + compileInfo->socVersion = ascendcPlatform.GetSocVersion(); + if (compileInfo->ubSize <= 0) { + OPS_LOG_E(context, "TilingPrepareForMoeInitRountingCustom fail to get ub size."); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(MoeInitRoutingCustom) + .Tiling(TilingForMoeInitRoutingCustom) + .TilingParse(TilingPrepareForMoeInitRountingCustom); +} // namespace optiling \ No newline at end of file diff 
--git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_common.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_common.h new file mode 100644 index 00000000000..5afeebaf98f --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_common.h @@ -0,0 +1,110 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_common.h + * \brief + */ +#ifndef MOE_CUSTOM_COMMON_H +#define MOE_CUSTOM_COMMON_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; +constexpr int64_t SPLIT_N = 0; +constexpr int64_t SPLIT_K = 1; +constexpr float MIN_FP32 = -3.4e38f; +constexpr int64_t FP32_ONE_REPEAT_NUM = 64; +constexpr int64_t ONE_REPEAT_SORT_NUM = 32; +constexpr int64_t ONE_REPEAT_COMPARE_NUM = 64; +constexpr int64_t BLOCK_BYTES = 32; +constexpr int64_t INT32_ONE_BLOCK_NUM = 8; +constexpr int64_t FP32_ONE_BLOCK_NUM = 8; +constexpr int64_t DROPLESS_MODE = 0; +constexpr int64_t DROP_PAD_MODE = 1; +constexpr int64_t ASSIST_NUM = 256; +constexpr int64_t ASSIST_INDEX_NUM = 32; +constexpr int64_t MRGSORT_LIST_MAX_ELEMENT = 2040; +constexpr float MAX_INT8 = 127.0f; +constexpr uint32_t INF = 0xFF7FFFFF; + +constexpr int64_t MERGE_LIST_TWO = 2; +constexpr int64_t MERGE_LIST_THREE = 3; +constexpr int64_t MERGE_LIST_FOUR = 4; + +constexpr int64_t MERGE_LIST_IDX_TWO = 2; +constexpr int64_t MERGE_LIST_IDX_THREE = 3; + 
+constexpr int64_t GATHER = 0; +constexpr int64_t SCATTER = 1; + +static constexpr int64_t NO_SCALE = 0; +static constexpr int64_t SCALE_1H = 1; +static constexpr int64_t SCALE_EH = 2; + +constexpr int64_t EXERPT_TOKENS_CUMSUM = 0; +constexpr int64_t EXERPT_TOKENS_COUNT = 1; +constexpr int64_t EXERPT_TOKENS_KEY_VALUE = 2; +constexpr int64_t EXERPT_TOKENS_NONE = 0; + +const __gm__ int32_t assist[256] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, + 4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, + 12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, + 16, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, + 20, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, 0, 0, 0, + 24, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 26, 0, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0, 0, + 28, 0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0}; + +__aicore__ inline int64_t Ceil(int64_t a, int64_t b) +{ + if (b == 0) { + return 0; + } + return (a + b - 1) / b; +} + +__aicore__ inline int64_t Align(int64_t elementNum, int64_t bytes) +{ + if (bytes == 0) { + return 0; + } + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES / bytes; +} + +__aicore__ inline int64_t AlignBytes(int64_t elementNum, int64_t bytes) +{ + return (elementNum * bytes + BLOCK_BYTES - 1) / BLOCK_BYTES * BLOCK_BYTES; +} + +template +__aicore__ inline T Min(T a, T b) +{ + return a > b ? b : a; +} + +template +__aicore__ inline T Max(T a, T b) +{ + return a < b ? 
b : a; +} + +template +__aicore__ inline void SetWaitFlag(HardEvent evt) +{ + event_t eventId = static_cast(GetTPipePtr()->FetchEventID(evt)); + SetFlag(eventId); + WaitFlag(eventId); +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_COMMON_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_expert_tokens_count.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_expert_tokens_count.h new file mode 100644 index 00000000000..c4fb6ba2f86 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_expert_tokens_count.h @@ -0,0 +1,371 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_expert_tokens_count.h + * \brief + */ +#ifndef MOE_CUSTOM_EXPERT_TOKENS_COUNT_H +#define MOE_CUSTOM_EXPERT_TOKENS_COUNT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t EXPERT_ID_VALUE_NUM = 2; +constexpr int64_t CUMSUM_MODE = 0; +constexpr int64_t COUNT_MODE = 1; +constexpr int64_t KEY_VALUE_MODE = 2; +constexpr int64_t KEY_VALUE_MODE_DIM_NUM = 2; +constexpr int64_t GATHER_SORT_CORE_NUM = 16; +constexpr int64_t DROP_LESS = 0; +constexpr int64_t DROP_PAD = 1; + +template +class ExpertTokensCount { +public: + __aicore__ inline ExpertTokensCount(){}; + template + __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expertTokensCount, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(int64_t loop, int64_t curLoopElements); + __aicore__ inline void Compute(int64_t curLoopElements); + __aicore__ inline void CopyOut(); + __aicore__ inline void CopyOutExpertTotalCount(); + + __aicore__ inline void expertCountCopyIn(); + __aicore__ inline void expertCountCompute(); + __aicore__ inline void expertCountCopyOut(); + +private: + GlobalTensor sortedexpertIdxGm_; + GlobalTensor expertCountTempGm_; + GlobalTensor expertTokensCountGm_; + GlobalTensor expertTotalCountGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expertIdxValueGm_; + TPipe *pipe_; + + TQue sortedExpertIdxInQueue_; + TQue expertCountOutToTempQueue_; + TQue expertCountTempInQueue_; + TQue expertIdxCountOutQueue_; + TQue expertTotalCountQueue_; + + const MoeCustomExpertTokensCountTilingData *expertTokensCountTilingData_; + int64_t coreNum_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t perCoreElements_; + int64_t curCoreElements_ = 0; + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + int64_t actualExpertNum_ = 0; + int64_t coreLoopsNum_ = 0; + int64_t 
perCorePerLoopElements_ = 0; + int64_t perCoreLastLoopElements_ = 0; + int64_t actualExpertTotalNum_ = 0; + int64_t expertNum_ = 0; + int64_t expertCountElements_ = 0; + bool expertTokensNumFlag_ = false; + int64_t dropPadMode_ = 0; + int32_t finalExpertId = -1; + int32_t expertTokenValue = 0; + int64_t ep_ = 0; + int64_t rowIdxType_ = 0; +}; + +template +template +__aicore__ inline void +ExpertTokensCount::Init(GM_ADDR expandedRowIdx, GM_ADDR expertTokensCount, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + coreNum_ = tilingData->coreNum; + pipe_ = tPipe; + expertTokensCountTilingData_ = &(tilingData->expertTokensCountTilingDataOp); + blockIdx_ = GetBlockIdx(); + needCoreNum_ = expertTokensCountTilingData_->needCoreNum; + perCoreElements_ = expertTokensCountTilingData_->perCoreElements; + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + actualExpertNum_ = tilingData->actualExpertNum; + expertNum_ = tilingData->expertNum; + expertTokensNumFlag_ = tilingData->expertTokensNumFlag; + dropPadMode_ = tilingData->dropPadMode; + ep_ = tilingData->ep; + rowIdxType_ = tilingData->rowIdxType; + + if (blockIdx_ == needCoreNum_ - 1) { + curCoreElements_ = expertTokensCountTilingData_->lastCoreElements; + coreLoopsNum_ = expertTokensCountTilingData_->lastCoreLoops; + perCorePerLoopElements_ = expertTokensCountTilingData_->lastCorePerLoopElements; + perCoreLastLoopElements_ = expertTokensCountTilingData_->lastCoreLastLoopElements; + } else { + curCoreElements_ = expertTokensCountTilingData_->perCoreElements; + coreLoopsNum_ = expertTokensCountTilingData_->perCoreLoops; + perCorePerLoopElements_ = expertTokensCountTilingData_->perCorePerLoopElements; + perCoreLastLoopElements_ = expertTokensCountTilingData_->perCoreLastLoopElements; + } + + if (CALC_ACTUAL_EXPERT_NUM) { + // key and value + int64_t kvFactor = 2; + GlobalTensor sortedNumGm; + sortedNumGm.SetGlobalBuffer((__gm__ int32_t *)workspace + + 
Align(tilingData->n * tilingData->k, sizeof(int32_t)) * kvFactor * kvFactor); + int32_t totalSortedNum = 0; + for (int32_t i = 0; i < 16; i++) { + totalSortedNum += sortedNumGm.GetValue(i); + } + perCoreElements_ = Ceil(totalSortedNum, GetBlockNum()); + needCoreNum_ = Ceil(totalSortedNum, perCoreElements_); + int64_t lastCoreElements = totalSortedNum - (needCoreNum_ - 1) * perCoreElements_; + if (blockIdx_ == needCoreNum_ - 1) { + curCoreElements_ = lastCoreElements; + } else { + curCoreElements_ = perCoreElements_; + } + coreLoopsNum_ = Ceil(curCoreElements_, expertTokensCountTilingData_->perCorePerLoopElements); + perCorePerLoopElements_ = Ceil(curCoreElements_, coreLoopsNum_); + perCoreLastLoopElements_ = curCoreElements_ - (coreLoopsNum_ - 1) * perCorePerLoopElements_; + } + + if constexpr (HISTOGRAMTYPE == KEY_VALUE_MODE) { + expertCountElements_ = ((actualExpertNum_ + 1) < expertNum_) ? (actualExpertNum_ + 1) * KEY_VALUE_MODE_DIM_NUM : + expertNum_ * KEY_VALUE_MODE_DIM_NUM; + } else { + expertCountElements_ = actualExpertNum_; + } + sortedexpertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + blockIdx_ * perCoreElements_, curCoreElements_); + expertTokensCountGm_.SetGlobalBuffer((__gm__ int64_t *)expertTokensCount, expertCountElements_); + expertCountTempGm_.SetGlobalBuffer( + (__gm__ int32_t *)workspace + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2, actualExpertNum_); + expertTotalCountGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)), + actualExpertNum_); + expertIdxValueGm_.SetGlobalBuffer( + (__gm__ int32_t *)workspace + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2 + + Align((actualExpertNum_), sizeof(int32_t)) + Align((actualExpertNum_), sizeof(int32_t)), + coreNum_ * 2); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreElements_, + curCoreElements_); + + if 
((tilingData->rowIdxType == GATHER) && (blockIdx_ < needCoreNum_)) { + InitGlobalMemory(expandedRowIdxGm_, curCoreElements_, -1); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + int64_t sortedExpertIdxInLen = Max(perCorePerLoopElements_, perCoreLastLoopElements_); + + pipe_->InitBuffer(sortedExpertIdxInQueue_, 1, AlignBytes(sortedExpertIdxInLen, sizeof(int32_t))); + pipe_->InitBuffer(expertCountOutToTempQueue_, 1, AlignBytes(actualExpertNum_, sizeof(int32_t))); + pipe_->InitBuffer(expertCountTempInQueue_, 1, AlignBytes(actualExpertNum_, sizeof(int32_t))); + + pipe_->InitBuffer(expertIdxCountOutQueue_, 1, AlignBytes(expertCountElements_, sizeof(int64_t))); + pipe_->InitBuffer(expertTotalCountQueue_, 1, AlignBytes(1, sizeof(int32_t))); + + if (blockIdx_ == 0) { + InitGlobalMemory(expertTotalCountGm_, 1, 0); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + SyncAll(); +} + +template +__aicore__ inline void ExpertTokensCount::Process() +{ + if (blockIdx_ < needCoreNum_) { + for (int64_t i = 0; i < coreLoopsNum_; i++) { + int64_t perLoopElements = (i == (coreLoopsNum_ - 1)) ? perCoreLastLoopElements_ : perCorePerLoopElements_; + CopyIn(i, perLoopElements); + Compute(perLoopElements); + CopyOut(); + } + if (ep_ == 1) { + CopyOutExpertTotalCount(); + } + } + if (ep_ == 1 || expertTokensNumFlag_ || dropPadMode_ == 1) { + SyncAll(); + } + /* copy expert tokens count result from worksapce to output GM. 
*/ + if (blockIdx_ == 0 && expertTokensNumFlag_) { + expertCountCopyIn(); + expertCountCompute(); + expertCountCopyOut(); + } +} + +template +__aicore__ inline void ExpertTokensCount::CopyIn(int64_t loop, int64_t curLoopElements) +{ + LocalTensor sortedExpertIdxInLocal = sortedExpertIdxInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(curLoopElements * sizeof(int32_t)), + 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + int64_t sortedexpertIdxOffset = loop * perCorePerLoopElements_; + DataCopyPad(sortedExpertIdxInLocal, sortedexpertIdxGm_[sortedexpertIdxOffset], dataCopyParams, dataCopyPadParams); + sortedExpertIdxInQueue_.EnQue(sortedExpertIdxInLocal); +} + +template +__aicore__ inline void ExpertTokensCount::Compute(int64_t curLoopElements) +{ + LocalTensor sortedExpertIdxInLocal = sortedExpertIdxInQueue_.DeQue(); + LocalTensor expertCountOutLocal = expertCountOutToTempQueue_.AllocTensor(); + Duplicate(expertCountOutLocal.ReinterpretCast(), static_cast(0), + static_cast(actualExpertNum_)); + SetWaitFlag(HardEvent::V_S); + int64_t i = 0; + int32_t lastExpertId = sortedExpertIdxInLocal.GetValue(0); + int32_t lastIndex = 0; + int64_t loopTokenCount = 0; + int32_t lastlastExpertId = lastExpertId; + for (i = 1; i < curLoopElements; i++) { + if ((lastExpertId >= expertEnd_) || (lastExpertId < expertStart_)) { + break; + } + int32_t curExpertId = sortedExpertIdxInLocal.GetValue(i); + if (curExpertId != lastExpertId || curExpertId >= expertEnd_) { + if constexpr (HISTOGRAMTYPE == COUNT_MODE || HISTOGRAMTYPE == KEY_VALUE_MODE) { + expertCountOutLocal.SetValue(lastExpertId - expertStart_, i - lastIndex); + loopTokenCount += i - lastIndex; + } else { + for (int64_t j = lastlastExpertId; j < lastExpertId; j++) { + expertCountOutLocal.SetValue(j - expertStart_, loopTokenCount); + } + loopTokenCount += i - lastIndex; + expertCountOutLocal.SetValue(lastExpertId - expertStart_, loopTokenCount); + } + lastIndex = i; + 
lastlastExpertId = lastExpertId; + lastExpertId = curExpertId; + } + } + if ((i == curLoopElements) && ((lastExpertId >= expertStart_) && (lastExpertId < expertEnd_))) { + if constexpr (HISTOGRAMTYPE == COUNT_MODE || HISTOGRAMTYPE == KEY_VALUE_MODE) { + expertCountOutLocal.SetValue(lastExpertId - expertStart_, i - lastIndex); + loopTokenCount += i - lastIndex; + } else { + for (int64_t j = lastlastExpertId; j < lastExpertId; j++) { + expertCountOutLocal.SetValue(j - expertStart_, loopTokenCount); + } + loopTokenCount += i - lastIndex; + expertCountOutLocal.SetValue(lastExpertId - expertStart_, loopTokenCount); + for (int64_t j = lastExpertId; j < expertEnd_; j++) { + expertCountOutLocal.SetValue(j - expertStart_, loopTokenCount); + } + } + } else { + if constexpr (HISTOGRAMTYPE == EXERPT_TOKENS_CUMSUM) { + for (int64_t j = lastlastExpertId; j < expertEnd_; j++) { + expertCountOutLocal.SetValue(j - expertStart_, loopTokenCount); + } + } + } + actualExpertTotalNum_ += loopTokenCount; + finalExpertId = lastExpertId; + expertTokenValue = (i - lastIndex); + + expertCountOutToTempQueue_.EnQue(expertCountOutLocal); + sortedExpertIdxInQueue_.FreeTensor(sortedExpertIdxInLocal); +} + +template +__aicore__ inline void ExpertTokensCount::CopyOutExpertTotalCount() +{ + LocalTensor expertTotalCountLocal = expertTotalCountQueue_.AllocTensor(); + DataCopyExtParams copyTotalCountParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0}; + expertTotalCountLocal.SetValue(0, static_cast(actualExpertTotalNum_)); + SetWaitFlag(HardEvent::S_MTE3); + SetAtomicAdd(); + DataCopyPad(expertTotalCountGm_, expertTotalCountLocal, copyTotalCountParams); + SetAtomicNone(); + expertTotalCountQueue_.FreeTensor(expertTotalCountLocal); +} + +template +__aicore__ inline void ExpertTokensCount::CopyOut() +{ + LocalTensor expertCountOutLocal = expertCountOutToTempQueue_.DeQue(); + DataCopyExtParams copyParams{static_cast(1), static_cast((actualExpertNum_) * sizeof(int32_t)), + 0, 0, 0}; + 
SetWaitFlag(HardEvent::S_MTE3); + SetAtomicAdd(); + DataCopyPad(expertCountTempGm_, expertCountOutLocal, copyParams); + SetAtomicNone(); + + if (dropPadMode_ == DROP_PAD) { + expertCountOutLocal.SetValue(0, finalExpertId); + expertCountOutLocal.SetValue(1, expertTokenValue); + DataCopyExtParams copyParams{static_cast(1), + static_cast(EXPERT_ID_VALUE_NUM * sizeof(int32_t)), 0, 0, 0}; + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expertIdxValueGm_[blockIdx_ * EXPERT_ID_VALUE_NUM], expertCountOutLocal, copyParams); + } + expertCountOutToTempQueue_.FreeTensor(expertCountOutLocal); +} + +template +__aicore__ inline void ExpertTokensCount::expertCountCopyIn() +{ + LocalTensor expertCountTempInLocal = expertCountTempInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast((actualExpertNum_) * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(expertCountTempInLocal, expertCountTempGm_, dataCopyParams, dataCopyPadParams); + expertCountTempInQueue_.EnQue(expertCountTempInLocal); +} + +template +__aicore__ inline void ExpertTokensCount::expertCountCompute() +{ + LocalTensor expertCountTempInLocal = expertCountTempInQueue_.DeQue(); + LocalTensor expertCountOutLocal = expertIdxCountOutQueue_.AllocTensor(); + if constexpr (HISTOGRAMTYPE == KEY_VALUE_MODE) { + int64_t expertOffset = 0; + Duplicate(expertCountOutLocal.ReinterpretCast(), static_cast(0), + static_cast(expertCountElements_ * KEY_VALUE_MODE)); + SetWaitFlag(HardEvent::V_S); + for (int64_t i = 0; i < actualExpertNum_; i++) { + int64_t expertCount = static_cast(expertCountTempInLocal.GetValue(i)); + if (expertCount != 0) { + expertCountOutLocal.SetValue(expertOffset * KEY_VALUE_MODE_DIM_NUM, i + expertStart_); + expertCountOutLocal.SetValue(expertOffset * KEY_VALUE_MODE_DIM_NUM + 1, expertCount); + expertOffset++; + } + } + } else { + Cast(expertCountOutLocal, expertCountTempInLocal, RoundMode::CAST_NONE, actualExpertNum_); + } + + 
expertIdxCountOutQueue_.EnQue(expertCountOutLocal); + expertCountTempInQueue_.FreeTensor(expertCountTempInLocal); +} + +template +__aicore__ inline void ExpertTokensCount::expertCountCopyOut() +{ + LocalTensor expertCountOutLocal = expertIdxCountOutQueue_.DeQue(); + DataCopyExtParams copyParams{static_cast(1), + static_cast(expertCountElements_ * sizeof(int64_t)), 0, 0, 0}; + DataCopyPad(expertTokensCountGm_, expertCountOutLocal, copyParams); + copyParams.blockLen = sizeof(int32_t); + expertIdxCountOutQueue_.FreeTensor(expertCountOutLocal); +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_EXPERT_TOKENS_COUNT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load.h new file mode 100644 index 00000000000..6b985ec080d --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load.h @@ -0,0 +1,280 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_full_load.h + * \brief + */ +#ifndef MOE_CUSTOM_FULL_LOAD_H +#define MOE_CUSTOM_FULL_LOAD_H + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeCustomFullLoad { +public: + __aicore__ inline MoeCustomFullLoad(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void ExpertCountCompute(); + __aicore__ inline void CopyOutDynamicQuant(); + +private: + int64_t sortNum; + + TPipe *pipe; + TQue sortDataCopyInQueue; + TQue sortDataCopyOutQueue; + TQue expertTokensCountOrCumsumOutQueue; + TQue smoothInQueue; + TQue inputXInQueue; + TQue inputXOutQueue; + TQue scaleOutQueue; + TQue rowIdxOutQueue; + + TBuf tempBuffer; + TBuf sortedBuffer; + TBuf quantTempBuffer; + + GlobalTensor inputXGm; + GlobalTensor smoothGm; + GlobalTensor expandedXGm; + GlobalTensor expandedScaleGm; + GlobalTensor expertIdxGm; + GlobalTensor expendedRowIdxGm; + GlobalTensor sortedExpertForSourceRowGm; + GlobalTensor expandDstToSrcRowGm; + GlobalTensor sortedexpertIdxGm; + GlobalTensor expertCountTempGm; + GlobalTensor expandedRowIdxGm; + GlobalTensor expertTokensCountOrCumsumGm; + + int64_t blockIdx = 0; + int64_t tileLength; + int64_t bufferNum = 1; + int64_t totalLength; + int64_t n; + int64_t k; + int64_t cols_; + int64_t expertNum_ = 256; + int64_t rowIdxType_; + int64_t kvFactor = 2; + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; +}; + +__aicore__ inline void MoeCustomFullLoad::CopyIn() +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(this->totalLength * sizeof(int32_t)), 0, 0, 0}; + 
DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); + LocalTensor rowIdxLocal = inLocal[this->sortNum]; + ArithProgression(rowIdxLocal, 0, 1, this->sortNum); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeCustomFullLoad::SortCompute() +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdx = inLocal[0]; + LocalTensor expertIdxFp32 = expertIdx.ReinterpretCast(); + Cast(expertIdxFp32, expertIdx, RoundMode::CAST_ROUND, this->tileLength); + Muls(expertIdxFp32, expertIdxFp32, (float)-1, this->tileLength); + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + + LocalTensor concatLocal; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum)); + Concat(concatLocal, expertIdxFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum)); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[this->sortNum].ReinterpretCast(); + Sort(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sortedExpertForSourceRowLocal = outLocal[0]; + LocalTensor expandDstToSrcRowLocal; + expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast(); + Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength); + + LocalTensor expertForSourceRowLocalInt32; + 
expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeCustomFullLoad::ExpertCountCompute() +{ + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + LocalTensor sortedExpertId = outLocal; + LocalTensor expertTokensLocalTensor = expertTokensCountOrCumsumOutQueue.AllocTensor(); + + int64_t i = 0; + int32_t lastExpertId = sortedExpertId.GetValue(0); + int32_t lastIndex = 0; + int64_t index = 0; + for (i = 1; i < this->totalLength; i++) { + int32_t curExpertId = sortedExpertId.GetValue(i); + if (curExpertId != lastExpertId) { + expertTokensLocalTensor.SetValue(index * kvFactor, lastExpertId); + expertTokensLocalTensor.SetValue(index * kvFactor + 1, i - lastIndex); + index++; + lastIndex = i; + lastExpertId = curExpertId; + } + } + if (i == this->totalLength) { + expertTokensLocalTensor.SetValue(index * kvFactor, lastExpertId); + expertTokensLocalTensor.SetValue(index * kvFactor + 1, i - lastIndex); + index++; + } + // totalLength < 256 + expertTokensLocalTensor.SetValue(index * kvFactor, 0); + expertTokensLocalTensor.SetValue(index * kvFactor + 1, 0); + SetWaitFlag(HardEvent::S_MTE3); + + expertTokensCountOrCumsumOutQueue.EnQue(expertTokensLocalTensor); + sortDataCopyOutQueue.EnQue(outLocal); +} + +__aicore__ inline void MoeCustomFullLoad::CopyOutDynamicQuant() +{ + LocalTensor expertTokensLocalTensor = expertTokensCountOrCumsumOutQueue.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = expertNum_ * sizeof(int64_t); + DataCopyPad(expertTokensCountOrCumsumGm, expertTokensLocalTensor, intriParams); + expertTokensCountOrCumsumOutQueue.FreeTensor(expertTokensLocalTensor); + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + + int64_t expertIdx = outLocal.GetValue(blockIdx); + 
LocalTensor xInLocal = inputXInQueue.AllocTensor(); + LocalTensor xOutLocal = inputXOutQueue.AllocTensor(); + LocalTensor smoothLocal = smoothInQueue.AllocTensor(); + LocalTensor scaleLocal = scaleOutQueue.AllocTensor(); + LocalTensor tempLocal = quantTempBuffer.Get(); + DataCopyExtParams copyInParams{1, static_cast(cols_ * sizeof(bfloat16_t)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, static_cast(cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(cols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyPad(xInLocal, inputXGm, copyInParams, {false, 0, 0, 0}); + DataCopyPad(smoothLocal, smoothGm[expertIdx * cols_], smoothParams, {false, 0, 0, 0}); + smoothInQueue.EnQue(smoothLocal); + smoothLocal = smoothInQueue.DeQue(); + Cast(tempLocal, xInLocal, RoundMode::CAST_NONE, cols_); + Mul(smoothLocal, tempLocal, smoothLocal, cols_); + // compute scale + Abs(tempLocal, smoothLocal, cols_); + ReduceMax(scaleLocal, tempLocal, tempLocal, cols_); + float scaleValue = scaleLocal.GetValue(0) / 127.0f; + Duplicate(scaleLocal, scaleValue, DST_REP_STRIDE); + Duplicate(tempLocal, scaleValue, cols_); + // compute quant + Div(tempLocal, smoothLocal, tempLocal, cols_); + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_ODD, cols_); // fp32->fp16 + Cast(xOutLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_RINT, cols_); // fp16->int8 + inputXOutQueue.EnQue(xOutLocal); + xOutLocal = inputXOutQueue.DeQue(); + scaleOutQueue.EnQue(scaleLocal); + scaleLocal = scaleOutQueue.DeQue(); + DataCopyPad(expandedXGm[blockIdx * cols_], xOutLocal, copyOutParams); + DataCopyPad(expandedScaleGm[blockIdx], scaleLocal, {1, 4, 0, 0, 0}); + smoothInQueue.FreeTensor(smoothLocal); + inputXInQueue.FreeTensor(xInLocal); + inputXOutQueue.FreeTensor(xOutLocal); + scaleOutQueue.FreeTensor(scaleLocal); + + if (blockIdx == 0) { + intriParams.blockLen = this->totalLength * sizeof(int32_t); + if (rowIdxType_ == 1) { + DataCopyPad(expandedRowIdxGm, outLocal[this->sortNum], 
intriParams); + } else if (rowIdxType_ == 0) { + LocalTensor rowIdxLocalTensor = rowIdxOutQueue.AllocTensor(); + for (int i = 0; i < this->totalLength; i++) { + int32_t dstIdx = outLocal[this->sortNum].GetValue(i); + rowIdxLocalTensor.SetValue(dstIdx, i); + } + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(expandedRowIdxGm, rowIdxLocalTensor, intriParams); + rowIdxOutQueue.FreeTensor(rowIdxLocalTensor); + } + } + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeCustomFullLoad::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedX, GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expandedScale, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe) +{ + this->pipe = tPipe; + this->blockIdx = GetBlockIdx(); + this->n = tilingData->n; + this->k = tilingData->k; + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + cols_ = tilingData->cols; + rowIdxType_ = tilingData->rowIdxType; + + expertIdxGm.SetGlobalBuffer((__gm__ int32_t *)expertIdx, this->tileLength); + + expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, this->tileLength); + expertTokensCountOrCumsumGm.SetGlobalBuffer((__gm__ int64_t *)expertTokensCountOrCumsum, this->tileLength); + + inputXGm.SetGlobalBuffer((__gm__ bfloat16_t *)x, this->n * cols_); + smoothGm.SetGlobalBuffer((__gm__ float *)scale, expertNum_ * cols_); + expandedXGm.SetGlobalBuffer((__gm__ int8_t *)expandedX, this->n * cols_ * this->k); + expandedScaleGm.SetGlobalBuffer((__gm__ float *)expandedScale, this->n * this->k); + + // key and value + int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize); + pipe->InitBuffer(tempBuffer, 
buffSize); + pipe->InitBuffer(sortedBuffer, buffSize); + pipe->InitBuffer(expertTokensCountOrCumsumOutQueue, bufferNum, Align(expertNum_ * kvFactor, sizeof(int32_t))); + + pipe->InitBuffer(smoothInQueue, bufferNum, AlignBytes(cols_, sizeof(float))); + pipe->InitBuffer(inputXInQueue, bufferNum, AlignBytes(cols_, sizeof(bfloat16_t))); + pipe->InitBuffer(inputXOutQueue, bufferNum, AlignBytes(cols_, sizeof(int8_t))); + pipe->InitBuffer(quantTempBuffer, AlignBytes(cols_, sizeof(float))); + pipe->InitBuffer(scaleOutQueue, bufferNum, AlignBytes(1, sizeof(float))); + pipe->InitBuffer(rowIdxOutQueue, bufferNum, AlignBytes(this->totalLength, sizeof(int32_t))); +} + +__aicore__ inline void MoeCustomFullLoad::Process() +{ + if (this->blockIdx < GetBlockNum()) { + CopyIn(); + SortCompute(); + ExpertCountCompute(); + CopyOutDynamicQuant(); + } +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_FULL_LOAD_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_base.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_base.h new file mode 100644 index 00000000000..897c33226cd --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_base.h @@ -0,0 +1,512 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_base_full_load.h + * \brief + */ +#ifndef MOE_CUSTOM_FULL_LOAD_BASE_H +#define MOE_CUSTOM_FULL_LOAD_BASE_H + +#include "moe_custom_common.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +template +class MoeCustomFullLoadBase { +public: + __aicore__ inline MoeCustomFullLoadBase(){}; + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + +protected: + __aicore__ inline void CopyIn(); + __aicore__ inline void Compute(); + __aicore__ inline void TilingInKernel(); + __aicore__ inline void SortComputeWithRange(); + __aicore__ inline void SortCompute(); + __aicore__ inline void CopyOutIdx(); + __aicore__ inline void CopyOutDefaultGatherIdx(); + __aicore__ inline void CopyOutDefaultTokenCountOrCumsum(); + __aicore__ inline void ComputeExpertTokenCountOrCumsum(); + +protected: + int64_t sortNum_; + const MoeCustomGatherOutComputeTilingData *gatherOutTilingData_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t coreIndicesElements_; + int64_t perCoreIndicesElements_; + int64_t k_; + int64_t n_; + int64_t cols_; + int64_t dropPadMode_; + int64_t activeNum_; + int64_t expertNum_; + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + int64_t bufferNum_ = 1; + int64_t kvFactor_ = 2; + int64_t totalLength_; + int64_t tileLength_; + int64_t expertTokensNumType_ = 0; + int64_t expertTokensNumFlag_ = 0; + uint64_t actual_idx_num_ = 0; + int64_t ep_ = 0; + int64_t gatherFirstFullload_ = 0; + int64_t isInputScale_ = 0; + int64_t rowIdxType_ = 0; + int64_t actualExpertNum_ = 0; + int64_t expertCountElements_ = 0; + int64_t curIndexStart_; + int64_t startXRow_; + int64_t endXRow_; + int64_t quantMode_ = -1; + + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; + static constexpr int64_t MASK_STRIDE = 64; + + TQue expandedRowIdxCopyOutQueue_; + TQue 
expandedExpertIdxCopyOutQueue_; + TQue expandDstToSrcRowQueue_; + TQue expertTokensCopyOutQueue_; + TQue sortDataCopyInQueue_; + + TBuf tempBuffer_; + TBuf sortedBuffer_; + + GlobalTensor expertIdxGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expertTokensCountOrCumsumGm_; + + TPipe *pipe_; +}; + +template +__aicore__ inline void MoeCustomFullLoadBase::Init(GM_ADDR expertIdx, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + this->gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + this->blockIdx_ = GetBlockIdx(); + this->n_ = tilingData->n; + this->k_ = tilingData->k; + this->cols_ = tilingData->cols; + this->expertStart_ = tilingData->expertStart; + this->expertEnd_ = tilingData->expertEnd; + this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; + + this->perCoreIndicesElements_ = this->gatherOutTilingData_->perCoreIndicesElements; + this->dropPadMode_ = tilingData->dropPadMode; + this->activeNum_ = tilingData->activeNum; + this->quantMode_ = tilingData->quantMode; + if (this->blockIdx_ == this->gatherOutTilingData_->needCoreNum - 1) { + this->coreIndicesElements_ = this->gatherOutTilingData_->lastCoreIndicesElements; + } else { + this->coreIndicesElements_ = this->gatherOutTilingData_->perCoreIndicesElements; + } + this->expertTokensNumType_ = tilingData->expertTokensNumType; + this->expertTokensNumFlag_ = tilingData->expertTokensNumFlag; + this->expertNum_ = tilingData->expertNum; + this->totalLength_ = tilingData->n * tilingData->k; + this->ep_ = tilingData->ep; + this->gatherFirstFullload_ = tilingData->gatherFirstFullload; + this->isInputScale_ = tilingData->isInputScale; + this->tileLength_ = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum_ = Ceil(this->tileLength_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->actual_idx_num_ = this->totalLength_; + this->rowIdxType_ = 
tilingData->rowIdxType; + this->actualExpertNum_ = tilingData->actualExpertNum; + this->pipe_ = tPipe; + + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx, this->tileLength_); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, this->tileLength_); + if (this->expertTokensNumFlag_ > 0) { + expertTokensCountOrCumsumGm_.SetGlobalBuffer((__gm__ int64_t *)expertTokensCountOrCumsum); + } + + if (expertTokensNumType_ == EXERPT_TOKENS_KEY_VALUE) { + expertCountElements_ = expertNum_ * EXERPT_TOKENS_KEY_VALUE; + } else { + expertCountElements_ = actualExpertNum_; + } + int64_t buffSize = this->sortNum_ * sizeof(int32_t); + + curIndexStart_ = this->blockIdx_ * this->perCoreIndicesElements_; + startXRow_ = curIndexStart_ / this->k_; + endXRow_ = (curIndexStart_ + this->coreIndicesElements_ - 1) / this->k_; + + pipe_->InitBuffer(expandedExpertIdxCopyOutQueue_, bufferNum_, buffSize); + pipe_->InitBuffer(expertTokensCopyOutQueue_, bufferNum_, AlignBytes(expertCountElements_, sizeof(int64_t))); + pipe_->InitBuffer(expandDstToSrcRowQueue_, bufferNum_, buffSize); + pipe_->InitBuffer(expandedRowIdxCopyOutQueue_, bufferNum_, buffSize); + pipe_->InitBuffer(sortDataCopyInQueue_, bufferNum_, buffSize * kvFactor_); + pipe_->InitBuffer(tempBuffer_, buffSize * kvFactor_); + pipe_->InitBuffer(sortedBuffer_, buffSize * kvFactor_); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::CopyIn() +{ + LocalTensor inLocal = sortDataCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(totalLength_ * sizeof(int32_t)), 0, + 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm_, dataCopyParams, dataCopyPadParams); + ArithProgression(inLocal[this->sortNum_], 0, 1, totalLength_); + sortDataCopyInQueue_.EnQue(inLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::Compute() +{ + if (ep_) { + SortComputeWithRange(); + } else { + SortCompute(); + } +} + +template 
+__aicore__ inline void MoeCustomFullLoadBase::SortComputeWithRange() +{ + LocalTensor inLocal = sortDataCopyInQueue_.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); + LocalTensor rowIdxLocal = inLocal[this->sortNum_].template ReinterpretCast(); + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, totalLength_); + PipeBarrier(); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, totalLength_); + PipeBarrier(); + if (gatherFirstFullload_) { + int64_t maskOffset = AlignBytes(Ceil(totalLength_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE, sizeof(int8_t)); + LocalTensor compareScalarMaskLocalTensor0 = tempBuffer_.Get()[maskOffset]; + LocalTensor compareScalarMaskLocalTensor1 = tempBuffer_.Get()[maskOffset * kvFactor_]; + LocalTensor gatherMaskLocalTensor = tempBuffer_.Get(); + + // Find elements >= expertStart_, which means -elements <= -expertStart_ + AscendC::CompareScalar( + compareScalarMaskLocalTensor0, expertIdxLocalFp32, static_cast(-expertStart_), AscendC::CMPMODE::LE, + (totalLength_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + + // Find elements < expertEnd_, which means -elements > -expertEnd_ + AscendC::CompareScalar( + compareScalarMaskLocalTensor1, expertIdxLocalFp32, static_cast(-expertEnd_), AscendC::CMPMODE::GT, + (totalLength_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + + And(gatherMaskLocalTensor.ReinterpretCast(), + compareScalarMaskLocalTensor0.ReinterpretCast(), + compareScalarMaskLocalTensor1.ReinterpretCast(), + Ceil(totalLength_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE / kvFactor_); + PipeBarrier(); + + uint64_t rsvdCnt = 0; + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = 1; + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = DST_REP_STRIDE; + gatherMaskParams.src1RepeatStride = 
DST_REP_STRIDE; + GatherMask(expertIdxLocalFp32, expertIdxLocalFp32, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(totalLength_), gatherMaskParams, rsvdCnt); + PipeBarrier(); + actual_idx_num_ = rsvdCnt; + sortNum_ = Ceil(actual_idx_num_, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + + GatherMask(rowIdxLocal, rowIdxLocal, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(totalLength_), gatherMaskParams, actual_idx_num_); + PipeBarrier(); + TilingInKernel(); + } else { + LocalTensor maskLocalTensor = tempBuffer_.Get(); + AscendC::CompareScalar( + maskLocalTensor, expertIdxLocalFp32, static_cast(-expertStart_), AscendC::CMPMODE::GT, + (totalLength_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + LocalTensor floatMinLocalTensor = sortedBuffer_.Get(); + Duplicate(floatMinLocalTensor, MIN_FP32, totalLength_); + PipeBarrier(); + Select(expertIdxLocalFp32, maskLocalTensor, floatMinLocalTensor, expertIdxLocalFp32, + SELMODE::VSEL_TENSOR_TENSOR_MODE, totalLength_); + PipeBarrier(); + } + // handle actual_idx_num_ == 0 + if (actual_idx_num_ < 1) { + sortDataCopyInQueue_.FreeTensor(inLocal); + return; + } + int64_t duplicateNum = actual_idx_num_ % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = actual_idx_num_ - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + + LocalTensor concatLocal = expertIdxLocalFp32; + LocalTensor tempTensor = tempBuffer_.Get(GetSortLen(this->sortNum_)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + LocalTensor sortedLocal = sortedBuffer_.Get(GetSortLen(this->sortNum_)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / 
ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, actual_idx_num_); + PipeBarrier(); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, actual_idx_num_); + PipeBarrier(); + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdxLocalInt32); + expandDstToSrcRowQueue_.EnQue(expandDstToSrcRowLocal); + sortDataCopyInQueue_.FreeTensor(inLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::SortCompute() +{ + LocalTensor inLocal = sortDataCopyInQueue_.DeQue(); + LocalTensor expertIdxLocal = inLocal[0]; + LocalTensor expertIdxLocalFp32 = expertIdxLocal.ReinterpretCast(); + Cast(expertIdxLocalFp32, expertIdxLocal, RoundMode::CAST_ROUND, totalLength_); + PipeBarrier(); + Muls(expertIdxLocalFp32, expertIdxLocalFp32, (float)-1, totalLength_); + PipeBarrier(); + int64_t duplicateNum = totalLength_ % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = totalLength_ - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + LocalTensor concatLocal = expertIdxLocalFp32; + LocalTensor tempTensor = tempBuffer_.Get(GetSortLen(this->sortNum_)); + Concat(concatLocal, expertIdxLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor rowIdxLocal = inLocal[this->sortNum_].template 
ReinterpretCast(); + LocalTensor sortedLocal = sortedBuffer_.Get(GetSortLen(this->sortNum_)); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor expandedExpertIdxLocal = expandedExpertIdxCopyOutQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue_.AllocTensor(); + LocalTensor expandDstToSrcRowLocalFp32 = expandDstToSrcRowLocal.ReinterpretCast(); + Extract(expandedExpertIdxLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + Muls(expandedExpertIdxLocal, expandedExpertIdxLocal, (float)-1, totalLength_); + PipeBarrier(); + LocalTensor expandedExpertIdxLocalInt32; + expandedExpertIdxLocalInt32 = expandedExpertIdxLocal.ReinterpretCast(); + Cast(expandedExpertIdxLocalInt32, expandedExpertIdxLocal, RoundMode::CAST_ROUND, totalLength_); + PipeBarrier(); + + Cast(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, + totalLength_); + PipeBarrier(); + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, totalLength_); + PipeBarrier(); + ArithProgression(inLocal[this->sortNum_], 0, 1, totalLength_); + PipeBarrier(); + if (duplicateNum > 0) { + int duplicateIndex = totalLength_ - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> (FP32_ONE_REPEAT_NUM - ONE_REPEAT_SORT_NUM)); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expandDstToSrcRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + PipeBarrier(); + } + Concat(concatLocal, expandDstToSrcRowLocalFp32, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + Sort(sortedLocal, concatLocal, rowIdxLocal, tempTensor, this->sortNum_ / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + Extract(tempTensor, expandedRowIdx, sortedLocal, this->sortNum_ / 
ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + if (rowIdxType_ == SCATTER or quantMode_ == 1) { + Muls(expandDstToSrcRowLocalFp32, expandDstToSrcRowLocalFp32, (float)-1, totalLength_); + PipeBarrier(); + Cast(expandDstToSrcRowLocal.ReinterpretCast(), expandDstToSrcRowLocalFp32, RoundMode::CAST_RINT, + totalLength_); + } + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdxLocalInt32); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + expandDstToSrcRowQueue_.EnQue(expandDstToSrcRowLocal); + sortDataCopyInQueue_.FreeTensor(inLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::CopyOutDefaultGatherIdx() +{ + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + Duplicate(expandedRowIdx, static_cast(-1), static_cast(totalLength_)); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyExtParams copyParams{static_cast(1), static_cast(totalLength_ * sizeof(int32_t)), 0, 0, + 0}; + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, copyParams); + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::CopyOutDefaultTokenCountOrCumsum() +{ + LocalTensor expertTokensOut = expertTokensCopyOutQueue_.AllocTensor(); + Duplicate(expertTokensOut.ReinterpretCast(), static_cast(0), + static_cast(expertCountElements_ * EXERPT_TOKENS_KEY_VALUE)); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyExtParams copyParams{static_cast(1), + static_cast(expertCountElements_ * sizeof(int64_t)), 0, 0, 0}; + DataCopyPad(expertTokensCountOrCumsumGm_, expertTokensOut, copyParams); + expertTokensCopyOutQueue_.FreeTensor(expertTokensOut); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::CopyOutIdx() +{ + LocalTensor expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue(); + LocalTensor expandDstToSrcRowLocal = expandDstToSrcRowQueue_.DeQue(); + if (rowIdxType_ == SCATTER) { + DataCopyExtParams copyParams{static_cast(1), static_cast(actual_idx_num_ * sizeof(int32_t)), + 0, 0, 0}; + 
DataCopyPad(expandedRowIdxGm_, expandDstToSrcRowLocal, copyParams); + } else if (ep_) { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.AllocTensor(); + Duplicate(expandedRowIdx, static_cast(-1), static_cast(totalLength_)); + SetWaitFlag(HardEvent::V_S); + for (int64_t i = 0; i < actual_idx_num_; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + if (curExpertId < expertStart_ || curExpertId >= expertEnd_) { + break; + } + int64_t outIndices = expandDstToSrcRowLocal.GetValue(i); + expandedRowIdx.SetValue(outIndices, i); + } + SetWaitFlag(HardEvent::S_MTE3); + DataCopyExtParams copyParams{static_cast(1), static_cast(totalLength_ * sizeof(int32_t)), 0, + 0, 0}; + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, copyParams); + expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + } else { + LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); + DataCopyExtParams copyParams{static_cast(1), static_cast(totalLength_ * sizeof(int32_t)), 0, + 0, 0}; + DataCopyPad(expandedRowIdxGm_, expandedRowIdx, copyParams); + expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + } + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdx); + expandDstToSrcRowQueue_.EnQue(expandDstToSrcRowLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::ComputeExpertTokenCountOrCumsum() +{ + // compute + LocalTensor expandedExpertIdx = expandedExpertIdxCopyOutQueue_.DeQue(); + LocalTensor expertTokensOut = expertTokensCopyOutQueue_.AllocTensor(); + Duplicate(expertTokensOut.ReinterpretCast(), static_cast(0), + static_cast(expertCountElements_ * EXERPT_TOKENS_KEY_VALUE)); + SetWaitFlag(HardEvent::V_S); + int64_t i = 0; + int32_t lastExpertId = expandedExpertIdx.GetValue(0); + int32_t lastLastId = lastExpertId; + int64_t tokenCount = 0; + int64_t lastIndex = 0; + int64_t Offset = 0; + for (i = 1; i < actual_idx_num_; i++) { + if ((lastExpertId >= expertEnd_) || (lastExpertId < expertStart_)) { + break; + } + int32_t curExpertId = 
expandedExpertIdx.GetValue(i); + if (curExpertId != lastExpertId || curExpertId >= expertEnd_) { + int64_t expertOffset = lastExpertId - expertStart_; + if (expertTokensNumType_ == EXERPT_TOKENS_KEY_VALUE) { + expertTokensOut.SetValue(Offset * EXERPT_TOKENS_KEY_VALUE, lastExpertId); + expertTokensOut.SetValue(Offset * EXERPT_TOKENS_KEY_VALUE + 1, i - lastIndex); + Offset += 1; + } else if (expertTokensNumType_ == EXERPT_TOKENS_COUNT) { + expertTokensOut.SetValue(expertOffset, i - lastIndex); + } else { + for (int64_t j = lastLastId; j < lastExpertId; j++) { + expertTokensOut.SetValue(j - expertStart_, tokenCount); + } + tokenCount += i - lastIndex; + expertTokensOut.SetValue(expertOffset, tokenCount); + } + lastIndex = i; + lastLastId = lastExpertId; + lastExpertId = curExpertId; + } + } + if ((i == actual_idx_num_) && ((lastExpertId >= expertStart_) && (lastExpertId < expertEnd_))) { + int64_t expertOffset = lastExpertId - expertStart_; + if (expertTokensNumType_ == EXERPT_TOKENS_KEY_VALUE) { + expertTokensOut.SetValue(Offset * EXERPT_TOKENS_KEY_VALUE, lastExpertId); + expertTokensOut.SetValue(Offset * EXERPT_TOKENS_KEY_VALUE + 1, i - lastIndex); + } else if (expertTokensNumType_ == EXERPT_TOKENS_COUNT) { + expertTokensOut.SetValue(expertOffset, i - lastIndex); + } else { + for (int64_t j = lastLastId; j < lastExpertId; j++) { + expertTokensOut.SetValue(j - expertStart_, tokenCount); + } + tokenCount += i - lastIndex; + expertTokensOut.SetValue(expertOffset, tokenCount); + for (int64_t j = lastExpertId; j < expertEnd_; j++) { + expertTokensOut.SetValue(j - expertStart_, tokenCount); + } + } + } else { + if (expertTokensNumType_ == EXERPT_TOKENS_CUMSUM) { + for (int64_t j = lastLastId; j < expertEnd_; j++) { + expertTokensOut.SetValue(j - expertStart_, tokenCount); + } + } + } + expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdx); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyExtParams copyParams{static_cast(1), + static_cast(expertCountElements_ * 
sizeof(int64_t)), 0, 0, 0}; + DataCopyPad(expertTokensCountOrCumsumGm_, expertTokensOut, copyParams); + SetWaitFlag(HardEvent::MTE3_V); + expertTokensCopyOutQueue_.FreeTensor(expertTokensOut); +} + +template +__aicore__ inline void MoeCustomFullLoadBase::TilingInKernel() +{ + int64_t coreNum = needCoreNum_; + perCoreIndicesElements_ = Ceil(actual_idx_num_, coreNum); + needCoreNum_ = Ceil(actual_idx_num_, perCoreIndicesElements_); + int64_t lastCoreIndicesElements = actual_idx_num_ - (needCoreNum_ - 1) * perCoreIndicesElements_; + if (blockIdx_ == needCoreNum_ - 1) { + coreIndicesElements_ = lastCoreIndicesElements; + } else { + coreIndicesElements_ = perCoreIndicesElements_; + } + curIndexStart_ = this->blockIdx_ * this->perCoreIndicesElements_; + startXRow_ = curIndexStart_ / this->k_; + endXRow_ = (curIndexStart_ + this->coreIndicesElements_ - 1) / this->k_; +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_FULL_LOAD_BASE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_dynamic_quant.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_dynamic_quant.h new file mode 100644 index 00000000000..5d7010f32c7 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_dynamic_quant.h @@ -0,0 +1,300 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_full_load_dynamic_quant.h + * \brief + */ +#ifndef MOE_CUSTOM_FULL_LOAD_DYNAMIC_QUANT_H +#define MOE_CUSTOM_FULL_LOAD_DYNAMIC_QUANT_H + +#include "moe_custom_full_load_base.h" +#include "moe_custom_common.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +template +class MoeCustomFullLoadDynamicQuant : public MoeCustomFullLoadBase { +public: + __aicore__ inline MoeCustomFullLoadDynamicQuant(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyOutXDynamicQuantFromGather(); + __aicore__ inline void CopyOutXDynamicQuantFromScatter(); + __aicore__ inline void FreeLocalTensor(); + __aicore__ inline void ComputeQuant(LocalTensor &smoothLocal); + +private: + TQue xCopyInQueue_; + TQue smoothInQueue_; + TBuf tmpBuff_; + TQue inputXOutQueue_; + TQue scaleOutQueue_; + + GlobalTensor xGm_; + GlobalTensor expandedXGm_; + GlobalTensor quantSmoothGm_; + GlobalTensor expandedScaleGm_; + + int64_t colsAlign_ = 0; +}; + +template +__aicore__ inline void MoeCustomFullLoadDynamicQuant::Init( + GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + MoeCustomFullLoadBase::Init(expertIdx, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe); + + xGm_.SetGlobalBuffer((__gm__ T *)x); + expandedXGm_.SetGlobalBuffer((__gm__ int8_t *)expandedX); + quantSmoothGm_.SetGlobalBuffer((__gm__ float *)scale); + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + this->colsAlign_ = Align(this->cols_, sizeof(T)); + if constexpr (IsSameType::value) { + 
this->pipe_->InitBuffer(xCopyInQueue_, 1, AlignBytes(this->cols_, sizeof(float))); + } else { + this->pipe_->InitBuffer(xCopyInQueue_, 1, 2 * AlignBytes(this->cols_, sizeof(T))); + } + this->pipe_->InitBuffer(inputXOutQueue_, 1, AlignBytes(this->cols_, sizeof(int8_t))); + this->pipe_->InitBuffer(smoothInQueue_, 1, AlignBytes(this->cols_, sizeof(float))); + this->pipe_->InitBuffer(tmpBuff_, AlignBytes(this->cols_, sizeof(float))); + this->pipe_->InitBuffer(scaleOutQueue_, 1, BLOCK_BYTES + BLOCK_BYTES); +} + +template +__aicore__ inline void MoeCustomFullLoadDynamicQuant::Process() +{ + if (this->blockIdx_ < this->needCoreNum_) { + this->CopyIn(); + this->Compute(); + + // vaild expert equal zero + if (this->needCoreNum_ < 1) { + if (this->blockIdx_ == 0) { + if (this->rowIdxType_ == GATHER) { + this->CopyOutDefaultGatherIdx(); + } + if (this->expertTokensNumFlag_ == 1) { + this->CopyOutDefaultTokenCountOrCumsum(); + } + } + return; + } + + if (this->blockIdx_ == 0) { + this->CopyOutIdx(); + } + + if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensNumFlag_ == 1) { + this->ComputeExpertTokenCountOrCumsum(); + } + + if (this->blockIdx_ < this->needCoreNum_) { + if constexpr (!COPYOUTTYPE && SMOOTHTYPE != SCALE_EH) { + CopyOutXDynamicQuantFromGather(); + } else { + CopyOutXDynamicQuantFromScatter(); + } + } + + FreeLocalTensor(); + } +} + +template +__aicore__ inline void +MoeCustomFullLoadDynamicQuant::ComputeQuant(LocalTensor &smoothLocal) +{ + LocalTensor tempLocal = tmpBuff_.Get(); + LocalTensor outLocal = inputXOutQueue_.AllocTensor(); + LocalTensor dynamicQuantLocal = scaleOutQueue_.AllocTensor(); + LocalTensor inLocal = xCopyInQueue_.DeQue(); + + if constexpr (!IsSameType::value && !IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[colsAlign_], RoundMode::CAST_NONE, this->cols_); + PipeBarrier(); + } + + if constexpr (SMOOTHTYPE != NO_SCALE) { + Mul(inLocal, inLocal, smoothLocal, this->cols_); + PipeBarrier(); + } + + Abs(tempLocal, 
inLocal, this->cols_); + PipeBarrier(); + + ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols_); + PipeBarrier(); + + float maxValue = dynamicQuantLocal.GetValue(0) / MAX_INT8; + + Duplicate(dynamicQuantLocal, maxValue, INT32_ONE_BLOCK_NUM); + PipeBarrier(); + Duplicate(tempLocal, maxValue, this->cols_); + PipeBarrier(); + + Div(tempLocal, inLocal, tempLocal, this->cols_); + PipeBarrier(); + + LocalTensor intLocal = tempLocal.ReinterpretCast(); + Cast(intLocal, tempLocal, RoundMode::CAST_RINT, this->cols_); + PipeBarrier(); + SetDeqScale((half)1.000000e+00f); + Cast(intLocal.ReinterpretCast(), intLocal, RoundMode::CAST_ROUND, this->cols_); + PipeBarrier(); + Cast(outLocal, intLocal.ReinterpretCast(), RoundMode::CAST_TRUNC, this->cols_); + + inputXOutQueue_.EnQue(outLocal); + scaleOutQueue_.EnQue(dynamicQuantLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadDynamicQuant::CopyOutXDynamicQuantFromScatter() +{ + LocalTensor sortedRowIdx = this->expandDstToSrcRowQueue_.template DeQue(); + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + + LocalTensor smoothLocal = smoothInQueue_.AllocTensor(); + ; + + if constexpr (SMOOTHTYPE == SCALE_1H) { + DataCopyPad(smoothLocal, quantSmoothGm_, smoothCopyParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + } + + int64_t dstIndexStart = this->curIndexStart_; + int64_t dstIndexEnd = dstIndexStart + this->coreIndicesElements_ - 1; + int32_t lastExpertIdx = -1; + + for (int64_t dstIndex = dstIndexStart; dstIndex <= dstIndexEnd; dstIndex++) { + if (this->dropPadMode_ == 
DROPLESS_MODE && dstIndex >= this->activeNum_) { + break; + } + int32_t srcIdx = sortedRowIdx.GetValue(dstIndex); + int32_t expertIdx = expandedExpertIdx.GetValue(dstIndex); + if (expertIdx < this->expertStart_ || expertIdx >= this->expertEnd_) { + break; + } + expertIdx = expertIdx - this->expertStart_; + LocalTensor xLocal = this->xCopyInQueue_.template AllocTensor(); + // copy in single x + if constexpr (IsSameType::value) { + DataCopyPad(xLocal, this->xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } else { + DataCopyPad(xLocal[colsAlign_], this->xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, + {false, 0, 0, 0}); + } + xCopyInQueue_.EnQue(xLocal); + + // copyin dynamic scale + if constexpr (SMOOTHTYPE == SCALE_EH) { + if (expertIdx != lastExpertIdx) { + DataCopyPad(smoothLocal, quantSmoothGm_[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + lastExpertIdx = expertIdx; + } + } + + ComputeQuant(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue_.DeQue(); + DataCopyPad(expandedScaleGm_[dstIndex], quantScaleLocal, quantScaleParams); + + LocalTensor outLocal = inputXOutQueue_.DeQue(); + DataCopyPad(this->expandedXGm_[dstIndex * this->cols_], outLocal, intriParams); + + inputXOutQueue_.FreeTensor(outLocal); + scaleOutQueue_.FreeTensor(quantScaleLocal); + this->xCopyInQueue_.FreeTensor(xLocal); + } + smoothInQueue_.FreeTensor(smoothLocal); + this->expandDstToSrcRowQueue_.EnQue(sortedRowIdx); + this->expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdx); +} + +template +__aicore__ inline void MoeCustomFullLoadDynamicQuant::CopyOutXDynamicQuantFromGather() +{ + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + 
DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + LocalTensor smoothLocal = smoothInQueue_.AllocTensor(); + int64_t curIndex = this->blockIdx_ * this->perCoreIndicesElements_; + int64_t curIndexEnd = curIndex + this->coreIndicesElements_ - 1; + + if constexpr (SMOOTHTYPE == SCALE_1H) { + DataCopyPad(smoothLocal, quantSmoothGm_, smoothCopyParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + } + + for (int64_t row = this->startXRow_; row <= this->endXRow_; row++) { + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(xLocal, this->xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } else { + DataCopyPad(xLocal[colsAlign_], this->xGm_[row * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + } + xCopyInQueue_.EnQue(xLocal); + ComputeQuant(smoothLocal); + + LocalTensor quantScaleLocal = scaleOutQueue_.DeQue(); + LocalTensor outLocal = inputXOutQueue_.DeQue(); + while (curIndex <= curIndexEnd && curIndex / this->k_ == row) { + int32_t outIndex = expandedRowIdx.GetValue(curIndex); + curIndex++; + if (outIndex == -1 || this->dropPadMode_ == DROPLESS_MODE && outIndex >= this->activeNum_) { + continue; + } + DataCopyPad(expandedXGm_[outIndex * this->cols_], outLocal, intriParams); + DataCopyPad(expandedScaleGm_[outIndex], quantScaleLocal, quantScaleParams); + } + + xCopyInQueue_.FreeTensor(xLocal); + inputXOutQueue_.FreeTensor(outLocal); + scaleOutQueue_.FreeTensor(quantScaleLocal); + } + + smoothInQueue_.FreeTensor(smoothLocal); + this->expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); +} + +template +__aicore__ inline void MoeCustomFullLoadDynamicQuant::FreeLocalTensor() +{ + if constexpr (!COPYOUTTYPE) { + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + 
this->expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + } + LocalTensor sortedRowIdx = this->expandDstToSrcRowQueue_.template DeQue(); + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + this->expandDstToSrcRowQueue_.FreeTensor(sortedRowIdx); + this->expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx); +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_FULL_LOAD_DYNAMIC_QUANT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_static_quant.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_static_quant.h new file mode 100644 index 00000000000..e2c074d3772 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_static_quant.h @@ -0,0 +1,229 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_static_quant_full_load.h + * \brief + */ +#ifndef MOE_CUSTOM_FULL_LOAD_STATIC_QUANT_H +#define MOE_CUSTOM_FULL_LOAD_STATIC_QUANT_H + +#include "moe_custom_full_load_base.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +template +class MoeCustomFullLoadStaticQuant : public MoeCustomFullLoadBase { +public: + __aicore__ inline MoeCustomFullLoadStaticQuant(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyOutXStaticQuant(); + __aicore__ inline void FreeLocalTensor(); + __aicore__ inline void ComputeQuant(int64_t xLocalLength); + +private: + TQue xCopyInQueue_; + TQue floatQueue_; + TQue halfQueue_; + TQue inputXOutQueue_; + + GlobalTensor xGm_; + GlobalTensor expandedXGm_; + GlobalTensor scaleGm_; + GlobalTensor offsetGm_; + + float scale_; + float offset_; +}; + +template +__aicore__ inline void MoeCustomFullLoadStaticQuant::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + MoeCustomFullLoadBase::Init(expertIdx, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe); + + xGm_.SetGlobalBuffer((__gm__ T *)x); + expandedXGm_.SetGlobalBuffer((__gm__ int8_t *)expandedX); + scaleGm_.SetGlobalBuffer((__gm__ float *)scale, 1); + offsetGm_.SetGlobalBuffer((__gm__ float *)offset, 1); + this->scale_ = scaleGm_.GetValue(0); + this->offset_ = offsetGm_.GetValue(0); + SetWaitFlag(HardEvent::S_V); + int64_t curIndexStart = this->blockIdx_ * this->perCoreIndicesElements_; + int64_t rowLength = 0; + if (this->ep_) { + rowLength = 1; + } else { + rowLength = 
(curIndexStart + this->coreIndicesElements_ - 1) / this->k_ - curIndexStart / this->k_ + 1; + } + int64_t xAlignedCount = Align(this->cols_, sizeof(int8_t)); + this->pipe_->InitBuffer(xCopyInQueue_, this->bufferNum_, xAlignedCount * sizeof(T) * rowLength); + this->pipe_->InitBuffer(inputXOutQueue_, 1, xAlignedCount * sizeof(int8_t) * rowLength); + this->pipe_->InitBuffer(floatQueue_, 1, xAlignedCount * sizeof(float) * rowLength); + this->pipe_->InitBuffer(halfQueue_, 1, xAlignedCount * sizeof(half) * rowLength); +} + +template +__aicore__ inline void MoeCustomFullLoadStaticQuant::Process() +{ + if (this->blockIdx_ < this->needCoreNum_) { + this->CopyIn(); + this->Compute(); + + // vaild expert equal zero + if (this->needCoreNum_ < 1) { + if (this->blockIdx_ == 0) { + if (this->rowIdxType_ == GATHER) { + this->CopyOutDefaultGatherIdx(); + } + if (this->expertTokensNumFlag_ == 1) { + this->CopyOutDefaultTokenCountOrCumsum(); + } + } + return; + } + + if (this->blockIdx_ == 0) { + this->CopyOutIdx(); + } + if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensNumFlag_ == 1) { + this->ComputeExpertTokenCountOrCumsum(); + } + if (this->blockIdx_ < this->needCoreNum_) { + CopyOutXStaticQuant(); + } + FreeLocalTensor(); + } +} + +template +__aicore__ inline void MoeCustomFullLoadStaticQuant::ComputeQuant(int64_t xLocalLength) +{ + LocalTensor floatLocal; + LocalTensor inLocal; + LocalTensor outLocal = inputXOutQueue_.AllocTensor(); + LocalTensor halfLocal = halfQueue_.AllocTensor(); + uint64_t elements = Align(this->cols_, sizeof(int8_t)) * xLocalLength; + if constexpr (IsSameType::value) { + floatLocal = this->xCopyInQueue_.template DeQue(); + } else { + inLocal = this->xCopyInQueue_.template DeQue(); + floatLocal = floatQueue_.AllocTensor(); + Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements); + PipeBarrier(); + } + Muls(floatLocal, floatLocal, this->scale_, elements); + PipeBarrier(); + Adds(floatLocal, floatLocal, this->offset_, elements); + 
PipeBarrier(); + LocalTensor intLocal = floatLocal.ReinterpretCast(); + Cast(intLocal, floatLocal, RoundMode::CAST_RINT, elements); + PipeBarrier(); + SetDeqScale((half)1.000000e+00f); + Cast(halfLocal, intLocal, RoundMode::CAST_ROUND, elements); + PipeBarrier(); + Cast(outLocal, halfLocal, RoundMode::CAST_TRUNC, elements); + inputXOutQueue_.EnQue(outLocal); + if constexpr (IsSameType::value) { + this->xCopyInQueue_.FreeTensor(floatLocal); + } else { + this->xCopyInQueue_.FreeTensor(inLocal); + floatQueue_.FreeTensor(floatLocal); + } + + halfQueue_.FreeTensor(halfLocal); +} + +template +__aicore__ inline void MoeCustomFullLoadStaticQuant::CopyOutXStaticQuant() +{ + int64_t curIndex = this->curIndexStart_; + int64_t curIndexEnd = curIndex + this->coreIndicesElements_ - 1; + + if (this->ep_) { + LocalTensor sortedRowIdx = this->expandDstToSrcRowQueue_.template DeQue(); + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + + DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + + for (int64_t dstIndex = curIndex; dstIndex <= curIndexEnd; dstIndex++) { + if (this->dropPadMode_ == DROPLESS_MODE && dstIndex >= this->activeNum_) { + break; + } + int32_t srcIdx = sortedRowIdx.GetValue(dstIndex); + int32_t expertIdx = expandedExpertIdx.GetValue(dstIndex); + if (expertIdx < this->expertStart_ || expertIdx >= this->expertEnd_) { + break; + } + LocalTensor inLocal = this->xCopyInQueue_.template AllocTensor(); + // copyinx + DataCopyPad(inLocal, this->xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + this->xCopyInQueue_.template EnQue(inLocal); + ComputeQuant(1); + + LocalTensor outLocal = inputXOutQueue_.DeQue(); + DataCopyPad(this->expandedXGm_[dstIndex * this->cols_], outLocal, intriParams); + inputXOutQueue_.FreeTensor(outLocal); + } + this->expandDstToSrcRowQueue_.EnQue(sortedRowIdx); + 
this->expandedExpertIdxCopyOutQueue_.EnQue(expandedExpertIdx); + } else { + LocalTensor xLocal = this->xCopyInQueue_.template AllocTensor(); + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + int64_t inFactor = Align(this->cols_, sizeof(int8_t)); + uint32_t dstStride = (inFactor * sizeof(T) - AlignBytes(this->cols_, sizeof(T))) / BLOCK_BYTES; + DataCopyExtParams dataXCopyParams{static_cast(this->endXRow_ - this->startXRow_ + 1), + static_cast(this->cols_ * sizeof(T)), 0, dstStride, 0}; + DataCopyPad(xLocal, this->xGm_[this->startXRow_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); + this->xCopyInQueue_.EnQue(xLocal); + SetWaitFlag(HardEvent::MTE2_V); + ComputeQuant(this->endXRow_ - this->startXRow_ + 1); + + LocalTensor outLocal = inputXOutQueue_.DeQue(); + int64_t k = 0; + DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + for (int64_t i = this->startXRow_; i <= this->endXRow_; i++) { + for (; k < this->coreIndicesElements_ && curIndex / this->k_ == i; curIndex++, k++) { + int32_t outIndex = expandedRowIdx.GetValue(curIndex); + if (outIndex < this->activeNum_) { + DataCopyPad(this->expandedXGm_[outIndex * this->cols_], outLocal[(i - this->startXRow_) * inFactor], + intriParams); + } + } + } + inputXOutQueue_.FreeTensor(outLocal); + this->expandedRowIdxCopyOutQueue_.EnQue(expandedRowIdx); + } +} + +template +__aicore__ inline void MoeCustomFullLoadStaticQuant::FreeLocalTensor() +{ + if (!this->ep_) { + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + this->expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + } + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + this->expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx); + LocalTensor sortedRowIdx = this->expandDstToSrcRowQueue_.template DeQue(); + this->expandDstToSrcRowQueue_.FreeTensor(sortedRowIdx); +} + +} // namespace MoeInitRoutingCustom +#endif // 
MOE_CUSTOM_FULL_LOAD_STATIC_QUANT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_unquantized.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_unquantized.h new file mode 100644 index 00000000000..2fbced98c80 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_full_load_unquantized.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_full_load_unquantized.h + * \brief + */ +#ifndef MOE_CUSTOM_FULL_LOAD_UNQUANTIZED_H +#define MOE_CUSTOM_FULL_LOAD_UNQUANTIZED_H + +#include "moe_custom_full_load_base.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +template +class MoeCustomFullLoadUnquantized : public MoeCustomFullLoadBase { +public: + __aicore__ inline MoeCustomFullLoadUnquantized(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +protected: + __aicore__ inline void FreeLocalTensor(); + __aicore__ inline void GatherOutX(); + __aicore__ inline void CopyOutScale(); + +protected: + TQue xCopyInQueue_; + TQue scaleCopyInQueue_; + + GlobalTensor xGm_; + GlobalTensor scaleGm_; + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedScaleGm_; +}; + +template +__aicore__ inline void MoeCustomFullLoadUnquantized::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, + GM_ADDR expandedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + MoeCustomFullLoadBase::Init(expertIdx, expandedRowIdx, expertTokensCountOrCumsum, workspace, tilingData, tPipe); + xGm_.SetGlobalBuffer((__gm__ T *)x); + if (this->isInputScale_) { + scaleGm_.SetGlobalBuffer((__gm__ float *)scale); + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + } + + expandedXGm_.SetGlobalBuffer((__gm__ T *)expandedX); + int64_t buffSize = this->sortNum_ * sizeof(int32_t); + int64_t row_length = + (this->curIndexStart_ + this->coreIndicesElements_ - 1) / this->k_ - this->curIndexStart_ / this->k_ + 1; + + if (this->ep_) { + this->pipe_->InitBuffer(xCopyInQueue_, this->bufferNum_, 
AlignBytes(this->cols_, sizeof(T))); + } else { + this->pipe_->InitBuffer(xCopyInQueue_, this->bufferNum_, AlignBytes(this->cols_, sizeof(T)) * row_length); + } + this->pipe_->InitBuffer(scaleCopyInQueue_, 1, AlignBytes(1, sizeof(float))); +} + +template +__aicore__ inline void MoeCustomFullLoadUnquantized::Process() +{ + if (this->blockIdx_ < this->needCoreNum_) { + this->CopyIn(); + this->Compute(); + + // vaild expert equal zero + if (this->needCoreNum_ < 1) { + if (this->blockIdx_ == 0) { + if (this->rowIdxType_ == GATHER) { + this->CopyOutDefaultGatherIdx(); + } + if (this->expertTokensNumFlag_ == 1) { + this->CopyOutDefaultTokenCountOrCumsum(); + } + } + return; + } + + if (this->blockIdx_ == 0) { + this->CopyOutIdx(); + } + + if (this->blockIdx_ == this->needCoreNum_ - 1 && this->expertTokensNumFlag_ == 1) { + this->ComputeExpertTokenCountOrCumsum(); + } + + if (this->blockIdx_ < this->needCoreNum_) { + this->GatherOutX(); + if (this->isInputScale_) { + this->CopyOutScale(); + } + } + + this->FreeLocalTensor(); + } +} + +template +__aicore__ inline void MoeCustomFullLoadUnquantized::GatherOutX() +{ + if (this->ep_) { + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + LocalTensor expandDstToSrcRowLocal = this->expandDstToSrcRowQueue_.template DeQue(); + int64_t startRowIdx = this->blockIdx_ * this->perCoreIndicesElements_; + int64_t endRowIdx = startRowIdx + this->coreIndicesElements_; + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + DataCopyExtParams copyParams{static_cast(1), static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + for (int64_t i = startRowIdx; i < endRowIdx && i < this->activeNum_; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + if (curExpertId < this->expertStart_ || curExpertId >= this->expertEnd_) { + break; + } + int64_t rowIdx = expandDstToSrcRowLocal.GetValue(i); + int64_t srcOffset = rowIdx / this->k_ * this->cols_; + int64_t 
dstOffset = i * this->cols_; + SetWaitFlag(HardEvent::MTE3_MTE2); + DataCopyPad(xLocal, xGm_[srcOffset], copyParams, padParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + DataCopyPad(expandedXGm_[dstOffset], xLocal, copyParams); + } + xCopyInQueue_.FreeTensor(xLocal); + this->expandedExpertIdxCopyOutQueue_.template EnQue(expandedExpertIdx); + this->expandDstToSrcRowQueue_.template EnQue(expandDstToSrcRowLocal); + } else { + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + DataCopyExtParams dataXCopyParams{static_cast(this->endXRow_ - this->startXRow_ + 1), + static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataXCopyPadParams{false, 0, 0, 0}; + DataCopyPad(xLocal, xGm_[this->startXRow_ * this->cols_], dataXCopyParams, dataXCopyPadParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + int64_t inFactor = Align(this->cols_, sizeof(T)); + DataCopyExtParams copyParams{static_cast(1), static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + int64_t curIndexStart = this->curIndexStart_; + int64_t k = 0; + for (int64_t i = this->startXRow_; i <= this->endXRow_; i++) { + for (; k < this->coreIndicesElements_ && curIndexStart / this->k_ == i; curIndexStart++, k++) { + int32_t outIndex = expandedRowIdx.GetValue(curIndexStart); + if (outIndex < this->activeNum_) { + DataCopyPad(expandedXGm_[outIndex * this->cols_], xLocal[(i - this->startXRow_) * inFactor], + copyParams); + } + } + } + xCopyInQueue_.FreeTensor(xLocal); + this->expandedRowIdxCopyOutQueue_.template EnQue(expandedRowIdx); + } +} + +template +__aicore__ inline void MoeCustomFullLoadUnquantized::FreeLocalTensor() +{ + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + LocalTensor expandDstToSrcRowLocal = this->expandDstToSrcRowQueue_.template DeQue(); + this->expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdx); + this->expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); + 
if (!this->ep_) { + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + this->expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); + } +} + +template +__aicore__ inline void MoeCustomFullLoadUnquantized::CopyOutScale() +{ + LocalTensor scaleLocal = scaleCopyInQueue_.AllocTensor(); + DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + if (this->ep_) { + LocalTensor expandedExpertIdx = this->expandedExpertIdxCopyOutQueue_.template DeQue(); + LocalTensor expandDstToSrcRowLocal = this->expandDstToSrcRowQueue_.template DeQue(); + int64_t startRowIdx = this->blockIdx_ * this->perCoreIndicesElements_; + int64_t endRowIdx = startRowIdx + this->coreIndicesElements_; + for (int64_t i = startRowIdx; i < endRowIdx && i < this->activeNum_; i++) { + int32_t curExpertId = expandedExpertIdx.GetValue(i); + if (curExpertId < this->expertStart_ || curExpertId >= this->expertEnd_) { + break; + } + int64_t rowIdx = expandDstToSrcRowLocal.GetValue(i); + SetWaitFlag(HardEvent::MTE3_MTE2); + DataCopyPad(scaleLocal, scaleGm_[rowIdx / this->k_], copyParams, padParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + DataCopyPad(expandedScaleGm_[i], scaleLocal, copyParams); + } + this->expandedExpertIdxCopyOutQueue_.template EnQue(expandedExpertIdx); + this->expandDstToSrcRowQueue_.template EnQue(expandDstToSrcRowLocal); + } else { + LocalTensor expandedRowIdx = this->expandedRowIdxCopyOutQueue_.template DeQue(); + int64_t curIndexStart = this->curIndexStart_; + int64_t k = 0; + for (int64_t i = this->startXRow_; i <= this->endXRow_; i++) { + SetWaitFlag(HardEvent::MTE3_MTE2); + DataCopyPad(scaleLocal, scaleGm_[i], copyParams, padParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + for (; k < this->coreIndicesElements_ && curIndexStart / this->k_ == i; curIndexStart++, k++) { + int32_t outIndex = expandedRowIdx.GetValue(curIndexStart); + if (outIndex < this->activeNum_) { + 
DataCopyPad(expandedScaleGm_[outIndex], scaleLocal, copyParams); + } + } + } + this->expandedRowIdxCopyOutQueue_.template EnQue(expandedRowIdx); + } + scaleCopyInQueue_.FreeTensor(scaleLocal); +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_FULL_LOAD_UNQUANTIZED_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_droppad_static_quant.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_droppad_static_quant.h new file mode 100644 index 00000000000..e0c5f00b890 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_droppad_static_quant.h @@ -0,0 +1,238 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_gather_droppad_static_quant.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_DROPPAD_STATIC_QUANT_H +#define MOE_CUSTOM_GATHER_DROPPAD_STATIC_QUANT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t GATHER_OUT_DROPPAD_QUANT_BUFFER_NUM = 2; + +template +class MoeGatherDroppadQuant { +public: + __aicore__ inline MoeGatherDroppadQuant(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyExpertIn(int64_t progress); + __aicore__ inline void Compute(); + __aicore__ inline void CopyXIn(int64_t xSrcOffset, int64_t curLoopCols); + __aicore__ inline void CopyOut(int64_t progress); + +private: + TPipe *pipe_; + TQue inputXCopyInQueue_; + TQue expandRowIdxCopyInQueue_; + TQue inputXCopyOutQueue_; + TQue floatQueue_; + TQue halfQueue_; + + GlobalTensor inputXGm_; + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor scaleGm_; + GlobalTensor offsetGm_; + + const MoeCustomGatherOutComputeTilingData *gatherOutTilingData_; + + int64_t needCoreNum_; + int64_t blockIdx_; + int64_t cols_; + int64_t n_; + int64_t k_; + int64_t currentLoopRows_; + int64_t coreRows_; + int64_t perLoopRows_; + int64_t lastLoopRows_; + int64_t rowLoops_; + int64_t colsTileLength_; + int64_t perLoopCols_; + int64_t lastLoopCols_; + int64_t colLoops_; + float scale_; + float offset_; + + int64_t indicesOffset_; + int64_t inputOffset_; + int64_t outOffset_; +}; + +template +__aicore__ inline void MoeGatherDroppadQuant::CopyExpertIn(int64_t progress) +{ + indicesOffset_ = progress * perLoopRows_; + LocalTensor indicesLocal = expandRowIdxCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(currentLoopRows_ * 
sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, expandedRowIdxGm_[indicesOffset_], dataCopyParams, dataCopyPadParams); + expandRowIdxCopyInQueue_.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeGatherDroppadQuant::CopyXIn(int64_t xSrcOffset, int64_t curLoopCols) +{ + LocalTensor inLocal = inputXCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal, inputXGm_[xSrcOffset], dataCopyParams, dataCopyPadParams); + inputXCopyInQueue_.EnQue(inLocal); +} + +template +__aicore__ inline void MoeGatherDroppadQuant::Compute() +{ + LocalTensor floatLocal; + LocalTensor inLocal; + LocalTensor outLocal = inputXCopyOutQueue_.AllocTensor(); + LocalTensor halfLocal = halfQueue_.AllocTensor(); + uint32_t elements = Align(colsTileLength_, sizeof(T)); + if constexpr (IsSameType::value) { + floatLocal = inputXCopyInQueue_.DeQue(); + } else { + inLocal = inputXCopyInQueue_.DeQue(); + floatLocal = floatQueue_.AllocTensor(); + Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements); + PipeBarrier(); + } + Muls(floatLocal, floatLocal, scale_, elements); + PipeBarrier(); + Adds(floatLocal, floatLocal, offset_, elements); + PipeBarrier(); + LocalTensor intLocal = floatLocal.ReinterpretCast(); + Cast(intLocal, floatLocal, RoundMode::CAST_RINT, elements); + PipeBarrier(); + SetDeqScale((half)1.000000e+00f); + PipeBarrier(); + Cast(halfLocal, intLocal, RoundMode::CAST_ROUND, elements); + PipeBarrier(); + Cast(outLocal, halfLocal, RoundMode::CAST_TRUNC, elements); + inputXCopyOutQueue_.EnQue(outLocal); + if constexpr (IsSameType::value) { + inputXCopyInQueue_.FreeTensor(floatLocal); + } else { + inputXCopyInQueue_.FreeTensor(inLocal); + floatQueue_.FreeTensor(floatLocal); + } + halfQueue_.FreeTensor(halfLocal); +} + +template +__aicore__ inline void 
MoeGatherDroppadQuant::CopyOut(int64_t progress) +{ + LocalTensor indicesLocal = expandRowIdxCopyInQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength_ = perLoopCols_; + for (int64_t colsLoop = 0; colsLoop < colLoops_; colsLoop++) { + int64_t initialRow = gatherOutTilingData_->perCoreIndicesElements * blockIdx_ + perLoopRows_ * progress; + int64_t curLoopRow = 0; + if (colsLoop == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + int64_t currentLoopStartRow = initialRow / k_; + int64_t currentLoopLastRow = (initialRow + currentLoopRows_ - 1) / k_; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + inputOffset_ = row * cols_ + colsLoop * perLoopCols_; + // input row position + CopyXIn(inputOffset_, colsTileLength_); + Compute(); + LocalTensor outLocal = inputXCopyOutQueue_.DeQue(); + DataCopyExtParams intriParams{1, static_cast(colsTileLength_ * sizeof(int8_t)), 0, 0, 0}; + while (curLoopRow < currentLoopRows_ && initialRow / k_ == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1) { + continue; + } + outOffset_ = outIndex * cols_ + colsLoop * perLoopCols_; + DataCopyPad(expandedXGm_[outOffset_], outLocal, intriParams); + } + inputXCopyOutQueue_.FreeTensor(outLocal); + } + } + expandRowIdxCopyInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeGatherDroppadQuant::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + + needCoreNum_ = gatherOutTilingData_->needCoreNum; + cols_ = tilingData->cols; + n_ = tilingData->n; + k_ = tilingData->k; + + if (blockIdx_ == needCoreNum_ - 1) { + coreRows_ = gatherOutTilingData_->lastCoreIndicesElements; + perLoopRows_ = 
gatherOutTilingData_->lastCorePerLoopIndicesElements; + lastLoopRows_ = gatherOutTilingData_->lastCoreLastLoopIndicesElements; + rowLoops_ = gatherOutTilingData_->lastCoreIndicesLoops; + } else { + coreRows_ = gatherOutTilingData_->perCoreIndicesElements; + perLoopRows_ = gatherOutTilingData_->perCorePerLoopIndicesElements; + lastLoopRows_ = gatherOutTilingData_->perCoreLastLoopIndicesElements; + rowLoops_ = gatherOutTilingData_->perCoreIndicesLoops; + } + perLoopCols_ = gatherOutTilingData_->perLoopCols; + lastLoopCols_ = gatherOutTilingData_->lastLoopCols; + colLoops_ = gatherOutTilingData_->colsLoops; + + inputXGm_.SetGlobalBuffer((__gm__ T *)inputX); + expandedXGm_.SetGlobalBuffer((__gm__ int8_t *)expandedX); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + + blockIdx_ * gatherOutTilingData_->perCoreIndicesElements, + Align(coreRows_, sizeof(int32_t))); + scaleGm_.SetGlobalBuffer((__gm__ float *)scale, 1); + offsetGm_.SetGlobalBuffer((__gm__ float *)offset, 1); + scale_ = scaleGm_.GetValue(0); + offset_ = offsetGm_.GetValue(0); + + pipe_->InitBuffer(inputXCopyInQueue_, GATHER_OUT_DROPPAD_QUANT_BUFFER_NUM, AlignBytes(perLoopCols_, sizeof(T))); + pipe_->InitBuffer(inputXCopyOutQueue_, GATHER_OUT_DROPPAD_QUANT_BUFFER_NUM, + AlignBytes(perLoopCols_, sizeof(int8_t))); + pipe_->InitBuffer(expandRowIdxCopyInQueue_, GATHER_OUT_DROPPAD_QUANT_BUFFER_NUM, + AlignBytes(perLoopRows_, sizeof(int32_t))); + pipe_->InitBuffer(floatQueue_, 1, AlignBytes(perLoopCols_, sizeof(float))); + pipe_->InitBuffer(halfQueue_, 1, AlignBytes(perLoopCols_, sizeof(half))); +} + +template +__aicore__ inline void MoeGatherDroppadQuant::Process() +{ + if (blockIdx_ < needCoreNum_) { + currentLoopRows_ = perLoopRows_; + for (int64_t loop = 0; loop < rowLoops_; loop++) { + if (loop == rowLoops_ - 1) { + currentLoopRows_ = lastLoopRows_; + } + CopyExpertIn(loop); + CopyOut(loop); + } + } +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_DROPPAD_STATIC_QUANT_H 
diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_dynamic_quant.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_dynamic_quant.h new file mode 100644 index 00000000000..be6abc8abd7 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_dynamic_quant.h @@ -0,0 +1,602 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_gather_dynamic_quant.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_DYNAMIC_QUANT_H +#define MOE_CUSTOM_GATHER_DYNAMIC_QUANT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; +constexpr int64_t GATHER_OUT_DYNAMIC_QUANT_BUFFER_NUM = 2; + +template +class MoeGatherOutDynamicQuant { +public: + __aicore__ inline MoeGatherOutDynamicQuant(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR expandedScale, GM_ADDR sortedExpertIdx, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyOutXDynamicQuantFromGather(int64_t progress); + __aicore__ inline void CopyOutXDynamicQuantFromScatter(int64_t progress); + __aicore__ inline void CopyOutXPartialDynamicQuantFromGather(int64_t progress); + __aicore__ inline void 
CopyOutXPartialDynamicQuantFromScatter(int64_t progress); + __aicore__ inline void CopyInExpandedExpertIdx(int64_t progress); + __aicore__ inline void Compute(LocalTensor &smoothLocal); + __aicore__ inline float ComputeMax(LocalTensor &inLocal, LocalTensor &tempLocal, + LocalTensor &scaleLocal, int32_t srcIdx, int32_t expertIdx, int64_t j); + __aicore__ inline void ComputeScale(LocalTensor &inLocal, LocalTensor &tempLocal, float scaleTemp, + int64_t dstIndex, int64_t j); + +private: + TPipe *pipe_; + TQue inputXInQueue_; + TQue smoothInQueue_; + TQue expandRowIdxInQueue_; + TQue calcQueue_; + TQue inputXOutQueue_; + TQue scaleOutQueue_; + + GlobalTensor inputXGm_; + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor quantSmoothGm_; + GlobalTensor expandedScaleGm_; + GlobalTensor quantTempGm_; + GlobalTensor expandedExpertIdxGm_; + GlobalTensor expertTotalCountGm_; + + const MoeCustomGatherOutComputeTilingData *gatherOutTilingData_; + + int64_t needCoreNum_; + int64_t blockIdx_; + int64_t cols_; + int64_t n_; + int64_t k_; + int64_t totalLength_; + int64_t perCoreRow_; + int64_t currentLoopRows_; + int64_t currentLoopRowsAlign_; + int64_t coreRows_; + int64_t perLoopRows_; + int64_t lastLoopRows_; + int64_t rowLoops_; + int64_t colsTileLength_; + int64_t perLoopCols_; + int64_t perLoopColsAlign_; + int64_t lastLoopCols_; + int64_t colLoops_; + int64_t isInputScale_; + int64_t expertStart_; + + int64_t indicesOffset_; + int64_t rowIdxType_ = 0; + int64_t dropPadMode_; + int64_t activeNum_; + int64_t ep_; + int64_t smoothType_; + int64_t coreNum_; + int64_t expertTotalCount_ = 0; +}; + +template +__aicore__ inline void MoeGatherOutDynamicQuant::CopyInExpandedExpertIdx(int64_t progress) +{ + indicesOffset_ = progress * perLoopRows_; + LocalTensor indicesLocal = expandRowIdxInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(currentLoopRows_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 
0}; + DataCopyPad(indicesLocal, expandedRowIdxGm_[indicesOffset_], dataCopyParams, dataCopyPadParams); + DataCopyPad(indicesLocal[currentLoopRowsAlign_], expandedExpertIdxGm_[indicesOffset_], dataCopyParams, + dataCopyPadParams); + expandRowIdxInQueue_.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutDynamicQuant::Compute(LocalTensor &smoothLocal) +{ + LocalTensor inLocal = inputXInQueue_.DeQue(); + + LocalTensor tempLocal = calcQueue_.AllocTensor(); + LocalTensor outLocal = inputXOutQueue_.AllocTensor(); + LocalTensor scaleLocal = scaleOutQueue_.AllocTensor(); + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign_], RoundMode::CAST_NONE, cols_); + PipeBarrier(); + } + + if (isInputScale_) { + Mul(inLocal, inLocal, smoothLocal, cols_); + PipeBarrier(); + } + + Abs(tempLocal, inLocal, cols_); + PipeBarrier(); + + ReduceMax(scaleLocal, tempLocal, tempLocal, cols_); // get max value and index [0,1] + + float scaleValue = scaleLocal.GetValue(0) / MAX_INT8; + + Duplicate(scaleLocal, scaleValue, INT32_ONE_BLOCK_NUM); + PipeBarrier(); + Duplicate(tempLocal, scaleValue, cols_); + PipeBarrier(); + + Div(tempLocal, inLocal, tempLocal, cols_); + PipeBarrier(); + + LocalTensor intLocal = tempLocal.ReinterpretCast(); + Cast(intLocal, tempLocal, RoundMode::CAST_RINT, cols_); + PipeBarrier(); + SetDeqScale((half)1.000000e+00f); + Cast(intLocal.ReinterpretCast(), intLocal, RoundMode::CAST_ROUND, cols_); + PipeBarrier(); + Cast(outLocal, intLocal.ReinterpretCast(), RoundMode::CAST_TRUNC, cols_); + + calcQueue_.FreeTensor(tempLocal); + inputXOutQueue_.EnQue(outLocal); + scaleOutQueue_.EnQue(scaleLocal); +} + +template +__aicore__ inline void MoeGatherOutDynamicQuant::CopyOutXDynamicQuantFromScatter(int64_t progress) +{ + DataCopyExtParams copyInParams{1, static_cast(perLoopCols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, static_cast(perLoopCols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams 
copyOutParams{1, static_cast(perLoopCols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + LocalTensor indicesLocal = expandRowIdxInQueue_.DeQue(); + LocalTensor smoothLocal = smoothInQueue_.AllocTensor(); + + // copyin [1,H] scale + if (smoothType_ == SCALE_1H) { + DataCopyPad(smoothLocal, quantSmoothGm_, smoothParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + } + + int32_t lastExpertIdx = -1; + for (int64_t i = 0; i < currentLoopRows_; i++) { + int64_t rowOffset = perCoreRow_ * blockIdx_ + perLoopRows_ * progress; + if (dropPadMode_ == DROPLESS_MODE && (rowOffset + i) >= activeNum_) { + break; + } + LocalTensor inLocal = inputXInQueue_.AllocTensor(); + int32_t srcIdx = indicesLocal.GetValue(i); + + int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign_ + i) - expertStart_; + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, inputXGm_[srcIdx / k_ * cols_], copyInParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[perLoopColsAlign_], inputXGm_[srcIdx / k_ * cols_], copyInParams, {false, 0, 0, 0}); + } + inputXInQueue_.EnQue(inLocal); + + // copyin dynamic scale + if (smoothType_ == SCALE_EH && expertIdx != lastExpertIdx) { + DataCopyPad(smoothLocal, quantSmoothGm_[expertIdx * this->cols_], smoothParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + lastExpertIdx = expertIdx; + } + Compute(smoothLocal); + inputXInQueue_.FreeTensor(inLocal); + LocalTensor scaleLocal = scaleOutQueue_.DeQue(); + DataCopyPad(expandedScaleGm_[(rowOffset + i)], scaleLocal, quantScaleParams); + LocalTensor outLocal = inputXOutQueue_.DeQue(); + DataCopyPad(expandedXGm_[(rowOffset + i) * cols_], outLocal, copyOutParams); + + inputXOutQueue_.FreeTensor(outLocal); + scaleOutQueue_.FreeTensor(scaleLocal); + } + + smoothInQueue_.FreeTensor(smoothLocal); + expandRowIdxInQueue_.FreeTensor(indicesLocal); +} 
+ +template +__aicore__ inline void MoeGatherOutDynamicQuant::CopyOutXDynamicQuantFromGather(int64_t progress) +{ + DataCopyExtParams copyInParams{1, static_cast(perLoopCols_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams smoothParams{1, static_cast(perLoopCols_ * sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(perLoopCols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + + LocalTensor indicesLocal = expandRowIdxInQueue_.DeQue(); + LocalTensor smoothLocal = smoothInQueue_.AllocTensor(); + + int64_t rowOffset = blockIdx_ * perCoreRow_ + progress * perLoopRows_; + int64_t startXRow = rowOffset / k_; + int64_t endXRow = (rowOffset + currentLoopRows_ - 1) / k_; + int64_t curIndex = 0; + + if (smoothType_ == SCALE_1H) { + DataCopyPad(smoothLocal, quantSmoothGm_, smoothParams, {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + } + + for (int64_t row = startXRow; row <= endXRow; row++) { + LocalTensor inLocal = inputXInQueue_.AllocTensor(); + if constexpr (IsSameType::value) { + DataCopyPad(inLocal, inputXGm_[row * cols_], copyInParams, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal[perLoopColsAlign_], inputXGm_[row * cols_], copyInParams, {false, 0, 0, 0}); + } + inputXInQueue_.EnQue(inLocal); + Compute(smoothLocal); + LocalTensor scaleLocal = scaleOutQueue_.DeQue(); + LocalTensor outLocal = inputXOutQueue_.DeQue(); + + while (curIndex < currentLoopRows_ && (rowOffset + curIndex) / this->k_ == row) { + int32_t outIndex = indicesLocal.GetValue(curIndex); + curIndex++; + if (outIndex == -1 || dropPadMode_ == DROPLESS_MODE && outIndex >= this->activeNum_) { + continue; + } + DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, copyOutParams); + DataCopyPad(expandedScaleGm_[outIndex], scaleLocal, quantScaleParams); + } + + inputXInQueue_.FreeTensor(inLocal); + inputXOutQueue_.FreeTensor(outLocal); + scaleOutQueue_.FreeTensor(scaleLocal); 
+ } + + smoothInQueue_.FreeTensor(smoothLocal); + expandRowIdxInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline float +MoeGatherOutDynamicQuant::ComputeMax(LocalTensor &inLocal, LocalTensor &tempLocal, + LocalTensor &scaleLocal, int32_t srcIdx, int32_t expertIdx, + int64_t j) +{ + LocalTensor smoothLocal = smoothInQueue_.AllocTensor(); + + DataCopyExtParams intriParamsT{1, static_cast(colsTileLength_ * sizeof(T)), 0, 0, 0}; + DataCopyExtParams intriParamsFp32{1, static_cast(colsTileLength_ * sizeof(float)), 0, 0, 0}; + + if constexpr (!IsSameType::value) { + DataCopyPad(inLocal.ReinterpretCast()[perLoopColsAlign_], inputXGm_[srcIdx * cols_ + j * perLoopCols_], + intriParamsT, {false, 0, 0, 0}); + } else { + DataCopyPad(inLocal, inputXGm_[srcIdx * cols_ + j * perLoopCols_], intriParamsT, {false, 0, 0, 0}); + } + + inputXInQueue_.EnQue(inLocal); + inLocal = inputXInQueue_.DeQue(); + + if (isInputScale_) { + DataCopyPad(smoothLocal, quantSmoothGm_[expertIdx * cols_ + j * perLoopCols_], intriParamsFp32, + {false, 0, 0, 0}); + smoothInQueue_.EnQue(smoothLocal); + smoothLocal = smoothInQueue_.DeQue(); + } + + if constexpr (!IsSameType::value) { + Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign_], RoundMode::CAST_NONE, colsTileLength_); + PipeBarrier(); + } + + if (isInputScale_) { + Mul(inLocal, inLocal, smoothLocal, colsTileLength_); + PipeBarrier(); + } + + Abs(tempLocal, inLocal, colsTileLength_); + PipeBarrier(); + + ReduceMax(scaleLocal[INT32_ONE_BLOCK_NUM], tempLocal, tempLocal, colsTileLength_); + + DataCopyPad(quantTempGm_[j * perLoopCols_], inLocal, intriParamsFp32); + smoothInQueue_.FreeTensor(smoothLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); + return scaleLocal.GetValue(INT32_ONE_BLOCK_NUM); +} + +template +__aicore__ inline void +MoeGatherOutDynamicQuant::ComputeScale(LocalTensor &inLocal, LocalTensor &tempLocal, + float scaleTemp, int64_t dstIndex, int64_t j) +{ + DataCopyExtParams copyInParams{1, static_cast(colsTileLength_ * 
sizeof(float)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast(colsTileLength_ * sizeof(int8_t)), 0, 0, 0}; + + LocalTensor outLocal = inputXOutQueue_.AllocTensor(); + + DataCopyPad(inLocal, quantTempGm_[j * perLoopCols_], copyInParams, {false, 0, 0, 0}); + inputXInQueue_.EnQue(inLocal); + inLocal = inputXInQueue_.DeQue(); + + Duplicate(tempLocal, scaleTemp, colsTileLength_); + PipeBarrier(); + + Div(tempLocal, inLocal, tempLocal, colsTileLength_); + PipeBarrier(); + + Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_TRUNC, colsTileLength_); + PipeBarrier(); + + Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, colsTileLength_); + + inputXOutQueue_.EnQue(outLocal); + outLocal = inputXOutQueue_.DeQue(); + DataCopyPad(expandedXGm_[dstIndex * cols_ + j * perLoopCols_], outLocal, copyOutParams); + + inputXOutQueue_.FreeTensor(outLocal); + SetWaitFlag(HardEvent::MTE3_MTE2); +} + +template +__aicore__ inline void +MoeGatherOutDynamicQuant::CopyOutXPartialDynamicQuantFromScatter(int64_t progress) +{ + LocalTensor indicesLocal = expandRowIdxInQueue_.DeQue(); + for (int64_t i = 0; i < currentLoopRows_; i++) { + int64_t rowOffset = perCoreRow_ * blockIdx_ + perLoopRows_ * progress; + if (dropPadMode_ == DROPLESS_MODE && (rowOffset + i) >= activeNum_) { + break; + } + int32_t srcIdx = indicesLocal.GetValue(i); + int32_t expertIdx = indicesLocal.GetValue(currentLoopRowsAlign_ + i) - expertStart_; + LocalTensor inLocal = inputXInQueue_.AllocTensor(); + LocalTensor tempLocal = calcQueue_.AllocTensor(); + LocalTensor scaleLocal = scaleOutQueue_.AllocTensor(); + + float tileMax; + float reduceMax = *((float *)&INF); + for (int64_t j = 0; j < colLoops_; j++) { + colsTileLength_ = perLoopCols_; + if (j == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + + if (smoothType_ == SCALE_1H) { + // 1H + tileMax = ComputeMax(inLocal, tempLocal, scaleLocal, srcIdx / k_, 0, j); + } else { + // EH + tileMax = ComputeMax(inLocal, tempLocal, 
scaleLocal, srcIdx / k_, expertIdx, j); + } + reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax; + } + + float scaleTemp = reduceMax / MAX_INT8; + Duplicate(scaleLocal, scaleTemp, INT32_ONE_BLOCK_NUM); + scaleOutQueue_.EnQue(scaleLocal); + scaleLocal = scaleOutQueue_.DeQue(); + + DataCopyPad(expandedScaleGm_[(rowOffset + i)], scaleLocal, {1, 4, 0, 0, 0}); + + for (int64_t j = 0; j < colLoops_; j++) { + colsTileLength_ = perLoopCols_; + if (j == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + ComputeScale(inLocal, tempLocal, scaleTemp, rowOffset + i, j); + } + inputXInQueue_.FreeTensor(inLocal); + calcQueue_.FreeTensor(tempLocal); + scaleOutQueue_.FreeTensor(scaleLocal); + } + expandRowIdxInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutDynamicQuant::CopyOutXPartialDynamicQuantFromGather(int64_t progress) +{ + LocalTensor indicesLocal = expandRowIdxInQueue_.DeQue(); + int64_t rowOffset = blockIdx_ * perCoreRow_ + progress * perLoopRows_; + int64_t startXRow = rowOffset / k_; + int64_t endXRow = (rowOffset + currentLoopRows_ - 1) / k_; + int64_t curIndex = 0; + + DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + + for (int64_t row = startXRow; row <= endXRow; row++) { + LocalTensor inLocal = inputXInQueue_.AllocTensor(); + LocalTensor tempLocal = calcQueue_.AllocTensor(); + LocalTensor quantScaleLocal = scaleOutQueue_.AllocTensor(); + + float reduceMax = *((float *)&INF); + for (int64_t j = 0; j < colLoops_; j++) { + colsTileLength_ = perLoopCols_; + if (j == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + + float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, row, 0, j); + reduceMax = (reduceMax > tileMax) ? 
reduceMax : tileMax; + } + + float scaleTemp = reduceMax / MAX_INT8; + Duplicate(quantScaleLocal, scaleTemp, INT32_ONE_BLOCK_NUM); + scaleOutQueue_.EnQue(quantScaleLocal); + quantScaleLocal = scaleOutQueue_.DeQue(); + + while (curIndex < currentLoopRows_ && (curIndex + rowOffset) / k_ == row) { + int32_t outIndex = indicesLocal.GetValue(curIndex); + curIndex++; + if (outIndex == -1 || (dropPadMode_ == DROPLESS_MODE && outIndex >= activeNum_)) { + continue; + } + DataCopyPad(expandedScaleGm_[outIndex], quantScaleLocal, quantScaleParams); + for (int64_t j = 0; j < colLoops_; j++) { + colsTileLength_ = perLoopCols_; + if (j == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + ComputeScale(inLocal, tempLocal, scaleTemp, outIndex, j); + } + } + inputXInQueue_.FreeTensor(inLocal); + calcQueue_.FreeTensor(tempLocal); + scaleOutQueue_.FreeTensor(quantScaleLocal); + } + expandRowIdxInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void +MoeGatherOutDynamicQuant::Init(GM_ADDR inputX, GM_ADDR quantSmooth, GM_ADDR sortedExpertIdx, + GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR expandedScale, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + cols_ = tilingData->cols; + n_ = tilingData->n; + k_ = tilingData->k; + totalLength_ = n_ * k_; + isInputScale_ = tilingData->isInputScale; + expertStart_ = tilingData->expertStart; + rowIdxType_ = tilingData->rowIdxType; + dropPadMode_ = tilingData->dropPadMode; + activeNum_ = tilingData->activeNum; + ep_ = tilingData->ep; + smoothType_ = tilingData->smoothType; + coreNum_ = tilingData->coreNum; + + // core split + int64_t actualExpertNum_ = tilingData->actualExpertNum; + if (ep_) { + expertTotalCountGm_.SetGlobalBuffer((__gm__ int32_t *)sortedExpertIdx + Align(n_ * k_, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)), + 1); + 
AscendC::DataCacheCleanAndInvalid(expertTotalCountGm_); + expertTotalCount_ = expertTotalCountGm_.GetValue(0); + } else { + expertTotalCount_ = totalLength_; + } + + perCoreRow_ = Ceil(expertTotalCount_, tilingData->coreNum); + needCoreNum_ = Ceil(expertTotalCount_, perCoreRow_); + int64_t lastCoreIndicesElements = expertTotalCount_ - (needCoreNum_ - 1) * perCoreRow_; + + // inner core split + int64_t originPerLoopElements; + if (blockIdx_ == needCoreNum_ - 1) { + coreRows_ = lastCoreIndicesElements; + originPerLoopElements = gatherOutTilingData_->lastCorePerLoopIndicesElements; + } else { + coreRows_ = perCoreRow_; + originPerLoopElements = gatherOutTilingData_->perCorePerLoopIndicesElements; + } + perLoopRows_ = Min(coreRows_, originPerLoopElements); + rowLoops_ = Ceil(coreRows_, perLoopRows_); + lastLoopRows_ = coreRows_ - (rowLoops_ - 1) * perLoopRows_; + + // cols split + perLoopCols_ = gatherOutTilingData_->perLoopCols; + lastLoopCols_ = gatherOutTilingData_->lastLoopCols; + colLoops_ = gatherOutTilingData_->colsLoops; + + perLoopColsAlign_ = Align(perLoopCols_, sizeof(T)); + + inputXGm_.SetGlobalBuffer((__gm__ T *)inputX); + expandedXGm_.SetGlobalBuffer((__gm__ int8_t *)expandedX); + + expandedExpertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)sortedExpertIdx + blockIdx_ * perCoreRow_, + Align(coreRows_, sizeof(int32_t))); + + if constexpr (COPYOUTTYPE == SCATTER) { + if (rowIdxType_ == SCATTER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreRow_, + Align(perCoreRow_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)sortedExpertIdx + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreRow_, + Align(perCoreRow_, sizeof(int32_t))); + } + } else { + if (rowIdxType_ == GATHER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreRow_, + Align(perCoreRow_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t 
*)sortedExpertIdx + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreRow_, + Align(perCoreRow_, sizeof(int32_t))); + } + } + + if (isInputScale_) { + quantSmoothGm_.SetGlobalBuffer((__gm__ float *)quantSmooth); + } + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + + if (colLoops_ > 1) { + quantTempGm_.SetGlobalBuffer((__gm__ float *)sortedExpertIdx + Align(totalLength_, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)) * 2 + + Align(totalLength_, sizeof(int32_t)) + blockIdx_ * cols_, + cols_ * sizeof(float)); + } + + currentLoopRowsAlign_ = Align(perLoopRows_, sizeof(int32_t)); + + int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols_, sizeof(T)); + perLoopColsAlignBytes = + Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES)); + pipe_->InitBuffer(expandRowIdxInQueue_, GATHER_OUT_DYNAMIC_QUANT_BUFFER_NUM, + 2 * AlignBytes(perLoopRows_, sizeof(int32_t))); + pipe_->InitBuffer(inputXInQueue_, GATHER_OUT_DYNAMIC_QUANT_BUFFER_NUM, perLoopColsAlignBytes); // percols * 2 * 4 + pipe_->InitBuffer(smoothInQueue_, GATHER_OUT_DYNAMIC_QUANT_BUFFER_NUM, + AlignBytes(perLoopCols_, sizeof(float))); // percols * 2 * 4 + pipe_->InitBuffer(calcQueue_, 1, AlignBytes(perLoopCols_, sizeof(float))); // percols * 1 * 4 + pipe_->InitBuffer(inputXOutQueue_, 1, AlignBytes(perLoopCols_, sizeof(int8_t))); // percols * 1 + pipe_->InitBuffer(scaleOutQueue_, 1, BLOCK_BYTES + BLOCK_BYTES); // 32 + 32 +} + +template +__aicore__ inline void MoeGatherOutDynamicQuant::Process() +{ + if (blockIdx_ < needCoreNum_) { + currentLoopRows_ = perLoopRows_; + if (colLoops_ > 1) { + for (int64_t loop = 0; loop < rowLoops_; loop++) { + if (loop == rowLoops_ - 1) { + currentLoopRows_ = lastLoopRows_; + } + CopyInExpandedExpertIdx(loop); + if constexpr (COPYOUTTYPE == GATHER) { + CopyOutXPartialDynamicQuantFromGather(loop); + } else { + CopyOutXPartialDynamicQuantFromScatter(loop); + } + } + } else { + for (int64_t loop 
= 0; loop < rowLoops_; loop++) { + if (loop == rowLoops_ - 1) { + currentLoopRows_ = lastLoopRows_; + } + CopyInExpandedExpertIdx(loop); + if constexpr (COPYOUTTYPE == GATHER) { + CopyOutXDynamicQuantFromGather(loop); + } else { + CopyOutXDynamicQuantFromScatter(loop); + } + } + } + } +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_DYNAMIC_QUANT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out.h new file mode 100644 index 00000000000..0ba44f76fa0 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out.h @@ -0,0 +1,321 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_gather_out.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_OUT_H +#define MOE_CUSTOM_GATHER_OUT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t GATHER_OUT_BUFFER_NUM = 2; + +template +class MoeGatherOut { +public: + __aicore__ inline MoeGatherOut(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR scale, GM_ADDR workspace, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR expandedScale, const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + __aicore__ inline void CopyExpertIn(int64_t progress); + __aicore__ inline void CopyXIn(int64_t xSrcOffset, int64_t curLoopCols); + __aicore__ inline void CopyXOut(int64_t xDstOffset, int64_t curLoopCols); + __aicore__ inline void CopyScaleIn(int64_t scaleSrcOffset); + __aicore__ inline void CopyScaleOut(int64_t scaleDstOffset); + __aicore__ inline void GatherCopyOut(int64_t progress); + __aicore__ inline void ScatterCopyOut(int64_t progress); + +private: + TPipe *pipe_; + TQueBind xCopyInQueue_; + TQueBind scaleCopyInQueue_; + TQue expandedRowIdxCopyInQueue_; + + GlobalTensor xGm_; + GlobalTensor xGscaleGm_; + GlobalTensor sortedExpertIdxGm_; + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedScaleGm_; + GlobalTensor expertTotalCountGm_; + + int64_t blockIdx_; + int64_t cols_; + int64_t n_; + int64_t k_; + int64_t activeNum_; + int64_t dropPadMode_; + + int64_t colsLoops_; + int64_t perLoopCols_; + int64_t lastLoopCols_; + + int64_t indicesLoops_; + int64_t curLoopElements_; + + int64_t perCoreIndicesElements_; + int64_t lastCoreIndicesElements_; + int64_t perCorePerLoopIndicesElements_; + int64_t lastCorePerLoopIndicesElements_; + int64_t curCorePerLoopIndicesElements_; + int64_t curCoreLastLoopIndicesElements_; + int64_t needCoreNum_; + int64_t curCoreIndicesElements_; + + int64_t actualExpertNum_; + int64_t 
expertTotalCount_; + + int64_t rowIdxType_; + int64_t isInputScale_; + int64_t coreNum_; +}; + +template +__aicore__ inline void MoeGatherOut::Init(GM_ADDR x, GM_ADDR scale, GM_ADDR workspace, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR expandedScale, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + + cols_ = tilingData->cols; + n_ = tilingData->n; + k_ = tilingData->k; + coreNum_ = tilingData->coreNum; + dropPadMode_ = tilingData->dropPadMode; + activeNum_ = tilingData->activeNum; + + isInputScale_ = tilingData->isInputScale; + rowIdxType_ = tilingData->rowIdxType; + + colsLoops_ = tilingData->gatherOutComputeParamsOp.colsLoops; + perLoopCols_ = tilingData->gatherOutComputeParamsOp.perLoopCols; + lastLoopCols_ = tilingData->gatherOutComputeParamsOp.lastLoopCols; + + actualExpertNum_ = tilingData->actualExpertNum; + + if constexpr (EP) { + expertTotalCountGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)), + 1); + AscendC::DataCacheCleanAndInvalid(expertTotalCountGm_); + expertTotalCount_ = expertTotalCountGm_.GetValue(0); + } else { + expertTotalCount_ = n_ * k_; + } + + perCorePerLoopIndicesElements_ = tilingData->gatherOutComputeParamsOp.perCorePerLoopIndicesElements; + lastCorePerLoopIndicesElements_ = tilingData->gatherOutComputeParamsOp.lastCorePerLoopIndicesElements; + perCoreIndicesElements_ = Ceil(expertTotalCount_, tilingData->coreNum); + needCoreNum_ = Ceil(expertTotalCount_, perCoreIndicesElements_); + lastCoreIndicesElements_ = expertTotalCount_ - (needCoreNum_ - 1) * perCoreIndicesElements_; + + if (blockIdx_ == needCoreNum_ - 1) { + curCoreIndicesElements_ = lastCoreIndicesElements_; + curCorePerLoopIndicesElements_ = Min(lastCorePerLoopIndicesElements_, curCoreIndicesElements_); + } else { + curCoreIndicesElements_ = perCoreIndicesElements_; + curCorePerLoopIndicesElements_ = 
Min(perCorePerLoopIndicesElements_, curCoreIndicesElements_); + } + indicesLoops_ = Ceil(curCoreIndicesElements_, curCorePerLoopIndicesElements_); + curCoreLastLoopIndicesElements_ = curCoreIndicesElements_ - (indicesLoops_ - 1) * curCorePerLoopIndicesElements_; + + xGm_.SetGlobalBuffer((__gm__ T *)x, n_ * cols_); + xGscaleGm_.SetGlobalBuffer((__gm__ float *)scale, n_); + + expandedXGm_.SetGlobalBuffer((__gm__ T *)expandedX); + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + + pipe_->InitBuffer(expandedRowIdxCopyInQueue_, GATHER_OUT_BUFFER_NUM, + AlignBytes(curCorePerLoopIndicesElements_, sizeof(int32_t))); + pipe_->InitBuffer(xCopyInQueue_, GATHER_OUT_BUFFER_NUM, AlignBytes(perLoopCols_, sizeof(T))); + pipe_->InitBuffer(scaleCopyInQueue_, GATHER_OUT_BUFFER_NUM, AlignBytes(1, sizeof(float))); + + sortedExpertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + blockIdx_ * perCoreIndicesElements_, + Align(curCoreIndicesElements_, sizeof(int32_t))); + + if constexpr (EP) { + if (rowIdxType_ == SCATTER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreIndicesElements_, + Align(curCoreIndicesElements_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreIndicesElements_, + Align(curCoreIndicesElements_, sizeof(int32_t))); + } + } else { + if (rowIdxType_ == GATHER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreIndicesElements_, + Align(curCoreIndicesElements_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreIndicesElements_, + Align(curCoreIndicesElements_, sizeof(int32_t))); + } + } +} + +template +__aicore__ inline void MoeGatherOut::CopyExpertIn(int64_t progress) +{ + LocalTensor subRowIdxLocal = expandedRowIdxCopyInQueue_.AllocTensor(); + DataCopyExtParams 
copyParams{1, static_cast(curLoopElements_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + DataCopyPad(subRowIdxLocal, expandedRowIdxGm_[progress * curCorePerLoopIndicesElements_], copyParams, padParams); + expandedRowIdxCopyInQueue_.EnQue(subRowIdxLocal); +} + +template +__aicore__ inline void MoeGatherOut::CopyXIn(int64_t xSrcOffset, int64_t curLoopCols) +{ + LocalTensor xLocal = xCopyInQueue_.AllocTensor(); + DataCopyExtParams copyParams0{static_cast(1), static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams padParams0{false, 0, 0, 0}; + DataCopyPad(xLocal, xGm_[xSrcOffset], copyParams0, padParams0); + xCopyInQueue_.EnQue(xLocal); +} + +template +__aicore__ inline void MoeGatherOut::CopyXOut(int64_t xDstOffset, int64_t curLoopCols) +{ + LocalTensor xLocal = xCopyInQueue_.DeQue(); + DataCopyExtParams copyParams2{1, static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyPad(expandedXGm_[xDstOffset], xLocal, copyParams2); + xCopyInQueue_.FreeTensor(xLocal); +} + +template +__aicore__ inline void MoeGatherOut::CopyScaleIn(int64_t scaleSrcOffset) +{ + LocalTensor scaleLocal = scaleCopyInQueue_.AllocTensor(); + DataCopyExtParams copyParams1{static_cast(1), static_cast(1 * sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams padParams1{false, 0, 0, 0}; + DataCopyPad(scaleLocal, xGscaleGm_[scaleSrcOffset], copyParams1, padParams1); + scaleCopyInQueue_.EnQue(scaleLocal); +} + +template +__aicore__ inline void MoeGatherOut::CopyScaleOut(int64_t scaleDstOffset) +{ + LocalTensor scaleLocal = scaleCopyInQueue_.DeQue(); + DataCopyExtParams copyParams3{1, static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPad(expandedScaleGm_[scaleDstOffset], scaleLocal, copyParams3); + scaleCopyInQueue_.FreeTensor(scaleLocal); +} + +template +__aicore__ inline void MoeGatherOut::GatherCopyOut(int64_t progress) +{ + LocalTensor subRowIdxLocal = expandedRowIdxCopyInQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + int64_t curLoopCols = 
perLoopCols_; + for (int64_t colsLoop = 0; colsLoop < colsLoops_; colsLoop++) { + int64_t initialRow = blockIdx_ * perCoreIndicesElements_ + curCorePerLoopIndicesElements_ * progress; + int64_t curLoopRow = 0; + if (colsLoop == colsLoops_ - 1) { + curLoopCols = lastLoopCols_; + } + int64_t currentLoopStartRow = initialRow / k_; + int64_t currentLoopLastRow = (initialRow + this->curLoopElements_ - 1) / k_; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor inLocal = xCopyInQueue_.AllocTensor(); + int64_t inputOffset = row * cols_ + colsLoop * perLoopCols_; + DataCopyExtParams xCopyParams{1, static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal, xGm_[inputOffset], xCopyParams, dataCopyPadParams); + // copy in scale + LocalTensor scaleLocal = scaleCopyInQueue_.AllocTensor(); + DataCopyExtParams scaleCopyParams{1, static_cast(sizeof(float)), 0, 0, 0}; + if (isInputScale_ == 1 && colsLoop == 0) { + DataCopyPadExtParams scalePadParams{false, 0, 0, 0}; + DataCopyPad(scaleLocal, xGscaleGm_[row], scaleCopyParams, scalePadParams); + } + SetWaitFlag(HardEvent::MTE2_MTE3); + DataCopyExtParams intriParams{1, static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + while (curLoopRow < this->curLoopElements_ && initialRow / k_ == row) { + int32_t outIndex = subRowIdxLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (dropPadMode_ == DROPLESS_MODE && outIndex >= activeNum_)) { + continue; + } + int64_t outOffset = outIndex * this->cols_ + colsLoop * this->perLoopCols_; + DataCopyPad(expandedXGm_[outOffset], inLocal, intriParams); + // copy out scale + if (isInputScale_ == 1 && colsLoop == 0) { + DataCopyPad(expandedScaleGm_[outIndex], scaleLocal, scaleCopyParams); + } + } + scaleCopyInQueue_.FreeTensor(scaleLocal); + xCopyInQueue_.FreeTensor(inLocal); + } + } + expandedRowIdxCopyInQueue_.FreeTensor(subRowIdxLocal); +} + +template +__aicore__ 
inline void MoeGatherOut::ScatterCopyOut(int64_t progress) +{ + int64_t curExpertLoopOffset = progress * curCorePerLoopIndicesElements_; + LocalTensor subRowIdxLocal = expandedRowIdxCopyInQueue_.DeQue(); + for (int64_t indicesIndex = 0; indicesIndex < curLoopElements_; indicesIndex++) { + int64_t rowIdx = subRowIdxLocal.GetValue(indicesIndex); + int64_t rowOffset = curExpertLoopOffset + indicesIndex + blockIdx_ * perCoreIndicesElements_; + if (activeNum_ > 0 && dropPadMode_ == DROPLESS_MODE && rowOffset >= activeNum_) { + break; + } + SetWaitFlag(HardEvent::S_MTE2); + if (isInputScale_ == 1) { + int64_t scaleSrcOffset = rowIdx / k_; + CopyScaleIn(scaleSrcOffset); + CopyScaleOut(indicesIndex + curExpertLoopOffset + blockIdx_ * perCoreIndicesElements_); + } + int64_t curLoopCols = perLoopCols_; + for (int64_t colsLoop = 0; colsLoop < colsLoops_; colsLoop++) { + if (colsLoop == colsLoops_ - 1) { + curLoopCols = lastLoopCols_; + } + int64_t xSrcOffset = rowIdx / k_ * cols_; + int64_t xDstOffset = (blockIdx_ * perCoreIndicesElements_ + curExpertLoopOffset + indicesIndex) * cols_; + int64_t colsLoopOffset = colsLoop * perLoopCols_; + CopyXIn(xSrcOffset + colsLoopOffset, curLoopCols); + CopyXOut(xDstOffset + colsLoopOffset, curLoopCols); + } + } + expandedRowIdxCopyInQueue_.FreeTensor(subRowIdxLocal); +} + +template +__aicore__ inline void MoeGatherOut::Process() +{ + if (blockIdx_ < needCoreNum_) { + curLoopElements_ = curCorePerLoopIndicesElements_; + for (int64_t loop = 0; loop < indicesLoops_; loop++) { + if (loop == indicesLoops_ - 1) { + curLoopElements_ = curCoreLastLoopIndicesElements_; + } + CopyExpertIn(loop); + if constexpr (!EP) { + GatherCopyOut(loop); + } else { + ScatterCopyOut(loop); + } + } + } +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_OUT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out_droppad.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out_droppad.h new 
file mode 100644 index 00000000000..d5229ba591c --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_out_droppad.h @@ -0,0 +1,210 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_gather_out_droppad.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_OUT_DROPPAD_H +#define MOE_CUSTOM_GATHER_OUT_DROPPAD_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t GATHER_OUT_DROPPAD_BUFFER_NUM = 2; + +template +class MoeGatherOutDroppad { +public: + __aicore__ inline MoeGatherOutDroppad(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR expandedRowIdx, GM_ADDR expandedX, + GM_ADDR expandedScale, GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyInIndices(int64_t progress); + __aicore__ inline void CopyOut(int64_t progress); + __aicore__ inline void CopyScaleIn(int64_t scaleSrcOffset, LocalTensor scaleLocal); + __aicore__ inline void CopyScaleOut(int64_t scaleDstOffset, LocalTensor scaleLocal); + +private: + TPipe *pipe_; + TQueBind xCopyInQueue_; + TQueBind scaleCopyInQueue_; + TQue expandedRowIdxCopyInQueue_; + + GlobalTensor inputXGm_; + GlobalTensor xGscaleGm_; + GlobalTensor expandedXGm_; + 
GlobalTensor expandedRowIdxGm_; + GlobalTensor expandedScaleGm_; + + const MoeCustomGatherOutComputeTilingData *gatherOutTilingData_; + + int64_t needCoreNum_; + int64_t blockIdx_; + int64_t cols_; + int64_t n_; + int64_t k_; + int64_t currentLoopRows_; + int64_t coreRows_; + int64_t perLoopRows_; + int64_t lastLoopRows_; + int64_t rowLoops_; + int64_t colsTileLength_; + int64_t perLoopCols_; + int64_t lastLoopCols_; + int64_t colLoops_; + int64_t isInputScale_; + + int64_t indicesOffset_; + int64_t inputOffset_; + int64_t outOffset_; +}; + +template +__aicore__ inline void MoeGatherOutDroppad::CopyInIndices(int64_t progress) +{ + indicesOffset_ = progress * perLoopRows_; + LocalTensor indicesLocal = expandedRowIdxCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(currentLoopRows_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, expandedRowIdxGm_[indicesOffset_], dataCopyParams, dataCopyPadParams); + expandedRowIdxCopyInQueue_.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutDroppad::CopyScaleIn(int64_t scaleSrcOffset, LocalTensor scaleLocal) +{ + DataCopyExtParams copyParams1{static_cast(1), static_cast(1 * sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams padParams1{false, 0, 0, 0}; + DataCopyPad(scaleLocal, xGscaleGm_[scaleSrcOffset], copyParams1, padParams1); + scaleCopyInQueue_.EnQue(scaleLocal); +} + +template +__aicore__ inline void MoeGatherOutDroppad::CopyScaleOut(int64_t scaleDstOffset, LocalTensor scaleLocal) +{ + DataCopyExtParams copyParams3{1, static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPad(expandedScaleGm_[scaleDstOffset], scaleLocal, copyParams3); +} + +template +__aicore__ inline void MoeGatherOutDroppad::CopyOut(int64_t progress) +{ + LocalTensor indicesLocal = expandedRowIdxCopyInQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength_ = perLoopCols_; + for (int64_t colsLoop = 0; colsLoop < colLoops_; colsLoop++) { + 
int64_t initialRow = gatherOutTilingData_->perCoreIndicesElements * blockIdx_ + perLoopRows_ * progress; + int64_t curLoopRow = 0; + if (colsLoop == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + int64_t currentLoopStartRow = initialRow / k_; + int64_t currentLoopLastRow = (initialRow + currentLoopRows_ - 1) / k_; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + LocalTensor scaleLocal = scaleCopyInQueue_.AllocTensor(); + if (isInputScale_ == 1) { + CopyScaleIn(row, scaleLocal); + LocalTensor scaleLocal = scaleCopyInQueue_.DeQue(); + } + inputOffset_ = row * cols_ + colsLoop * perLoopCols_; + // input row position + LocalTensor inLocal = xCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(colsTileLength_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal, inputXGm_[inputOffset_], dataCopyParams, dataCopyPadParams); + SetWaitFlag(HardEvent::MTE2_MTE3); + DataCopyExtParams intriParams{1, static_cast(colsTileLength_ * sizeof(T)), 0, 0, 0}; + while (curLoopRow < currentLoopRows_ && initialRow / k_ == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1) { + continue; + } + outOffset_ = outIndex * cols_ + colsLoop * perLoopCols_; + DataCopyPad(expandedXGm_[outOffset_], inLocal, intriParams); + if (isInputScale_ == 1) { + CopyScaleOut(outIndex, scaleLocal); + } + } + xCopyInQueue_.FreeTensor(inLocal); + scaleCopyInQueue_.FreeTensor(scaleLocal); + } + } + expandedRowIdxCopyInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutDroppad::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + + needCoreNum_ = 
gatherOutTilingData_->needCoreNum; + cols_ = tilingData->cols; + n_ = tilingData->n; + k_ = tilingData->k; + isInputScale_ = tilingData->isInputScale; + + if (blockIdx_ == needCoreNum_ - 1) { + coreRows_ = gatherOutTilingData_->lastCoreIndicesElements; + perLoopRows_ = gatherOutTilingData_->lastCorePerLoopIndicesElements; + lastLoopRows_ = gatherOutTilingData_->lastCoreLastLoopIndicesElements; + rowLoops_ = gatherOutTilingData_->lastCoreIndicesLoops; + } else { + coreRows_ = gatherOutTilingData_->perCoreIndicesElements; + perLoopRows_ = gatherOutTilingData_->perCorePerLoopIndicesElements; + lastLoopRows_ = gatherOutTilingData_->perCoreLastLoopIndicesElements; + rowLoops_ = gatherOutTilingData_->perCoreIndicesLoops; + } + perLoopCols_ = gatherOutTilingData_->perLoopCols; + lastLoopCols_ = gatherOutTilingData_->lastLoopCols; + colLoops_ = gatherOutTilingData_->colsLoops; + + inputXGm_.SetGlobalBuffer((__gm__ T *)inputX, coreRows_ * cols_); + xGscaleGm_.SetGlobalBuffer((__gm__ float *)scale, n_); + expandedXGm_.SetGlobalBuffer((__gm__ T *)expandedX, n_ * k_ * cols_); + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + + blockIdx_ * gatherOutTilingData_->perCoreIndicesElements, + Align(coreRows_, sizeof(int32_t))); + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + + pipe_->InitBuffer(xCopyInQueue_, GATHER_OUT_DROPPAD_BUFFER_NUM, AlignBytes(perLoopCols_, sizeof(T))); + pipe_->InitBuffer(expandedRowIdxCopyInQueue_, GATHER_OUT_DROPPAD_BUFFER_NUM, + AlignBytes(perLoopRows_, sizeof(int32_t))); + pipe_->InitBuffer(scaleCopyInQueue_, GATHER_OUT_DROPPAD_BUFFER_NUM, AlignBytes(1, sizeof(float))); +} + +template +__aicore__ inline void MoeGatherOutDroppad::Process() +{ + if (blockIdx_ < needCoreNum_) { + currentLoopRows_ = perLoopRows_; + for (int64_t loop = 0; loop < rowLoops_; loop++) { + if (loop == rowLoops_ - 1) { + currentLoopRows_ = lastLoopRows_; + } + CopyInIndices(loop); + CopyOut(loop); + } + } +} +} // namespace 
MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_OUT_DROPPAD_H diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_sort_multi_core.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_sort_multi_core.h new file mode 100644 index 00000000000..6ed3f2d723c --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_sort_multi_core.h @@ -0,0 +1,242 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_gather_sort_multi_core.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_SORT_MULTI_CORE_H +#define MOE_CUSTOM_GATHER_SORT_MULTI_CORE_H + +#include "moe_custom_common.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t SORT32_ALIGN_ELEMENT = 32; +constexpr int64_t PARALLEL_GATHERED_SORT_NEED_CORE_NUM = 16; +constexpr int64_t MULTI_GATHERED_MAX_NUM = 4096; // 8192 * 8 / 16 + +class MoeGatherSortMultiCore { +public: + __aicore__ inline MoeGatherSortMultiCore(){}; + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void Compute(); + __aicore__ inline void CopyOut(); + +private: + TPipe *pipe_; + TBuf buffer_; + GlobalTensor workspaceGm_; + GlobalTensor expendedRowIdxGm_; + GlobalTensor expertIdxGm_; + GlobalTensor sortedExpertIdxGm_; + GlobalTensor sortedExpertIndexGm_; + GlobalTensor sortedNumGm_; + + TQue sortedNumCopyOutQueue_; + + int64_t expertIdxOffset_ = 0; + int64_t expertIndexOffset_ = 0; + int64_t compareScalarMask0Offset_ = 0; + int64_t compareScalarMask1Offset_ = 0; + int64_t gatherMaskOffset_ = 0; + + int64_t totalLength_; + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + int64_t actual_expert_num_ = 0; + int64_t needCoreNum_ = 0; + int64_t perCoreElements_ = 0; + int64_t blockIdx_; + int64_t currentCoreElements_ = 0; + int64_t needSortNum_ = 0; + int64_t kvFactor = 2; + + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; + static constexpr int64_t MASK_STRIDE = 64; +}; + +__aicore__ inline void MoeGatherSortMultiCore::CopyIn() +{ + LocalTensor expertIdx = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(currentCoreElements_ * 
sizeof(int32_t)), 0, 0, 0}; + + DataCopyPad(expertIdx, expertIdxGm_[blockIdx_ * perCoreElements_], dataCopyParams, dataCopyPadParams); + SetWaitFlag(HardEvent::MTE2_V); +} + +__aicore__ inline void MoeGatherSortMultiCore::Compute() +{ + LocalTensor expertIdx = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + LocalTensor expertIdxFp32 = expertIdx.ReinterpretCast(); + LocalTensor gatheredExpertIdx = buffer_.Get(); + LocalTensor gatheredExpertIdxFp32 = gatheredExpertIdx.ReinterpretCast(); + + Cast(expertIdxFp32, expertIdx, RoundMode::CAST_ROUND, currentCoreElements_); + PipeBarrier(); + Muls(expertIdxFp32, expertIdxFp32, (float)-1, currentCoreElements_); + PipeBarrier(); + + LocalTensor compareScalarMaskLocalTensor0 = buffer_.Get()[compareScalarMask0Offset_]; + LocalTensor compareScalarMaskLocalTensor1 = buffer_.Get()[compareScalarMask1Offset_]; + LocalTensor gatherMaskLocalTensor = buffer_.Get()[gatherMaskOffset_]; + + // Find elements >= expertStart_, which means -elements <= -expertStart_ + AscendC::CompareScalar( + compareScalarMaskLocalTensor0, expertIdxFp32, static_cast(-expertStart_), AscendC::CMPMODE::LE, + (currentCoreElements_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + + // Find elements < expertEnd_, which means -elements > -expertEnd_ + AscendC::CompareScalar( + compareScalarMaskLocalTensor1, expertIdxFp32, static_cast(-expertEnd_), AscendC::CMPMODE::GT, + (currentCoreElements_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + + // Get experts between [expert_start, expert_end) + And(gatherMaskLocalTensor.ReinterpretCast(), compareScalarMaskLocalTensor0.ReinterpretCast(), + compareScalarMaskLocalTensor1.ReinterpretCast(), + Ceil(currentCoreElements_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE / kvFactor); + PipeBarrier(); + + uint64_t sortedNum = 0; + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = 1; + 
gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = DST_REP_STRIDE; + gatherMaskParams.src1RepeatStride = DST_REP_STRIDE; + GatherMask(gatheredExpertIdxFp32, expertIdxFp32, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(currentCoreElements_), gatherMaskParams, sortedNum); + PipeBarrier(); + actual_expert_num_ = sortedNum; + int64_t needSortNum = Ceil(static_cast(sortedNum), ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + needSortNum_ = needSortNum; + + // Handle actual_expert_num_ == 0 + if (actual_expert_num_ < 1) { + return; + } + + LocalTensor expertIndex = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + LocalTensor gatheredExpertIndex = buffer_.Get()[needSortNum]; + ArithProgression(expertIndex, blockIdx_ * perCoreElements_, 1, currentCoreElements_); + GatherMask(gatheredExpertIndex, expertIndex, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(currentCoreElements_), gatherMaskParams, sortedNum); + PipeBarrier(); + int64_t duplicateNum = sortedNum % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = sortedNum - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(gatheredExpertIdxFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + PipeBarrier(); + + LocalTensor concatLocal; + LocalTensor sortTempTensor = buffer_.Get()[needSortNum * kvFactor]; + Concat(concatLocal, gatheredExpertIdxFp32, sortTempTensor, needSortNum / ONE_REPEAT_SORT_NUM); + LocalTensor sortedLocal = buffer_.Get()[needSortNum * kvFactor + needSortNum * kvFactor * kvFactor]; + Sort(sortedLocal, concatLocal, gatheredExpertIndex.ReinterpretCast(), sortTempTensor, + needSortNum / ONE_REPEAT_SORT_NUM); + SetWaitFlag(HardEvent::V_MTE3); +} + +__aicore__ inline void MoeGatherSortMultiCore::CopyOut() +{ + // Copy out sortedLocal for MergeSort + if (actual_expert_num_ > 0) { + 
LocalTensor sortedLocal = + buffer_.Get()[needSortNum_ * kvFactor + needSortNum_ * kvFactor * kvFactor]; + DataCopyExtParams extParams{static_cast(1), + static_cast(2 * actual_expert_num_ * sizeof(float)), 0, 0, 0}; + int64_t curCoreStartIndex = 2 * GetBlockIdx() * perCoreElements_; + DataCopyPad(sortedExpertIdxGm_[curCoreStartIndex], sortedLocal, extParams); + } + + // Copyout actual_expert_num_ + LocalTensor sortedNumOutLocal = sortedNumCopyOutQueue_.AllocTensor(); + sortedNumOutLocal.SetValue(0, actual_expert_num_); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyExtParams copyParams3{static_cast(1), static_cast(sizeof(uint32_t)), 0, 0, 0}; + DataCopyPad(sortedNumGm_[GetBlockIdx()], sortedNumOutLocal, copyParams3); + + sortedNumCopyOutQueue_.FreeTensor(sortedNumOutLocal); +} + +__aicore__ inline void MoeGatherSortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + totalLength_ = tilingData->n * tilingData->k; + + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx); + + expendedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expendedRowIdx); + + workspaceGm_.SetGlobalBuffer((__gm__ int32_t *)workspace); + + sortedExpertIdxGm_.SetGlobalBuffer((__gm__ float *)workspace); + sortedExpertIndexGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(totalLength_, sizeof(int32_t))); + + // key and value + sortedNumGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(totalLength_, sizeof(int32_t)) * kvFactor * kvFactor); + + needCoreNum_ = PARALLEL_GATHERED_SORT_NEED_CORE_NUM; + perCoreElements_ = Ceil(totalLength_, needCoreNum_); + + int32_t lastCoreElements = totalLength_ - (needCoreNum_ - 1) * perCoreElements_; + if (blockIdx_ == (needCoreNum_ - 1)) { + currentCoreElements_ = lastCoreElements; + } else { + currentCoreElements_ = 
perCoreElements_; + } + + // expertIdxOffset_ + expertIdxOffset_ = AlignBytes(currentCoreElements_, sizeof(int32_t)); + expertIndexOffset_ = expertIdxOffset_; + + gatherMaskOffset_ = expertIdxOffset_ * kvFactor; + int64_t maskOffset = + AlignBytes(Ceil(currentCoreElements_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE, sizeof(int8_t)); + compareScalarMask0Offset_ = gatherMaskOffset_ + maskOffset; + compareScalarMask1Offset_ = compareScalarMask0Offset_ + maskOffset; + int64_t bufferSize = MULTI_GATHERED_MAX_NUM * kvFactor * kvFactor * kvFactor * sizeof(int32_t); + pipe_->InitBuffer(sortedNumCopyOutQueue_, 1, AlignBytes(1, sizeof(int32_t))); + pipe_->InitBuffer(buffer_, bufferSize); // 73728 Bytes +} + +__aicore__ inline void MoeGatherSortMultiCore::Process() +{ + if (blockIdx_ < PARALLEL_GATHERED_SORT_NEED_CORE_NUM) { + CopyIn(); + Compute(); + CopyOut(); + } + SyncAll(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_SORT_MULTI_CORE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_static_quant.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_static_quant.h new file mode 100644 index 00000000000..78bd9e1c9bf --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_gather_static_quant.h @@ -0,0 +1,329 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_gather_quant.h + * \brief + */ +#ifndef MOE_CUSTOM_GATHER_STATIC_QUANT_H +#define MOE_CUSTOM_GATHER_STATIC_QUANT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t GATHER_OUT_QUANT_BUFFER_NUM = 2; + +template +class MoeGatherOutQuant { +public: + __aicore__ inline MoeGatherOutQuant(){}; + __aicore__ inline void Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, GM_ADDR expandedRowIdx, + GM_ADDR expandedX, GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyExpertIn(int64_t progress); + __aicore__ inline void Compute(int64_t curLoopCols); + __aicore__ inline void CopyXIn(int64_t xSrcOffset, int64_t curLoopCols); + __aicore__ inline void CopyXOut(int64_t xDstOffset, int64_t curLoopCols); + __aicore__ inline void ScatterCopyOut(int64_t progress); + __aicore__ inline void GatherCopyOut(int64_t progress); + +private: + TPipe *pipe_; + TQue inputXCopyInQueue_; + TQue expandRowIdxCopyInQueue_; + TQue inputXCopyOutQueue_; + TQue floatQueue_; + TQue halfQueue_; + + GlobalTensor inputXGm_; + GlobalTensor expandedXGm_; + GlobalTensor expandedRowIdxGm_; + GlobalTensor scaleGm_; + GlobalTensor offsetGm_; + GlobalTensor expertTotalCountGm_; + + const MoeCustomGatherOutComputeTilingData *gatherOutTilingData_; + + int64_t needCoreNum_; + int64_t blockIdx_; + int64_t cols_; + int64_t n_; + int64_t k_; + int64_t perCoreRow_; + int64_t currentLoopRows_; + int64_t coreRows_; + int64_t perLoopRows_; + int64_t lastLoopRows_; + int64_t rowLoops_; + int64_t colsTileLength_; + int64_t perLoopCols_; + int64_t lastLoopCols_; + int64_t colLoops_; + float scale_; + float offset_; + int64_t rowIdxType_; + int64_t dropPadMode_; + int64_t activeNum_; + int64_t indicesOffset_; + int64_t coreNum_; + int64_t inputOffset_; + int64_t outOffset_; + int64_t expertTotalCount_; 
+}; + +template +__aicore__ inline void MoeGatherOutQuant::Init(GM_ADDR inputX, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + blockIdx_ = GetBlockIdx(); + + gatherOutTilingData_ = &(tilingData->gatherOutComputeParamsOp); + cols_ = tilingData->cols; + n_ = tilingData->n; + k_ = tilingData->k; + rowIdxType_ = tilingData->rowIdxType; + dropPadMode_ = tilingData->dropPadMode; + activeNum_ = tilingData->activeNum; + coreNum_ = tilingData->coreNum; + + // core split + int64_t actualExpertNum_ = tilingData->actualExpertNum; + + if constexpr (EP) { + expertTotalCountGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)), + 1); + AscendC::DataCacheCleanAndInvalid(expertTotalCountGm_); + expertTotalCount_ = expertTotalCountGm_.GetValue(0); + } else { + expertTotalCount_ = n_ * k_; + } + + perCoreRow_ = Ceil(expertTotalCount_, tilingData->coreNum); + needCoreNum_ = Ceil(expertTotalCount_, perCoreRow_); + int64_t lastCoreIndicesElements_ = expertTotalCount_ - (needCoreNum_ - 1) * perCoreRow_; + + // inner core split + int64_t originPerLoopElements; + if (blockIdx_ == needCoreNum_ - 1) { + coreRows_ = lastCoreIndicesElements_; + originPerLoopElements = gatherOutTilingData_->lastCorePerLoopIndicesElements; + } else { + coreRows_ = perCoreRow_; + originPerLoopElements = gatherOutTilingData_->perCorePerLoopIndicesElements; + } + perLoopRows_ = Min(coreRows_, originPerLoopElements); + rowLoops_ = Ceil(coreRows_, perLoopRows_); + lastLoopRows_ = coreRows_ - (rowLoops_ - 1) * perLoopRows_; + + // cols split + perLoopCols_ = gatherOutTilingData_->perLoopCols; + lastLoopCols_ = gatherOutTilingData_->lastLoopCols; + colLoops_ = gatherOutTilingData_->colsLoops; + + inputXGm_.SetGlobalBuffer((__gm__ T *)inputX); + expandedXGm_.SetGlobalBuffer((__gm__ int8_t *)expandedX); + + if 
constexpr (EP) { + if (rowIdxType_ == SCATTER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreRow_, + Align(coreRows_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreRow_, + Align(coreRows_, sizeof(int32_t))); + } + } else { + if (rowIdxType_ == GATHER) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreRow_, + Align(coreRows_, sizeof(int32_t))); + } else { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(n_ * k_, sizeof(int32_t)) + + blockIdx_ * perCoreRow_, + Align(coreRows_, sizeof(int32_t))); + } + } + + + scaleGm_.SetGlobalBuffer((__gm__ float *)scale, 1); + offsetGm_.SetGlobalBuffer((__gm__ float *)offset, 1); + scale_ = scaleGm_.GetValue(0); + offset_ = offsetGm_.GetValue(0); + + pipe_->InitBuffer(inputXCopyInQueue_, GATHER_OUT_QUANT_BUFFER_NUM, AlignBytes(perLoopCols_, sizeof(T))); + pipe_->InitBuffer(inputXCopyOutQueue_, GATHER_OUT_QUANT_BUFFER_NUM, AlignBytes(perLoopCols_, sizeof(int8_t))); + pipe_->InitBuffer(expandRowIdxCopyInQueue_, GATHER_OUT_QUANT_BUFFER_NUM, AlignBytes(perLoopRows_, sizeof(int32_t))); + pipe_->InitBuffer(floatQueue_, 1, AlignBytes(perLoopCols_, sizeof(float))); + pipe_->InitBuffer(halfQueue_, 1, AlignBytes(perLoopCols_, sizeof(half))); +} + +template +__aicore__ inline void MoeGatherOutQuant::CopyExpertIn(int64_t progress) +{ + indicesOffset_ = progress * perLoopRows_; + LocalTensor indicesLocal = expandRowIdxCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{1, static_cast(currentLoopRows_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(indicesLocal, expandedRowIdxGm_[indicesOffset_], dataCopyParams, dataCopyPadParams); + expandRowIdxCopyInQueue_.EnQue(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutQuant::CopyXIn(int64_t xSrcOffset, 
int64_t curLoopCols) +{ + LocalTensor inLocal = inputXCopyInQueue_.AllocTensor(); + DataCopyExtParams copyParams0{static_cast(1), static_cast(curLoopCols * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams padParams0{false, 0, 0, 0}; + DataCopyPad(inLocal, inputXGm_[xSrcOffset], copyParams0, padParams0); + inputXCopyInQueue_.EnQue(inLocal); +} + +template +__aicore__ inline void MoeGatherOutQuant::CopyXOut(int64_t xDstOffset, int64_t curLoopCols) +{ + LocalTensor outLocal = inputXCopyOutQueue_.DeQue(); + DataCopyExtParams copyParams2{1, static_cast(curLoopCols * sizeof(int8_t)), 0, 0, 0}; + DataCopyPad(expandedXGm_[xDstOffset], outLocal, copyParams2); + inputXCopyOutQueue_.FreeTensor(outLocal); +} + +template +__aicore__ inline void MoeGatherOutQuant::Compute(int64_t curLoopCols) +{ + LocalTensor floatLocal; + LocalTensor inLocal; + LocalTensor outLocal = inputXCopyOutQueue_.AllocTensor(); + LocalTensor halfLocal = halfQueue_.AllocTensor(); + uint32_t elements = Align(curLoopCols, sizeof(T)); + if constexpr (IsSameType::value) { + floatLocal = inputXCopyInQueue_.DeQue(); + } else { + inLocal = inputXCopyInQueue_.DeQue(); + floatLocal = floatQueue_.AllocTensor(); + Cast(floatLocal, inLocal, RoundMode::CAST_NONE, elements); + PipeBarrier(); + } + Muls(floatLocal, floatLocal, scale_, elements); + PipeBarrier(); + Adds(floatLocal, floatLocal, offset_, elements); + PipeBarrier(); + LocalTensor intLocal = floatLocal.ReinterpretCast(); + Cast(intLocal, floatLocal, RoundMode::CAST_RINT, elements); + PipeBarrier(); + SetDeqScale((half)1.000000e+00f); + PipeBarrier(); + Cast(halfLocal, intLocal, RoundMode::CAST_ROUND, elements); + PipeBarrier(); + Cast(outLocal, halfLocal, RoundMode::CAST_TRUNC, elements); + inputXCopyOutQueue_.EnQue(outLocal); + if constexpr (IsSameType::value) { + inputXCopyInQueue_.FreeTensor(floatLocal); + } else { + inputXCopyInQueue_.FreeTensor(inLocal); + floatQueue_.FreeTensor(floatLocal); + } + halfQueue_.FreeTensor(halfLocal); +} + +template +__aicore__ 
inline void MoeGatherOutQuant::ScatterCopyOut(int64_t progress) +{ + LocalTensor indicesLocal = expandRowIdxCopyInQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + for (int64_t indicesIndex = 0; indicesIndex < currentLoopRows_; indicesIndex++) { + int64_t rowOffset = perCoreRow_ * blockIdx_ + perLoopRows_ * progress; + int64_t rowIdx = indicesLocal.GetValue(indicesIndex); + int64_t xSrcOffset = rowIdx / k_ * cols_; + int64_t xDstOffset = (rowOffset + indicesIndex) * cols_; + int64_t curLoopCols = perLoopCols_; + if (activeNum_ > 0 && dropPadMode_ == DROPLESS_MODE && (rowOffset + indicesIndex) >= activeNum_) { + break; + } + SetWaitFlag(HardEvent::S_MTE2); + for (int64_t colsLoop = 0; colsLoop < colLoops_; colsLoop++) { + if (colsLoop == colLoops_ - 1) { + curLoopCols = lastLoopCols_; + } + int64_t colsLoopOffset = colsLoop * perLoopCols_; + CopyXIn(xSrcOffset + colsLoopOffset, curLoopCols); + Compute(curLoopCols); + CopyXOut(xDstOffset + colsLoopOffset, curLoopCols); + } + } + expandRowIdxCopyInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutQuant::GatherCopyOut(int64_t progress) +{ + LocalTensor indicesLocal = expandRowIdxCopyInQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + colsTileLength_ = perLoopCols_; + for (int64_t colsLoop = 0; colsLoop < colLoops_; colsLoop++) { + int64_t initialRow = perCoreRow_ * blockIdx_ + perLoopRows_ * progress; + int64_t curLoopRow = 0; + if (colsLoop == colLoops_ - 1) { + colsTileLength_ = lastLoopCols_; + } + int64_t currentLoopStartRow = initialRow / k_; + int64_t currentLoopLastRow = (initialRow + currentLoopRows_ - 1) / k_; + for (int64_t row = currentLoopStartRow; row <= currentLoopLastRow; row++) { + inputOffset_ = row * cols_ + colsLoop * perLoopCols_; + // input row position + CopyXIn(inputOffset_, colsTileLength_); + Compute(colsTileLength_); + LocalTensor outLocal = inputXCopyOutQueue_.DeQue(); + DataCopyExtParams intriParams{1, static_cast(colsTileLength_ * sizeof(int8_t)), 0, 0, 
0}; + SetWaitFlag(HardEvent::MTE2_MTE3); + while (curLoopRow < currentLoopRows_ && initialRow / k_ == row) { + int32_t outIndex = indicesLocal.GetValue(curLoopRow); + curLoopRow++; + initialRow++; + if (outIndex == -1 || (dropPadMode_ == DROPLESS_MODE && outIndex >= activeNum_)) { + continue; + } + outOffset_ = outIndex * cols_ + colsLoop * perLoopCols_; + DataCopyPad(expandedXGm_[outOffset_], outLocal, intriParams); + } + inputXCopyOutQueue_.FreeTensor(outLocal); + } + } + expandRowIdxCopyInQueue_.FreeTensor(indicesLocal); +} + +template +__aicore__ inline void MoeGatherOutQuant::Process() +{ + if (blockIdx_ < needCoreNum_) { + currentLoopRows_ = perLoopRows_; + for (int64_t loop = 0; loop < rowLoops_; loop++) { + if (loop == rowLoops_ - 1) { + currentLoopRows_ = lastLoopRows_; + } + CopyExpertIn(loop); + if constexpr (EP) { + ScatterCopyOut(loop); + } else { + GatherCopyOut(loop); + } + } + } +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_GATHER_STATIC_QUANT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort.h new file mode 100644 index 00000000000..4ae95fa2b00 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort.h @@ -0,0 +1,207 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_mrgsort.h + * \brief + */ +#ifndef MOE_CUSTOM_MRGSORT_H +#define MOE_CUSTOM_MRGSORT_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +struct MoeMrgsortParam { + int64_t perListElements; + int64_t lastListElements; + int64_t oneLoopMaxElements; +}; + +class MoeMrgsort { +public: + __aicore__ inline MoeMrgsort(){}; + __aicore__ inline void Init(MoeMrgsortParam *param); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeMrgsortParam *param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput; + + LocalTensor ubInputs[4]; + LocalTensor ubOutput; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail{0}; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeMrgsort::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeMrgsort::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput) +{ + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeMrgsort::SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput) +{ + this->gmOutput = gmOutput; + this->ubOutput = ubOutput; +} + +__aicore__ inline void MoeMrgsort::UpdateMrgParam() +{ + if 
(this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeMrgsort::CopyIn() +{ + this->remainListNum = 0; + event_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeMrgsort::MrgsortCompute() +{ + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->ubOutput, 
sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->ubOutput, this->tmpUbInputs[0], + Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeMrgsort::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeMrgsort::CopyOut() +{ + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = GetSortLen(curLoopSortedNum) * sizeof(float); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams); + outOffset += GetSortLen(curLoopSortedNum); +} + +__aicore__ inline void MoeMrgsort::Init(MoeMrgsortParam *param) +{ + this->param = param; + this->remainListNum = listNum; + + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeMrgsort::Process() +{ + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + CopyOut(); + } + + ClearCache(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_MRGSORT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out.h new file mode 
100644 index 00000000000..84fb4b6726c --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out.h @@ -0,0 +1,232 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_mrgsort_out.h + * \brief + */ +#ifndef MOE_CUSTOM_MRGSORT_OUT_H +#define MOE_CUSTOM_MRGSORT_OUT_H + +#include "moe_custom_mrgsort.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeMrgsortOut { +public: + __aicore__ inline MoeMrgsortOut(){}; + __aicore__ inline void Init(MoeMrgsortParam *param, TPipe *tPipe); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput1, GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2); + __aicore__ inline void SetBuffer(LocalTensor &tempBuffer); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void Extract(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeMrgsortParam *param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput1; + GlobalTensor gmOutput2; + + LocalTensor ubInputs[4]; + LocalTensor tempBuffer; + + // for extract + 
LocalTensor ubOutput1; + LocalTensor ubOutput2; + + // for copy out + LocalTensor ubOutputInt1; + LocalTensor ubOutputInt2; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeMrgsortOut::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeMrgsortOut::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput) +{ + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeMrgsortOut::SetOutput(GlobalTensor &gmOutput1, GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2) +{ + this->gmOutput1 = gmOutput1; + this->ubOutput1 = ubOutput1; + this->ubOutputInt1 = ubOutput1.ReinterpretCast(); + + this->gmOutput2 = gmOutput2; + this->ubOutput2 = ubOutput2.ReinterpretCast(); + this->ubOutputInt2 = ubOutput2.ReinterpretCast(); +} + +__aicore__ inline void MoeMrgsortOut::SetBuffer(LocalTensor &tempBuffer) +{ + this->tempBuffer = tempBuffer; +} + +__aicore__ inline void MoeMrgsortOut::UpdateMrgParam() +{ + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeMrgsortOut::CopyIn() +{ + this->remainListNum = 0; + event_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + 
SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeMrgsortOut::MrgsortCompute() +{ + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->tempBuffer, this->tmpUbInputs[0], + Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeMrgsortOut::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += 
GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeMrgsortOut::Extract() +{ + AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM)); + Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float))); + Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float))); +} + +__aicore__ inline void MoeMrgsortOut::CopyOut() +{ + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = curLoopSortedNum * sizeof(int32_t); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams); + DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams); + + outOffset += curLoopSortedNum; +} + +__aicore__ inline void MoeMrgsortOut::Init(MoeMrgsortParam *param, TPipe *tPipe) +{ + this->param = param; + this->allRemainElements = 0; + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + if (i == listNum - 1) { + listRemainElements[i] = param->lastListElements; + } else { + listRemainElements[i] = param->perListElements; + } + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeMrgsortOut::Process() +{ + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + Extract(); + CopyOut(); + } + ClearCache(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_MRGSORT_OUT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out_performance.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out_performance.h new file mode 100644 index 00000000000..650c90bcbcc --- /dev/null +++ 
b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_out_performance.h @@ -0,0 +1,239 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_mrgsort_out_performance.h + * \brief + */ +#ifndef MOE_CUSTOM_MRGSORT_OUT_PERFORMANCE_H +#define MOE_CUSTOM_MRGSORT_OUT_PERFORMANCE_H + +#include "moe_custom_mrgsort_performance.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +constexpr int64_t MAX_MRGSORT_LIST = 4; +constexpr int64_t MAX_MRGSORT_LIST_TOTAL = 16; + +class MoeMrgsortOutPerformance { +public: + __aicore__ inline MoeMrgsortOutPerformance(){}; + __aicore__ inline void Init(MoeMrgsortPerformanceParam *param, TPipe *tPipe); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput, + GlobalTensor &gmActualSortNum); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput1, GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2); + __aicore__ inline void SetBuffer(LocalTensor &tempBuffer); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void Extract(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeMrgsortPerformanceParam 
*param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput1; + GlobalTensor gmOutput2; + GlobalTensor gmActualSortNum; + + LocalTensor ubInputs[4]; + LocalTensor tempBuffer; + + // for extract + LocalTensor ubOutput1; + LocalTensor ubOutput2; + + // for copy out + LocalTensor ubOutputInt1; + LocalTensor ubOutputInt2; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4] = {0}; + int64_t listRemainElements[4] = {0}; + int64_t lengths[4] = {0}; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail; + uint16_t elementCountListTail[4] = {0}; + uint32_t listSortedNums[4] = {0}; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeMrgsortOutPerformance::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeMrgsortOutPerformance::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput, + GlobalTensor &gmActualSortNum) +{ + if (this->listNum == 0) { + this->gmActualSortNum = gmActualSortNum; + } + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; + this->listNum += 1; +} + +__aicore__ inline void MoeMrgsortOutPerformance::SetOutput(GlobalTensor &gmOutput1, + GlobalTensor &gmOutput2, + LocalTensor &ubOutput1, LocalTensor &ubOutput2) +{ + this->gmOutput1 = gmOutput1; + this->ubOutput1 = ubOutput1; + this->ubOutputInt1 = ubOutput1.ReinterpretCast(); + + this->gmOutput2 = gmOutput2; + this->ubOutput2 = ubOutput2.ReinterpretCast(); + this->ubOutputInt2 = ubOutput2.ReinterpretCast(); +} + +__aicore__ inline void MoeMrgsortOutPerformance::SetBuffer(LocalTensor &tempBuffer) +{ + this->tempBuffer = tempBuffer; +} + +__aicore__ inline void MoeMrgsortOutPerformance::UpdateMrgParam() +{ + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if 
(this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeMrgsortOutPerformance::CopyIn() +{ + this->remainListNum = 0; + event_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeMrgsortOutPerformance::MrgsortCompute() +{ + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->tempBuffer, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->tempBuffer, this->tmpUbInputs[0], + 
Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeMrgsortOutPerformance::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeMrgsortOutPerformance::Extract() +{ + AscendC::Extract(this->ubOutput1, this->ubOutput2, this->tempBuffer, Ceil(curLoopSortedNum, ONE_REPEAT_SORT_NUM)); + Muls(this->ubOutput1, this->ubOutput1, (float)-1, Align(curLoopSortedNum, sizeof(float))); + Cast(this->ubOutputInt1, this->ubOutput1, RoundMode::CAST_ROUND, Align(curLoopSortedNum, sizeof(float))); +} + +__aicore__ inline void MoeMrgsortOutPerformance::CopyOut() +{ + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = curLoopSortedNum * sizeof(int32_t); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(this->gmOutput1[outOffset], this->ubOutputInt1, intriParams); + DataCopyPad(this->gmOutput2[outOffset], this->ubOutputInt2, intriParams); + + outOffset += curLoopSortedNum; +} + +__aicore__ inline void MoeMrgsortOutPerformance::Init(MoeMrgsortPerformanceParam *param, TPipe *tPipe) +{ + this->param = param; + for (int64_t i = 0; i < MAX_MRGSORT_LIST_TOTAL; i++) { + listRemainElements[i / MAX_MRGSORT_LIST] += static_cast(gmActualSortNum.GetValue(i)); + } + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i * MAX_MRGSORT_LIST); + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeMrgsortOutPerformance::Process() +{ + for (; 
allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + Extract(); + CopyOut(); + } + ClearCache(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_MRGSORT_OUT_PERFORMANCE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_performance.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_performance.h new file mode 100644 index 00000000000..5f5cc97a0a1 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_mrgsort_performance.h @@ -0,0 +1,206 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_mrgsort_performance.h + * \brief + */ +#ifndef MOE_CUSTOM_MRGSORT_PERFORMANCE_H +#define MOE_CUSTOM_MRGSORT_PERFORMANCE_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +struct MoeMrgsortPerformanceParam { + int64_t perListElements; + int64_t oneLoopMaxElements; +}; + +class MoeMrgsortPerformance { +public: + __aicore__ inline MoeMrgsortPerformance(){}; + __aicore__ inline void Init(MoeMrgsortPerformanceParam *param); + __aicore__ inline void Process(); + __aicore__ inline void SetInput(GlobalTensor &gmInput, LocalTensor &ubInput, + GlobalTensor &gmActualSortNum); + __aicore__ inline void SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void UpdateMrgParam(); + __aicore__ inline void MrgsortCompute(); + __aicore__ inline void UpdateSortInfo(); + __aicore__ inline void CopyOut(); + __aicore__ inline void ClearCache(); + +private: + MoeMrgsortPerformanceParam *param = nullptr; + + GlobalTensor gmInputs[4]; + GlobalTensor gmOutput; + GlobalTensor gmActualSortNum; + + LocalTensor ubInputs[4]; + LocalTensor ubOutput; + + int64_t listNum{0}; + int64_t remainListNum{0}; + int64_t outOffset{0}; + int64_t offsets[4]; + int64_t listRemainElements[4]; + int64_t lengths[4]; + int64_t allRemainElements{0}; + int64_t curLoopSortedNum{0}; + + // for MrgSort + uint16_t validBitTail{0}; + uint16_t elementCountListTail[4]; + uint32_t listSortedNums[4]; + LocalTensor tmpUbInputs[4]; +}; + +__aicore__ inline void MoeMrgsortPerformance::ClearCache() +{ + this->listNum = 0; + this->allRemainElements = 0; + this->outOffset = 0; +} + +__aicore__ inline void MoeMrgsortPerformance::SetInput(GlobalTensor &gmInput, LocalTensor &ubInput, + GlobalTensor &gmActualSortNum) +{ + if (this->listNum == 0) { + this->gmActualSortNum = gmActualSortNum; + } + this->gmInputs[listNum] = gmInput; + this->ubInputs[listNum] = ubInput; 
+ this->listNum += 1; +} + +__aicore__ inline void MoeMrgsortPerformance::SetOutput(GlobalTensor &gmOutput, LocalTensor &ubOutput) +{ + this->gmOutput = gmOutput; + this->ubOutput = ubOutput; +} + +__aicore__ inline void MoeMrgsortPerformance::UpdateMrgParam() +{ + if (this->remainListNum == MERGE_LIST_TWO) { + elementCountListTail[MERGE_LIST_IDX_TWO] = 0; + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0011; + } else if (this->remainListNum == MERGE_LIST_THREE) { + elementCountListTail[MERGE_LIST_IDX_THREE] = 0; + validBitTail = 0b0111; + } else if (this->remainListNum == MERGE_LIST_FOUR) { + validBitTail = 0b1111; + } else { + validBitTail = 0b0001; + } +} + +__aicore__ inline void MoeMrgsortPerformance::CopyIn() +{ + this->remainListNum = 0; + event_t eventIdMte3ToMte2 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + SetFlag(eventIdMte3ToMte2); + WaitFlag(eventIdMte3ToMte2); + for (int64_t i = 0, j = 0; i < listNum; i++) { + lengths[i] = Min(param->oneLoopMaxElements, listRemainElements[i]); + if (lengths[i] > 0) { + DataCopy(this->ubInputs[i], this->gmInputs[i][offsets[i]], + Align(GetSortLen(lengths[i]), sizeof(float))); + tmpUbInputs[j] = this->ubInputs[i]; + elementCountListTail[j] = lengths[i]; + this->remainListNum += 1; + j++; + } + } +} + +__aicore__ inline void MoeMrgsortPerformance::MrgsortCompute() +{ + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + if (this->remainListNum == MERGE_LIST_TWO) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[0], tmpUbInputs[0]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_THREE) { + MrgSortSrcList sortListTail = + MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], tmpUbInputs[0]); + MrgSort(this->ubOutput, 
sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else if (this->remainListNum == MERGE_LIST_FOUR) { + MrgSortSrcList sortListTail = MrgSortSrcList(tmpUbInputs[0], tmpUbInputs[1], tmpUbInputs[MERGE_LIST_IDX_TWO], + tmpUbInputs[MERGE_LIST_IDX_THREE]); + MrgSort(this->ubOutput, sortListTail, elementCountListTail, listSortedNums, validBitTail, 1); + } else { + DataCopy(this->ubOutput, this->tmpUbInputs[0], + Align(GetSortLen(elementCountListTail[0]), sizeof(float))); + listSortedNums[0] = elementCountListTail[0]; + } +} + +__aicore__ inline void MoeMrgsortPerformance::UpdateSortInfo() +{ + curLoopSortedNum = 0; + for (int64_t i = 0, j = 0; i < listNum; i++) { + if (lengths[i] > 0) { + // update remain size + listRemainElements[i] -= listSortedNums[j]; + allRemainElements -= listSortedNums[j]; + // update offset + offsets[i] += GetSortOffset(listSortedNums[j]); + // update current loop sorted nums + curLoopSortedNum += listSortedNums[j]; + j += 1; + } + } +} + +__aicore__ inline void MoeMrgsortPerformance::CopyOut() +{ + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = GetSortLen(curLoopSortedNum) * sizeof(float); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(this->gmOutput[outOffset], this->ubOutput, intriParams); + outOffset += GetSortLen(curLoopSortedNum); +} + +__aicore__ inline void MoeMrgsortPerformance::Init(MoeMrgsortPerformanceParam *param) +{ + this->param = param; + for (int64_t i = 0; i < listNum; i++) { + offsets[i] = GetSortOffset(param->perListElements * i); + listRemainElements[i] = static_cast(gmActualSortNum.GetValue(i)); + allRemainElements += listRemainElements[i]; + } +} + +__aicore__ inline void MoeMrgsortPerformance::Process() +{ + for (; allRemainElements > 0;) { + CopyIn(); + UpdateMrgParam(); + MrgsortCompute(); + UpdateSortInfo(); + CopyOut(); + } + + ClearCache(); +} +} // 
namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_MRGSORT_PERFORMANCE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather.h new file mode 100644 index 00000000000..03f35a2e72f --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather.h @@ -0,0 +1,204 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_row_idx_gather.h + * \brief + */ +#ifndef MOE_CUSTOM_ROW_IDX_GATHER_H +#define MOE_CUSTOM_ROW_IDX_GATHER_H + +#include "moe_custom_common.h" +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class RowIdxGather { +public: + __aicore__ inline RowIdxGather(){}; + __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(int64_t loop, int64_t elements); + __aicore__ inline void Compute(int64_t loop, int64_t elements); + __aicore__ inline void CopyOut(int64_t loop, int64_t elements, GlobalTensor &RowIdxDstGm_); + __aicore__ inline void AssistInit(); + +private: + GlobalTensor expandedRowIdxGm_; + GlobalTensor sortedExpertIndicesGm_; + GlobalTensor expertTokensCountGm_; + GlobalTensor expertTotalCountGm_; + GlobalTensor assistGm_; + GlobalTensor gatherIndicesGm_; + + TPipe *pipe_; + + TQue sortedExpertIndicesInQueue_; + TQue copyOutQueue_; + TBuf assistBuffer_; + + const MoeCustomSrcToDstComputeTilingData *srcToDstComputeTilingData_; + int64_t blockIdx_; + int64_t needCoreNum_; + int64_t perCoreElements_; + int64_t actualExpertNum_ = 0; + int64_t ep_ = 0; + int64_t rowIdxType_ = 0; + int64_t expertTotalCount_ = 0; + + int64_t loops_ = 0; + int64_t perLoopElements_ = 0; + int64_t lastLoopElements_ = 0; +}; + +__aicore__ inline void RowIdxGather::AssistInit() +{ + LocalTensor assistTensor = assistBuffer_.Get(ASSIST_NUM); + DataCopy(assistTensor, assistGm_, ASSIST_NUM); + SetWaitFlag(HardEvent::MTE2_V); + Adds(assistTensor, assistTensor, (int32_t)(blockIdx_ * perCoreElements_), ASSIST_NUM); +} + +__aicore__ inline void RowIdxGather::Init(GM_ADDR expandedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + pipe_ = tPipe; + srcToDstComputeTilingData_ = &(tilingData->srcToDstComputeParamsOp); + blockIdx_ = 
GetBlockIdx(); + actualExpertNum_ = tilingData->actualExpertNum; + ep_ = tilingData->ep; + rowIdxType_ = tilingData->rowIdxType; + + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, actualExpertNum_); + + if (ep_) { + expertTotalCountGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2 + + Align(actualExpertNum_, sizeof(int32_t)), + actualExpertNum_); + AscendC::DataCacheCleanAndInvalid(expertTotalCountGm_); + expertTotalCount_ = expertTotalCountGm_.GetValue(0); + } else { + expertTotalCount_ = tilingData->n * tilingData->k; + } + assistGm_.SetGlobalBuffer((__gm__ int32_t *)assist, ASSIST_NUM); + perCoreElements_ = Ceil(expertTotalCount_, srcToDstComputeTilingData_->needCoreNum); + needCoreNum_ = Ceil(expertTotalCount_, perCoreElements_); + + int64_t lastCoreElements = expertTotalCount_ - (needCoreNum_ - 1) * perCoreElements_; + int64_t perCoreLoops = Ceil(perCoreElements_, srcToDstComputeTilingData_->perCorePerLoopElements); + int64_t perCorePerLoopElements = Ceil(perCoreElements_, perCoreLoops); + int64_t perCoreLastLoopElements = perCoreElements_ - (perCoreLoops - 1) * perCorePerLoopElements; + + int64_t lastCoreLoops = Ceil(lastCoreElements, srcToDstComputeTilingData_->perCorePerLoopElements); + int64_t lastCorePerLoopElements = Ceil(lastCoreElements, lastCoreLoops); + int64_t lastCoreLastLoopELements = lastCoreElements - (lastCoreLoops - 1) * lastCorePerLoopElements; + + loops_ = perCoreLoops; + if (blockIdx_ == needCoreNum_ - 1) { + loops_ = lastCoreLoops; + perLoopElements_ = lastCorePerLoopElements; + lastLoopElements_ = lastCoreLastLoopELements; + } else { + loops_ = perCoreLoops; + perLoopElements_ = perCorePerLoopElements; + lastLoopElements_ = perCoreLastLoopElements; + } + + if (rowIdxType_ == SCATTER) { + sortedExpertIndicesGm_.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx + blockIdx_ * perCoreElements_, + actualExpertNum_); + } else { + 
sortedExpertIndicesGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t)) + + blockIdx_ * perCoreElements_, + actualExpertNum_); + } + + if ((ep_ == 0 && rowIdxType_ == SCATTER) && (blockIdx_ < needCoreNum_)) { + expandedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t))); + } + pipe_->InitBuffer(sortedExpertIndicesInQueue_, 1, AlignBytes(perLoopElements_, sizeof(int32_t))); + pipe_->InitBuffer(copyOutQueue_, 1, Ceil(perLoopElements_, ASSIST_NUM) * ASSIST_NUM * BLOCK_BYTES); + pipe_->InitBuffer(assistBuffer_, ASSIST_NUM * sizeof(int32_t)); +} + +__aicore__ inline void RowIdxGather::Process() +{ + if (ep_ == 1 && rowIdxType_ == SCATTER) { + return; + } else { + if (blockIdx_ < needCoreNum_) { + AssistInit(); + for (int64_t loop = 0; loop < loops_; loop++) { + int64_t elements = perLoopElements_; + if (loop == loops_ - 1) { + elements = lastLoopElements_; + } + CopyIn(loop, elements); + Compute(loop, elements); + CopyOut(loop, elements, expandedRowIdxGm_); + } + } + } + AscendC::SyncAll(); +} + +__aicore__ inline void RowIdxGather::CopyIn(int64_t loop, int64_t elements) +{ + LocalTensor sortedExpertIndicesInLocal = sortedExpertIndicesInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(elements * sizeof(int32_t)), 0, 0, + 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(sortedExpertIndicesInLocal, sortedExpertIndicesGm_[loop * perLoopElements_], dataCopyParams, + dataCopyPadParams); + sortedExpertIndicesInQueue_.EnQue(sortedExpertIndicesInLocal); +} + +__aicore__ inline void RowIdxGather::Compute(int64_t loop, int64_t elements) +{ + LocalTensor outLocal = copyOutQueue_.AllocTensor(); + LocalTensor assistTensor = assistBuffer_.Get(ASSIST_NUM); + PipeBarrier(); + int64_t loops = Ceil(elements, ASSIST_INDEX_NUM); + for (int64_t i = 0; i < loops; i++) { + Adds(outLocal[i * ASSIST_NUM], assistTensor, + 
static_cast(perLoopElements_ * loop + i * ASSIST_INDEX_NUM), ASSIST_NUM); + } + PipeBarrier(); + copyOutQueue_.EnQue(outLocal); +} + +__aicore__ inline void RowIdxGather::CopyOut(int64_t loop, int64_t elements, GlobalTensor &RowIdxDstGm_) +{ + LocalTensor inLocal = sortedExpertIndicesInQueue_.DeQue(); + LocalTensor outLocal = copyOutQueue_.DeQue(); + SetWaitFlag(HardEvent::MTE2_S); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = sizeof(int32_t); + uint32_t outOffset; + for (int64_t idx = 0; idx < elements; idx++) { + outOffset = inLocal.GetValue(idx); + SetWaitFlag(HardEvent::S_MTE3); + DataCopyPad(RowIdxDstGm_[outOffset], outLocal[idx * INT32_ONE_BLOCK_NUM], intriParams); + } + + sortedExpertIndicesInQueue_.FreeTensor(inLocal); + copyOutQueue_.FreeTensor(outLocal); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_ROW_IDX_GATHER_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad.h new file mode 100644 index 00000000000..b8a94e5364b --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad.h @@ -0,0 +1,306 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
 * \file moe_custom_row_idx_gather_droppad.h
 * \brief Scatter/gather stage of MoeInitRoutingCustom in drop-pad mode: writes the
 *        sorted-row -> expanded-row index mapping and zero-pads unused capacity slots.
 *        NOTE(review): template parameter lists and angle-bracket arguments appear to have
 *        been stripped by extraction throughout this file (e.g. "template", "TQue", 
 *        "GlobalTensor", "static_cast(...)") — restore them against the original source.
 */
#ifndef MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_H
#define MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_H

#include "moe_custom_common.h"

namespace MoeInitRoutingCustom {
using namespace AscendC;

// Per-core kernel: walks the (sorted) expert-id stream, assigns each token a slot
// index expertIdx * expertCapacity + tokenCount, records it in expandedRowIdx, and
// fills capacity slots that received no token with zeros (and zero scales when enabled).
template
class MoeCustomSrcToDstWithCapacity {
public:
    __aicore__ inline MoeCustomSrcToDstWithCapacity(){};
    // Binds GM buffers, unpacks tiling (per-core / last-core row splits, column tiling)
    // and allocates the UB queues via tPipe.
    __aicore__ inline void Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX, GM_ADDR expandedScale, GM_ADDR workspace,
                                const TilingData *tilingData, TPipe *tPipe);
    // Entry point: AssistInit -> per-loop CopyIn/CopyOut -> CopyOutRemain -> SyncAll.
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn(int64_t progress);
    __aicore__ inline void CopyOut(int64_t progress);
    __aicore__ inline void CopyOutRemain();
    __aicore__ inline void SyncAll();
    __aicore__ inline void AssistInit();

private:
    TPipe *pipe;
    TQue copyInQueue;       // holds [dstToSrc rows | expert ids] for the current loop
    TQue copyOutQueue;      // single-element staging for expandedRowIdx writes
    TQue copyOutZeroQueue;  // one row of zeros used to pad empty capacity slots
    TQue scaleOutZeroQueue; // one block of zero scales (only when needScaleCopy)

    GlobalTensor expandDstToSrcRowGm;
    GlobalTensor expandedRowIdxGm;
    GlobalTensor expertIdxValueGm;   // per-core (lastExpertId, count) pairs in workspace
    GlobalTensor expandedExpertIdxGm;
    GlobalTensor expandedXGm;
    GlobalTensor expandedScaleGm;

    LocalTensor outTmpLocal; // zero row (dequeued once in Process, freed in CopyOutRemain)
    LocalTensor scaleLocal;  // zero scale block

    const MoeCustomSrcToDstCapacityComputeTilingData *srcToDstTilingData;
    int64_t coreNum;
    int64_t blockIdx;
    int64_t totalLength;   // n * k tokens overall
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t rowLoops;
    int64_t expertCapacity;
    int64_t expertNum;
    int64_t cols;
    int64_t perLoopCols;
    int64_t lastLoopCols;
    int64_t colLoops;
    int64_t isInputScale_;
    int64_t quantMode_;

    // Running scan state while walking the sorted expert-id stream.
    int64_t tokenCount = 0;
    int32_t lastExpertId = -1;       // -1 marks "not yet seeded from lastCoreExpertId"
    int32_t lastCoreExpertId = 0;    // expert id the previous core stopped at
    int32_t lastCoreExpertIdNum = 0; // tokens already consumed for that expert by prior cores
    bool needScaleCopy = false;
};

// Prepares the zero-fill tensors and seeds the scan state from the previous cores'
// (expertId, count) pairs so each core resumes the capacity count where they left off.
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::AssistInit()
{
    if constexpr (IsSameType::value) {
        LocalTensor outLocal = copyOutZeroQueue.AllocTensor();
        Duplicate(outLocal, static_cast(0), this->perLoopCols);
        copyOutZeroQueue.EnQue(outLocal);
    } else {
        LocalTensor outLocal = copyOutZeroQueue.AllocTensor();
        Duplicate(outLocal, static_cast(0), this->perLoopCols);
        copyOutZeroQueue.EnQue(outLocal);
    }
    if (this->needScaleCopy) {
        LocalTensor scaleOutLocal = scaleOutZeroQueue.AllocTensor();
        Duplicate(scaleOutLocal, 0.0f, FP32_ONE_BLOCK_NUM);
        scaleOutZeroQueue.EnQue(scaleOutLocal);
    }

    if (this->blockIdx != 0) {
        // Each workspace pair is (expertId, tokenCount) for one preceding core; walk
        // backwards and accumulate counts while those cores ended on the same expert.
        this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2);
        this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * 2 + 1);
        for (int64_t i = this->blockIdx - 2; i >= 0; i--) {
            int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * 2);
            if (lastExpertIdx < this->lastCoreExpertId) {
                break;
            }
            int32_t lastExpertNum = expertIdxValueGm.GetValue(i * 2 + 1);
            this->lastCoreExpertIdNum += lastExpertNum;
        }
    }
}

// Loads this loop's dst->src row indices and their expert ids into one UB tensor
// ([0, length) = row indices, [length, 2*length) = expert ids).
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::CopyIn(int64_t progress)
{
    LocalTensor inLocal = copyInQueue.AllocTensor();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length);
    DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length);
    copyInQueue.EnQue(inLocal);
}

// For each token in this loop: zero-fills the remaining capacity of every expert we
// skipped past, then (if this expert still has capacity) writes the destination slot
// index into expandedRowIdx. Tokens beyond capacity are dropped (no index written).
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::CopyOut(int64_t progress)
{
    LocalTensor inLocal = copyInQueue.DeQue();
    LocalTensor outLocal = copyOutQueue.AllocTensor();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0};
    DataCopyExtParams ScaleParams{1, static_cast(sizeof(float)), 0, 0, 0};

    SetWaitFlag(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        // First loop on this core: resume from where the previous cores stopped.
        this->lastExpertId = this->lastCoreExpertId;
        this->tokenCount = this->lastCoreExpertIdNum;
    }
    for (int64_t idx = 0; idx < currentLoopRows; idx++) {
        int32_t expertIdx = inLocal[length].GetValue(idx);
        int32_t index = 0;
        // Experts with ids in (lastExpertId, expertIdx) got no more tokens: pad their
        // unused capacity slots with zeros (and zero scales when enabled).
        while (this->lastExpertId < expertIdx) {
            while (this->tokenCount < this->expertCapacity) {
                index = this->lastExpertId * this->expertCapacity + this->tokenCount;
                if (this->needScaleCopy) {
                    DataCopyPad(expandedScaleGm[index], this->scaleLocal, ScaleParams);
                }
                int64_t col = this->perLoopCols;
                for (int64_t i = 0; i < this->colLoops; i++) {
                    if (i == this->colLoops - 1) {
                        col = this->lastLoopCols;
                    }
                    DataCopyExtParams copyParams1{static_cast(1), static_cast(col * sizeof(T)), 0,
                                                  0, 0};
                    DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal,
                                copyParams1);
                }
                this->tokenCount++;
            }
            this->tokenCount = 0;
            this->lastExpertId++;
        }

        if (this->tokenCount < this->expertCapacity) {
            int32_t outOffset = inLocal.GetValue(idx);
            index = expertIdx * this->expertCapacity + this->tokenCount;
            outLocal.SetValue(0, index);
            SetWaitFlag(HardEvent::S_MTE3);
            DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
            this->tokenCount++;
        }
    }
    copyInQueue.FreeTensor(inLocal);
    copyOutQueue.FreeTensor(outLocal);
}

// Last core only: zero-pads the capacity of all experts after the final token's expert.
// All cores release the zero-fill tensors here.
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::CopyOutRemain()
{
    if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
        copyOutZeroQueue.FreeTensor(this->outTmpLocal);
        if (this->needScaleCopy) {
            scaleOutZeroQueue.FreeTensor(this->scaleLocal);
        }
        return;
    }
    DataCopyExtParams ScaleParams{1, static_cast(sizeof(float)), 0, 0, 0};
    while (this->lastExpertId < this->expertNum) {
        while (this->tokenCount < this->expertCapacity) {
            int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
            if (this->needScaleCopy) {
                DataCopyPad(expandedScaleGm[index], this->scaleLocal, ScaleParams);
            }
            int64_t col = this->perLoopCols;
            for (int64_t i = 0; i < this->colLoops; i++) {
                if (i == this->colLoops - 1) {
                    col = this->lastLoopCols;
                }
                DataCopyExtParams copyParams{static_cast(1), static_cast(col * sizeof(T)), 0, 0, 0};
                DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
                SetWaitFlag(HardEvent::MTE3_S);
            }
            this->tokenCount++;
        }
        this->tokenCount = 0;
        this->lastExpertId++;
    }
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
    if (this->needScaleCopy) {
        scaleOutZeroQueue.FreeTensor(this->scaleLocal);
    }
}

// Cross-core barrier; skipped in single-core runs and in the kernel unit-test build.
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::SyncAll()
{
    if (coreNum == 1) {
        return;
    }
#ifndef __CCE_KT_TEST__
    AscendC::SyncAll();
#endif
}

// Binds all GM buffers (outputs plus the workspace regions written by earlier stages)
// and sizes the UB queues from the tiling data.
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::Init(GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                                           GM_ADDR expandedScale, GM_ADDR workspace,
                                                           const TilingData *tilingData,
                                                           TPipe *tPipe)
{
    int64_t blockNum = GetBlockNum();
    this->pipe = tPipe;
    this->blockIdx = GetBlockIdx();

    this->coreNum = tilingData->coreNum;
    this->totalLength = tilingData->n * tilingData->k;
    this->srcToDstTilingData = &(tilingData->srcToDstDropPadParamsOp);
    this->expertNum = tilingData->expertNum;
    this->expertCapacity = tilingData->expertCapacity;
    this->cols = tilingData->cols;
    this->isInputScale_ = tilingData->isInputScale;
    this->quantMode_ = tilingData->quantMode;

    // The last active core takes the (possibly smaller) tail split of rows.
    if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
        this->coreRows = this->srcToDstTilingData->lastCoreRows;
        this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->srcToDstTilingData->perCoreRows;
        this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->perCoreLoops;
    }
    this->perLoopCols = this->srcToDstTilingData->perLoopCols;
    this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
    this->colLoops = this->srcToDstTilingData->colLoops;
    // Scales are forwarded only in non-quant mode (quantMode == -1) with scale input present.
    this->needScaleCopy = (this->isInputScale_ != 0 && this->quantMode_ == -1);

    expandedScaleGm.SetGlobalBuffer((__gm__ float *)expandedScale);

    int64_t length = Align(this->totalLength, sizeof(int32_t));
    expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, length);
    expandedXGm.SetGlobalBuffer((__gm__ T *)expandedX, this->expertNum * this->expertCapacity * this->cols);

    // Workspace layout (written by earlier pipeline stages):
    //   [0, length)            sorted expert ids
    //   [length, 2*length)     dst -> src row indices
    //   then 2 * Align(expertNum) of expert bookkeeping, then coreNum*2 (expertId, count) pairs.
    expandedExpertIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace +
                                            this->blockIdx * this->srcToDstTilingData->perCoreRows,
                                        Align(this->coreRows, sizeof(int32_t)));
    expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t *)workspace + length +
                                            this->blockIdx * this->srcToDstTilingData->perCoreRows,
                                        Align(this->coreRows, sizeof(int32_t)));
    expertIdxValueGm.SetGlobalBuffer(
        (__gm__ int32_t *)workspace + length * 2 + Align(this->expertNum, sizeof(int32_t)) * 2, this->coreNum * 2);

    pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
    pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
    if constexpr (IsSameType::value) {
        pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));
    } else {
        pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(T)));
    }
    if (this->needScaleCopy) {
        pipe->InitBuffer(scaleOutZeroQueue, 1, BLOCK_BYTES);
    }
}

// Drives the row loops on active cores; every core (active or not) joins the final sync.
template
__aicore__ inline void MoeCustomSrcToDstWithCapacity::Process()
{
    if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
        AssistInit();
        this->outTmpLocal = copyOutZeroQueue.DeQue();
        if (this->needScaleCopy) {
            this->scaleLocal = scaleOutZeroQueue.DeQue();
        }
        currentLoopRows = perLoopRows;
        for (int64_t loop = 0; loop < this->rowLoops; loop++) {
            if (loop == this->rowLoops - 1) {
                currentLoopRows = lastLoopRows;
            }
            CopyIn(loop);
            CopyOut(loop);
        }
        CopyOutRemain();
    }
    this->SyncAll();
}
} // namespace MoeInitRoutingCustom
#endif // MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_H
\ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad_dynamic.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad_dynamic.h new file mode 100644 index 00000000000..3a1800dc7b0 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_row_idx_gather_droppad_dynamic.h @@ -0,0 +1,582 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
 * \file moe_custom_row_idx_gather_droppad_dynamic.h
 * \brief Fused scatter + gather + per-token dynamic int8 quantization for the drop-pad
 *        path of MoeInitRoutingCustom.
 *        NOTE(review): template parameter lists and angle-bracket arguments appear to have
 *        been stripped by extraction throughout this file (e.g. "template", "TQue",
 *        "GlobalTensor", "static_cast(...)", "PipeBarrier()") — restore them against the
 *        original source.
 */
#ifndef MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_DYNAMIC_H
#define MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_DYNAMIC_H

#include "moe_custom_common.h"

namespace MoeInitRoutingCustom {
using namespace AscendC;

// Per-core kernel: like MoeCustomSrcToDstWithCapacity, but instead of only writing the
// row-index mapping it also gathers each routed token from x, optionally multiplies by a
// smooth scale (per-expert SCALE_EH or shared SCALE_1H), dynamically quantizes it to int8
// (per-token scale = max|row| / 127), and writes the row + its scale to expandedX /
// dynamicQuantScale. Has a single-pass path (cols fit in UB) and a tiled multi-loop path
// that stages fp32 rows through quantSrcGm workspace.
template
class MoeCustomSrcToDstAndGather {
public:
    __aicore__ inline MoeCustomSrcToDstAndGather(){};
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx, GM_ADDR expandedX,
                                GM_ADDR dynamicQuantScale, GM_ADDR workspace, const TilingData *tilingData,
                                TPipe *tPipe);
    __aicore__ inline void Process();

private:
    __aicore__ inline void CopyIn(int64_t progress);
    __aicore__ inline void CopyOut(int64_t progress);
    __aicore__ inline void CopyOutLoops(int64_t progress);
    __aicore__ inline void Compute(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);
    __aicore__ inline float ComputeMax(LocalTensor &inLocal, LocalTensor &tempLocal,
                                       LocalTensor &dynamicQuantLocal, int32_t srcIdx, int32_t expertIdx,
                                       int64_t j);
    __aicore__ inline void ComputeScale(LocalTensor &inLocal, LocalTensor &tempLocal, float scaleTemp,
                                        int64_t dstIndex, int64_t j);
    __aicore__ inline void ComputeLoops(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx);

    __aicore__ inline void CopyOutRemain();
    // NOTE(review): SyncAll is declared but its definition is not present in this header
    // (Process below never calls it either) — confirm whether it is dead or missing.
    __aicore__ inline void SyncAll();
    __aicore__ inline void AssistInit();

private:
    TPipe *pipe;
    TQue copyInQueue;       // [dstToSrc rows | expert ids] for the current loop
    TQue copyOutQueue;      // single-element staging for expandedRowIdx writes
    TQue copyOutZeroQueue;  // zero int8 row for padding empty capacity slots

    TQue inputXInQueue;     // gathered input row (T, cast to fp32 in place)
    TQue smoothInQueue;     // smooth-scale row
    TQue calcQueue;         // fp32 scratch for abs/div/cast chain
    TQue inputXOutQueue;    // quantized int8 row
    TQue scaleOutQueue;     // per-token dynamic-quant scale
    TQue scaleOutZeroQueue; // zero scale for padded slots

    GlobalTensor expandDstToSrcRowGm;
    GlobalTensor expandedRowIdxGm;
    GlobalTensor expertIdxValueGm;   // per-core (expertId, count) pairs in workspace
    GlobalTensor expandedExpertIdxGm;
    GlobalTensor expandedXGm;        // int8 output

    GlobalTensor inputXGm;
    GlobalTensor quantSmoothGm;
    GlobalTensor dynamicQuantScaleGm;
    GlobalTensor quantSrcGm;         // fp32 staging in workspace for the multi-loop path

    LocalTensor outTmpLocal;      // zero row
    LocalTensor scaleOutTmpLocal; // zero scale
    LocalTensor smoothLocal;      // shared smooth row in the single-pass path

    const MoeCustomSrcToDstCapacityComputeTilingData *srcToDstTilingData;

    int64_t coreNum;
    int64_t blockIdx;
    int64_t totalLength; // n * k tokens
    int64_t currentLoopRows;
    int64_t coreRows;
    int64_t perLoopRows;
    int64_t lastLoopRows;
    int64_t rowLoops;
    int64_t expertCapacity;
    int64_t expertNum;
    int64_t cols;
    int64_t perLoopCols;
    int64_t lastLoopCols;
    int64_t colLoops;
    int64_t perLoopColsAlign;
    int64_t k;
    int64_t colsTileLength; // current column tile size inside ComputeMax/ComputeScale
    int64_t smoothType;     // NO_SCALE / SCALE_1H (shared) / SCALE_EH (per expert)

    // Running scan state while walking the sorted expert-id stream.
    int64_t tokenCount = 0;
    int32_t lastExpertId = -1;       // -1 marks "not yet seeded from lastCoreExpertId"
    int32_t lastCoreExpertId = 0;
    int32_t lastCoreExpertIdNum = 0;
};

// Prepares zero-fill tensors and seeds the scan state from preceding cores'
// (expertId, count) pairs, accumulating counts of cores that ended on the same expert.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::AssistInit()
{
    LocalTensor outLocal = copyOutZeroQueue.AllocTensor();
    Duplicate(outLocal, static_cast(0), this->perLoopCols);
    copyOutZeroQueue.EnQue(outLocal);
    LocalTensor scaleOutLocal = scaleOutZeroQueue.AllocTensor();
    Duplicate(scaleOutLocal, 0.0f, 8);
    scaleOutZeroQueue.EnQue(scaleOutLocal);

    if (this->blockIdx != 0) {
        this->lastCoreExpertId = expertIdxValueGm.GetValue((this->blockIdx - 1) * EXPERT_ID_VALUE_NUM);
        this->lastCoreExpertIdNum = expertIdxValueGm.GetValue((this->blockIdx - 1) * EXPERT_ID_VALUE_NUM + 1);
        for (int64_t i = this->blockIdx - 2; i >= 0; i--) {
            int32_t lastExpertIdx = expertIdxValueGm.GetValue(i * EXPERT_ID_VALUE_NUM);
            if (lastExpertIdx < this->lastCoreExpertId) {
                break;
            }
            int32_t lastExpertNum = expertIdxValueGm.GetValue(i * EXPERT_ID_VALUE_NUM + 1);
            this->lastCoreExpertIdNum += lastExpertNum;
        }
    }
}

// Loads this loop's dst->src row indices and expert ids into one UB tensor.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::CopyIn(int64_t progress)
{
    LocalTensor inLocal = copyInQueue.AllocTensor();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopy(inLocal, expandDstToSrcRowGm[progress * perLoopRows], length);
    DataCopy(inLocal[length], expandedExpertIdxGm[progress * perLoopRows], length);

    copyInQueue.EnQue(inLocal);
}

// Single-pass quantization of one token row (whole row fits in UB):
// gather x[srcIdx/k], optional smooth multiply, per-row max/127 scale, int8 cast,
// then write scale and quantized row to GM.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::Compute(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx)
{
    DataCopyExtParams copyInParams{1, static_cast(this->cols * sizeof(T)), 0, 0, 0};
    DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0};
    DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0};

    LocalTensor inLocal = inputXInQueue.AllocTensor();

    // srcIdx counts expanded (token, k) pairs; srcIdx / k recovers the source token row.
    if constexpr (IsSameType::value) {
        DataCopyPad(inLocal, inputXGm[srcIdx / this->k * this->cols], copyInParams, {false, 0, 0, 0});
    } else {
        // Non-fp32 input lands in the upper half so the in-place cast below can widen it.
        DataCopyPad(inLocal.template ReinterpretCast()[perLoopColsAlign], inputXGm[srcIdx / this->k * this->cols],
                    copyInParams, {false, 0, 0, 0});
    }

    if (smoothType == SCALE_EH) {
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols], smoothParams, {false, 0, 0, 0});
    }

    inputXInQueue.EnQue(inLocal);
    smoothInQueue.EnQue(smoothLocal);
    smoothLocal = smoothInQueue.DeQue();

    inLocal = inputXInQueue.DeQue();

    LocalTensor tempLocal = calcQueue.AllocTensor();
    LocalTensor outLocal = inputXOutQueue.AllocTensor();
    LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor();

    if constexpr (!IsSameType::value) {
        Cast(inLocal, inLocal.template ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols);
        PipeBarrier();
    }

    if (smoothType != NO_SCALE) {
        Mul(inLocal, inLocal, smoothLocal, this->cols);
        PipeBarrier();
    }

    Abs(tempLocal, inLocal, this->cols);
    PipeBarrier();

    ReduceMax(dynamicQuantLocal, tempLocal, tempLocal, this->cols);
    PipeBarrier();

    // Per-token dynamic scale: max(|row|) / 127.
    float maxValue = dynamicQuantLocal.GetValue(0) / MAX_INT8;

    Duplicate(dynamicQuantLocal, maxValue, FP32_ONE_BLOCK_NUM);
    Duplicate(tempLocal, maxValue, this->cols);
    PipeBarrier();

    Div(tempLocal, inLocal, tempLocal, this->cols);
    PipeBarrier();

    // fp32 -> int32 (rint) -> half (with deq scale 1.0) -> int8 (trunc) cast chain.
    Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_RINT, this->cols);
    PipeBarrier();
    SetDeqScale((half)1.000000e+00f);
    Cast(tempLocal.ReinterpretCast(), tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->cols);
    PipeBarrier();
    Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_TRUNC, this->cols);

    calcQueue.FreeTensor(tempLocal);
    inputXOutQueue.EnQue(outLocal);
    scaleOutQueue.EnQue(dynamicQuantLocal);

    LocalTensor quantScaleLocal = scaleOutQueue.DeQue();
    DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, quantScaleParams);

    outLocal = inputXOutQueue.DeQue();
    DataCopyPad(expandedXGm[dstIdx * this->cols], outLocal, copyOutParams);

    inputXInQueue.FreeTensor(inLocal);
    inputXOutQueue.FreeTensor(outLocal);
    scaleOutQueue.FreeTensor(quantScaleLocal);
}

// Single-pass path: zero-pads skipped experts' capacity (row + scale), then for each
// surviving token writes its expandedRowIdx slot and quantizes it via Compute().
template
__aicore__ inline void MoeCustomSrcToDstAndGather::CopyOut(int64_t progress)
{
    LocalTensor inLocal = copyInQueue.DeQue();
    LocalTensor outLocal = copyOutQueue.AllocTensor();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0};
    DataCopyExtParams copyParams1{static_cast(1), static_cast(this->cols * sizeof(int8_t)), 0, 0,
                                  0};
    DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0};

    SetWaitFlag(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        this->lastExpertId = this->lastCoreExpertId;
        this->tokenCount = this->lastCoreExpertIdNum;
    }
    for (int64_t idx = 0; idx < currentLoopRows; idx++) {
        int32_t expertIdx = inLocal[length].GetValue(idx);
        int32_t index = 0;
        while (this->lastExpertId < expertIdx) {
            while (this->tokenCount < this->expertCapacity) {
                index = this->lastExpertId * this->expertCapacity + this->tokenCount;
                DataCopyPad(expandedXGm[index * this->cols], this->outTmpLocal, copyParams1);
                DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, quantScaleParams);
                this->tokenCount++;
            }
            this->tokenCount = 0;
            this->lastExpertId++;
        }

        if (this->tokenCount < this->expertCapacity) {
            int32_t outOffset = inLocal.GetValue(idx);
            index = expertIdx * this->expertCapacity + this->tokenCount;
            outLocal.SetValue(0, index);
            SetWaitFlag(HardEvent::S_MTE3);
            DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
            Compute(outOffset, index, expertIdx);
            this->tokenCount++;
        }
    }
    copyInQueue.FreeTensor(inLocal);
    copyOutQueue.FreeTensor(outLocal);
}

// Multi-loop path, pass 1 over column tile j: gather + cast + smooth-multiply the tile,
// record its |max| into dynamicQuantLocal, and stage the fp32 tile in quantSrcGm.
// Returns the tile's max so the caller can reduce across tiles.
// NOTE(review): the local smoothLocal here shadows the member of the same name — confirm
// this is intentional.
template
__aicore__ inline float MoeCustomSrcToDstAndGather::ComputeMax(LocalTensor &inLocal,
                                                               LocalTensor &tempLocal,
                                                               LocalTensor &dynamicQuantLocal,
                                                               int32_t srcIdx, int32_t expertIdx, int64_t j)
{
    LocalTensor smoothLocal = smoothInQueue.AllocTensor();

    DataCopyExtParams intriParamsT{1, static_cast(colsTileLength * sizeof(T)), 0, 0, 0};
    DataCopyExtParams intriParamsFp32{1, static_cast(colsTileLength * sizeof(float)), 0, 0, 0};

    if constexpr (!IsSameType::value) {
        DataCopyPad(inLocal.ReinterpretCast()[perLoopColsAlign],
                    inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
    } else {
        DataCopyPad(inLocal, inputXGm[srcIdx * this->cols + j * this->perLoopCols], intriParamsT, {false, 0, 0, 0});
    }

    inputXInQueue.EnQue(inLocal);
    inLocal = inputXInQueue.DeQue();

    if constexpr (!IsSameType::value) {
        Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, colsTileLength);
        PipeBarrier();
    }

    if (smoothType != NO_SCALE) {
        DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols + j * this->perLoopCols], intriParamsFp32,
                    {false, 0, 0, 0});
        smoothInQueue.EnQue(smoothLocal);
        smoothLocal = smoothInQueue.DeQue();

        Mul(inLocal, inLocal, smoothLocal, colsTileLength);
        PipeBarrier();
    }

    Abs(tempLocal, inLocal, colsTileLength);
    PipeBarrier();

    ReduceMax(dynamicQuantLocal[FP32_ONE_BLOCK_NUM], tempLocal, tempLocal, colsTileLength);

    // Stage the fp32 tile so pass 2 (ComputeScale) can re-read it from workspace.
    DataCopyPad(quantSrcGm[j * this->perLoopCols], inLocal, intriParamsFp32);
    smoothInQueue.FreeTensor(smoothLocal);
    SetWaitFlag(HardEvent::MTE3_MTE2);

    return dynamicQuantLocal.GetValue(FP32_ONE_BLOCK_NUM);
}

// Multi-loop path, pass 2 over column tile j: re-reads the staged fp32 tile, divides by
// the token-wide scale, casts to int8 and writes it to expandedX.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::ComputeScale(LocalTensor &inLocal,
                                                                LocalTensor &tempLocal,
                                                                float scaleTemp, int64_t dstIndex, int64_t j)
{
    DataCopyExtParams copyInParams{1, static_cast(colsTileLength * sizeof(float)), 0, 0, 0};
    DataCopyExtParams copyOutParams{1, static_cast(colsTileLength * sizeof(int8_t)), 0, 0, 0};

    LocalTensor outLocal = inputXOutQueue.AllocTensor();

    DataCopyPad(inLocal, quantSrcGm[j * this->perLoopCols], copyInParams, {false, 0, 0, 0});
    inputXInQueue.EnQue(inLocal);
    inLocal = inputXInQueue.DeQue();

    Duplicate(tempLocal, scaleTemp, colsTileLength);
    PipeBarrier();

    Div(tempLocal, inLocal, tempLocal, colsTileLength);
    PipeBarrier();

    // Same fp32 -> int32 -> half -> int8 cast chain as the single-pass path.
    Cast(tempLocal.ReinterpretCast(), tempLocal, RoundMode::CAST_RINT, colsTileLength);
    PipeBarrier();
    SetDeqScale((half)1.000000e+00f);
    Cast(tempLocal.ReinterpretCast(), tempLocal.ReinterpretCast(), RoundMode::CAST_ROUND,
         colsTileLength);
    PipeBarrier();
    Cast(outLocal, tempLocal.ReinterpretCast(), RoundMode::CAST_TRUNC, colsTileLength);

    inputXOutQueue.EnQue(outLocal);
    outLocal = inputXOutQueue.DeQue();
    DataCopyPad(expandedXGm[dstIndex * this->cols + j * this->perLoopCols], outLocal, copyOutParams);

    inputXOutQueue.FreeTensor(outLocal);
    SetWaitFlag(HardEvent::MTE3_MTE2);
}

// Multi-loop quantization of one token: pass 1 computes the row max across all column
// tiles (seeded with -FLT_MAX via the 0xFF7FFFFF bit pattern), writes the scale, then
// pass 2 quantizes every tile with that scale.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::ComputeLoops(int32_t srcIdx, int32_t dstIdx, int32_t expertIdx)
{
    LocalTensor inLocal = inputXInQueue.AllocTensor();
    LocalTensor tempLocal = calcQueue.AllocTensor();
    LocalTensor quantScaleLocal = scaleOutQueue.AllocTensor();
    DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0};

    uint32_t tmp = 0xFF7FFFFF; // bit pattern of -FLT_MAX
    float reduceMax = *((float *)&tmp);
    for (int64_t j = 0; j < this->colLoops; j++) {
        colsTileLength = this->perLoopCols;
        if (j == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;
        }
        float tileMax = ComputeMax(inLocal, tempLocal, quantScaleLocal, srcIdx / this->k, expertIdx, j);
        reduceMax = (reduceMax > tileMax) ? reduceMax : tileMax;
    }

    float scaleTemp = reduceMax / 127.0f;
    Duplicate(quantScaleLocal, scaleTemp, 8);
    scaleOutQueue.EnQue(quantScaleLocal);
    quantScaleLocal = scaleOutQueue.DeQue();

    DataCopyPad(dynamicQuantScaleGm[dstIdx], quantScaleLocal, quantScaleParams);

    for (int64_t j = 0; j < this->colLoops; j++) {
        colsTileLength = this->perLoopCols;
        if (j == this->colLoops - 1) {
            colsTileLength = this->lastLoopCols;
        }
        ComputeScale(inLocal, tempLocal, scaleTemp, dstIdx, j);
    }

    inputXInQueue.FreeTensor(inLocal);
    calcQueue.FreeTensor(tempLocal);
    scaleOutQueue.FreeTensor(quantScaleLocal);
}

// Multi-loop path equivalent of CopyOut: zero-pads skipped experts tile by tile, then
// quantizes surviving tokens via ComputeLoops (per-expert smooth only for SCALE_EH).
template
__aicore__ inline void MoeCustomSrcToDstAndGather::CopyOutLoops(int64_t progress)
{
    LocalTensor inLocal = copyInQueue.DeQue();
    LocalTensor outLocal = copyOutQueue.AllocTensor();
    int64_t length = Align(currentLoopRows, sizeof(int32_t));
    DataCopyExtParams copyParams{static_cast(1), static_cast(sizeof(int32_t)), 0, 0, 0};
    DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0};

    SetWaitFlag(HardEvent::MTE2_S);
    if (this->lastExpertId == -1) {
        this->lastExpertId = this->lastCoreExpertId;
        this->tokenCount = this->lastCoreExpertIdNum;
    }
    for (int64_t idx = 0; idx < currentLoopRows; idx++) {
        int32_t expertIdx = inLocal[length].GetValue(idx);
        SetWaitFlag(HardEvent::S_MTE3);
        int32_t index = 0;
        while (this->lastExpertId < expertIdx) {
            while (this->tokenCount < this->expertCapacity) {
                index = this->lastExpertId * this->expertCapacity + this->tokenCount;
                int64_t col = this->perLoopCols;
                DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, quantScaleParams);
                for (int64_t i = 0; i < this->colLoops; i++) {
                    if (i == this->colLoops - 1) {
                        col = this->lastLoopCols;
                    }
                    DataCopyExtParams copyParams1{static_cast(1), static_cast(col * sizeof(int8_t)),
                                                  0, 0, 0};
                    DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal,
                                copyParams1);
                }
                this->tokenCount++;
            }
            this->tokenCount = 0;
            this->lastExpertId++;
        }

        if (this->tokenCount < this->expertCapacity) {
            int32_t outOffset = inLocal.GetValue(idx);
            index = expertIdx * this->expertCapacity + this->tokenCount;
            outLocal.SetValue(0, index);
            SetWaitFlag(HardEvent::S_MTE3);
            DataCopyPad(expandedRowIdxGm[outOffset], outLocal, copyParams);
            if (smoothType == SCALE_EH) {
                ComputeLoops(outOffset, index, expertIdx);
            } else {
                ComputeLoops(outOffset, index, 0);
            }
            SetWaitFlag(HardEvent::MTE3_S);
            this->tokenCount++;
        }
    }
    copyInQueue.FreeTensor(inLocal);
    copyOutQueue.FreeTensor(outLocal);
}

// Last core only: zero-pads (row + scale) the capacity of all experts after the final
// token's expert. All cores release their zero-fill tensors here.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::CopyOutRemain()
{
    DataCopyExtParams quantScaleParams{1, static_cast(sizeof(int32_t)), 0, 0, 0};
    if (this->blockIdx != this->srcToDstTilingData->needCoreNum - 1) {
        copyOutZeroQueue.FreeTensor(this->outTmpLocal);
        scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
        return;
    }
    while (this->lastExpertId < this->expertNum) {
        while (this->tokenCount < this->expertCapacity) {
            int32_t index = this->lastExpertId * this->expertCapacity + this->tokenCount;
            int64_t col = this->perLoopCols;
            DataCopyPad(dynamicQuantScaleGm[index], this->scaleOutTmpLocal, quantScaleParams);
            for (int64_t i = 0; i < this->colLoops; i++) {
                if (i == this->colLoops - 1) {
                    col = this->lastLoopCols;
                }
                DataCopyExtParams copyParams{static_cast(1), static_cast(col * sizeof(int8_t)), 0,
                                             0, 0};
                DataCopyPad(expandedXGm[index * this->cols + i * this->perLoopCols], this->outTmpLocal, copyParams);
                SetWaitFlag(HardEvent::MTE3_S);
            }
            this->tokenCount++;
        }
        this->tokenCount = 0;
        this->lastExpertId++;
    }
    copyOutZeroQueue.FreeTensor(this->outTmpLocal);
    scaleOutZeroQueue.FreeTensor(this->scaleOutTmpLocal);
}

// Binds GM buffers (inputs, outputs, workspace regions from earlier stages, and — when
// column tiling is needed — a per-core fp32 staging area) and sizes the UB queues.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::Init(GM_ADDR x, GM_ADDR scale, GM_ADDR expandedRowIdx,
                                                        GM_ADDR expandedX, GM_ADDR dynamicQuantScale,
                                                        GM_ADDR workspace, const TilingData *tilingData,
                                                        TPipe *tPipe)
{
    int64_t blockNum = GetBlockNum();
    this->pipe = tPipe;
    this->blockIdx = GetBlockIdx();

    this->coreNum = tilingData->coreNum;
    this->totalLength = tilingData->n * tilingData->k;
    this->srcToDstTilingData = &(tilingData->srcToDstDropPadDynamicParamsOp);
    this->expertNum = tilingData->expertNum;
    this->expertCapacity = tilingData->expertCapacity;
    this->cols = tilingData->cols;
    this->k = tilingData->k;
    this->smoothType = tilingData->smoothType;

    // The last active core takes the (possibly smaller) tail split of rows.
    if (this->blockIdx == this->srcToDstTilingData->needCoreNum - 1) {
        this->coreRows = this->srcToDstTilingData->lastCoreRows;
        this->perLoopRows = this->srcToDstTilingData->lastCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->lastCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->lastCoreLoops;
    } else {
        this->coreRows = this->srcToDstTilingData->perCoreRows;
        this->perLoopRows = this->srcToDstTilingData->perCorePerLoopRows;
        this->lastLoopRows = this->srcToDstTilingData->perCoreLastLoopRows;
        this->rowLoops = this->srcToDstTilingData->perCoreLoops;
    }
    this->perLoopCols = this->srcToDstTilingData->perLoopCols;
    this->lastLoopCols = this->srcToDstTilingData->lastLoopCols;
    this->colLoops = this->srcToDstTilingData->colLoops;
    this->perLoopColsAlign = Align(this->perLoopCols, sizeof(T));

    inputXGm.SetGlobalBuffer((__gm__ T *)x);
    quantSmoothGm.SetGlobalBuffer((__gm__ float *)scale);
    dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float *)dynamicQuantScale);

    int64_t length = Align(this->totalLength, sizeof(int32_t));
    expandedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expandedRowIdx, length);
    expandedXGm.SetGlobalBuffer((__gm__ int8_t *)expandedX, this->expertNum * this->expertCapacity * this->cols);

    // Workspace layout mirrors the non-dynamic drop-pad kernel; the fp32 staging
    // area (quantSrcGm) follows the bookkeeping section, one `cols` stripe per core.
    expandedExpertIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace +
                                            this->blockIdx * this->srcToDstTilingData->perCoreRows,
                                        Align(this->coreRows, sizeof(int32_t)));
    expandDstToSrcRowGm.SetGlobalBuffer((__gm__ int32_t *)workspace + length +
                                            this->blockIdx * this->srcToDstTilingData->perCoreRows,
                                        Align(this->coreRows, sizeof(int32_t)));
    expertIdxValueGm.SetGlobalBuffer(
        (__gm__ int32_t *)workspace + length * 2 + Align(this->expertNum, sizeof(int32_t)) * 2, this->coreNum * 2);
    if (this->colLoops > 1) {
        quantSrcGm.SetGlobalBuffer((__gm__ float *)workspace + length * 2 +
                                       Align(this->expertNum, sizeof(int32_t)) * 2 + this->coreNum * 2 +
                                       this->blockIdx * this->cols,
                                   this->cols * sizeof(float));
    }

    pipe->InitBuffer(copyInQueue, 1, AlignBytes(this->perLoopRows, sizeof(int32_t)) * 2);
    pipe->InitBuffer(copyOutQueue, 1, AlignBytes(INT32_ONE_BLOCK_NUM, sizeof(int32_t)));
    pipe->InitBuffer(copyOutZeroQueue, 1, AlignBytes(this->perLoopCols, sizeof(int16_t)));

    // inputXInQueue must hold the raw T tile plus its fp32 widening, hence the
    // sizeof(float)/sizeof(T) scaling with a two-block floor.
    int64_t perLoopColsAlignBytes = AlignBytes(this->perLoopCols, sizeof(T));
    perLoopColsAlignBytes =
        Max(int64_t(perLoopColsAlignBytes * sizeof(float) / sizeof(T)), int64_t(BLOCK_BYTES + BLOCK_BYTES));

    pipe->InitBuffer(inputXInQueue, 1, perLoopColsAlignBytes);
    pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
    pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float)));
    pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t)));
    pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES);
    pipe->InitBuffer(scaleOutZeroQueue, 1, BLOCK_BYTES);
}

// Drives the row loops on active cores, dispatching to the tiled (colLoops > 1) or
// single-pass path; in the single-pass SCALE_1H case the shared smooth row is loaded once.
template
__aicore__ inline void MoeCustomSrcToDstAndGather::Process()
{
    if (this->blockIdx < this->srcToDstTilingData->needCoreNum) {
        AssistInit();
        this->outTmpLocal = copyOutZeroQueue.DeQue();
        this->scaleOutTmpLocal = scaleOutZeroQueue.DeQue();
        currentLoopRows = perLoopRows;
        if (colLoops > 1) {
            for (int64_t loop = 0; loop < this->rowLoops; loop++) {
                if (loop == this->rowLoops - 1) {
                    currentLoopRows = lastLoopRows;
                }
                CopyIn(loop);
                CopyOutLoops(loop);
            }
        } else {
            smoothLocal = smoothInQueue.AllocTensor();
            if (smoothType == SCALE_1H) {
                DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0};
                DataCopyPad(smoothLocal, quantSmoothGm, smoothParams, {false, 0, 0, 0});
            }
            for (int64_t loop = 0; loop < this->rowLoops; loop++) {
                if (loop == this->rowLoops - 1) {
                    currentLoopRows = lastLoopRows;
                }
                CopyIn(loop);
                CopyOut(loop);
            }
            smoothInQueue.FreeTensor(smoothLocal);
        }
        CopyOutRemain();
    }
}
} // namespace MoeInitRoutingCustom
#endif // MOE_CUSTOM_ROW_IDX_GATHER_DROPPAD_DYNAMIC_H
\ No newline at end of file
diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_actual_expert.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_actual_expert.h
new file mode 100644
index 00000000000..b8c7355aed1
--- /dev/null
+++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_actual_expert.h
@@ -0,0 +1,430 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file moe_custom_sort_actual_expert.h + * \brief + */ +#ifndef MOE_CUSTOM_SORT_ACTUAL_EXPERT_H +#define MOE_CUSTOM_SORT_ACTUAL_EXPERT_H + +namespace MoeInitRoutingCustom { +using namespace AscendC; +constexpr int64_t MULTI_GATHERED_SORT_CORE_NUM = 16; +constexpr int64_t MULTI_GATHERED_SORT_THRSHOLD = 5632; +constexpr int64_t SINGLE_GATHERED_BUFFER_NUM = 2; +constexpr int64_t SINGLE_GATHERED_MAX_NUM = 21845; + +template +class MoeSortActualExpert { +public: + __aicore__ inline MoeSortActualExpert(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, GM_ADDR expendedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline bool Process(); + __aicore__ inline void multiCoreGatheredSort(); + __aicore__ inline void CopyOutExpandRowIdx(); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void TilingInKernel(); + __aicore__ inline void ExpertCountCompute(); + __aicore__ inline void CopyOut(); + __aicore__ inline void CopyOutExpertCount(); + +private: + TPipe *pipe; + TBuf buffer_; + TQueBind scaleCopyInQueue_; + TQue sortedNumCopyOutQueue_; + + GlobalTensor xGm_; + GlobalTensor scaleGm_; + GlobalTensor expandedXGm_; + GlobalTensor expertTokensCountOrCumsumGm_; + GlobalTensor expandedScaleGm_; + GlobalTensor expendedRowIdxGm_; + GlobalTensor expertIdxGm_; + GlobalTensor workspaceGm_; + GlobalTensor workspaceExpertIdxGm_; + GlobalTensor workspaceGatheredSortNumGm_; + GlobalTensor workspaceGatheredExpertIdxGm_; + GlobalTensor workspaceGatheredExpertIndexGm_; + + int64_t expertIdxOffset_ = 0; + int64_t expertIndexOffset_ = 0; + int64_t compareScalarMaskOffset_ = 0; + int64_t compareScalarMask0Offset_ = 0; + int64_t compareScalarMask1Offset_ = 0; + int64_t gatherMaskOffset_ = 0; + + int64_t totalLength_; + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + 
int64_t actual_expert_num_ = 0; + int64_t cols_ = 0; + int64_t rowIdxType_ = 0; + int64_t isInputScale_ = 0; + int64_t k_ = 0; + + int64_t needSortNum_ = 0; + + int64_t needCoreNum_ = 0; + int64_t perCoreElements_ = 0; + int64_t lastCoreElements_ = 0; + int64_t curCoreElements_ = 0; + int64_t curCoreStartIndex_ = 0; + + bool needMultiSort = false; + + int64_t kvFactor = 2; + + static constexpr int64_t DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; + static constexpr int64_t MASK_STRIDE = 64; +}; + +template +__aicore__ inline void MoeSortActualExpert::CopyIn() +{ + LocalTensor expertIdx = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(this->totalLength_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(expertIdx, expertIdxGm_, dataCopyParams, dataCopyPadParams); + SetWaitFlag(HardEvent::MTE2_V); +} + +template +__aicore__ inline void MoeSortActualExpert::SortCompute() +{ + LocalTensor expertIdx = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + LocalTensor expertIdxFp32 = expertIdx.ReinterpretCast(); + LocalTensor gatheredExpertIdx = buffer_.Get(); + LocalTensor gatheredExpertIdxFp32 = gatheredExpertIdx.ReinterpretCast(); + + Cast(expertIdxFp32, expertIdx, RoundMode::CAST_ROUND, this->totalLength_); + PipeBarrier(); + Muls(expertIdxFp32, expertIdxFp32, (float)-1, this->totalLength_); + PipeBarrier(); + + LocalTensor compareScalarMaskLocalTensor0 = buffer_.Get()[compareScalarMask0Offset_]; + LocalTensor compareScalarMaskLocalTensor1 = buffer_.Get()[compareScalarMask1Offset_]; + LocalTensor gatherMaskLocalTensor = buffer_.Get()[gatherMaskOffset_]; + + AscendC::CompareScalar( + compareScalarMaskLocalTensor0, expertIdxFp32, static_cast(-expertStart_), AscendC::CMPMODE::LE, + (this->totalLength_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + + AscendC::CompareScalar( + 
compareScalarMaskLocalTensor1, expertIdxFp32, static_cast(-expertEnd_), AscendC::CMPMODE::GT, + (this->totalLength_ + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + PipeBarrier(); + And(gatherMaskLocalTensor.ReinterpretCast(), compareScalarMaskLocalTensor0.ReinterpretCast(), + compareScalarMaskLocalTensor1.ReinterpretCast(), + Ceil(this->totalLength_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE / kvFactor); + PipeBarrier(); + + uint64_t rsvdCnt = 0; + GatherMaskParams gatherMaskParams; + gatherMaskParams.repeatTimes = 1; + gatherMaskParams.src0BlockStride = 1; + gatherMaskParams.src0RepeatStride = 8; + gatherMaskParams.src1RepeatStride = 8; + GatherMask(gatheredExpertIdxFp32, expertIdxFp32, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(this->totalLength_), gatherMaskParams, rsvdCnt); + PipeBarrier(); + actual_expert_num_ = rsvdCnt; + // Handle actual_expert_num_ == 0 + if (actual_expert_num_ < 1) { + return; + } + int64_t needSortNum = Ceil(static_cast(rsvdCnt), ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + needSortNum_ = needSortNum; + + LocalTensor expertIndex = buffer_.Get()[expertIdxOffset_ / sizeof(int32_t)]; + LocalTensor gatheredExpertIndex = buffer_.Get()[needSortNum]; + ArithProgression(expertIndex, 0, 1, this->totalLength_); + GatherMask(gatheredExpertIndex, expertIndex, gatherMaskLocalTensor.ReinterpretCast(), true, + static_cast(this->totalLength_), gatherMaskParams, rsvdCnt); + PipeBarrier(); + if (rsvdCnt > MULTI_GATHERED_SORT_THRSHOLD) { + if (GetBlockIdx() == 0) { + SetWaitFlag(HardEvent::V_MTE3); + DataCopyExtParams copyParams{1, static_cast(rsvdCnt * sizeof(int32_t)), 0, 0, 0}; + DataCopyPad(workspaceGatheredExpertIdxGm_, gatheredExpertIdxFp32, copyParams); + DataCopyPad(workspaceGatheredExpertIndexGm_, gatheredExpertIndex, copyParams); + } + needMultiSort = true; + return; + } + int64_t duplicateNum = rsvdCnt % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = rsvdCnt - 
duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(gatheredExpertIdxFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + + PipeBarrier(); + LocalTensor concatLocal; + LocalTensor sortTempTensor = buffer_.Get()[needSortNum * kvFactor]; + Concat(concatLocal, gatheredExpertIdxFp32, sortTempTensor, needSortNum / ONE_REPEAT_SORT_NUM); + LocalTensor sortedLocal = buffer_.Get()[needSortNum * kvFactor + needSortNum * kvFactor * kvFactor]; + Sort(sortedLocal, concatLocal, gatheredExpertIndex.ReinterpretCast(), sortTempTensor, + needSortNum / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + LocalTensor sortedExpertIdx = gatheredExpertIdxFp32; + LocalTensor sortedExpertIndex = gatheredExpertIndex.ReinterpretCast(); + + Extract(sortedExpertIdx, sortedExpertIndex.ReinterpretCast(), sortedLocal, + needSortNum / ONE_REPEAT_SORT_NUM); + PipeBarrier(); + + LocalTensor sortedExpertIdxInt32 = sortedExpertIdx.ReinterpretCast(); + + Muls(sortedExpertIdx, sortedExpertIdx, (float)-1, rsvdCnt); + Cast(sortedExpertIdxInt32, sortedExpertIdx, RoundMode::CAST_ROUND, rsvdCnt); +} + +template +__aicore__ inline void MoeSortActualExpert::TilingInKernel() +{ + int64_t coreNum = needMultiSort ? 
MULTI_GATHERED_SORT_CORE_NUM : GetBlockNum(); + perCoreElements_ = Ceil(actual_expert_num_, coreNum); + needCoreNum_ = Ceil(actual_expert_num_, perCoreElements_); + lastCoreElements_ = actual_expert_num_ - (needCoreNum_ - 1) * perCoreElements_; + if (GetBlockIdx() == needCoreNum_ - 1) { + curCoreElements_ = lastCoreElements_; + } else { + curCoreElements_ = perCoreElements_; + } + curCoreStartIndex_ = GetBlockIdx() * perCoreElements_; +} + +template +__aicore__ inline void MoeSortActualExpert::multiCoreGatheredSort() +{ + needSortNum_ = Ceil(static_cast(curCoreElements_), ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + perCoreElements_ = Ceil(this->totalLength_, MULTI_GATHERED_SORT_CORE_NUM); + + LocalTensor sortedNumOutLocal = sortedNumCopyOutQueue_.AllocTensor(); + LocalTensor gatheredExpertIdxFp32 = buffer_.Get(); + LocalTensor gatheredExpertIndex = buffer_.Get()[needSortNum_]; + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(curCoreElements_ * sizeof(float)), + 0, 0, 0}; + DataCopyPadExtParams expertIdxPadParams{false, 0, 0, 0}; + DataCopyPad(gatheredExpertIdxFp32, workspaceGatheredExpertIdxGm_[curCoreStartIndex_], dataCopyParams, + expertIdxPadParams); + DataCopyPadExtParams expertIndexPadParams{false, 0, 0, 0}; + DataCopyPad(gatheredExpertIndex, workspaceGatheredExpertIndexGm_[curCoreStartIndex_], dataCopyParams, + expertIndexPadParams); + SetWaitFlag(HardEvent::MTE2_V); + + LocalTensor concatLocal; + LocalTensor sortTempTensor = buffer_.Get()[needSortNum_ * kvFactor]; + // Duplicate MIN_FP32 + int64_t duplicateNum = curCoreElements_ % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = curCoreElements_ - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(gatheredExpertIdxFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + Concat(concatLocal, gatheredExpertIdxFp32, 
sortTempTensor, needSortNum_ / ONE_REPEAT_SORT_NUM); + LocalTensor sortedLocal = buffer_.Get()[needSortNum_ * kvFactor + needSortNum_ * kvFactor * kvFactor]; + Sort(sortedLocal, concatLocal, gatheredExpertIndex.ReinterpretCast(), sortTempTensor, + needSortNum_ / ONE_REPEAT_SORT_NUM); + + // Copy out sortedLocal for MergeSort + SetWaitFlag(HardEvent::V_MTE3); + int64_t curCoreSortedStartIndex = kvFactor * GetBlockIdx() * perCoreElements_; + dataCopyParams.blockLen = static_cast(kvFactor * curCoreElements_ * sizeof(float)); + DataCopyPad(workspaceExpertIdxGm_[curCoreSortedStartIndex], sortedLocal, dataCopyParams); + // Copyout sortedNum + sortedNumOutLocal.SetValue(0, curCoreElements_); + SetWaitFlag(HardEvent::S_MTE3); + dataCopyParams.blockLen = static_cast(sizeof(int32_t)); + DataCopyPad(workspaceGatheredSortNumGm_[GetBlockIdx()], sortedNumOutLocal, dataCopyParams); + sortedNumCopyOutQueue_.FreeTensor(sortedNumOutLocal); +} + +template +__aicore__ inline void MoeSortActualExpert::CopyOutExpandRowIdx() +{ + LocalTensor sortedExpertIndex = buffer_.Get()[needSortNum_]; + SetWaitFlag(HardEvent::V_MTE3); + if (GetBlockIdx() == 0) { + DataCopyExtParams copyParams{1, static_cast(actual_expert_num_ * sizeof(int32_t)), 0, 0, 0}; + DataCopyPad(expendedRowIdxGm_, sortedExpertIndex, copyParams); + } +} + +template +__aicore__ inline void MoeSortActualExpert::ExpertCountCompute() +{ + LocalTensor sortedExpertIdx = buffer_.Get()[curCoreStartIndex_]; + LocalTensor expertCountLocalTensor = buffer_.Get()[needSortNum_ * kvFactor]; + Duplicate(expertCountLocalTensor, 0, expertEnd_ - expertStart_); + + for (int64_t i = 0; i < curCoreElements_; i++) { + int64_t expertIdx = sortedExpertIdx.GetValue(i) - expertStart_; + int32_t curExpertCount = expertCountLocalTensor.GetValue(expertIdx); + expertCountLocalTensor.SetValue(expertIdx, curExpertCount + 1); + } + SetWaitFlag(HardEvent::S_MTE3); + DataCopyExtParams copyOutParams1{1, static_cast((expertEnd_ - expertStart_) * sizeof(int32_t)), 
0, 0, 0}; + SetAtomicAdd(); + DataCopyPad(workspaceGm_, expertCountLocalTensor, copyOutParams1); + SetAtomicNone(); +} + +template +__aicore__ inline void MoeSortActualExpert::CopyOut() +{ + LocalTensor sortedExpertIndex = buffer_.Get()[needSortNum_ + curCoreStartIndex_]; + int64_t xLocalOffset = (needSortNum_ * kvFactor + ASSIST_NUM) * sizeof(int32_t) / sizeof(T); + LocalTensor xLocalTensor = buffer_.Get()[xLocalOffset]; + + for (int64_t i = 0; i < curCoreElements_; i++) { + int64_t srcRow = sortedExpertIndex.GetValue(i) / k_; + int64_t dstRow = i + curCoreStartIndex_; + SetWaitFlag(HardEvent::S_MTE2); + + LocalTensor scaleLocalTensor; + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(xLocalTensor, xGm_[srcRow * cols_], dataCopyParams, dataCopyPadParams); + if (isInputScale_ == 1) { + scaleLocalTensor = scaleCopyInQueue_.AllocTensor(); + DataCopyExtParams dataCopyParams2{static_cast(1), static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams2{false, 0, 0, 0}; + DataCopyPad(scaleLocalTensor, scaleGm_[srcRow], dataCopyParams2, dataCopyPadParams2); + scaleCopyInQueue_.EnQue(scaleLocalTensor); + } + SetWaitFlag(HardEvent::MTE2_MTE3); + DataCopyExtParams copyOutParams1{1, static_cast(cols_ * sizeof(T)), 0, 0, 0}; + DataCopyPad(expandedXGm_[dstRow * cols_], xLocalTensor, copyOutParams1); + if (isInputScale_ == 1) { + scaleLocalTensor = scaleCopyInQueue_.DeQue(); + DataCopyExtParams copyOutParams2{1, static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPad(expandedScaleGm_[dstRow], scaleLocalTensor, copyOutParams2); + scaleCopyInQueue_.FreeTensor(scaleLocalTensor); + } + } +} + +template +__aicore__ inline void MoeSortActualExpert::CopyOutExpertCount() +{ + LocalTensor expertCountLocalTensor = buffer_.Get()[needSortNum_ * kvFactor]; + LocalTensor expertCountLocalTensorInt64 = + buffer_.Get()[needSortNum_ * kvFactor + 
ASSIST_NUM].ReinterpretCast(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast((expertEnd_ - expertStart_) * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(expertCountLocalTensor, workspaceGm_, dataCopyParams, dataCopyPadParams); + SetWaitFlag(HardEvent::MTE2_V); + Cast(expertCountLocalTensorInt64, expertCountLocalTensor, RoundMode::CAST_NONE, (expertEnd_ - expertStart_)); + SetWaitFlag(HardEvent::V_MTE3); + DataCopyExtParams copyOutParams1{1, static_cast((expertEnd_ - expertStart_) * sizeof(int64_t)), 0, 0, 0}; + DataCopyPad(expertTokensCountOrCumsumGm_, expertCountLocalTensorInt64, copyOutParams1); +} + +template +__aicore__ inline void MoeSortActualExpert::Init(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR expandedX, + GM_ADDR expendedRowIdx, GM_ADDR expertTokensCountOrCumsum, + GM_ADDR expandedScale, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + this->pipe = tPipe; + this->totalLength_ = tilingData->n * tilingData->k; + cols_ = tilingData->cols; + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + rowIdxType_ = tilingData->rowIdxType; + isInputScale_ = tilingData->isInputScale; + k_ = tilingData->k; + + expertIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expertIdx); + + expendedRowIdxGm_.SetGlobalBuffer((__gm__ int32_t *)expendedRowIdx); + + xGm_.SetGlobalBuffer((__gm__ T *)x); + scaleGm_.SetGlobalBuffer((__gm__ float *)scale); + expandedXGm_.SetGlobalBuffer((__gm__ T *)expandedX); + expertTokensCountOrCumsumGm_.SetGlobalBuffer((__gm__ int64_t *)expertTokensCountOrCumsum); + expandedScaleGm_.SetGlobalBuffer((__gm__ float *)expandedScale); + workspaceGm_.SetGlobalBuffer((__gm__ int32_t *)workspace, ASSIST_NUM); + if (GetBlockIdx() == 0) { + InitGlobalMemory(workspaceGm_, ASSIST_NUM, 0); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + workspaceExpertIdxGm_.SetGlobalBuffer((__gm__ float *)workspace); + int64_t offset = 
kvFactor * Align(this->totalLength_, sizeof(int32_t)); + workspaceGatheredExpertIdxGm_.SetGlobalBuffer((__gm__ float *)workspace + offset); + offset += Align(this->totalLength_, sizeof(float)); + workspaceGatheredExpertIndexGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + offset); + offset += Align(this->totalLength_, sizeof(float)); + workspaceGatheredSortNumGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + offset); + + expertIdxOffset_ = AlignBytes(this->totalLength_, sizeof(int32_t)); + expertIndexOffset_ = expertIdxOffset_; + + gatherMaskOffset_ = expertIdxOffset_ * kvFactor; + int64_t maskOffset = + AlignBytes(Ceil(this->totalLength_, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE, sizeof(int8_t)); + compareScalarMask0Offset_ = gatherMaskOffset_ + maskOffset; + compareScalarMask1Offset_ = compareScalarMask0Offset_ + maskOffset; + int64_t maskOffsetMax = Ceil(SINGLE_GATHERED_MAX_NUM, MASK_STRIDE) * MASK_STRIDE / DST_REP_STRIDE; + int64_t bufferSize = + AlignBytes(SINGLE_GATHERED_MAX_NUM, sizeof(int32_t)) * kvFactor + maskOffsetMax + maskOffsetMax + maskOffsetMax; + pipe->InitBuffer(scaleCopyInQueue_, SINGLE_GATHERED_BUFFER_NUM, 32); + pipe->InitBuffer(sortedNumCopyOutQueue_, SINGLE_GATHERED_BUFFER_NUM, 32); + pipe->InitBuffer(buffer_, bufferSize); // 182992 Bytes +} + +template +__aicore__ inline bool MoeSortActualExpert::Process() +{ + CopyIn(); + SortCompute(); + TilingInKernel(); + if (needMultiSort) { + SyncAll(); + if (GetBlockIdx() < needCoreNum_) { + multiCoreGatheredSort(); + } + SyncAll(); + return false; + } + + if (GetBlockIdx() < needCoreNum_) { + CopyOutExpandRowIdx(); + } + if (GetBlockIdx() < needCoreNum_) { + ExpertCountCompute(); + CopyOut(); + } + SyncAll(); + if (GetBlockIdx() == GetBlockNum() - 1) { + CopyOutExpertCount(); + } + return true; +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_SORT_ACTUAL_EXPERT_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_base.h 
b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_base.h new file mode 100644 index 00000000000..98db4b120be --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_base.h @@ -0,0 +1,71 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_custom_sort_base.h + * \brief + */ +#ifndef MOE_CUSTOM_SORT_BASE_H +#define MOE_CUSTOM_SORT_BASE_H + +#include "kernel_operator.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeSortBase { +public: + __aicore__ inline MoeSortBase(){}; + __aicore__ inline int64_t GetSyncRound(); + +protected: + __aicore__ inline void CleanWSCache(); + __aicore__ inline void SyncAll(); + +protected: + TPipe *pipe; + TQue sortDataCopyInQueue; + TQue sortDataCopyOutQueue; + TBuf tempBuffer; + TBuf sortedBuffer; + + GlobalTensor expertIdxGm; + GlobalTensor expendedRowIdxGm; + GlobalTensor sortedExpertForSourceRowGm; + GlobalTensor expandDstToSrcRowGm; + GlobalTensor sortedexpertIdxGm; + GlobalTensor expertCountTempGm; + + int64_t tileLength; + int64_t bufferNum = 1; + int64_t totalLength; + int64_t coreNum; + + int64_t expertStart_ = 0; + int64_t expertEnd_ = 0; + int64_t n; + int64_t k; + int64_t ep_ = 0; + int64_t oneLoopMaxElements_; + int64_t rowIdxType_ = 0; + + static constexpr int64_t SYNC_GM_NUM = 2; + static constexpr int64_t WORK_GM_NUM = 2; + static constexpr int64_t 
DST_BLK_STRIDE = 1; + static constexpr int64_t DST_REP_STRIDE = 8; +}; + +__aicore__ inline void MoeSortBase::SyncAll() +{ + AscendC::SyncAll(); +} + +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_SORT_BASE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core.h new file mode 100644 index 00000000000..a3985f84130 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core.h @@ -0,0 +1,377 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_sort_multi_core.h + * \brief + */ +#ifndef MOE_CUSTOM_VBS_ONE_CORE_H +#define MOE_CUSTOM_VBS_ONE_CORE_H + +#include "moe_custom_sort_base.h" +#include "moe_custom_mrgsort.h" +#include "moe_custom_mrgsort_out.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeSortMultiCore : public MoeSortBase { +public: + __aicore__ inline MoeSortMultiCore(){}; + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void VBSProcess(); + __aicore__ inline void UBSortProcess(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void OneCoreVMSProcess(int64_t listNum, int64_t perListElements, int64_t lastListElements); + __aicore__ inline void VMSProcess(); + __aicore__ inline void SortOutProcess(); + __aicore__ inline void VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void UBSortCompute(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum); + __aicore__ inline void InitMoeMrgSort(MoeMrgsort *sorter, int64_t listNum, int64_t coreOffset, int64_t loopOffset); + __aicore__ inline void InitMoeMrgSortOut(MoeMrgsortOut *sorter, int64_t listNum, int64_t coreOffset); + +private: + GlobalTensor workspaceGms[2]; + // GlobalTensor expertTokensCountGm_; + + const MoeCustomVBSComputeTilingData *vbsTilingData; + const MoeCustomVMSMiddleComputeTilingData *vmsTilingData; + const MoeCustomSortOutComputeTilingData *sortOutTilingData; + + // for MoeMrgsort + MoeMrgsort mrgsorter; + MoeMrgsortParam mrgsortParam; + + int64_t coreNum; + int64_t blockIdx; + int64_t srcWsIndex = 0; + + int64_t listNum; + int64_t perListElements; + int64_t lastListElements; + + int64_t sortTotalLength; + int64_t sortCoreLoops; + int64_t sortCoreLoopElements; + int64_t 
sortCoreLastLoopElements; + + int64_t perCoreExpert; + int64_t needInitExpertCore; + int64_t currentCoreExpert; + + static constexpr int64_t MAX_MRGSORT_LIST = 4; +}; + +__aicore__ inline void MoeSortMultiCore::VBSCopyIn(int64_t progress, int64_t size, int64_t sortNum) +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + int64_t inOffset = progress * sortCoreLoopElements; + DataCopyExtParams dataCopyParams{static_cast(1), static_cast(size * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm[inOffset], dataCopyParams, dataCopyPadParams); + + LocalTensor rowIdxLocal = inLocal[sortNum]; + int64_t startValue = this->blockIdx * this->vbsTilingData->perCoreElements + inOffset; + SetWaitFlag(HardEvent::MTE3_S); + ArithProgression(rowIdxLocal, startValue, 1, size); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeSortMultiCore::UBSortCompute(int64_t progress, int64_t size, int64_t sortNum) +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertForSourceRowLocal = inLocal[0]; + LocalTensor expertForSourceRowLocalFp32; + + expertForSourceRowLocalFp32 = expertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalFp32, expertForSourceRowLocal, RoundMode::CAST_ROUND, sortNum); + + Muls(expertForSourceRowLocalFp32, expertForSourceRowLocalFp32, (float)-1, sortNum); + + if (ep_) { + LocalTensor maskLocalTensor = sortedBuffer.Get(); + AscendC::CompareScalar( + maskLocalTensor, expertForSourceRowLocalFp32, static_cast(-expertStart_), AscendC::CMPMODE::GT, + (sortNum + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * ONE_REPEAT_COMPARE_NUM); + LocalTensor floatMinLocalTensor = tempBuffer.Get(); + Duplicate(floatMinLocalTensor, MIN_FP32, sortNum); + Select(expertForSourceRowLocalFp32, maskLocalTensor, floatMinLocalTensor, expertForSourceRowLocalFp32, + SELMODE::VSEL_TENSOR_TENSOR_MODE, sortNum); + } + + int64_t duplicateNum = size % 
ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = size - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertForSourceRowLocalFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + + LocalTensor concatLocal = expertForSourceRowLocalFp32; + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(sortNum)); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[sortNum].ReinterpretCast(); + Sort(outLocal, concatLocal, sourceRowLocal, sortedLocal, sortNum / ONE_REPEAT_SORT_NUM); + + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeSortMultiCore::VBSCopyOut(int64_t progress, int64_t size, int64_t sortNum) +{ + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + DataCopy(workspaceGms[0][this->blockIdx * GetSortLen(this->vbsTilingData->perCoreElements) + + GetSortLen(progress * sortCoreLoopElements)], + outLocal, Align(GetSortLen(size), sizeof(float))); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeSortMultiCore::InitMoeMrgSort(MoeMrgsort *sorter, int64_t listNum, int64_t coreOffset, + int64_t loopOffset) +{ + GlobalTensor srcWsGm = workspaceGms[srcWsIndex][blockIdx * coreOffset + loopOffset]; + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(oneLoopMaxElements_) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + GlobalTensor dstWsGm = workspaceGms[1 - srcWsIndex][blockIdx * coreOffset + loopOffset]; + sorter->SetOutput(dstWsGm, outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void 
MoeSortMultiCore::InitMoeMrgSortOut(MoeMrgsortOut *sorter, int64_t listNum, int64_t coreOffset) +{ + GlobalTensor srcWsGm = workspaceGms[srcWsIndex]; + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + + for (int64_t i = 0; i < listNum; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(oneLoopMaxElements_) * i]; + sorter->SetInput(srcWsGm, inLocalT); + } + + LocalTensor outLocalV = outLocal[oneLoopMaxElements_ * MAX_MRGSORT_LIST]; + sorter->SetOutput(this->sortedexpertIdxGm, this->expendedRowIdxGm, outLocal, outLocalV); + + LocalTensor tempBuffer = sortedBuffer.Get(GetSortLen(oneLoopMaxElements_) * MAX_MRGSORT_LIST); + sorter->SetBuffer(tempBuffer); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeSortMultiCore::OneCoreVMSProcess(int64_t listNum, int64_t perListElements, + int64_t lastListElements) +{ + int64_t coreOffset = GetSortLen(this->vbsTilingData->perCoreElements); + mrgsortParam.oneLoopMaxElements = oneLoopMaxElements_; + + for (int64_t i = 0; listNum >= 1; i++) { + int64_t loops = (listNum + MAX_MRGSORT_LIST - 1) / MAX_MRGSORT_LIST; + int64_t remainListNum = listNum - (loops - 1) * MAX_MRGSORT_LIST; + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + + int64_t loopOffset = GetSortLen(mrgsortParam.perListElements * MAX_MRGSORT_LIST); + for (int64_t loop = 0; loop < loops - 1; loop++) { + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, loop * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, (loops - 1) * loopOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + + listNum = loops; + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + 
perListElements = perListElements * MAX_MRGSORT_LIST; + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + if (loops == 1) { + break; + } + } +} + +__aicore__ inline void MoeSortMultiCore::UBSortProcess(int64_t progress, int64_t size, int64_t sortNum) +{ + VBSCopyIn(progress, size, sortNum); + UBSortCompute(progress, size, sortNum); + VBSCopyOut(progress, size, sortNum); +} + +__aicore__ inline void MoeSortMultiCore::VBSProcess() +{ + if (this->blockIdx < this->vbsTilingData->needCoreNum) { + int64_t sortNum = Ceil(sortCoreLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + for (int64_t loop = 0; loop < sortCoreLoops - 1; loop++) { + UBSortProcess(loop, sortCoreLoopElements, sortNum); + } + + sortNum = Ceil(sortCoreLastLoopElements, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + UBSortProcess(sortCoreLoops - 1, sortCoreLastLoopElements, sortNum); + + if (sortCoreLoops > 1) { + OneCoreVMSProcess(sortCoreLoops, sortCoreLoopElements, sortCoreLastLoopElements); + } + } + SyncAll(); +} + +__aicore__ inline void MoeSortMultiCore::VMSProcess() +{ + int64_t currentStageNeedCoreNum = this->vmsTilingData->needCoreNum; + perListElements = this->vbsTilingData->perCoreElements; + lastListElements = this->vbsTilingData->lastCoreElements; + listNum = this->vbsTilingData->needCoreNum; + + for (; listNum > MAX_MRGSORT_LIST;) { + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + int64_t coreOffset = GetSortLen(perListElements * MAX_MRGSORT_LIST); + int64_t remainListNum = listNum - (currentStageNeedCoreNum - 1) * MAX_MRGSORT_LIST; + + if (this->blockIdx < currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = perListElements; + mrgsortParam.oneLoopMaxElements = oneLoopMaxElements_; + InitMoeMrgSort(&mrgsorter, MAX_MRGSORT_LIST, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } else if (this->blockIdx == currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + 
mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = oneLoopMaxElements_; + InitMoeMrgSort(&mrgsorter, remainListNum, coreOffset, 0); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + listNum = currentStageNeedCoreNum; + currentStageNeedCoreNum = Ceil(listNum, MAX_MRGSORT_LIST); + srcWsIndex = (srcWsIndex + 1) % WORK_GM_NUM; + + lastListElements = perListElements * (remainListNum - 1) + lastListElements; + perListElements = perListElements * MAX_MRGSORT_LIST; + + SyncAll(); + } +} + +__aicore__ inline void MoeSortMultiCore::SortOutProcess() +{ + if (this->blockIdx < 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.lastListElements = lastListElements; + mrgsortParam.oneLoopMaxElements = oneLoopMaxElements_; + + MoeMrgsortOut sorter; + InitMoeMrgSortOut(&sorter, listNum, GetSortLen(perListElements)); + sorter.Init(&mrgsortParam, pipe); + sorter.Process(); + } + SyncAll(); +} + +__aicore__ inline void MoeSortMultiCore::Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->vbsTilingData = &(tilingData->vbsComputeParamsOp); + this->vmsTilingData = &(tilingData->vmsMiddleComputeParamsOp); + this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp); + + this->blockIdx = GetBlockIdx(); + this->tileLength = this->vbsTilingData->perCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->perCoreElements; + if (this->blockIdx == tilingData->vbsComputeParamsOp.needCoreNum - 1) { + this->tileLength = this->vbsTilingData->lastCorePerLoopElements; + this->sortTotalLength = this->vbsTilingData->lastCoreElements; + } + this->n = tilingData->n; + this->k = tilingData->k; + this->ep_ = tilingData->ep; + this->oneLoopMaxElements_ = ep_ ? 
this->sortOutTilingData->oneLoopMaxElements : MRGSORT_LIST_MAX_ELEMENT; + + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + rowIdxType_ = tilingData->rowIdxType; + + // VBS param init + if (this->blockIdx == this->vbsTilingData->needCoreNum - 1) { + sortCoreLoops = this->vbsTilingData->lastCoreLoops; + sortCoreLoopElements = this->vbsTilingData->lastCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->lastCoreLastLoopElements; + } else { + sortCoreLoops = this->vbsTilingData->perCoreLoops; + sortCoreLoopElements = this->vbsTilingData->perCorePerLoopElements; + sortCoreLastLoopElements = this->vbsTilingData->perCoreLastLoopElements; + } + + this->pipe = tPipe; + expertIdxGm.SetGlobalBuffer((__gm__ int32_t *)expertIdx + + this->blockIdx * tilingData->vbsComputeParamsOp.perCoreElements, + this->sortTotalLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace), + Align(this->totalLength, sizeof(int32_t))); + if (rowIdxType_ == SCATTER) { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expendedRowIdx, Align(this->totalLength, sizeof(int32_t))); + } else { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)), + Align(this->totalLength, sizeof(int32_t))); + } + + if (GetBlockIdx() == 0) { + expertCountTempGm.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2, + tilingData->actualExpertNum); + InitGlobalMemory(expertCountTempGm, tilingData->actualExpertNum, 0); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + + // key and value + int64_t kvFactor = 2; + workspaceGms[0].SetGlobalBuffer((__gm__ float *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2 + + tilingData->actualExpertNum, + Align(this->totalLength, sizeof(int32_t)) * kvFactor); + workspaceGms[1].SetGlobalBuffer((__gm__ float *)workspace + + Align(this->totalLength, sizeof(int32_t)) * (kvFactor + 2) + + 
tilingData->actualExpertNum, + Align(this->totalLength, sizeof(int32_t)) * kvFactor); + + int64_t bufferSize = Ceil(Max(oneLoopMaxElements_ * MAX_MRGSORT_LIST, sortCoreLoopElements), ONE_REPEAT_SORT_NUM) * + ONE_REPEAT_SORT_NUM * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortedBuffer, bufferSize); + if (ep_) { + pipe->InitBuffer(tempBuffer, bufferSize); + } +} + +__aicore__ inline void MoeSortMultiCore::Process() +{ + VBSProcess(); + VMSProcess(); + SortOutProcess(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_VBS_ONE_CORE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core_performance.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core_performance.h new file mode 100644 index 00000000000..1a678bd8657 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_multi_core_performance.h @@ -0,0 +1,171 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_sort_multi_core_performance.h + * \brief + */ +#ifndef MOE_CUSTOM_VBS_ONE_CORE_PERFORMANCE_H +#define MOE_CUSTOM_VBS_ONE_CORE_PERFORMANCE_H + +#include "moe_custom_sort_base.h" +#include "moe_custom_mrgsort_performance.h" +#include "moe_custom_mrgsort_out_performance.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeSortMultiCorePerformance : public MoeSortBase { +public: + __aicore__ inline MoeSortMultiCorePerformance(){}; + __aicore__ inline void Init(GM_ADDR expendedRowIdx, GM_ADDR workspace, const MoeInitRoutingCustomTilingData *tilingData, + TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void VMSProcess(); + __aicore__ inline void SortOutProcess(); + __aicore__ inline void InitMoeMrgSort(MoeMrgsortPerformance *sorter, int64_t coreOffset); + __aicore__ inline void InitMoeMrgSortOut(MoeMrgsortOutPerformance *sorter); + +private: + GlobalTensor workspaceGms[2]; + GlobalTensor workspaceGatheredSortNumGm_; + + const MoeCustomSortOutComputeTilingData *sortOutTilingData; + const MoeCustomVBSComputeTilingData *vbsTilingData; + + // for MoeMrgsortPerformance + MoeMrgsortPerformance mrgsorter; + MoeMrgsortPerformanceParam mrgsortParam; + + int64_t blockIdx; + + int64_t perListElements; + int64_t maxPerListElements; +}; + +__aicore__ inline void MoeSortMultiCorePerformance::InitMoeMrgSort(MoeMrgsortPerformance *sorter, int64_t coreOffset) +{ + GlobalTensor srcWsGm = workspaceGms[0][this->blockIdx * coreOffset]; // 0-3 + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + GlobalTensor sortNumGm = workspaceGatheredSortNumGm_[this->blockIdx * MAX_MRGSORT_LIST]; + for (int64_t i = 0; i < MAX_MRGSORT_LIST; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(maxPerListElements) * i]; + sorter->SetInput(srcWsGm, inLocalT, sortNumGm); + } + GlobalTensor dstWsGm = workspaceGms[1][this->blockIdx * coreOffset]; + 
sorter->SetOutput(dstWsGm, outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeSortMultiCorePerformance::InitMoeMrgSortOut(MoeMrgsortOutPerformance *sorter) +{ + GlobalTensor srcWsGm = workspaceGms[1]; + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + GlobalTensor sortNumGm = workspaceGatheredSortNumGm_; + for (int64_t i = 0; i < MAX_MRGSORT_LIST; i++) { + LocalTensor inLocalT = inLocal[GetSortLen(maxPerListElements) * i]; + sorter->SetInput(srcWsGm, inLocalT, sortNumGm); + } + + LocalTensor outLocalV = outLocal[maxPerListElements * MAX_MRGSORT_LIST]; + sorter->SetOutput(this->sortedexpertIdxGm, this->expendedRowIdxGm, outLocal, outLocalV); + + LocalTensor tempBuffer = sortedBuffer.Get(GetSortLen(maxPerListElements) * MAX_MRGSORT_LIST); + sorter->SetBuffer(tempBuffer); + sortDataCopyInQueue.FreeTensor(inLocal); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeSortMultiCorePerformance::VMSProcess() +{ + int64_t currentStageNeedCoreNum = MAX_MRGSORT_LIST; + int64_t coreOffset = GetSortLen(perListElements * MAX_MRGSORT_LIST); + if (this->blockIdx <= currentStageNeedCoreNum - 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.oneLoopMaxElements = maxPerListElements; + InitMoeMrgSort(&mrgsorter, coreOffset); + mrgsorter.Init(&mrgsortParam); + mrgsorter.Process(); + } + SyncAll(); +} + +__aicore__ inline void MoeSortMultiCorePerformance::SortOutProcess() +{ + if (this->blockIdx < 1) { + mrgsortParam.perListElements = perListElements; + mrgsortParam.oneLoopMaxElements = maxPerListElements; + MoeMrgsortOutPerformance sorter; + InitMoeMrgSortOut(&sorter); + sorter.Init(&mrgsortParam, pipe); + sorter.Process(); + InitGlobalMemory(expertCountTempGm, expertEnd_ - expertStart_, 0); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + SyncAll(); +} + +__aicore__ inline void 
MoeSortMultiCorePerformance::Init(GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + this->totalLength = tilingData->n * tilingData->k; + this->blockIdx = GetBlockIdx(); + this->n = tilingData->n; + this->k = tilingData->k; + this->vbsTilingData = &(tilingData->vbsComputeParamsOp); + this->sortOutTilingData = &(tilingData->sortOutComputeParamsOp); + this->perListElements = Ceil(this->totalLength, MAX_MRGSORT_LIST_TOTAL); + this->maxPerListElements = this->sortOutTilingData->oneLoopMaxElements; + + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + rowIdxType_ = tilingData->rowIdxType; + + this->pipe = tPipe; + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace), + Align(this->totalLength, sizeof(int32_t))); + if (rowIdxType_ == SCATTER) { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expendedRowIdx, Align(this->totalLength, sizeof(int32_t))); + } else { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)), + Align(this->totalLength, sizeof(int32_t))); + } + + // key and value + int64_t kvFactor = 2; + workspaceGms[0].SetGlobalBuffer((__gm__ float *)workspace, Align(this->totalLength, sizeof(float)) * kvFactor); + workspaceGms[1].SetGlobalBuffer((__gm__ float *)workspace + Align(this->totalLength, sizeof(float)) * kvFactor, + Align(this->totalLength, sizeof(float)) * kvFactor); + workspaceGatheredSortNumGm_.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(this->totalLength, sizeof(int32_t)) * kvFactor * kvFactor, + MAX_MRGSORT_LIST_TOTAL); + expertCountTempGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->totalLength, sizeof(int32_t)) * 2, + expertEnd_ - expertStart_); + + int64_t bufferSize = Ceil(maxPerListElements * MAX_MRGSORT_LIST, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM * + sizeof(float) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, bufferSize); + 
pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, bufferSize); + pipe->InitBuffer(sortedBuffer, bufferSize); + pipe->InitBuffer(tempBuffer, bufferSize); +} + +__aicore__ inline void MoeSortMultiCorePerformance::Process() +{ + VMSProcess(); + SortOutProcess(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_VBS_ONE_CORE_PERFORMANCE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_one_core.h b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_one_core.h new file mode 100644 index 00000000000..a83ee7b4c62 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_custom_sort_one_core.h @@ -0,0 +1,167 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_custom_sort_one_core.h + * \brief + */ +#ifndef MOE_CUSTOM_SORT_ONE_CORE_H +#define MOE_CUSTOM_SORT_ONE_CORE_H + +#include "moe_custom_sort_base.h" + +namespace MoeInitRoutingCustom { +using namespace AscendC; + +class MoeSortOneCore : public MoeSortBase { +public: + __aicore__ inline MoeSortOneCore(){}; + __aicore__ inline void Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe); + __aicore__ inline void Process(); + +private: + __aicore__ inline void CopyIn(); + __aicore__ inline void SortCompute(); + __aicore__ inline void ExpertCountCompute(); + __aicore__ inline void CopyOut(); + +private: + int64_t sortNum; +}; + +__aicore__ inline void MoeSortOneCore::CopyIn() +{ + LocalTensor inLocal = sortDataCopyInQueue.AllocTensor(); + DataCopyExtParams dataCopyParams{static_cast(1), + static_cast(this->totalLength * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams dataCopyPadParams{false, 0, 0, 0}; + DataCopyPad(inLocal[0], expertIdxGm, dataCopyParams, dataCopyPadParams); + LocalTensor rowIdxLocal = inLocal[this->sortNum]; + ArithProgression(rowIdxLocal, 0, 1, this->sortNum); + sortDataCopyInQueue.EnQue(inLocal); +} + +__aicore__ inline void MoeSortOneCore::SortCompute() +{ + LocalTensor inLocal = sortDataCopyInQueue.DeQue(); + LocalTensor expertIdx = inLocal[0]; + LocalTensor expertIdxFp32 = expertIdx.ReinterpretCast(); + Cast(expertIdxFp32, expertIdx, RoundMode::CAST_ROUND, this->tileLength); + Muls(expertIdxFp32, expertIdxFp32, (float)-1, this->tileLength); + + if (ep_) { + LocalTensor maskLocalTensor = sortedBuffer.Get(); + AscendC::CompareScalar(maskLocalTensor, expertIdxFp32, static_cast(-expertStart_), AscendC::CMPMODE::GT, + (this->totalLength + ONE_REPEAT_COMPARE_NUM - 1) / ONE_REPEAT_COMPARE_NUM * + ONE_REPEAT_COMPARE_NUM); + LocalTensor floatMinLocalTensor = tempBuffer.Get(); + Duplicate(floatMinLocalTensor, MIN_FP32, this->tileLength); + Select(expertIdxFp32, 
maskLocalTensor, floatMinLocalTensor, expertIdxFp32, SELMODE::VSEL_TENSOR_TENSOR_MODE, + this->totalLength); + } + + int64_t duplicateNum = this->totalLength % ONE_REPEAT_SORT_NUM; + if (duplicateNum > 0) { + int duplicateIndex = this->totalLength - duplicateNum; + uint64_t mask0 = UINT64_MAX; + mask0 = mask0 << duplicateNum; + mask0 = mask0 & (UINT64_MAX >> ONE_REPEAT_SORT_NUM); + uint64_t mask[2] = {mask0, 0}; + Duplicate(expertIdxFp32[duplicateIndex], MIN_FP32, mask, 1, DST_BLK_STRIDE, DST_REP_STRIDE); + } + + LocalTensor concatLocal; + LocalTensor tempTensor = tempBuffer.Get(GetSortLen(this->sortNum)); + Concat(concatLocal, expertIdxFp32, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + + LocalTensor sortedLocal = sortedBuffer.Get(GetSortLen(this->sortNum)); + LocalTensor sourceRowLocal; + sourceRowLocal = inLocal[this->sortNum].ReinterpretCast(); + Sort(sortedLocal, concatLocal, sourceRowLocal, tempTensor, this->sortNum / ONE_REPEAT_SORT_NUM); + + LocalTensor outLocal = sortDataCopyOutQueue.AllocTensor(); + LocalTensor sortedExpertForSourceRowLocal = outLocal[0]; + LocalTensor expandDstToSrcRowLocal; + expandDstToSrcRowLocal = outLocal[this->sortNum].ReinterpretCast(); + Extract(sortedExpertForSourceRowLocal, expandDstToSrcRowLocal, sortedLocal, this->sortNum / ONE_REPEAT_SORT_NUM); + Muls(sortedExpertForSourceRowLocal, sortedExpertForSourceRowLocal, (float)-1, this->tileLength); + + LocalTensor expertForSourceRowLocalInt32; + expertForSourceRowLocalInt32 = sortedExpertForSourceRowLocal.ReinterpretCast(); + Cast(expertForSourceRowLocalInt32, sortedExpertForSourceRowLocal, RoundMode::CAST_ROUND, this->tileLength); + sortDataCopyOutQueue.EnQue(outLocal); + sortDataCopyInQueue.FreeTensor(inLocal); +} + +__aicore__ inline void MoeSortOneCore::CopyOut() +{ + LocalTensor outLocal = sortDataCopyOutQueue.DeQue(); + DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = this->totalLength * sizeof(int32_t); + DataCopyPad(sortedexpertIdxGm, 
outLocal[0], intriParams); + DataCopyPad(expendedRowIdxGm, outLocal[this->sortNum], intriParams); + sortDataCopyOutQueue.FreeTensor(outLocal); +} + +__aicore__ inline void MoeSortOneCore::Init(GM_ADDR expertIdx, GM_ADDR expendedRowIdx, GM_ADDR workspace, + const MoeInitRoutingCustomTilingData *tilingData, TPipe *tPipe) +{ + this->pipe = tPipe; + this->tileLength = Align(tilingData->vbsComputeParamsOp.lastCorePerLoopElements, sizeof(int32_t)); + this->sortNum = Ceil(this->tileLength, ONE_REPEAT_SORT_NUM) * ONE_REPEAT_SORT_NUM; + this->totalLength = tilingData->n * tilingData->k; + this->coreNum = tilingData->coreNum; + this->ep_ = tilingData->ep; + expertStart_ = tilingData->expertStart; + expertEnd_ = tilingData->expertEnd; + rowIdxType_ = tilingData->rowIdxType; + + expertIdxGm.SetGlobalBuffer((__gm__ int32_t *)expertIdx, this->tileLength); + sortedexpertIdxGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace), + Align(this->totalLength, sizeof(int32_t))); + if (rowIdxType_ == SCATTER) { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)expendedRowIdx, this->tileLength); + } else { + expendedRowIdxGm.SetGlobalBuffer((__gm__ int32_t *)workspace + Align(this->tileLength, sizeof(int32_t)), + Align(this->tileLength, sizeof(int32_t))); + } + + if (GetBlockIdx() == 0) { + expertCountTempGm.SetGlobalBuffer((__gm__ int32_t *)workspace + + Align(tilingData->n * tilingData->k, sizeof(int32_t)) * 2, + tilingData->actualExpertNum); + InitGlobalMemory(expertCountTempGm, tilingData->actualExpertNum, 0); + SetWaitFlag(HardEvent::MTE3_MTE2); + } + + int64_t coreNum = GetBlockNum(); + + // key and value + int64_t kvFactor = 2; + int64_t buffSize = this->sortNum * sizeof(int32_t) * kvFactor; + pipe->InitBuffer(sortDataCopyInQueue, bufferNum, buffSize); + pipe->InitBuffer(sortDataCopyOutQueue, bufferNum, buffSize); + pipe->InitBuffer(tempBuffer, buffSize); + pipe->InitBuffer(sortedBuffer, buffSize); +} + +__aicore__ inline void MoeSortOneCore::Process() +{ + if 
(GetBlockIdx() < 1) { + CopyIn(); + SortCompute(); + CopyOut(); + } + this->SyncAll(); +} +} // namespace MoeInitRoutingCustom +#endif // MOE_CUSTOM_SORT_ONE_CORE_H \ No newline at end of file diff --git a/csrc/moe_init_routing_custom/op_kernel/moe_init_routing_custom.cpp b/csrc/moe_init_routing_custom/op_kernel/moe_init_routing_custom.cpp new file mode 100644 index 00000000000..b91983aec51 --- /dev/null +++ b/csrc/moe_init_routing_custom/op_kernel/moe_init_routing_custom.cpp @@ -0,0 +1,412 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! 
+ * \file moe_init_routing_custom.cpp + * \brief + */ +#include "moe_custom_mrgsort_out.h" +#include "moe_custom_mrgsort.h" +#include "moe_custom_sort_one_core.h" +#include "moe_custom_sort_multi_core.h" +#include "moe_custom_gather_sort_multi_core.h" +#include "moe_custom_expert_tokens_count.h" +#include "moe_custom_row_idx_gather.h" +#include "moe_custom_gather_out.h" +#include "moe_custom_gather_dynamic_quant.h" +#include "moe_custom_gather_static_quant.h" +#include "moe_custom_full_load.h" +#include "moe_custom_full_load_dynamic_quant.h" +#include "moe_custom_full_load_static_quant.h" +#include "moe_custom_full_load_unquantized.h" +#include "moe_custom_sort_actual_expert.h" +#include "moe_custom_sort_multi_core_performance.h" +#include "moe_custom_row_idx_gather_droppad_dynamic.h" +#include "moe_custom_row_idx_gather_droppad.h" +#include "moe_custom_gather_out_droppad.h" +#include "moe_custom_gather_droppad_static_quant.h" + +#define MOE_INIT_ROUTING_CUSTOM_PERFORMANCE 2000000 +#define UNQUANTIZED_FULLLOAD 2100000 +#define STATIC_QUANT_FULLLOAD 2200000 +#define DYNAMIC_QUANT_GATHER_NO_SCALE_FULLLOAD 2300000 +#define DYNAMIC_QUANT_GATHER_1H_DIM_SCALE_FULLLOAD 2301000 +#define DYNAMIC_QUANT_GATHER_EH_SCALE_FULLLOAD 2302000 +#define DYNAMIC_QUANT_SCATTER_NO_SCALE_FULLLOAD 2310000 +#define DYNAMIC_QUANT_SCATTER_1H_SCALE_FULLLOAD 2311000 +#define DYNAMIC_QUANT_SCATTER_EH_SCALE_FULLLOAD 2312000 + +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_NODROP 1000000 +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_SCATTER_NODROP 1001000 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_NODROP 1100000 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_SCATTER_NODROP 1101000 + +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_NODROP 1020000 +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_SCATTER_NODROP 1021000 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_NODROP 1120000 +#define 
MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_SCATTER_NODROP 1121000 + +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_NODROP 1010000 +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_SCATTER_NODROP 1011000 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_NODROP 1110000 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_SCATTER_NODROP 1111000 + +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_DROP 1000100 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_DROP 1100100 +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_DROP 1020100 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_DROP 1120100 +#define MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_DROP 1010100 +#define MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_DROP 1110100 + +#define MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_GATHER 1200000 +#define MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_SCATTER 1201000 +#define MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_GATHER 1300000 +#define MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_SCATTER 1301000 + + +using namespace AscendC; +using namespace MoeInitRoutingCustom; +extern "C" __global__ __aicore__ void moe_init_routing_custom(GM_ADDR x, GM_ADDR expertIdx, GM_ADDR scale, GM_ADDR offset, + GM_ADDR expandedX, GM_ADDR expandedRowIdx, + GM_ADDR expertTokensCountOrCumsum, GM_ADDR expandedScale, + GM_ADDR workspace, GM_ADDR tiling) +{ + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0); + if (g_coreType == AIC) { + return; + } + + GET_TILING_DATA(tilingData, tiling); + if (workspace == nullptr) { + return; + } + + GM_ADDR userWS = GetUserWorkspace(workspace); + if (userWS == nullptr) { + return; + } + + auto t = &tilingData; + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_PERFORMANCE)) { + TPipe fullLoadPipe; + MoeCustomFullLoad op; + op.Init(x, expertIdx, scale, offset, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); 
+ return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_GATHER_NO_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_GATHER_1H_DIM_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_GATHER_EH_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_SCATTER_NO_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_SCATTER_1H_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(DYNAMIC_QUANT_SCATTER_EH_SCALE_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadDynamicQuant op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, 
expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(UNQUANTIZED_FULLLOAD)) { + TPipe fullLoadPipe; + MoeCustomFullLoadUnquantized op; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + return; + } + + if (TILING_KEY_IS(STATIC_QUANT_FULLLOAD)) { + if constexpr (!IsSameType::value) { + TPipe fullLoadPipe; + MoeCustomFullLoadStaticQuant op; + op.Init(x, expertIdx, scale, offset, expandedX, expandedRowIdx, expertTokensCountOrCumsum, userWS, t, + &fullLoadPipe); + op.Process(); + fullLoadPipe.Destroy(); + } + return; + } + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_SCATTER)) { + TPipe sortActualExpertPipe; + MoeSortActualExpert op; + bool isFinished = false; + op.Init(x, expertIdx, scale, expandedX, expandedRowIdx, expertTokensCountOrCumsum, expandedScale, userWS, t, + &sortActualExpertPipe); + isFinished = op.Process(); + sortActualExpertPipe.Destroy(); + if (isFinished) { + return; + } + } + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_SCATTER)) { + TPipe gatherSortMultiCorePipe; + MoeGatherSortMultiCore op; + op.Init(expertIdx, expandedRowIdx, userWS, t, &gatherSortMultiCorePipe); + op.Process(); + gatherSortMultiCorePipe.Destroy(); + } + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_SCATTER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_SCATTER)) { + TPipe mergeSortMultiCorePipe; + MoeSortMultiCorePerformance op; + op.Init(expandedRowIdx, userWS, t, &mergeSortMultiCorePipe); + op.Process(); + 
mergeSortMultiCorePipe.Destroy(); + } + + TPipe sortPipe; + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_DROP)) { + MoeSortOneCore op; + op.Init(expertIdx, expandedRowIdx, userWS, t, &sortPipe); + op.Process(); + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_DROP)) { + MoeSortMultiCore op; + op.Init(expertIdx, expandedRowIdx, userWS, t, &sortPipe); + op.Process(); + } + sortPipe.Destroy(); + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_SCATTER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_GATHER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_SCATTER)) { + TPipe histogramPipe; + 
if (t->expertTokensNumType == CUMSUM_MODE) { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } else if (t->expertTokensNumType == COUNT_MODE) { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } else { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } + + } else { + if (t->dropPadMode == 1 || t->ep == 1 || t->expertTokensNumFlag != EXERPT_TOKENS_NONE) { + TPipe histogramPipe; + if (t->expertTokensNumType == CUMSUM_MODE) { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } else if (t->expertTokensNumType == COUNT_MODE) { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } else { + ExpertTokensCount countOp; + countOp.Init(expandedRowIdx, expertTokensCountOrCumsum, userWS, t, &histogramPipe); + countOp.Process(); + histogramPipe.Destroy(); + } + } + } + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_DROP)) { + TPipe rowIdxGatherDropPadPipe; + MoeCustomSrcToDstWithCapacity rowIdxGatherDropPadOp; + rowIdxGatherDropPadOp.Init(expandedRowIdx, expandedX, expandedScale, userWS, t, &rowIdxGatherDropPadPipe); + rowIdxGatherDropPadOp.Process(); + rowIdxGatherDropPadPipe.Destroy(); + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_DROP)) { + TPipe rowIdxGatherDropPadPipe; + MoeCustomSrcToDstWithCapacity 
rowIdxGatherDropPadOp; + rowIdxGatherDropPadOp.Init(expandedRowIdx, expandedX, expandedScale, userWS, t, &rowIdxGatherDropPadPipe); + rowIdxGatherDropPadOp.Process(); + rowIdxGatherDropPadPipe.Destroy(); + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_DROP)) { + if constexpr (!IsSameType::value) { + TPipe gatherPipe; + MoeCustomSrcToDstAndGather gatherDroppadDynamicQuantOp; + gatherDroppadDynamicQuantOp.Init(x, scale, expandedRowIdx, expandedX, expandedScale, userWS, t, + &gatherPipe); + gatherDroppadDynamicQuantOp.Process(); + gatherPipe.Destroy(); + } + } else { + TPipe rowIdxPipe; + RowIdxGather rowIdxGatherOp; + rowIdxGatherOp.Init(expandedRowIdx, userWS, t, &rowIdxPipe); + rowIdxGatherOp.Process(); + rowIdxPipe.Destroy(); + } + + if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTONECORE_SCATTER) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_GATHER_SORTMULTICORE_SCATTER)) { + TPipe gatherPipe; + if (t->ep == 1) { + MoeGatherOut gatherOp; + gatherOp.Init(x, scale, userWS, expandedRowIdx, expandedX, expandedScale, t, &gatherPipe); + gatherOp.Process(); + gatherPipe.Destroy(); + } else { + MoeGatherOut gatherOp; + gatherOp.Init(x, scale, userWS, expandedRowIdx, expandedX, expandedScale, t, &gatherPipe); + gatherOp.Process(); + gatherPipe.Destroy(); + } + + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_DYNAMICQUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_SCATTER_NODROP) || + 
TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_DYNAMICQUANT_GATHER_NODROP)) { + if constexpr (!IsSameType::value) { + TPipe gatherPipe; + if (t->ep == 0 and t->smoothType != SCALE_EH) { + MoeGatherOutDynamicQuant gatherDynamicQuantOp; + gatherDynamicQuantOp.Init(x, scale, userWS, expandedRowIdx, expandedX, expandedScale, t, &gatherPipe); + gatherDynamicQuantOp.Process(); + gatherPipe.Destroy(); + } else { + MoeGatherOutDynamicQuant gatherDynamicQuantOp; + gatherDynamicQuantOp.Init(x, scale, userWS, expandedRowIdx, expandedX, expandedScale, t, &gatherPipe); + gatherDynamicQuantOp.Process(); + gatherPipe.Destroy(); + } + } + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_SCATTER_NODROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_NODROP)) { + if constexpr (!IsSameType::value) { + TPipe gatherPipe; + if (t->ep == 1) { + MoeGatherOutQuant gatherStaticQuantOp; + gatherStaticQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, userWS, t, &gatherPipe); + gatherStaticQuantOp.Process(); + gatherPipe.Destroy(); + } else { + MoeGatherOutQuant gatherStaticQuantOp; + gatherStaticQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, userWS, t, &gatherPipe); + gatherStaticQuantOp.Process(); + gatherPipe.Destroy(); + } + } + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_GATHER_DROP)) { + TPipe gatherPipe; + MoeGatherOutDroppad gatherDroppadOp; + gatherDroppadOp.Init(x, scale, expandedRowIdx, expandedX, expandedScale, userWS, t, &gatherPipe); + gatherDroppadOp.Process(); + gatherPipe.Destroy(); + } else if (TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTONECORE_QUANT_GATHER_DROP) || + TILING_KEY_IS(MOE_INIT_ROUTING_CUSTOM_SORTMULTICORE_QUANT_GATHER_DROP)) { + if constexpr (!IsSameType::value) { 
+ TPipe gatherPipe; + MoeGatherDroppadQuant gatherDroppadStaticQuantOp; + gatherDroppadStaticQuantOp.Init(x, scale, offset, expandedRowIdx, expandedX, userWS, t, &gatherPipe); + gatherDroppadStaticQuantOp.Process(); + gatherPipe.Destroy(); + } + } +} \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 96b6205e674..a8077b6747b 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -1118,6 +1118,106 @@ at::Tensor combine_prefill(const at::Tensor& x, const at::Tensor& topk_idx, cons return combined_x; } +std::tuple npu_moe_init_routing_custom( + const at::Tensor &x, const at::Tensor &expert_idx, + const c10::optional &scale, const c10::optional &offset, int64_t active_num, + int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type, + bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type) +{ + constexpr int64_t DIM_X = 2; + constexpr int64_t DIM_EXPERT_IDX = 2; + constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2; + constexpr int64_t EXPERT_TOKENS_COUNT = 1; + constexpr int64_t EXPERT_TOKENS_KEY_VALUE = 2; + constexpr int64_t QUANT_MODE_UNQUANT = -1; + constexpr int64_t QUANT_MODE_DYNAMIC_QUANT = 1; + constexpr int64_t CUMSUM = 0; + constexpr int64_t COUNT = 1; + constexpr int64_t KEY_VALUE = 2; + + if (active_expert_range.empty()) { + active_expert_range = at::IntArrayRef({0, expert_num}); + } + + int64_t x_dim = x.dim(); + TORCH_CHECK(x_dim == DIM_X, "The x should be ", DIM_X, + "-Dimension, current is ", x_dim, "-Dimension."); + + int64_t expert_idx_dim = expert_idx.dim(); + TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX, "The expert_idx should be ", DIM_EXPERT_IDX, + "-Dimension, current is ", expert_idx_dim, "-Dimension."); + + int64_t active_expert_range_length = active_expert_range.size(); + TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE, "The active_expert_range should be ", 
LENGTH_ACTIVE_EXPERT_RANGE, + "-Dimension, current is ", expert_idx_dim, "-Dimension."); + + int expert_length = active_expert_range[1] - active_expert_range[0]; + auto x_size = x.sizes(); + auto expert_idx_size = expert_idx.sizes(); + + int bs = x_size[0]; + int h = x_size[1]; + int k = expert_idx_size[1]; + int64_t expanded_scale_len = 0; + at::Tensor expanded_x; + + if (drop_pad_mode == 1) { // Drop/Pad + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({expert_num, expert_capacity, h}, x.options()); + } else { + expanded_x = at::empty({expert_num, expert_capacity, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = expert_num * expert_capacity; + } else { // Dropless / Active + if (active_num > 0) { // Active + int64_t num_out_tokens = std::min((int64_t)bs * k, active_num); + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({num_out_tokens, h}, x.options()); + } else { + expanded_x = at::empty({num_out_tokens, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = num_out_tokens; + } else { // Dropless + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({bs * k, h}, x.options()); + } else { + expanded_x = at::empty({bs * k, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = bs * k; + } + } + + at::Tensor expanded_row_idx = at::empty({bs * k}, expert_idx.options()); + at::Tensor expert_tokens_count_or_cumsum; + if (expert_tokens_num_type >= CUMSUM && expert_tokens_num_type <= COUNT) { + // expert_tokens_count_or_cumsum in [end-start, ] + expert_tokens_count_or_cumsum = at::empty({expert_length}, x.options().dtype(at::kLong)); + } else if (expert_tokens_num_type == KEY_VALUE) { + // key_value in [2, end-start] + expert_tokens_count_or_cumsum = at::empty({expert_num, 2}, x.options().dtype(at::kLong)); + } + at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat)); + EXEC_NPU_CMD(aclnnMoeInitRoutingCustom, + x, + expert_idx, + scale, + offset, + active_num, 
+ expert_capacity, + expert_num, + drop_pad_mode, + expert_tokens_num_type, + expert_tokens_num_flag, + quant_mode, + active_expert_range, + row_idx_type, + expanded_x, + expanded_row_idx, + expert_tokens_count_or_cumsum, + expanded_scale); + return std::tie(expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale); +} + } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -1257,4 +1357,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) "num_ranks) -> Tensor"); ops.impl("combine_prefill", torch::kPrivateUse1, &vllm_ascend::combine_prefill); + ops.def( + "npu_moe_init_routing_custom(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, " + " int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, " + " bool expert_tokens_num_flag=False, int quant_mode=0, int[2] active_expert_range=[], " + " int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)" + ); + ops.impl("npu_moe_init_routing_custom", torch::kPrivateUse1, &vllm_ascend::npu_moe_init_routing_custom); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index e114a981f69..af2c237a224 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -283,6 +283,89 @@ std::tuple matmul_allreduce_add_rmsnorm_meta( return {output, add_out}; } +std::tuple npu_moe_init_routing_custom_meta( + const at::Tensor &x, const at::Tensor &expert_idx, + const c10::optional &scale, const c10::optional &offset, int64_t active_num, + int64_t expert_capacity, int64_t expert_num, int64_t drop_pad_mode, int64_t expert_tokens_num_type, + bool expert_tokens_num_flag, int64_t quant_mode, at::IntArrayRef active_expert_range, int64_t row_idx_type) +{ + constexpr int64_t DIM_X = 2; + constexpr int64_t DIM_EXPERT_IDX = 2; + constexpr int64_t LENGTH_ACTIVE_EXPERT_RANGE = 2; + constexpr int64_t EXPERT_TOKENS_COUNT = 1; + constexpr int64_t EXPERT_TOKENS_KEY_VALUE = 2; + constexpr int64_t 
QUANT_MODE_UNQUANT = -1; + constexpr int64_t QUANT_MODE_DYNAMIC_QUANT = 1; + constexpr int64_t CUMSUM = 0; + constexpr int64_t COUNT = 1; + constexpr int64_t KEY_VALUE = 2; + + if (active_expert_range.empty()) { + active_expert_range = at::IntArrayRef({0, expert_num}); + } + + int64_t x_dim = x.dim(); + TORCH_CHECK(x_dim == DIM_X, "The x should be ", DIM_X, + "-Dimension, current is ", x_dim, "-Dimension."); + + int64_t expert_idx_dim = expert_idx.dim(); + TORCH_CHECK(expert_idx_dim == DIM_EXPERT_IDX, "The expert_idx should be ", DIM_EXPERT_IDX, + "-Dimension, current is ", expert_idx_dim, "-Dimension."); + + int64_t active_expert_range_length = active_expert_range.size(); + TORCH_CHECK(active_expert_range_length == LENGTH_ACTIVE_EXPERT_RANGE, "The active_expert_range should be ", LENGTH_ACTIVE_EXPERT_RANGE, + "-Dimension, current is ", expert_idx_dim, "-Dimension."); + + int expert_length = active_expert_range[1] - active_expert_range[0]; + auto x_size = x.sizes(); + auto expert_idx_size = expert_idx.sizes(); + + int bs = x_size[0]; + int h = x_size[1]; + int k = expert_idx_size[1]; + int64_t expanded_scale_len = 0; + at::Tensor expanded_x; + + if (drop_pad_mode == 1) { // Drop/Pad + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({expert_num, expert_capacity, h}, x.options()); + } else { + expanded_x = at::empty({expert_num, expert_capacity, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = expert_num * expert_capacity; + } else { // Dropless / Active + if (active_num > 0) { // Active + int64_t num_out_tokens = std::min((int64_t)bs * k, active_num); + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({num_out_tokens, h}, x.options()); + } else { + expanded_x = at::empty({num_out_tokens, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = num_out_tokens; + } else { // Dropless + if (quant_mode == QUANT_MODE_UNQUANT) { + expanded_x = at::empty({bs * k, h}, x.options()); + } else { + expanded_x = at::empty({bs * 
k, h}, x.options().dtype(at::kChar)); + } + expanded_scale_len = bs * k; + } + } + + at::Tensor expanded_row_idx = at::empty({bs * k}, expert_idx.options()); + at::Tensor expert_tokens_count_or_cumsum; + if (expert_tokens_num_type >= CUMSUM && expert_tokens_num_type <= COUNT) { + // expert_tokens_count_or_cumsum in [end-start, ] + expert_tokens_count_or_cumsum = at::empty({expert_length}, x.options().dtype(at::kLong)); + } else if (expert_tokens_num_type == KEY_VALUE) { + // key_value in [2, end-start] + expert_tokens_count_or_cumsum = at::empty({expert_num, 2}, x.options().dtype(at::kLong)); + } + + at::Tensor expanded_scale = at::empty({expanded_scale_len}, x.options().dtype(at::kFloat)); + return {expanded_x, expanded_row_idx, expert_tokens_count_or_cumsum, expanded_scale}; +} + } // namespace meta } // namespace vllm_ascend @@ -316,5 +399,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("dispatch_ffn_combine", &vllm_ascend::meta::dispatch_ffn_combine_meta); // matmul allreduce add rmsnorm ops.impl("matmul_allreduce_add_rmsnorm", &vllm_ascend::meta::matmul_allreduce_add_rmsnorm_meta); + // moe_init_routing_custom + ops.impl("npu_moe_init_routing_custom", &vllm_ascend::meta::npu_moe_init_routing_custom_meta); } } diff --git a/tests/e2e/nightly/ops/test_moe_init_routing_custom.py b/tests/e2e/nightly/ops/test_moe_init_routing_custom.py new file mode 100644 index 00000000000..05fd5112985 --- /dev/null +++ b/tests/e2e/nightly/ops/test_moe_init_routing_custom.py @@ -0,0 +1,349 @@ +import itertools +import random + +import numpy as np +import torch + +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() + + +def adapter_capacity(sorted_row_idx, sorted_expert_idx, capacity): + count = 0 + last = sorted_expert_idx[0] + for i, val in enumerate(sorted_expert_idx): + if last != val: + count = 1 + last = val + else: + count += 1 + if count > capacity: + sorted_expert_idx[i] = -1 + sorted_row_idx[i] = -1 + + +def 
moe_init_routing_golden(x, expert_idx, scale, offset, active_num, + expert_capacity, expert_num, drop_pad_mode, + expert_tokens_num_type, expert_tokens_num_flag, + active_expert_range, quant_mode, row_idx_type): + if drop_pad_mode == 1: + if expert_num <= 0: + print("expert num can not be 0") + return + expert_start = active_expert_range[0] if drop_pad_mode == 0 else 0 + expert_end = active_expert_range[1] if drop_pad_mode == 0 else expert_num + num_rows = x.shape[0] + h = x.shape[1] + k = expert_idx.shape[-1] + expert_idx_in = expert_idx.copy().reshape(-1) + actual_expert_total_num: int = np.sum((expert_idx_in >= expert_start) + & (expert_idx_in < expert_end)) + + expert_idx_in[(expert_idx_in + < expert_start)] = np.int32(np.iinfo(np.int32).max) + sorted_expert_indices = np.argsort(expert_idx_in, axis=-1, kind="stable") + sorted_expert_idx = expert_idx_in[sorted_expert_indices] + if row_idx_type == 1: + expanded_row_idx = sorted_expert_indices[:actual_expert_total_num] + else: + expanded_row_idx = np.ones(num_rows * k).astype(np.int32) * -1 + tmp_indices = np.arange(actual_expert_total_num) + expanded_row_idx[ + sorted_expert_indices[:actual_expert_total_num]] = tmp_indices + + if not expert_tokens_num_flag: + expert_tokens_count = torch.tensor([0]) + else: + if drop_pad_mode == 0: + if expert_tokens_num_type == 1: + expert_tokens_count = np.bincount( + sorted_expert_idx[:actual_expert_total_num] - expert_start) + expert_tokens_count = np.concatenate([ + expert_tokens_count, + np.zeros((expert_end - expert_start) - + len(expert_tokens_count)).astype(np.int64) + ]) + elif expert_tokens_num_type == 0: + expert_tokens_count = np.bincount( + sorted_expert_idx[:actual_expert_total_num] - expert_start) + expert_tokens_count = np.concatenate([ + expert_tokens_count, + np.zeros((expert_end - expert_start) - + len(expert_tokens_count)).astype(np.int64) + ]) + expert_tokens_count = np.cumsum(expert_tokens_count) + elif expert_tokens_num_type == 2: + expert_id, counts = 
np.unique( + sorted_expert_idx[:actual_expert_total_num], + return_counts=True) + expert_tokens_count = np.column_stack((expert_id, counts)) + if expert_tokens_count.shape[0] < expert_num: + expert_tokens_count = np.concatenate( + (expert_tokens_count, [ + [0, 0], + ]), axis=0) + else: + expert_tokens_count = np.bincount( + sorted_expert_idx[:actual_expert_total_num] - expert_start) + zeros_array = np.zeros( + (expert_end - expert_start) - len(expert_tokens_count), + dtype=np.int64) + expert_tokens_count = np.concatenate( + [expert_tokens_count, zeros_array]) + expert_tokens_count = expert_tokens_count.astype(np.int64) + + if drop_pad_mode == 0: + if active_num == 0: + active_num = actual_expert_total_num + else: + active_num = min(active_num, actual_expert_total_num) + expanded_scale = None + expanded_x = x[sorted_expert_indices[:active_num] // k, :] + if scale is not None and quant_mode == -1: + expanded_scale = scale[sorted_expert_indices[:active_num] // k] + else: + adapter_capacity(sorted_expert_indices, sorted_expert_idx, + expert_capacity) + + sort_row_tmp = np.full((expert_num * expert_capacity), -1, dtype=int) + offset_tmp = 0 + lastExpertId = 0 + for i, val in enumerate(sorted_expert_indices): + if val != -1: + if lastExpertId != sorted_expert_idx[i]: + offset_tmp = 0 + lastExpertId = sorted_expert_idx[i] + sort_row_tmp[sorted_expert_idx[i] * expert_capacity + + offset_tmp] = sorted_expert_indices[i] + offset_tmp = offset_tmp + 1 + + expanded_row_idx = np.full(sorted_expert_indices.shape, -1) + for i, val in enumerate(sort_row_tmp): + if val != -1: + expanded_row_idx[val] = i + + expanded_x_mask = np.full((expert_num * expert_capacity, h), + 1, + dtype=int) + expanded_x = np.full((expert_num * expert_capacity, h), + 0, + dtype=x.dtype) + for i, val in enumerate(sort_row_tmp): + if val != -1: + expanded_x[i] = x[val // k] + expanded_x_mask[i] = np.full((h, ), 0, dtype=int) + + if quant_mode == -1: + expanded_x = expanded_x + expanded_row_idx = 
expanded_row_idx + if scale is not None and drop_pad_mode == 1: + expanded_scale = np.full((expert_num * expert_capacity, ), + 0, + dtype=scale.dtype) + for i, val in enumerate(sort_row_tmp): + if val != -1: + expanded_scale[i] = scale[val // k] + if scale is None: + expanded_scale = None + + if quant_mode == 0: + expanded_scale = None + expanded_x_fp16 = expanded_x.astype(np.float16) + if scale is not None: + scale_val = scale.astype(np.float16) + else: + raise ValueError("scale cannot be None when quant_mode is 0") + if offset is not None: + offset_val = offset.astype(np.float16) + else: + raise ValueError("offset cannot be None when quant_mode is 0") + scale_rst = expanded_x_fp16 * scale_val[0] + add_offset = scale_rst + offset_val[0] + round_data = np.rint(add_offset) + round_data = np.clip(round_data, -128, 127) + expanded_x = round_data.astype(np.int8) + + if quant_mode == 1: + x_final = expanded_x.astype(np.float32) + if scale is None: + x_abs = np.abs(x_final) + x_max = np.max(x_abs, axis=-1, keepdims=True) + expanded_scale = x_max / 127 + expanded_x = x_final / expanded_scale + expanded_x = np.round(expanded_x).astype(np.int8) + else: + if scale.shape[0] == 1: + x_final = x_final * scale + else: + if drop_pad_mode == 0: + x_final = x_final * scale[sorted_expert_idx[:active_num] - + expert_start] + + else: + for i, val in enumerate(sort_row_tmp): + if val != -1: + x_final[i] = x_final[i] * scale[i // + expert_capacity] + x_abs = np.abs(x_final) + x_max = np.max(x_abs, axis=-1, keepdims=True) + expanded_scale = x_max / 127 + expanded_x = x_final / expanded_scale + expanded_x = np.round(expanded_x).astype(np.int8) + if x.dtype == np.int8: + expanded_scale = None + if drop_pad_mode == 1: + expanded_x = np.ma.array(expanded_x, mask=expanded_x_mask).filled(0) + expanded_x = expanded_x.reshape(expert_num, expert_capacity, h) + + return expanded_x, expanded_row_idx, expert_tokens_count, expanded_scale + + +def npu_pta(x, expert_idx, scale, offset, active_num, 
expert_capacity, + expert_num, drop_pad_mode, expert_tokens_num_type, + expert_tokens_num_flag, quant_mode, active_expert_range, + row_idx_type): + expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = torch.ops._C_ascend.npu_moe_init_routing_custom( + x, + expert_idx, + scale=scale, + offset=offset, + active_num=active_num, + expert_capacity=expert_capacity, + expert_num=expert_num, + drop_pad_mode=drop_pad_mode, + expert_tokens_num_type=expert_tokens_num_type, + expert_tokens_num_flag=expert_tokens_num_flag, + quant_mode=quant_mode, + active_expert_range=active_expert_range, + row_idx_type=row_idx_type) + + return expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale + + +def cmp_out_golden(x_golden, x_out, dtype): + if dtype == 'int8': + cmp = np.isclose(x_out.cpu().numpy()[:len(x_golden)], x_golden, atol=1) + else: + cmp = np.isclose(x_out.cpu().numpy()[:len(x_golden)], + x_golden, + rtol=1e-05, + atol=1e-05) + return np.all(cmp) + + +def test_moe_npu(x, expert_idx, scale, offset, active_num, expert_capacity, + expert_num, drop_pad_mode, expert_tokens_num_type, + expert_tokens_num_flag, quant_mode, active_expert_range, + row_idx_type): + x_npu = x.npu() + expert_idx_npu = expert_idx.npu() + scale_npu = scale.npu() if scale is not None else None + offset_npu = offset.npu() if offset is not None else None + + x_numpy = x.numpy() + expert_idx_numpy = expert_idx.numpy() + scale_numpy = scale.numpy() if scale is not None else None + offset_numpy = offset.numpy() if offset is not None else None + + expanded_x_golden, expanded_row_idx_golden, expert_token_cumsum_or_count_golden, expanded_scale_golden = moe_init_routing_golden( + x_numpy, expert_idx_numpy, scale_numpy, offset_numpy, active_num, + expert_capacity, expert_num, drop_pad_mode, expert_tokens_num_type, + expert_tokens_num_flag, active_expert_range, quant_mode, row_idx_type) + + expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = npu_pta( + 
x_npu, expert_idx_npu, scale_npu, offset_npu, active_num, + expert_capacity, expert_num, drop_pad_mode, expert_tokens_num_type, + expert_tokens_num_flag, quant_mode, active_expert_range, row_idx_type) + if quant_mode == -1: + expanded_x_result = cmp_out_golden(expanded_x_golden, expanded_x, + "float32") + else: + expanded_x_result = cmp_out_golden(expanded_x_golden, expanded_x, + "int8") + + expanded_row_idx_result = cmp_out_golden(expanded_row_idx_golden, + expanded_row_idx, "int32") + + if expert_tokens_num_flag: + expert_tokens_result = cmp_out_golden( + expert_token_cumsum_or_count_golden, expert_token_cumsum_or_count, + "int64") + else: + expert_tokens_result = True + + if quant_mode == 1 or (quant_mode == -1 and scale is not None): + expand_scale_result = cmp_out_golden(expanded_scale_golden.flatten(), + expanded_scale, "float32") + else: + expand_scale_result = True + + compare_result = expanded_x_result and expanded_row_idx_result and expert_tokens_result and expand_scale_result + # print('=======case result=======: ', compare_result) + return compare_result + + +def test_moe_init_routing_custom(): + failed_test_cnt = 0 + drop_pad_mode = [0, 1] + expert_tokens_num_type = [0, 1, 2] + expert_tokens_num_flag = [True, False] + quant_mode = [0, 1, -1] + row_idx_type = [0, 1] + scale_type = [0, 1, 2] + product_result = itertools.product(drop_pad_mode, expert_tokens_num_type, + expert_tokens_num_flag, quant_mode, + row_idx_type, scale_type) + + for idx, (drop_pad_mode_, expert_tokens_num_type_, expert_tokens_num_flag_, + quant_mode_, row_idx_type_, + scale_type_) in enumerate(product_result, 5): + expert_num_ = random.randint(2, 500) + expert_start = random.randint(0, expert_num_ - 1) + expert_end = random.randint(expert_start + 1, expert_num_) + active_expert_range_ = [expert_start, expert_end] + + N = random.randint(1, 100) + H = random.randint(12, 100) + K = random.randint(1, 12) + x_ = torch.randn(N, H, dtype=torch.float16) * 5 + expert_capacity_ = 
random.randint(1, N - 1) if N > 1 else 1 + expert_idx_ = torch.randint(0, + expert_num_ - 1, (N, K), + dtype=torch.int32) + active_num_ = N * K + + if drop_pad_mode_ == 1: + active_expert_range_ = [0, expert_num_] + expert_tokens_num_type_ = 1 + row_idx_type_ = 0 + + if quant_mode_ == 0: + scale_ = torch.randn(1, dtype=torch.float) + offset_ = torch.randn(1, dtype=torch.float) + elif quant_mode_ == -1: + scale_ = None + offset_ = None + else: + if scale_type_ == 0: + scale_ = None + offset_ = None + elif scale_type_ == 1: + scale_ = torch.randn(1, H, dtype=torch.float) + offset_ = None + else: + scale_ = torch.randn(active_expert_range_[1] - + active_expert_range_[0], + H, + dtype=torch.float) + offset_ = None + + result_pta = test_moe_npu(x_, expert_idx_, scale_, offset_, + active_num_, expert_capacity_, expert_num_, + drop_pad_mode_, expert_tokens_num_type_, + expert_tokens_num_flag_, quant_mode_, + active_expert_range_, row_idx_type_) + if not result_pta: + failed_test_cnt += 1 + + assert (failed_test_cnt == 0)