Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions csrc/build_aclnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,28 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# Resolve the directory containing this script once and reuse it for both ops
# (the previous SCRIPT_DIR_BF16 recomputed the identical cd/dirname/pwd).
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
HCCL_STRUCT_BASENAME=$(basename "$HCCL_STRUCT_FILE_PATH")

TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
TARGET_FILE="$TARGET_DIR/$HCCL_STRUCT_BASENAME"
# for dispatch_ffn_combine_bf16
SCRIPT_DIR_BF16="$SCRIPT_DIR"
TARGET_DIR_BF16="$SCRIPT_DIR_BF16/dispatch_ffn_combine_bf16/op_kernel/utils/"
TARGET_FILE_BF16="$TARGET_DIR_BF16/$HCCL_STRUCT_BASENAME"

echo "*************************************"
# Quote the expansion so a path containing spaces is not word-split.
echo "$HCCL_STRUCT_FILE_PATH"
echo "$TARGET_DIR"

# Copy the HCCL struct header into each op's utils directory and rename the
# two struct definitions so the bundled copies do not clash with the
# CANN-provided HcclOpResParam / HcclRankRelationResV2 declarations.
for target_file in "$TARGET_FILE" "$TARGET_FILE_BF16"; do
    cp "$HCCL_STRUCT_FILE_PATH" "$(dirname "$target_file")"
    sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$target_file"
    sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$target_file"
done

