Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#include "error_log.h"
#include "hcom_topo_info.h"
#include "register/op_def_registry.h"
#include "dispatch_ffn_combine_tiling.h"
#include "../op_kernel/dispatch_ffn_combine_tiling.h"
#include <vector>
#include <map>
#include <algorithm>
#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"
#include "../op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h"

using namespace AscendC;
using namespace ge;
Expand Down Expand Up @@ -278,8 +278,12 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con
uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) +
info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 +
info.maxOutputSize * sizeof(float) * 2 +
std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
info.maxOutputSize * info.N * sizeof(int16_t) +
info.maxOutputSize * n2 * sizeof(int16_t) +
info.maxOutputSize * info.K * sizeof(int8_t) +
info.maxOutputSize * k2 * sizeof(int8_t) +
info.worldSize * sizeof(int32_t) * 16 +
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;

workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace);

Expand Down
20 changes: 1 addition & 19 deletions csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,11 @@ extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1
GM_ADDR c, GM_ADDR expertTokenNums, GM_ADDR workspaceGM, GM_ADDR tilingGM)
{
REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData);
if (TILING_KEY_IS(1000000)) {
KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000001)) {
KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, false> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000010)) {
if (TILING_KEY_IS(1000010)) {
KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, false, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
} else if (TILING_KEY_IS(1000011)) {
KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2);
GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM);
DispatchFFNCombine<int8_t, DTYPE_W1, DTYPE_OUT, true, true> op;
op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, expertTokenNums, workspaceGM, tilingGM);
op.Process();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ __aicore__ inline void DispatchFFNCombine<TemplateMMA2ACFunc>::Process()
using BlockEpilogue1 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy1, CType, PerTokenScaleType,
D1Type, TileElemWiseMuls, TileCopy1>;

using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant<ubStages>;
using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2<ubStages>;
using TileCopy2 = Epilogue::Tile::TileCopy<ArchTag, CType, ScaleType, PerTokenScaleType, D2Type>;
using BlockEpilogue2 = Epilogue::Block::BlockEpilogue<EpilogueDispatchPolicy2, CType,PerTokenScaleType,
D2Type, TileCopy2>;
Expand Down
Loading