CUSTOM_OPS_ARRAY=(
"grouped_matmul_swiglu_quant_weight_nz_tensor_list"
"lightning_indexer"
"sparse_flash_attention"
"dispatch_ffn_combine"
"dispatch_ffn_combine_bf16"
"dispatch_gmm_combine_decode"
"moe_combine_normal"
"moe_dispatch_normal"
Expand Down
66 changes: 66 additions & 0 deletions csrc/dispatch_ffn_combine_bf16/op_host/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

# Collect -I options for the DispatchFFNCombineBF16 host build: the first
# existing Ascend C include directory under the CANN package (package layouts
# name the arch dir differently), plus the vendored catlass headers if present.
set(_DISPATCH_FFN_INC_OPTS)
foreach(_dispatch_ffn_arch aarch64 arm64 ${CMAKE_SYSTEM_PROCESSOR})
    set(_dispatch_ffn_ascendc_inc "${ASCEND_CANN_PACKAGE_PATH}/${_dispatch_ffn_arch}-linux/ascendc/include")
    if (EXISTS "${_dispatch_ffn_ascendc_inc}")
        list(APPEND _DISPATCH_FFN_INC_OPTS "-I${_dispatch_ffn_ascendc_inc}")
        break()  # first match wins, mirroring the if/elseif chain this replaces
    endif()
endforeach()
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
    list(APPEND _DISPATCH_FFN_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
endif()

# Kernel compile options for the DispatchFFNCombineBF16 op.
# --cce-auto-sync=on enables automatic pipe synchronization in the CCE compiler.
# -DHCCL_COMM selects the HCCL-communication code path in the kernel sources.
# NOTE(review): -Werror is hardcoded unconditionally; consider guarding it
# behind a project option so newer compilers with added warnings still build.
add_ops_compile_options(
    OP_NAME DispatchFFNCombineBF16
    OPTIONS --cce-auto-sync=on
    -Wno-deprecated-declarations
    -Werror
    -DHCCL_COMM
    ${_DISPATCH_FFN_INC_OPTS}
)

# Op definition (OpDef registration) goes into the inner aclnn host library.
target_sources(op_host_aclnnInner PRIVATE
    dispatch_ffn_combine_bf16_def.cpp
)

# Public aclnn entry points (aclnnDispatchFFNCombineBF16*) go into opapi.
target_sources(opapi PRIVATE
    aclnn_dispatch_ffn_combine_bf16.cpp
)

# In the closed (non-open) project the aclnn wrappers are additionally
# compiled into the train and infer op libraries.
if (NOT BUILD_OPEN_PROJECT)
    target_sources(aclnn_ops_train PRIVATE
        aclnn_dispatch_ffn_combine_bf16.cpp
    )

    target_sources(aclnn_ops_infer PRIVATE
        aclnn_dispatch_ffn_combine_bf16.cpp
    )
endif ()

# Tiling implementation for the op.
# NOTE(review): dispatch_ffn_combine_bf16_tiling.cpp is referenced here but not
# part of this diff chunk — confirm it exists in op_host.
target_sources(optiling PRIVATE
    dispatch_ffn_combine_bf16_tiling.cpp
)

# Tiling code includes headers from this dir and from the sibling op_kernel dir.
target_include_directories(optiling PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
    ${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel
)

# InferShape/InferDataType registration goes into the ops proto library.
target_sources(opsproto PRIVATE
    dispatch_ffn_combine_bf16_proto.cpp
)

# Install the public aclnn header for this op.
# The file name is a single known literal, so install it directly: the previous
# file(GLOB _GMM_Aclnn_header ...) was an unnecessary glob (only re-evaluated at
# configure time) under a variable name copy-pasted from the GMM op.
# OPTIONAL keeps the previous behavior of not failing when the header is absent.
install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_ffn_combine_bf16.h
    DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "aclnn_dispatch_ffn_combine_bf16.h"
#include <algorithm>
// #include "aclnn_kernels/common/op_error_check.h"
// #include "opdev/op_log.h"
// #include "opdev/common_types.h"
// #include "opdev/platform.h"
// #include "ophost/matmul_util.h"
#include <unistd.h>
#include <vector>
#include <string>
#include <iostream>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/file.h>
#include <climits>
#include "../op_host/error_log.h"
// using namespace op;

// using namespace op;

#ifdef __cplusplus
extern "C" {
#endif

// Validation limits.
// NOTE(review): none of these four constants are referenced in the visible
// code of this file — presumably kept for parity with the non-BF16 variant;
// confirm before removing.
static constexpr size_t TWO_DIMS = 2;
static constexpr int64_t KVALUE_MIN = 256;
static constexpr int64_t KVALUE_MAX = 65535;
static constexpr size_t HCCL_GROUP_NAME_MAX = 128U;
// Server type passed to NnopbaseSetHcclServerType below.
// NOTE(review): presumably mirrors the enum defined inside nnopbase — the
// numeric values must stay in sync with that definition; confirm.
enum NnopbaseHcclServerType {
    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
    NNOPBASE_HCCL_SERVER_TYPE_MTE,
    NNOPBASE_HCCL_SERVER_TYPE_END
};

// Inner (generated) entry points that the public wrappers below forward to.
// Note the inner GetWorkspaceSize takes two extra parameters (transB, weightNz)
// that the public API pins to fixed values.
extern aclnnStatus aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
                                                                    const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                                    const aclTensor* probs,
                                                                    const char* group, int64_t maxOutputSize,
                                                                    bool transB, bool weightNz,
                                                                    const aclTensor* out,
                                                                    uint64_t* workspaceSize, aclOpExecutor** executor);
extern aclnnStatus aclnnInnerDispatchFFNCombineBF16(void *workspace, uint64_t workspaceSize,
                                                    aclOpExecutor *executor, aclrtStream stream);
// Declared weak so linking succeeds when nnopbase does not provide this hook;
// callers null-check it before use (see aclnnDispatchFFNCombineBF16).
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);



/**
 * Phase 1 of the two-phase aclnn call for DispatchFFNCombineBF16: computes the
 * required workspace size and builds the executor.
 *
 * Thin wrapper around the generated inner entry point: the public API does not
 * expose the transB/weightNz knobs, so they are pinned here (transB=false,
 * weightNz=true — presumably selecting non-transposed, FRACTAL_NZ-format
 * weights; confirm against the op definition).
 *
 * NOTE: this body previously contained stray PR-review text ("Comment thread"
 * / "marked this conversation as resolved") pasted into the source, which
 * could not compile; it has been removed.
 */
aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
                                                        const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                        const aclTensor* probs,
                                                        const char* group, int64_t maxOutputSize,
                                                        const aclTensor* out,
                                                        uint64_t* workspaceSize, aclOpExecutor** executor)
{
    const bool transB = false;
    const bool weightNz = true;
    return aclnnInnerDispatchFFNCombineBF16GetWorkspaceSize(x, weight1, weight2, expertId, scale1, scale2, probs, group,
                                                            maxOutputSize, transB, weightNz,
                                                            out, workspaceSize, executor);
}

/**
 * Phase 2 of the two-phase aclnn call: launches the op on the given stream
 * using the executor produced by aclnnDispatchFFNCombineBF16GetWorkspaceSize.
 * When the weak nnopbase hook is present, the executor's HCCL server type is
 * switched to MTE before dispatch.
 */
aclnnStatus aclnnDispatchFFNCombineBF16(void* workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
    // Weak symbol: only call it if the linker actually resolved it.
    if (NnopbaseSetHcclServerType != nullptr) {
        NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
    }
    return aclnnInnerDispatchFFNCombineBF16(workspace, workspaceSize, executor, stream);
}
#ifdef __cplusplus
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

#ifndef OP_API_INC_DISPATCH_FFN_COMBINE_BF16_
#define OP_API_INC_DISPATCH_FFN_COMBINE_BF16_

#include <string>

#include "aclnn/aclnn_base.h"
#include "hccl/hccl.h"
#include "hccl/hccl_types.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Phase 1 of the two-phase aclnn call for the DispatchFFNCombineBF16 op:
 * computes the workspace size and builds the executor.
 *
 * @param x             input activations
 * @param weight1/weight2  per-expert weight tensor lists for the two FFN matmuls
 * @param expertId      routed expert indices
 * @param scale1/scale2 per-expert scale tensor lists
 * @param probs         routing probabilities
 * @param group         HCCL communication group name
 * @param maxOutputSize upper bound for the output size
 * @param out           output tensor
 * @param workspaceSize [out] required workspace bytes
 * @param executor      [out] executor to pass to aclnnDispatchFFNCombineBF16
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16GetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2,
                                                                                               const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2,
                                                                                               const aclTensor* probs,
                                                                                               const char* group, int64_t maxOutputSize,
                                                                                               const aclTensor* out,
                                                                                               uint64_t* workspaceSize, aclOpExecutor** executor);

/**
 * Phase 2: launches the op on the given stream with the workspace and
 * executor obtained from the GetWorkspaceSize call above.
 */
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineBF16(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
                                                                               aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif  // OP_API_INC_DISPATCH_FFN_COMBINE_BF16_
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
* \file dispatch_ffn_combine_bf16_def.cpp
* \brief
*/
#include "register/op_def_registry.h"

namespace ops {
// OpDef registration for DispatchFFNCombineBF16.
// Each DataType/Format list declares four supported configurations, column by
// column: (fp16, ND weights), (bf16, ND weights), (fp16, FRACTAL_NZ weights),
// (bf16, FRACTAL_NZ weights).
class DispatchFFNCombineBF16 : public OpDef {
public:
    explicit DispatchFFNCombineBF16(const char *name) : OpDef(name) {
        // Input activations.
        this->Input("a")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-expert weights for the first matmul; NZ format in configs 3/4.
        // IgnoreContiguous: weights may be non-contiguous views.
        this->Input("w1")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .IgnoreContiguous();
        // Per-expert weights for the second matmul; same layout rules as w1.
        this->Input("w2")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
            .IgnoreContiguous();
        // Routed expert indices (int32).
        this->Input("expertIdx")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-expert scales for the first matmul (int64).
        this->Input("scale1")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Per-expert scales for the second matmul (int64).
        this->Input("scale2")
            .ParamType(DYNAMIC)
            .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Routing probabilities (float32).
        this->Input("probs")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

        // Output: same dtype column as input "a".
        this->Output("out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16})
            .Format({ ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

        // Attributes: HCCL group name is mandatory; the rest default as shown.
        this->Attr("group").AttrType(REQUIRED).String();
        this->Attr("M").AttrType(OPTIONAL).Int();
        this->Attr("transB").AttrType(OPTIONAL).Bool(false);
        this->Attr("weightNz").AttrType(OPTIONAL).Bool(false);

        OpAICoreConfig aicore_config;
        aicore_config.DynamicCompileStaticFlag(true)
            .DynamicFormatFlag(true)
            .DynamicRankSupportFlag(true)
            .DynamicShapeSupportFlag(true)
            .NeedCheckSupportFlag(false)
            .PrecisionReduceFlag(true)
            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
            .ExtendCfgInfo("jitCompile.flag", "static_false")
            .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
        // Only ascend910_93 is registered.
        this->AICore().AddConfig("ascend910_93", aicore_config);
        // NOTE(review): commented-out ascend910b config — delete if 910b support
        // is not planned.
        // this->AICore().AddConfig("ascend910b", aicore_config);
        // Bind the "group" attribute as the MC2 HCCL communication group.
        this->MC2().HcclGroup("group");
    }
};

OP_ADD(DispatchFFNCombineBF16);
}  // namespace ops
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
* \file dispatch_ffn_proto.cpp
* \brief
*/
#include <graph/utils/type_utils.h>
#include <register/op_impl_registry.h>
// #include "../../common/ophost/op_util.h"
// #include "../../common/ophost/hcom_topo_info.h"
// #include "log/ops_log.h"

using namespace ge;
namespace ops {
const size_t ATTR_GROUP = 0;
const size_t ATTR_RANK_SIZE = 1;
const size_t SUPPORT_DIM_SIZE = 2;

static ge::graphStatus InferShapeDispatchFFNCombineBF16(gert::InferShapeContext* context) {
return ge::GRAPH_SUCCESS;
}

static ge::graphStatus InferDataTypeDispatchFFNCombineBF16(gert::InferDataTypeContext* context) {
// auto d_type = context->GetInputDataType(0);
// context->SetOutputDataType(0, d_type);
return ge::GRAPH_SUCCESS;
}
Comment thread
guanguan0308 marked this conversation as resolved.

IMPL_OP_INFERSHAPE(DispatchFFNCombineBF16)
.InferShape(InferShapeDispatchFFNCombineBF16)
.InferDataType(InferDataTypeDispatchFFNCombineBF16);
} // namespace ops
Loading