From 698cf18311d0949cb4505119086837c9977f8754 Mon Sep 17 00:00:00 2001 From: xulei_ict Date: Sun, 1 Feb 2026 18:05:48 +0800 Subject: [PATCH] perf: Optimize DispatchFFNCombine performance Signed-off-by: xulei_ict --- .../op_host/dispatch_ffn_combine_tiling.cpp | 12 +- .../op_kernel/dispatch_ffn_combine.cpp | 20 +- .../op_kernel/dispatch_ffn_combine.h | 2 +- .../op_kernel/dispatch_ffn_combine_kernel.hpp | 648 ++++++++++++------ .../moe_init_routing_quant_v2/moe_v2_common.h | 1 + .../moe_v2_fullload_dynamic_quant.h | 84 +-- .../moe_v2_gather_dynamic_quant.h | 39 +- .../op_kernel/unpermute/moe_token_unpermute.h | 4 +- .../utils/block_epilogue_pertoken_v2.hpp | 243 +++++++ ...block_mmad_preload_async_fixpipe_quant.hpp | 53 +- .../op_kernel/utils/const_args.hpp | 4 +- .../utils/dispatch_policy_custom.hpp | 4 +- .../op_kernel/utils/hccl_shmem.hpp | 239 +++++-- 13 files changed, 939 insertions(+), 414 deletions(-) create mode 100644 csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp index 8b16f0b919d..d41d4d93ff3 100644 --- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp @@ -17,11 +17,11 @@ #include "error_log.h" #include "hcom_topo_info.h" #include "register/op_def_registry.h" -#include "dispatch_ffn_combine_tiling.h" +#include "../op_kernel/dispatch_ffn_combine_tiling.h" #include #include #include -#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" +#include "../op_kernel/moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" using namespace AscendC; using namespace ge; @@ -278,8 +278,12 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con uint64_t cocWorkspace = (info.M + 256 - 1) / 256 * 256 * info.topK *sizeof(int32_t) + info.worldSize * info.worldSize * info.expertPerRank * sizeof(int32_t) * 3 + info.maxOutputSize * sizeof(float) * 2 + - std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) + - std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t)); + info.maxOutputSize * info.N * sizeof(int16_t) + + info.maxOutputSize * n2 * sizeof(int16_t) + + info.maxOutputSize * info.K * sizeof(int8_t) + + info.maxOutputSize * k2 * sizeof(int8_t) + + info.worldSize * sizeof(int32_t) * 16 + + (info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16; workSpaces[0] = SYSTEM_NEED_WORKSPACE + std::max(cocWorkspace, initRoutingWorkspace); diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp index db3cf771fd0..1ccdfd6fc59 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.cpp @@ -23,29 +23,11 @@ extern "C" __global__ __aicore__ void dispatch_ffn_combine(GM_ADDR x, GM_ADDR w1 GM_ADDR c, GM_ADDR workspaceGM, GM_ADDR tilingGM) { REGISTER_TILING_DEFAULT(DispatchFFNCombineTilingData); - if (TILING_KEY_IS(1000000)) { - KERNEL_TASK_TYPE(1000000, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); - op.Process(); - } else if (TILING_KEY_IS(1000001)) { - KERNEL_TASK_TYPE(1000001, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); - op.Process(); - } else if (TILING_KEY_IS(1000010)) { + if (TILING_KEY_IS(1000010)) { KERNEL_TASK_TYPE(1000010, KERNEL_TYPE_MIX_AIC_1_2); GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); DispatchFFNCombine op; op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); op.Process(); - } else if (TILING_KEY_IS(1000011)) { - KERNEL_TASK_TYPE(1000011, KERNEL_TYPE_MIX_AIC_1_2); - GET_TILING_DATA_WITH_STRUCT(DispatchFFNCombineTilingData, tilingData, tilingGM); - DispatchFFNCombine op; - op.Init(x, w1, w2, expertId, scale1, scale2, probs, c, workspaceGM, tilingGM); - op.Process(); } } \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h index 704809dcc4a..0bb329ae70b 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -232,7 +232,7 @@ __aicore__ inline void DispatchFFNCombine::Process() using BlockEpilogue1 = Epilogue::Block::BlockEpilogue; - using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequant; + using EpilogueDispatchPolicy2 = Epilogue::EpilogueAtlasA2PerTokenDequantV2; using TileCopy2 = Epilogue::Tile::TileCopy; using BlockEpilogue2 = Epilogue::Block::BlockEpilogue; diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index 422595aae95..f2bef06db93 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -22,21 +22,38 @@ #include "catlass/matrix_coord.hpp" #include "catlass/epilogue/tile/tile_copy.hpp" -#include "utils/block_mmad_preload_async_fixpipe_quant.hpp" -#include "utils/copy_gm_to_l1_custom.hpp" -#include "utils/copy_l0c_to_gm_custom.hpp" -#include "utils/block_epilogue_pertoken_row.hpp" -#include "utils/block_epilogue_pertoken_swiglu.hpp" -#include "utils/hccl_shmem.hpp" -#include "utils/const_args.hpp" -#include "utils/layout3d.hpp" -#include "utils/get_tensor_addr.hpp" - -#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" -#include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" -#include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" -#include "unpermute/moe_token_unpermute.h" - +#ifndef HCCL_COMM + #include "block_mmad_preload_async_fixpipe_quant.hpp" + #include "copy_gm_to_l1_custom.hpp" + #include "copy_l0c_to_gm_custom.hpp" + #include "block_epilogue_pertoken_row.hpp" + #include "block_epilogue_pertoken_v2.hpp" + #include "block_epilogue_pertoken_swiglu.hpp" + #include "hccl_shmem.hpp" + #include "const_args.hpp" + #include "layout3d.hpp" + #include "tiling/moe_init_routing_quant_v2_tiling.h" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" + #include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" + #include "moe_token_unpermute.h" + #include "get_tensor_addr.hpp" + inline __gm__ struct OpSystemRunCfg g_opSystemRunCfg{Catlass::L2_OFFSET}; +#else + #include "utils/block_mmad_preload_async_fixpipe_quant.hpp" + #include "utils/copy_gm_to_l1_custom.hpp" + #include "utils/copy_l0c_to_gm_custom.hpp" + #include "utils/block_epilogue_pertoken_row.hpp" + #include "utils/block_epilogue_pertoken_v2.hpp" + #include "utils/block_epilogue_pertoken_swiglu.hpp" + #include "utils/hccl_shmem.hpp" + #include "utils/const_args.hpp" + #include "utils/layout3d.hpp" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" + #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" + #include "moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h" + #include "unpermute/moe_token_unpermute.h" + #include "utils/get_tensor_addr.hpp" +#endif using namespace AscendC; @@ -44,7 +61,6 @@ namespace Catlass::Gemm::Kernel { constexpr uint16_t SYNCFLAGC2V = 9; constexpr uint16_t SYNCFLAGV2C = 10; -constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; template < class BlockMmad_, @@ -103,6 +119,7 @@ class DispatchFFNCombineKernel { uint32_t rank; uint32_t rankSize; int32_t ubMoveNum; + GM_ADDR symmetricPtr; //-------------- GM_ADDR expertIdx; GM_ADDR moeInitRoutingQuantV2Scale; @@ -140,7 +157,8 @@ class DispatchFFNCombineKernel { GM_ADDR moeInitRoutingQuantV2Offset_, GM_ADDR expertTokensBeforeCapacity_, GM_ADDR probs_, GM_ADDR ptrWorkspace_, int32_t ubMoveNum_, - optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_ + optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_, + GM_ADDR symmetricPtr_ = nullptr ) : problemShape(problemShape_), EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_), rank(rank_), rankSize(rankSize_), topK(topK_), @@ -155,7 +173,7 @@ class DispatchFFNCombineKernel { expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_), moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), expertTokensBeforeCapacity(expertTokensBeforeCapacity_), probs(probs_), - ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_), + ptrWorkspace(ptrWorkspace_), ubMoveNum(ubMoveNum_),symmetricPtr(symmetricPtr_), moeInitRoutingQuantV2TilingData(moeInitRoutingQuantV2TilingData_) { } @@ -192,9 +210,7 @@ class DispatchFFNCombineKernel { void operator()(Params const ¶ms) { GMM1(params); - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); - GMM2(params); } @@ -203,32 +219,26 @@ class DispatchFFNCombineKernel { CATLASS_DEVICE void operator()(Params const ¶ms) { - Dispatch(params); - AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); - - Combine(params); + DispatchAndCombine(params); } private: CATLASS_DEVICE void initBuffer(Params const ¶ms) { + #ifndef HCCL_COMM + shmem.initShmem(params.symmetricPtr, params.rank, params.rankSize); + #endif workspaceInfo = WorkspaceInfo(params); peermemInfo = PeermemInfo(params, shmem); - cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); - gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); - gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); - gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); gmPerTokenScale2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale2)); - tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert)); - - tokenPerExpertLayout = Layout3D( AlignUp(params.EP * params.expertPerRank, ALIGN_128), params.expertPerRank); + tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank); + preSumBeforeRank.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrSumBeforeRank)); } template @@ -284,6 +294,51 @@ class DispatchFFNCombineKernel { AscendC::WaitFlag(EVENT_ID1); } + // Move tokens and scales together, then write them to different positions respectively + template + CATLASS_DEVICE void CopyGMToGMPerToken( + AscendC::GlobalTensor dst, + AscendC::GlobalTensor dstScale, + AscendC::GlobalTensor src, + int32_t rows, + int32_t hiddenSize + ) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + + constexpr int32_t BufferNum = 2; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + constexpr int tmpBufferOffset = 96 * 1024; // half of UB + AscendC::LocalTensor tmpBuffer2 = resource.ubBuf.template GetBufferByByte(tmpBufferOffset); + uint32_t copyInNum = hiddenSize + ALIGN_512; + // [ReduceScatter] 2. Pre Interface Sync + int pingpongId = 0; + for (uint32_t processIndex = 0; processIndex < rows; ++processIndex) { + AscendC::TEventID EVENT_ID = pingpongId == 0 ? EVENT_ID0 : EVENT_ID1; + AscendC::LocalTensor buf = pingpongId == 0 ? tmpBuffer1 : tmpBuffer2; + AscendC::LocalTensor bufScale = buf[hiddenSize].template ReinterpretCast(); + auto inputOffset = processIndex * copyInNum; + auto outputOffset = processIndex * hiddenSize; + // [ReduceScatter] 2. Pre Interface Sync + AscendC::WaitFlag(EVENT_ID); + // [ReduceScatter] 3. Start shmem_mte_get_mem_nbi + AscendC::DataCopy(buf, src[inputOffset], copyInNum); + AscendC::SetFlag(EVENT_ID); + AscendC::WaitFlag(EVENT_ID); + AscendC::DataCopy(dst[outputOffset], buf, hiddenSize); + AscendC::DataCopyPad(dstScale[processIndex], bufScale, {1, 4, 0, 0, 0}); + + // [ReduceScatter] 4. Post Interface Sync + AscendC::SetFlag(EVENT_ID); + pingpongId = (pingpongId + 1) % BufferNum; + } + // [ReduceScatter] 4. Post Interface Sync + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + CATLASS_DEVICE void GetCumsumForMMAIV(AscendC::GlobalTensor & tokenPerExpert, AscendC::GlobalTensor & result, uint32_t expertPerRank, uint32_t rankId, uint32_t EP) { @@ -295,7 +350,7 @@ class DispatchFFNCombineKernel { AscendC::DataCopyPad( tmpBuffer1, tokenPerExpert[rankId * expertPerRank], - {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, ALIGN_128) - expertPerRank) * sizeof(int32_t)), 0}, + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16((AlignUp(EP * expertPerRank, 128) - expertPerRank) * sizeof(int32_t)), 0}, {} ); @@ -322,6 +377,8 @@ class DispatchFFNCombineKernel { icache_preload(8); BlockScheduler blockScheduler; BlockMmad blockMmad(resource); + float aivFinishGroups = 0.0f; + __gm__ float* aivFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE; int64_t gmGroupOffsetA = 0; int64_t gmGroupOffsetB = 0; @@ -330,20 +387,10 @@ class DispatchFFNCombineKernel { uint32_t syncGroupIdx = 0; int64_t preCurrentmSum = 0; int32_t syncLoopIdx = -1; - + uint16_t syncgmmIdx = 0; AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); // Wait for AIV to finish cumsum for matmul - syncgmmIdx++; - - constexpr uint32_t MAX_EXPERTS_PER_RANK = 32; - __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK]; - __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK]; - - int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank; - for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { - weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr(loopIdx, params.ptrB1)); - scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale1)); - } + syncgmmIdx ++; AscendC::PipeBarrier(); for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { @@ -356,11 +403,9 @@ class DispatchFFNCombineKernel { AscendC::GlobalTensor gmB1; AscendC::GlobalTensor gmS; int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; - gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx])); - gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx])); - + gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB1))); + gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale1))); AscendC::PipeBarrier(); - if (currentM <= L1TileShape::M) { gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -380,6 +425,7 @@ class DispatchFFNCombineKernel { AscendC::CrossCoreWaitFlag<0x2>(syncgmmIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); syncgmmIdx ++; } + // Compute block location GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); @@ -407,6 +453,7 @@ class DispatchFFNCombineKernel { if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } + // Synchronization signal: GMM1 notifies SwiGLU [1] blockMmad.Finalize(syncLoopIdx, SYNCFLAGC2V); } @@ -427,6 +474,7 @@ class DispatchFFNCombineKernel { if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } + // Synchronization signal: GMM1 notifies SwiGLU [2] blockMmad.Finalize(syncLoopIdx + 1, SYNCFLAGC2V); } @@ -435,7 +483,7 @@ class DispatchFFNCombineKernel { icache_preload(8); BlockScheduler blockScheduler; BlockMmad blockMmad(resource); - + uint32_t n2 = params.problemShape.k(); uint32_t k2 = params.problemShape.n() / 2; @@ -455,14 +503,6 @@ class DispatchFFNCombineKernel { lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; } - constexpr uint32_t MAX_EXPERTS_PER_RANK = 8; - __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK]; - __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK]; - int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank; - for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { - weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(loopIdx, params.ptrB2)); - scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale2)); - } AscendC::PipeBarrier(); for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { @@ -474,11 +514,10 @@ class DispatchFFNCombineKernel { } AscendC::GlobalTensor gmB2; AscendC::GlobalTensor gmS2; - AscendC::PipeBarrier(); int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; - gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx])); - gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx])); - + gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(arrayGroupIdx, params.ptrB2))); + gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(arrayGroupIdx, params.ptrScale2))); + AscendC::PipeBarrier(); if (currentM <= L1TileShape::M) { gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -498,6 +537,7 @@ class DispatchFFNCombineKernel { if (params.expertPerRank > lastDequantExpertNum && groupIdx + 1 == params.expertPerRank - lastDequantExpertNum) { AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGV2C); } + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { if (loopIdx + coreNum >= coreLoops) { syncLoopIdx = groupIdx; @@ -518,12 +558,12 @@ class DispatchFFNCombineKernel { int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * n2 : 0); // One scale group per expert if (currentM > 0) { blockMmad( - gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, - gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, - gmC2[gmGroupOffsetC + gmOffsetC], layoutC, - gmS2[gmOffsetS], layoutScale, - actualBlockShape, syncLoopIdx, 0 - ); + gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, + gmB2[gmGroupOffsetB + gmOffsetB], layoutB2, + gmC2[gmGroupOffsetC + gmOffsetC], layoutC, + gmS2[gmOffsetS], layoutScale, + actualBlockShape, syncLoopIdx, 0 + ); } } preCurrentmSum += currentM; @@ -534,34 +574,34 @@ class DispatchFFNCombineKernel { gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { blockMmad.SynchronizeBlock(); } - blockMmad.Finalize(params.expertPerRank - 1, 0); } - CATLASS_DEVICE - void ResetTokenPerExpert(AscendC::GlobalTensor & tokenPerExpert, int32_t num) - { - if (coreIdx != coreNum - 1) { - return; - } + + CATLASS_DEVICE + void InitArithProgress(Params const ¶ms) { + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0); - AscendC::Duplicate(tmp, 0, num); + AscendC::Duplicate(tmpBuffer1, 0.0f, (params.EP + 1) * FLAGSTRIDE); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::DataCopy(tokenPerExpert, tmp, num); + + AscendC::GlobalTensor flagGlobalBase; + flagGlobalBase.SetGlobalBuffer(workspaceInfo.ptrSoftFlagBase); + AscendC::DataCopy(flagGlobalBase, tmpBuffer1, (params.EP + 1) * FLAGSTRIDE); } + CATLASS_DEVICE - void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const ¶ms, int64_t localTokenPerExpertOffset){ + void CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(Params const ¶ms, int64_t localTokenPerExpertOffset){ + uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, 128); AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); - uint32_t numPerCore = AlignUp(params.EP * params.expertPerRank, ALIGN_128); + AscendC::LocalTensor prevSumBuf = tmpBuffer[numPerCore]; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { if (dstEpIdx == params.rank) { continue; @@ -578,11 +618,11 @@ class DispatchFFNCombineKernel { using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; CopyGmToUb copyGmToUb; CopyUbToGm copyUbToGm; - + AscendC::WaitFlag(EVENT_ID0); - - copyGmToUb(tmpBuffer, srcAddress[0], - layout::RowMajor{ 1, numPerCore}, + + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, numPerCore}, layout::RowMajor{1, numPerCore}); AscendC::SetFlag(EVENT_ID0); @@ -590,36 +630,126 @@ class DispatchFFNCombineKernel { AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - copyUbToGm(dstAddress[0], tmpBuffer, - layout::RowMajor{ 1, numPerCore}, + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, numPerCore}, layout::RowMajor{1, numPerCore}); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); } for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { - if (dstEpIdx == params.rank) { - continue; + if (dstEpIdx != params.rank) { + int32_t intPer512 = CACHE_LINE / sizeof(int); + for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, 128); checkIdx += intPer512) { + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); + gm_signal_wait_until_ne(sync_check, 0); + } + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + } else { + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + AscendC::PipeBarrier(); + int32_t prevSum = 0; + int32_t j = 0; + for (int32_t i = 0; i < (params.rank + 1) * params.expertPerRank; i++) { + if (i >= params.rank * params.expertPerRank) { + prevSumBuf(j) = prevSum; + j++; + } + prevSum += tmpBuffer(i); } - int32_t intPer512 = CACHE_LINE / sizeof(int); - for(int32_t checkIdx = 0; checkIdx < AlignUp(params.EP * params.expertPerRank, ALIGN_128); checkIdx += intPer512) { - __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); - gm_signal_wait_until_ne(sync_check, 0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopyPad(preSumBeforeRank[dstEpIdx * params.expertPerRank], prevSumBuf, + AscendC::DataCopyParams{1, static_cast(params.expertPerRank * sizeof(int32_t)), 0, 0}); + } + + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void ResetTokenPerExpert(int32_t num) + { + if (coreIdx != coreNum - 1) { + return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmp, 0, num); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert, tmp, num); + } + + CATLASS_DEVICE + void UpdateAicFlags(const Params ¶ms) + { + float flagBase = 1.0f * params.expertPerRank; + __gm__ float* aicFinishPtr = workspaceInfo.ptrSoftFlagBase + params.EP * FLAGSTRIDE; + float flag = 0.0f; + float lastflag = -1.0f; + AscendC::LocalTensor tmpBuffer1 = resource.ubBuf.template GetBufferByByte(0); + __gm__ float* flagPtr = workspaceInfo.ptrSoftFlagBase; + AscendC::GlobalTensor flagGM; + flagGM.SetGlobalBuffer(flagPtr); + int32_t flagBufferSize = max(4, params.EP) * FLAGSTRIDE; + AscendC::LocalTensor dstValueBuffer = resource.ubBuf.template GetBufferByByte(flagBufferSize); + AscendC::LocalTensor sharedTmpBuffer = resource.ubBuf.template GetBufferByByte((flagBufferSize + 64)); + uint64_t mask[1] = {0}; + uint32_t repeatNum = (flagBufferSize / (4 * FLAGSTRIDE)); + for (int32_t i = 0; i < 4; i ++) { + if (i < params.EP) { + mask[0] |= 1ull * (1ull << (i * 16)); } - AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + while (flag < flagBase) { + flag = flagBase; + AscendC::DataCopy(tmpBuffer1, flagGM, params.EP * FLAGSTRIDE); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); + + AscendC::ReduceMin(dstValueBuffer, tmpBuffer1, sharedTmpBuffer, mask, repeatNum, 8, false); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + flag = min(flag, dstValueBuffer.GetValue(0)); + + if (flag > lastflag) { + *aicFinishPtr = flag; + gm_dcci(aicFinishPtr); + lastflag = flag; + } } - AscendC::SyncAll(); } CATLASS_DEVICE - void Dispatch(Params const ¶ms) { + void CombineSetFlag() { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + } + + + CATLASS_DEVICE + void DispatchAndCombine(Params const ¶ms) { icache_preload(8); int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t); GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem @@ -634,33 +764,32 @@ class DispatchFFNCombineKernel { ¶ms.moeInitRoutingQuantV2TilingData, params.initRoutingQuantTilingKey); AscendC::SyncAll(); - CrossRankSyncAndlocalTokenPerExpertAllGather(params, localTokenPerExpertOffset); + + CrossRankSyncAndlocalTokenPerExpertAllGatherAndGetSumPreRankV2(params, localTokenPerExpertOffset); + if (coreIdx == 0) { GetCumsumForMMAIV(tokenPerExpert, cumsumMM, params.expertPerRank, params.rank, params.EP); } - AscendC::SyncAll(); - uint16_t syncgmm1Idx = 0; - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); - syncgmm1Idx++; - + uint32_t curGroupOffset = 0; int32_t prevSumBeforeRank = 0; int32_t groupIdxDeq = 0; + int32_t prevSum = 0; if (coreIdx < params.EP) { - for (int32_t i = 0; i < params.rank * params.expertPerRank; i++) { - prevSumBeforeRank += tokenPerExpert(tokenPerExpertLayout(coreIdx, 0, i)); - } - m_prevSumBeforeRank = prevSumBeforeRank; + prevSum = preSumBeforeRank(coreIdx * params.expertPerRank); } - int prevSum = prevSumBeforeRank; - uint32_t prevGroupSum1 = 0; + AscendC::SyncAll(); + uint16_t syncgmm1Idx = 0; + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmm1Idx++; + + uint32_t prevGroupSum1 = 0, dequantSum1 = 0, dequantSum2 = 0; uint32_t dequantSum = 0; - int32_t syncLoopIdx = -1; - uint32_t n = params.problemShape.n(); - BlockEpilogue1 blockEpilogue(resource, n); + + icache_preload(8); for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { // The ith core reads data from the ith rank's peermem - groupIdxDeq = groupIdx - 2; + uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1; if (rowStart < params.maxOutputSize) { @@ -673,137 +802,199 @@ class DispatchFFNCombineKernel { GM_ADDR otherRankPtr = shmem(0, dstEpIdx); AscendC::GlobalTensor gmRemoteA; gmRemoteA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA*>(otherRankPtr + peermemInfo.offsetA)); - AscendC::GlobalTensor gmRemotePerTokenScale; - gmRemotePerTokenScale.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale*>(otherRankPtr + peermemInfo.offsetPeerPerTokenScale)); + MatrixCoord offsetA{rowStart, 0}; - MatrixCoord shapeA{rows, params.problemShape.k()}; MatrixCoord offsetPeer{rowSrc, 0}; int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); - int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer); + int64_t gmOffsetPeer = rowSrc * (params.problemShape.k() + ALIGN_512); // Communication data - CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum); - // Communication scale - CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows); + CopyGMToGMPerToken(gmA[gmOffsetA], gmPerTokenScale1[rowStart], gmRemoteA[gmOffsetPeer], rows, params.problemShape.k()); } - } - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1) { - syncLoopIdx++; - AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); } AscendC::SyncAll(); - AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); // V notifies C that the current communication round is complete - syncgmm1Idx++; - - if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) { - uint32_t rowStartThisCore = 0; - MatrixCoord offsetC{0U, 0}; - uint32_t dequantLen = prevGroupSum1 - dequantSum; - if (dequantLen >= params.maxOutputSize) { - dequantLen = dequantLen - params.maxOutputSize; + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT); + syncgmm1Idx ++; + + prevGroupSum1 += currentM; + + // Token count and truncation logic for the first SwiGLU operation + if (groupIdx + 1 <= params.epilogueGranularity) { + if (dequantSum1 + currentM <= params.maxOutputSize) { + dequantSum1 += currentM; + } else if (dequantSum1 < params.maxOutputSize) { + dequantSum1 = params.maxOutputSize; } - - MatrixCoord shapeC{dequantLen, params.problemShape.n()}; - LayoutC layoutC{dequantLen, params.problemShape.n()}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); } - prevGroupSum1 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - dequantSum += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); - if (groupIdx + 1 == params.epilogueGranularity && groupIdx < params.expertPerRank - 1) { - dequantSum = 0; + + // Token count and truncation logic for the second SwiGLU operation + if (groupIdx + 1 > params.epilogueGranularity && dequantSum1 < params.maxOutputSize) { + if (dequantSum1 + dequantSum2 + currentM <= params.maxOutputSize) { + dequantSum2 += currentM; + } else if (dequantSum1 + dequantSum2 < params.maxOutputSize) { + dequantSum2 += params.maxOutputSize - dequantSum1 - dequantSum2; + } } } - syncLoopIdx ++; + + uint32_t n2 = params.problemShape.k(); + + typename BlockEpilogue2::Params epilogueParams{ + static_cast(params.EP), + static_cast(params.expertPerRank), + static_cast(params.rank), + reinterpret_cast<__gm__ int32_t *>(shmem() + peermemInfo.offsetPeerTokenPerExpert), + params.layoutD2, + static_cast(n2), + static_cast(L1TileShape::N), + shmem, + static_cast(peermemInfo.offsetD) + }; + + uint32_t n = params.problemShape.n(); + BlockEpilogue2 blockEpilogue2(resource, epilogueParams); + BlockEpilogue1 blockEpilogue1(resource, n); + + // Synchronous wait: SwiGLU waits for GMM1 [1] AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); AscendC::SyncAll(); - - uint32_t lastDequantExpertNum = params.expertPerRank; - if (params.epilogueGranularity < params.expertPerRank) { - lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; + if (dequantSum1 > 0) { + uint32_t rowStartThisCore = 0; + MatrixCoord offsetC{0U, 0}; + MatrixCoord shapeC{dequantSum1, params.problemShape.n()}; + LayoutC layoutC{dequantSum1, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], params.epilogueCoreNum); } - if (lastDequantExpertNum < params.expertPerRank) { + AscendC::SyncAll(); + // Synchronization signal: SwiGLU notifies GMM2 [1] + AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); + + if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0)) { + // Synchronous wait: SwiGLU waits for GMM1 [2] + AscendC::CrossCoreWaitFlag<0x2>(SYNCFLAGC2V); + AscendC::SyncAll(); + if (dequantSum2 > 0) { + uint32_t rowStartThisCore = dequantSum1; + MatrixCoord offsetC{rowStartThisCore, 0}; + uint32_t dequantLen = dequantSum2; + MatrixCoord shapeC{dequantLen, params.problemShape.n()}; + LayoutC layoutC{dequantLen, params.problemShape.n()}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); + blockEpilogue1(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + } + AscendC::SyncAll(); + // Synchronization signal: SwiGLU notifies GMM2 [2] AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(SYNCFLAGV2C); } - if (prevGroupSum1 - dequantSum < params.maxOutputSize) { - uint32_t rowStartThisCore = prevGroupSum1 - dequantSum;; - MatrixCoord offsetC{rowStartThisCore, 0}; - uint32_t dequantLen = dequantSum; - if (prevGroupSum1 >= params.maxOutputSize) { - dequantLen = dequantSum - (prevGroupSum1 - params.maxOutputSize); + + blockEpilogue1.Finalize(); + + + CombineSetFlag(); + + CombineV2(params, blockEpilogue2); + + AscendC::SyncAll(); + #ifndef __CROSSRANKSYNCANDALLGATHERV1__ + ResetTokenPerExpert(params.EP * AlignUp(params.EP * params.expertPerRank, 128)); + #endif + shmem.InitStatusTargetSum(); + if (get_subblockid() == 0) { + AscendC::LocalTensor ctrBuffer = resource.ubBuf.template GetBufferByByte(0); + shmem.CrossRankSyncV2Set(ctrBuffer); + } else { + uint32_t uboffset = 0; + uint32_t aicCoreNum = coreNum / 2; + uint32_t aicCoreIdx = get_block_idx(); + uint32_t sendRankNum_ = params.EP / aicCoreNum; + uint32_t remainderRankNum = params.EP % aicCoreNum; + if (aicCoreIdx < remainderRankNum) { + sendRankNum_++; } - MatrixCoord shapeC{dequantLen, params.problemShape.n()}; - LayoutC layoutC{dequantLen, params.problemShape.n()}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetD = params.layoutD1.GetOffset(offsetC); - blockEpilogue(gmC[gmOffsetC], shapeC, gmPerTokenScale1[rowStartThisCore], gmPermutedToken[gmOffsetD], gmPerTokenScale2[rowStartThisCore], coreNum); + AscendC::LocalTensor statusTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sendRankNum_ * UB_ALIGN; + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += params.EP * sizeof(float); + AscendC::LocalTensor gatherTmpTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sizeof(uint32_t); + AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(uboffset); + uboffset += sizeof(float); + shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor); + MoeTokenUnpermuteTilingData tilingData; + MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2); + KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; + kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); + kernelMoeTokenUnpermuteOp.Process(); } - blockEpilogue.Finalize(); + } CATLASS_DEVICE - void Combine(Params const ¶ms) { - int32_t prevSumBeforeRank = 0; - if (coreIdx < params.EP) { - prevSumBeforeRank = m_prevSumBeforeRank; - } - - int prevSum = prevSumBeforeRank; + void CombineV2(Params const ¶ms, BlockEpilogue2 & blockEpilogue) { + BlockScheduler blockScheduler; + int32_t syncLoopIdx = 0; + uint32_t startCoreIdx = 0; + uint32_t aicCoreNum = coreNum / 2; + uint32_t aicCoreIdx = get_block_idx(); + uint32_t aivSubCoreIdx = get_subblockid(); + uint32_t preSrcExpertSum = 0; uint32_t n2 = params.problemShape.k(); uint32_t k2 = params.problemShape.n() / 2; - - // TODO compute the cumsum of tokenPerExpert - typename BlockEpilogue2::Params epilogueParams{ - static_cast(params.EP), - static_cast(params.expertPerRank), - reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace), - static_cast(n2) - }; - BlockEpilogue2 blockEpilogue(resource, epilogueParams); - int32_t prevGroupSum2 = 0; + icache_preload(8); for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { - AscendC::CrossCoreWaitFlag<0x2>(groupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT); - AscendC::SyncAll(); + uint32_t currentExpertM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + if (preSrcExpertSum >= params.maxOutputSize) { + currentExpertM = 0; + } else if (preSrcExpertSum + currentExpertM > params.maxOutputSize) { + currentExpertM = params.maxOutputSize - preSrcExpertSum; + } + GemmCoord inGroupProblemShape{currentExpertM, n2, k2}; // M N K + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + uint32_t startLoopIdx = ((aicCoreIdx < startCoreIdx) ? (aicCoreIdx + aicCoreNum) : aicCoreIdx) - startCoreIdx; - for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { - __gm__ void* dstPeermemPtr = shmem(peermemInfo.offsetD, dstEpIdx); - AscendC::GlobalTensor gmRemotePeer; - gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD2*>(dstPeermemPtr)); - uint32_t srcRowOffset = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum2; - if (srcRowOffset < params.maxOutputSize) { - uint32_t dataRows = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); - if (srcRowOffset + dataRows > params.maxOutputSize) { - dataRows = params.maxOutputSize - srcRowOffset; - } - uint32_t dstRowOffset = prevSum; - prevSum += dataRows; - MatrixCoord offsetC{srcRowOffset, 0}; - MatrixCoord offsetPeer{dstRowOffset, 0}; - MatrixCoord shapeC{dataRows, n2}; - int64_t gmOffsetC = params.layoutD2.GetOffset(offsetC); - int64_t gmOffsetPeer = params.layoutD2.GetOffset(offsetPeer); - if constexpr (std::is_same_v) { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmPerTokenScale2[srcRowOffset], gmRemotePeer[gmOffsetPeer]); - } else { - blockEpilogue(gmC2[gmOffsetC], shapeC, gmRemotePeer[gmOffsetPeer]); + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicCoreNum) { + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + int32_t m0 = 16; + // Block count, the shape of each block is (m0, actualBlockShape.n()) + int32_t m_rows = (actualBlockShape.m() + m0 - 1) / m0; + int32_t aiv_m_rows = m_rows / 2; + if (aivSubCoreIdx == 1 && aiv_m_rows * 2 < m_rows) { + aiv_m_rows += 1; + } + uint32_t m_offset = blockCoord.m() * L1TileShape::M;//blockOffset + if(aivSubCoreIdx == 1) { + m_offset += (m_rows / 2) * m0; + } + + + for (;syncLoopIdx <= groupIdx; syncLoopIdx ++) { + int32_t flag_id = syncLoopIdx / CROSS_CORE_FLAG_MAX_SET_COUNT; + AscendC::CrossCoreWaitFlag<0x2>(flag_id); + } + + for (int32_t cur_row = 0; cur_row < aiv_m_rows; cur_row ++) { + GemmCoord realTileCoord{m_offset, blockCoord.n() * L1TileShape::N, 1}; + uint32_t actualm = m0; + if(aivSubCoreIdx == 1 && cur_row == aiv_m_rows - 1){ + actualm = actualBlockShape.m() - (m_rows / 2) * m0 - cur_row * m0; } + GemmCoord realTileShape{actualm, actualBlockShape.n(), 1}; + blockEpilogue(gmC2, gmPerTokenScale2, realTileCoord, realTileShape, groupIdx, preSrcExpertSum, preSumBeforeRank); + m_offset += m0; } } - prevGroupSum2 += cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); + preSrcExpertSum += currentExpertM; + startCoreIdx = (startCoreIdx + coreLoops) % aicCoreNum; } blockEpilogue.Finalize(); - AscendC::SyncAll(); - ResetTokenPerExpert(tokenPerExpert, params.EP * AlignUp(params.EP * params.expertPerRank, ALIGN_128)); - shmem.CrossRankSync(); - MoeTokenUnpermuteTilingData tilingData; - MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); - KernelMoeTokenUnpermute kernelMoeTokenUnpermuteOp; - - kernelMoeTokenUnpermuteOp.Init(shmem() + peermemInfo.offsetD, workspaceInfo.expandedRowIdx, params.probs, reinterpret_cast(params.ptrOutput), &tilingData); - kernelMoeTokenUnpermuteOp.Process(); } + private: struct WorkspaceInfo { GM_ADDR ptrA; @@ -815,6 +1006,9 @@ class DispatchFFNCombineKernel { GM_ADDR ptrPerTokenScale2; GM_ADDR expandedRowIdx; GM_ADDR ptrTokenPerExpert; + GM_ADDR ptrSumBeforeRank; + __gm__ float* ptrSoftFlagBase; + CATLASS_DEVICE WorkspaceInfo(){} @@ -842,15 +1036,21 @@ class DispatchFFNCombineKernel { workspaceOffset += (params.EP * params.EP * params.expertPerRank) * sizeof(int32_t); ptrC = params.ptrWorkspace + workspaceOffset; - ptrC2 = ptrC; - workspaceOffset += max(params.maxOutputSize * params.problemShape.n() * sizeof(ElementC), - params.maxOutputSize * n2 * sizeof(ElementC)); + workspaceOffset += params.maxOutputSize * params.problemShape.n() * sizeof(ElementC); + ptrC2 = params.ptrWorkspace + workspaceOffset; + workspaceOffset += params.maxOutputSize * n2 * sizeof(ElementC); ptrA = params.ptrWorkspace + workspaceOffset; - ptrPermutedToken = ptrA; - workspaceOffset += max(params.maxOutputSize * params.problemShape.k() * sizeof(ElementA), - params.maxOutputSize * k2 * sizeof(ElementA)); + + workspaceOffset += params.maxOutputSize * params.problemShape.k() * sizeof(ElementA); + ptrPermutedToken = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.maxOutputSize * k2 * sizeof(ElementA); + ptrSumBeforeRank = params.ptrWorkspace + workspaceOffset; + + workspaceOffset += params.EP * sizeof(int32_t) * FLAGSTRIDE; + ptrSoftFlagBase = reinterpret_cast<__gm__ float*>(params.ptrWorkspace + workspaceOffset); } }; @@ -877,12 +1077,9 @@ class DispatchFFNCombineKernel { uint32_t coreIdx; uint32_t coreNum; - Params params; WorkspaceInfo workspaceInfo; PeermemInfo peermemInfo; - int64_t m_prevSumBeforeRank; - AscendC::GlobalTensor gmA; AscendC::GlobalTensor gmC; @@ -894,10 +1091,11 @@ class DispatchFFNCombineKernel { AscendC::GlobalTensor tokenPerExpert; AscendC::GlobalTensor cumsumMM; + AscendC::GlobalTensor preSumBeforeRank; Layout3D tokenPerExpertLayout; HcclShmem shmem; }; } // namespace Catlass::Gemm::Kernel -#endif // DISPATH_FFN_COMBINE_KERNEL_HPP +#endif // DISPATCH_FFN_COMBINE_KERNEL_HPP diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h index c190033ade8..d29286b1bb1 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_common.h @@ -44,6 +44,7 @@ constexpr int64_t EXERPT_TOKENS_COUNT = 2; constexpr int64_t EXERPT_TOKENS_CUMSUM = 1; constexpr int64_t EXERPT_TOKENS_NONE = 0; constexpr int64_t EXERPT_TOKENS_BEFORE_CAPACITY = 1; +constexpr int64_t ALIGN_512 = 512; const __gm__ int32_t assist[256] = { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h index 824e9af303a..9d77c5e2e44 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_fullload_dynamic_quant.h @@ -35,7 +35,6 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { __aicore__ inline void CopyOutIdx(); __aicore__ inline void CopyOutEmpty(); __aicore__ inline void CopyOutXQuant1H(); - __aicore__ inline void CopyOutXQuantEH(); __aicore__ inline void ComputeExpertTokenCountOrCumsum(); __aicore__ inline void Compute(LocalTensor& smoothLocal); @@ -49,6 +48,7 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { int64_t k_; int64_t n_; int64_t cols_; + int64_t cols_scale_; int64_t activateRows_; int64_t expertNum; int64_t expertCapacity; @@ -63,12 +63,10 @@ class MoeV2FullLoadDynamicQuant : public MoeV2SortBase { TQue smoothInQueue; TQue calcQueue; TQue inputXOutQueue; - TQue scaleOutQueue; GlobalTensor xGm_; GlobalTensor expertIdxGm_; GlobalTensor quantSmoothGm; - GlobalTensor dynamicQuantScaleGm; GlobalTensor expandedXGm_; GlobalTensor expandedRowIdxGm_; @@ -225,7 +223,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = outLocal[this->cols_].template ReinterpretCast(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[colsAlign], RoundMode::CAST_NONE, this->cols_); @@ -259,7 +257,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Compute(LocalTensor& calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); - scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -275,7 +272,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; - DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams intriParams{1, static_cast((this->cols_ + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; LocalTensor smoothLocal; if (smoothType == 1) { @@ -295,7 +292,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { xCopyInQueue_.EnQue(xLocal); Compute(smoothLocal); - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curRowsStart <= curRowsEnd && curRowsStart / this->k_ == row) { int32_t outIndex = expandedRowIdx.GetValue(curRowsStart); @@ -303,74 +299,13 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuant1H() { if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows_)) { continue; } - DataCopyPad(expandedXGm_[outIndex * cols_], outLocal, intriParams); - DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + DataCopyPad(expandedXGm_[outIndex * this->cols_scale_], outLocal, intriParams); } xCopyInQueue_.FreeTensor(xLocal); inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); } - - if (smoothType == 1) { - smoothInQueue.FreeTensor(smoothLocal); - } - expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); -} - -template -__aicore__ inline void MoeV2FullLoadDynamicQuant::CopyOutXQuantEH() { - LocalTensor expandedRowIdx = expandedRowIdxCopyOutQueue_.DeQue(); expandedRowIdxCopyOutQueue_.FreeTensor(expandedRowIdx); - - Muls(expandDstToSrcRowLocal.ReinterpretCast(), expandDstToSrcRowLocal.ReinterpretCast(), (float)-1, - this->totalLength); - pipe_barrier(PIPE_V); - LocalTensor sortedRowIdx = expandDstToSrcRowLocal.ReinterpretCast(); - Cast(sortedRowIdx, expandDstToSrcRowLocal.ReinterpretCast(), RoundMode::CAST_ROUND, this->totalLength); - - int64_t curRowsStart = this->blockIdx_ * this->perCoreRows_; - int64_t curRowsEnd = curRowsStart + this->coreRows_ - 1; - - DataCopyExtParams dataXCopyParams{1, static_cast(this->cols_ * sizeof(T)), 0, 0, 0}; - DataCopyExtParams smoothCopyParams{1, static_cast(this->cols_ * sizeof(float)), 0, 0, 0}; - DataCopyExtParams intriParams{1, static_cast(this->cols_ * sizeof(int8_t)), 0, 0, 0}; - - for (int64_t row = curRowsStart; row <= curRowsEnd; row++) { - if (this->dropPadMode == DROPLESS_MODE && row >= this->activateRows_) { - break; - } - int32_t srcIdx = sortedRowIdx.GetValue(row); - int32_t expertIdx = expandedExpertIdxLocal.GetValue(row); - - LocalTensor inLocal = xCopyInQueue_.AllocTensor(); - LocalTensor smoothLocal = smoothInQueue.AllocTensor(); - if constexpr (IsSameType::value) { - DataCopyPad(inLocal, xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); - } else { - DataCopyPad(inLocal[colsAlign], xGm_[srcIdx / this->k_ * this->cols_], dataXCopyParams, {false, 0, 0, 0}); - } - DataCopyPad(smoothLocal, quantSmoothGm[expertIdx * this->cols_], smoothCopyParams, {false, 0, 0, 0}); - xCopyInQueue_.EnQue(inLocal); - smoothInQueue.EnQue(smoothLocal); - smoothLocal = smoothInQueue.DeQue(); - - Compute(smoothLocal); - - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); - DataCopyPad(dynamicQuantScaleGm[row], quantScaleLocal, {1, 4, 0, 0, 0}); - - LocalTensor outLocal = inputXOutQueue.DeQue(); - DataCopyPad(expandedXGm_[row * this->cols_], outLocal, intriParams); - - xCopyInQueue_.FreeTensor(inLocal); - smoothInQueue.FreeTensor(smoothLocal); - inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); - } - - expandDstToSrcRowQueue_.FreeTensor(expandDstToSrcRowLocal); - expandedExpertIdxCopyOutQueue_.FreeTensor(expandedExpertIdxLocal); } template @@ -386,6 +321,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp this->k_ = tilingData->k; this->n_ = tilingData->n; this->cols_ = tilingData->cols; + this->cols_scale_ = this->cols_ + ALIGN_512; this->needCoreNum_ = this->gatherOutTilingData_->needCoreNum; this->perCoreRows_ = this->gatherOutTilingData_->perCoreRows; this->activateRows_ = this->gatherOutTilingData_->activateRows; @@ -416,7 +352,6 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp Align(this->expertNum, sizeof(int32_t))); } quantSmoothGm.SetGlobalBuffer((__gm__ float*)quantSmooth); - dynamicQuantScaleGm.SetGlobalBuffer((__gm__ float*)dynamicQuantScale); int64_t kvFactor = 2; int64_t buffSize = this->sortNum_ * sizeof(int32_t); @@ -440,8 +375,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Init(GM_ADDR x, GM_ADDR exp } pipe->InitBuffer(smoothInQueue, 1, AlignBytes(this->cols_, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->cols_, sizeof(float))); - pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_, sizeof(int8_t))); - pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); + pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->cols_scale_, sizeof(int8_t))); } template @@ -457,11 +391,7 @@ __aicore__ inline void MoeV2FullLoadDynamicQuant::Process() { } else { CopyOutEmpty(); } - if (smoothType == 2) { - CopyOutXQuantEH(); - } else { - CopyOutXQuant1H(); - } + CopyOutXQuant1H(); } } } // namespace MoeInitRoutingQuantV2 diff --git a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h index 924e854891c..64852f31315 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h +++ b/csrc/dispatch_ffn_combine/op_kernel/moe_init_routing_quant_v2/moe_v2_gather_dynamic_quant.h @@ -66,6 +66,7 @@ class MoeV2GatherDynamicQuant { int64_t needCoreNum; int64_t blockIdx; int64_t cols; + int64_t cols_scale_; int64_t n; int64_t k; int64_t totalLength; @@ -117,7 +118,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s LocalTensor tempLocal = calcQueue.AllocTensor(); LocalTensor outLocal = inputXOutQueue.AllocTensor(); - LocalTensor dynamicQuantLocal = scaleOutQueue.AllocTensor(); + LocalTensor dynamicQuantLocal = outLocal[this->cols].template ReinterpretCast(); if constexpr (!IsSameType::value) { Cast(inLocal, inLocal.ReinterpretCast()[perLoopColsAlign], RoundMode::CAST_NONE, this->cols); @@ -151,7 +152,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Compute(LocalTensor& s calcQueue.FreeTensor(tempLocal); inputXOutQueue.EnQue(outLocal); - scaleOutQueue.EnQue(dynamicQuantLocal); } template @@ -163,7 +163,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr int64_t currentLoopStartRow = initialRow / this->k; int64_t currentLoopLastRow = (initialRow + this->currentLoopRows - 1) / this->k; DataCopyExtParams copyInParams{1, static_cast(this->cols * sizeof(T)), 0, 0, 0}; - DataCopyExtParams copyOutParams{1, static_cast(this->cols * sizeof(int8_t)), 0, 0, 0}; + DataCopyExtParams copyOutParams{1, static_cast((this->cols + BLOCK_BYTES) * sizeof(int8_t)), 0, 0, 0}; DataCopyExtParams smoothParams{1, static_cast(this->cols * sizeof(float)), 0, 0, 0}; LocalTensor smoothLocal; @@ -187,7 +187,6 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr // Compute quantization Compute(smoothLocal); - LocalTensor quantScaleLocal = scaleOutQueue.DeQue(); LocalTensor outLocal = inputXOutQueue.DeQue(); while (curLoopRow < this->currentLoopRows && initialRow / this->k == row) { @@ -197,15 +196,11 @@ __aicore__ inline void MoeV2GatherDynamicQuant::CopyOutXQuant1H(int64_t progr if (outIndex == -1 || (this->dropPadMode == DROPLESS_MODE && outIndex >= this->activateRows)) { continue; } - DataCopyPad(expandedXGm[outIndex * cols], outLocal, copyOutParams); - DataCopyPad(dynamicQuantScaleGm[outIndex], quantScaleLocal, {1, 4, 0, 0, 0}); + // Scale is placed after the data position + DataCopyPad(expandedXGm[outIndex * cols_scale_], outLocal, copyOutParams); } inputXInQueue.FreeTensor(inLocal); inputXOutQueue.FreeTensor(outLocal); - scaleOutQueue.FreeTensor(quantScaleLocal); - } - if (smoothType == 1) { - smoothInQueue.FreeTensor(smoothLocal); } expandRowIdxInQueue.FreeTensor(indicesLocal); } @@ -463,6 +458,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR this->needCoreNum = this->gatherOutTilingData->needCoreNum; this->activateRows = this->gatherOutTilingData->activateRows; this->cols = tilingData->cols; + this->cols_scale_ = this->cols + ALIGN_512; this->n = tilingData->n; this->k = tilingData->k; this->totalLength = tilingData->n * tilingData->k; @@ -518,32 +514,15 @@ __aicore__ inline void MoeV2GatherDynamicQuant::Init(GM_ADDR inputX, GM_ADDR pipe->InitBuffer(smoothInQueue, BUFFER_NUM, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(calcQueue, 1, AlignBytes(this->perLoopCols, sizeof(float))); pipe->InitBuffer(inputXOutQueue, 1, AlignBytes(this->perLoopCols, sizeof(int8_t))); - pipe->InitBuffer(scaleOutQueue, 1, BLOCK_BYTES + BLOCK_BYTES); } template __aicore__ inline void MoeV2GatherDynamicQuant::Process() { if (this->blockIdx < this->needCoreNum) { currentLoopRows = perLoopRows; - if (colLoops > 1) { // A single row cannot be fully loaded; workspace is required - if (smoothType == 2) { - for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { - CopyInExpandedExpertIdx(loop); - CopyOutPartialXQuantEH(loop); - } - currentLoopRows = lastLoopRows; - CopyInExpandedExpertIdx(this->rowLoops - 1); - CopyOutPartialXQuantEH(this->rowLoops - 1); - } else { - for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { - CopyInExpandedRowIdx(loop); - CopyOutPartialXQuant1H(loop); - } - currentLoopRows = lastLoopRows; - CopyInExpandedRowIdx(this->rowLoops - 1); - CopyOutPartialXQuant1H(this->rowLoops - 1); - } - } else { // A single row can be fully loaded + if (colLoops > 1) { // Cannot fit all data in one row, workspace is required + trap(); // Not supported + } else { // All data can fit in one row if (smoothType == 2) { for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) { CopyInExpandedExpertIdx(loop); diff --git a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h index 1255b5cfef4..adb805b8fc0 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h +++ b/csrc/dispatch_ffn_combine/op_kernel/unpermute/moe_token_unpermute.h @@ -85,8 +85,8 @@ KernelMoeTokenUnpermute::Init(GM_ADDR permuted_tokens, GM_ADD GM_ADDR unpermuted_tokens, const MoeTokenUnpermuteTilingData *__restrict tiling_data) { - this->blockIdx = get_block_idx() + get_subblockid() * get_block_num(); - this->blockNum = get_block_num() * get_subblockdim(); + this->blockIdx = get_block_idx(); + this->blockNum = get_block_num(); if (blockIdx >= blockNum) { return; diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp new file mode 100644 index 00000000000..926622dc5da --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_v2.hpp @@ -0,0 +1,243 @@ +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_V2_ONLY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "hccl_shmem.hpp" +#include "layout3d.hpp" + +namespace Catlass::Epilogue::Block { +template < + uint32_t UB_STAGES_, + class CType_, + class LayoutPerTokenScale_, + class DType_, + class TileCopy_ +> +class BlockEpilogue < + EpilogueAtlasA2PerTokenDequantV2, + CType_, + Gemm::GemmType, + DType_, + TileCopy_ +> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantV2; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + using CopyScaleGmToUb = Epilogue::Tile::CopyGm2Ub>; + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + struct Params { + __gm__ int32_t *ptrTokenPerExpert{nullptr}; + int32_t EP; + int32_t expertPerRank; + int32_t n2; + LayoutC layoutC; + int32_t n0; + int32_t rank; + HcclShmem shmem; + int32_t offsetD; + + CATLASS_DEVICE + Params() {}; + CATLASS_DEVICE + Params(int32_t EP_, int32_t expertPerRank_, int32_t rank_, __gm__ int32_t *ptrTokenPerExpert_, + LayoutC layoutC_, int32_t n2_, int32_t n0_, HcclShmem& shmem_, int32_t offsetD_) : + ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), + expertPerRank(expertPerRank_),rank(rank_), layoutC(layoutC_), n2(n2_), n0(n0_), + shmem(shmem_), offsetD(offsetD_) + {} + }; + + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + //ub:192KB + n0 = params.n0; + size_t ubOffset = 0; + for(int32_t i = 0; i < 2; i++) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(ElementD); + ubFp32List[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += max_len * sizeof(float); + scaleUbList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += (max_len / n0) * sizeof(float); + source_scale_offset[i] = -1; + } + tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.ptrTokenPerExpert)); + tokenPerExpertLayout = Layout3D(AlignUp(params.EP * params.expertPerRank, 128), params.expertPerRank); + is_ping = true; + } + + CATLASS_DEVICE + void Finalize() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + } + CATLASS_DEVICE + ~BlockEpilogue() + { + + } + CATLASS_DEVICE + void operator() ( + AscendC::GlobalTensor const &gmC, + AscendC::GlobalTensor const &gmPerTokenScale, + GemmCoord& blockCoord, + GemmCoord& actualBlockShape, + int32_t groupIdx, + int32_t preSrcExpertSum, + AscendC::GlobalTensor preSumBeforeRank + ){ + is_ping = !is_ping; + auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1; + auto event_id_2 = is_ping ? EVENT_ID2 : EVENT_ID3; + + auto &ubC = ubCList[is_ping]; + auto &ubD = ubDList[is_ping]; + int32_t gmCOffset = preSrcExpertSum * params.n2 + blockCoord.m() * params.n2 + blockCoord.n(); + auto gmTileC = gmC[gmCOffset]; + auto &ubCFp32 = ubFp32List[is_ping]; + auto &scaleUb = scaleUbList[is_ping]; + + LayoutC layoutGM{actualBlockShape.m(), actualBlockShape.n(), params.n2}; + LayoutC layoutUB{actualBlockShape.m(), actualBlockShape.n(), n0}; + + + AscendC::WaitFlag(event_id); + copyGmToUbC(ubC, gmTileC, layoutUB, layoutGM); + AscendC::SetFlag(event_id); + + AscendC::WaitFlag(event_id); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, -1, repeat, {1, 1, 8, 4}); + AscendC::SetFlag(event_id); + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); + + int32_t gmScaleOffset = preSrcExpertSum + blockCoord.m(); + layout::VectorLayout scaleLauout{actualBlockShape.m()}; + if (source_scale_offset[event_id] != gmScaleOffset) { + source_scale_offset[event_id] = gmScaleOffset; + copyScaleGmToUb(scaleUb, gmPerTokenScale[gmScaleOffset], scaleLauout, scaleLauout); + } + + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + + + + + AscendC::WaitFlag(event_id_2); + AscendC::WaitFlag(event_id_2); // Note that the value must be MTE2_S instead of MTE2_V. + // Otherwise, 0 will be read, causing garbled characters. + AscendC::PipeBarrier(); + for (int32_t row = 0; row < actualBlockShape.m(); ++row) { + float scale = scaleUb(row); + Muls(ubCFp32[n0* row], ubCFp32[n0 * row] , scale, -1, (actualBlockShape.n() + 127) / 128 * 2, {1, 1, 8, 8}); + } + AscendC::PipeBarrier(); + AscendC::WaitFlag(event_id); + AscendC::Cast(ubD, ubCFp32, AscendC::RoundMode::CAST_RINT, -1, repeat, {1, 1, 4, 8}); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id_2); + AscendC::SetFlag(event_id); + + int32_t lenTile = actualBlockShape.m(); + int32_t stTile = blockCoord.m(); + int32_t edTile = stTile + lenTile; + int32_t preSumRankInExpert = 0; + int32_t tileOffset = 0; + + AscendC::WaitFlag(event_id); + for (int32_t dstEpIdx = 0; dstEpIdx < params.EP; dstEpIdx ++) { + int32_t lenRankInExpert = tokenPerExpert(tokenPerExpertLayout(dstEpIdx, params.rank, groupIdx)); + int32_t dstExpertOffset = preSumBeforeRank(dstEpIdx * params.expertPerRank + groupIdx); + int32_t stRankInExpert = preSumRankInExpert; + int32_t edRankInExpert = stRankInExpert + lenRankInExpert; + preSumRankInExpert += lenRankInExpert; + if (stRankInExpert >= edTile) { + break; + } + else if (edRankInExpert <= stTile) { + continue; + } + int32_t stData = max(stRankInExpert, stTile); + int32_t edData = min(edRankInExpert, edTile); + uint32_t lenData = edData - stData; + if (lenData <= 0){ + continue; + } + + uint32_t dstOffsetInExpert = 0; + if (stTile > stRankInExpert) { + dstOffsetInExpert = stTile - stRankInExpert; + } + AscendC::GlobalTensor gmRemotePeer; + __gm__ void* dstPeermemPtr = params.shmem(params.offsetD, dstEpIdx); + gmRemotePeer.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD*>(dstPeermemPtr)); + MatrixCoord dstOffset{dstOffsetInExpert + dstExpertOffset, blockCoord.n()}; + int64_t gmDstOffset = params.layoutC.GetOffset(dstOffset); + auto gmTileD = gmRemotePeer[gmDstOffset]; + LayoutC layoutGM2{lenData, actualBlockShape.n(), params.n2}; + LayoutC layoutUB2{lenData, actualBlockShape.n(), n0}; + copyUbToGmD(gmTileD, ubD[tileOffset * n0], layoutGM2, layoutUB2); + tileOffset += lenData; + } + AscendC::SetFlag(event_id); + + } +private: + + Params params; + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + AscendC::LocalTensor ubFp32List[UB_STAGES]; + AscendC::LocalTensor scaleUbList[UB_STAGES]; + int32_t source_scale_offset[UB_STAGES]; + + int32_t max_len = 8 * 32 / 4 * 128; + int32_t n0; + bool is_ping = false; + + + int32_t repeat = 128; + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; + + CopyScaleGmToUb copyScaleGmToUb; + AscendC::GlobalTensor tokenPerExpert; + Layout3D tokenPerExpertLayout; +}; +} +#endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp index 3b435f267f3..4f9351808e7 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_mmad_preload_async_fixpipe_quant.hpp @@ -22,8 +22,6 @@ namespace Catlass::Gemm::Block { -constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; - template __aicore__ inline void SyncFlagFunc(int32_t eventID) { @@ -153,9 +151,11 @@ struct BlockMmad < L1TileShape::K, L1TileShape::N); CATLASS_DEVICE - BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + BlockMmad(Arch::Resource &resource, __gm__ int32_t* flagPtr = nullptr, int32_t expertPerRank = 0, uint32_t l1BufAddrStart = 0) { syncGroupIdx = 0; + ptrSoftFlagBase_ = flagPtr; + expertPerRank_ = expertPerRank; InitL1(resource, l1BufAddrStart); InitL0A(resource); InitL0B(resource); @@ -272,9 +272,21 @@ struct BlockMmad < CATLASS_DEVICE void Finalize(int32_t target, int32_t flag = 0) { - for(;syncGroupIdx <= target; syncGroupIdx++) { - int32_t flagId = syncGroupIdx / CROSS_CORE_FLAG_MAX_SET_COUNT + flag; - AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + if (ptrSoftFlagBase_ != nullptr) { + if (target < 0) { + return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::GlobalTensor flagGlobal; + flagGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_) + (expertPerRank_ + AscendC::GetBlockIdx()) * FLAGSTRIDE); + AscendC::DataCopy(flagGlobal, l1FTensor[target * 16], FLAGSTRIDE); + } + else { + for(;syncGroupIdx <= target; syncGroupIdx++) { + int32_t flagId = syncGroupIdx / 15 + flag; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } } } private: @@ -291,7 +303,6 @@ struct BlockMmad < layout::VectorLayout layoutScale; int32_t syncLoopIdx; int32_t flag; - CATLASS_DEVICE L1TileMmadParams() = default; }; @@ -310,11 +321,24 @@ struct BlockMmad < AscendC::SetFlag(l1AEventList[i]); AscendC::SetFlag(l1BEventList[i]); } + uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; if constexpr (std::is_same_v) { - uint32_t l1SOffset = l1BOffset + L1B_TILE_SIZE * L1_STAGES; l1STensor = resource.l1Buf.template GetBufferByByte(l1SOffset); AscendC::SetFlag(0); } + if (ptrSoftFlagBase_ != nullptr) { + // Initialize the flag matrix (structure as below): + // 1 0 0 0 0 0 0 0 + // 2 0 0 0 0 0 0 0 + // ... + // 16 0 0 0 0 0 0 0 + // Then move it to L1 + uint32_t l1FOffset = l1SOffset + L1S_TILE_SIZE; + l1FTensor = resource.l1Buf.template GetBufferByByte(l1FOffset); + AscendC::GlobalTensor flagBase; + flagBase.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(ptrSoftFlagBase_)); + AscendC::DataCopy(l1FTensor, flagBase, expertPerRank_ * FLAGSTRIDE); + } } CATLASS_DEVICE @@ -463,12 +487,20 @@ struct BlockMmad < if constexpr (std::is_same_v) { AscendC::SetFlag(0); } + #ifdef __TILE_SYNC__ + if (params.flag > 0) { + int32_t flagId = params.flag + params.syncLoopIdx / 8; + AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(flagId); + } + #else Finalize(params.syncLoopIdx, params.flag); + #endif } } AscendC::LocalTensor l1ATensorList[L1_STAGES]; AscendC::LocalTensor l1BTensorList[L1_STAGES]; AscendC::LocalTensor l1STensor; + AscendC::LocalTensor l1FTensor; int32_t syncGroupIdx; int32_t l1AEventList[L1_STAGES]; int32_t l1BEventList[L1_STAGES]; @@ -497,8 +529,11 @@ struct BlockMmad < CopyL1ToL0A copyL1ToL0A; CopyL1ToL0B copyL1ToL0B; CopyL0CToGm copyL0CToGm; + + __gm__ int32_t* ptrSoftFlagBase_ = nullptr; + int32_t expertPerRank_; }; } // namespace Catlass::Gemm::Block -#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_FIXPIPE_QUANT_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp index 84cb6c4ec91..3249138e178 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp @@ -5,5 +5,7 @@ constexpr static uint64_t MB_SIZE = 1024 * 1024UL; constexpr static int32_t NUMS_PER_FLAG = 16; constexpr static int32_t CACHE_LINE = 512; constexpr static int32_t RESET_VAL = 0xffff; -constexpr static int32_t ALIGN_128 = 128; +constexpr static int32_t FLAGSTRIDE = 16; +constexpr static int32_t UB_ALIGN = 32; +constexpr uint16_t CROSS_CORE_FLAG_MAX_SET_COUNT = 15; #endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp index 31fdbad1c27..7e30114e404 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/dispatch_policy_custom.hpp @@ -33,13 +33,13 @@ namespace Catlass::Epilogue { }; template - struct EpilogueAtlasA2PerTokenDequantQuant { + struct EpilogueAtlasA2PerTokenDequantSwigluQuant { using ArchTag = Arch::AtlasA2; static constexpr uint32_t UB_STAGES = UB_STAGES_; }; template - struct EpilogueAtlasA2PerTokenDequantSwigluQuant { + struct EpilogueAtlasA2PerTokenDequantV2 { using ArchTag = Arch::AtlasA2; static constexpr uint32_t UB_STAGES = UB_STAGES_; }; diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp index cfbb4daf8b1..93d4c9e7d4e 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp @@ -5,13 +5,28 @@ #include "kernel_operator.h" #include "const_args.hpp" +#ifdef HCCL_COMM #include "moe_distribute_base.h" +using namespace AscendC::HcclContextDef; -#ifndef HCCL_COMM +#else #include "shmem_api.h" #endif #define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ +constexpr int32_t MAX_RANK_SIZE = 32; +constexpr int32_t SHMEM_MEM = 700 * MB_SIZE; + +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; + +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t STATE_OFFSET = 512; + +FORCE_INLINE_AICORE void AicSyncAll() { + AscendC::CrossCoreSetFlag<0x0, PIPE_FIX>(8); + AscendC::CrossCoreWaitFlag<0x0>(8); +} template FORCE_INLINE_AICORE void gm_store(__gm__ T *addr, T val) { @@ -23,10 +38,11 @@ FORCE_INLINE_AICORE T gm_load(__gm__ T *cache) { return *((__gm__ T *)cache); } -FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { +template +FORCE_INLINE_AICORE void gm_dcci(__gm__ T * addr) { using namespace AscendC; GlobalTensor global; - global.SetGlobalBuffer(addr); + global.SetGlobalBuffer(reinterpret_cast(addr)); // Important: add hint to avoid dcci being optimized by compiler __asm__ __volatile__(""); @@ -37,26 +53,20 @@ FORCE_INLINE_AICORE void gm_dcci(__gm__ uint8_t * addr) { FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *sig_addr, int32_t cmp_val) { do { gm_dcci((__gm__ uint8_t *)sig_addr); - if (*sig_addr == cmp_val) { return *sig_addr; } - - // in case when peer pe enters next barrier if (*sig_addr == cmp_val + 1) { return *sig_addr; } } while (true); - - // never reach return -1; } - FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { do { AscendC::LocalTensor ub; - ub.address_.logicPos = static_cast(TPosition::VECIN); + ub.address_.logicPos = static_cast(AscendC::TPosition::VECIN); ub.address_.bufferAddr = 0; AscendC::GlobalTensor sig; sig.SetGlobalBuffer(sig_addr); @@ -71,59 +81,53 @@ FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32 } -constexpr int32_t MAX_RANK_SIZE = 32; class HcclShmem { public: #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context - __gm__ HcclOpResParamCustom *WinContext_{nullptr}; - Hccl hccl_; - size_t m_segmentSize; - int32_t m_rank; - int32_t m_rankSize; - - FORCE_INLINE_AICORE - HcclShmem(){ - auto contextGM0 = AscendC::GetHcclContext(); - WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; - - m_rank = WinContext_->localUsrRankId; - m_rankSize = WinContext_->rankSize; - m_segmentSize = WinContext_->winSize; - - } - - FORCE_INLINE_AICORE - size_t SegmentSize() const { - return m_segmentSize; - } - - FORCE_INLINE_AICORE - int32_t RankSize() const { - return m_rankSize; - } + __gm__ HcclOpResParamCustom *WinContext_{nullptr}; + Hccl hccl_; + AscendC::LocalTensor ub; + FORCE_INLINE_AICORE + HcclShmem(){ + auto contextGM0 = AscendC::GetHcclContext(); + WinContext_ = (__gm__ HcclOpResParamCustom *)contextGM0; + + m_rank = WinContext_->localUsrRankId; + m_rankSize = WinContext_->rankSize; + m_segmentSize = WinContext_->winSize; + } + #else + FORCE_INLINE_AICORE + HcclShmem(){ + m_segmentSize = SHMEM_MEM; + } + FORCE_INLINE_AICORE + void initShmem(GM_ADDR symmetricPtr_, size_t rank, size_t rankSize) { + symmetricPtr = symmetricPtr_; + m_rank = rank; + m_rankSize = rankSize; + } #endif FORCE_INLINE_AICORE - GM_ADDR operator() () const { // No argument: return local peermem + GM_ADDR operator() () const { // No parameters: return pointer to local peermem #ifdef HCCL_COMM return (GM_ADDR)(WinContext_->localWindowsIn); #else - return reinterpret_cast(shmemi_get_state()->heap_base); + return reinterpret_cast(shmem_ptr(symmetricPtr, m_rank)); #endif } FORCE_INLINE_AICORE - GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address + GM_ADDR operator() (int32_t index) const { // With index parameter: return pointer to the base address of remote peermem #ifdef HCCL_COMM return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn : ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn); #else - return reinterpret_cast(shmem_ptr(shmemi_get_state()->heap_base, index)); + return reinterpret_cast(shmem_ptr(symmetricPtr, index)); #endif } - - FORCE_INLINE_AICORE GM_ADDR operator () (int64_t offset, int32_t rankId) const { #ifdef HCCL_COMM @@ -136,15 +140,28 @@ class HcclShmem { return (GM_ADDR)((rankId == m_rank) ? WinContext_->localWindowsIn : ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset; #else - return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId); + return reinterpret_cast(shmem_ptr((symmetricPtr + offset), rankId)); #endif } + + FORCE_INLINE_AICORE + size_t SegmentSize() const { + return m_segmentSize; + } + + FORCE_INLINE_AICORE + int32_t RankSize() const { + return m_rankSize; + } + + FORCE_INLINE_AICORE ~HcclShmem() { } + FORCE_INLINE_AICORE void CrossRankSync() { uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); @@ -165,12 +182,146 @@ class HcclShmem { gm_store(sync_base, count); } + + FORCE_INLINE_AICORE + void InitStatusTargetSum() + { + using namespace AscendC; + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + SELF_STATE_OFFSET; + //uint64_t self_state_offset = (m_segmentSize - 2 * MB_SIZE); + // ep state + //uint32_t coreIdx = get_block_idx();; + uint32_t coreIdx = GetBlockIdx(); + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)((*this)() + flag_offset)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(selfStatusTensor[coreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Set(AscendC::LocalTensor ctrBuffer) { + //subblockid = 0 + uint32_t stateOffset_ = STATE_OFFSET; + // uint32_t epStateOffsetOnWin_ = m_rank * stateOffset_; + + uint64_t flag_offset = (m_segmentSize - MB_SIZE) + m_rank * stateOffset_; + //uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int vec_size = get_block_num(); + int vec_id = get_block_idx(); + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + pipe_barrier(PIPE_ALL); + + ctrBuffer.SetValue(0, epStateValue_); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + for (uint32_t dstEpIdx = vec_id; dstEpIdx < m_rankSize; dstEpIdx += vec_size) { + AscendC::GlobalTensor gmDstStates; + gmDstStates.SetGlobalBuffer((__gm__ int32_t*)((*this)(flag_offset, dstEpIdx))); + DataCopy(gmDstStates, ctrBuffer, 8); + } + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE + void CrossRankSyncV2Wait(AscendC::LocalTensor statusTensor, AscendC::LocalTensor gatherMaskOutTensor, + AscendC::LocalTensor gatherTmpTensor, AscendC::LocalTensor statusSumOutTensor) { + + uint64_t flag_offset = (m_segmentSize - MB_SIZE); + int vec_size = get_block_num(); + int vec_id = get_block_idx(); + uint32_t stateOffset_ = STATE_OFFSET; + + uint32_t sendRankNum_ = m_rankSize / vec_size; + uint32_t remainderRankNum = m_rankSize % vec_size; + uint32_t startRankId_ = sendRankNum_ * vec_id; + if (vec_id < remainderRankNum) { + sendRankNum_++; + startRankId_ += vec_id; + } else { + startRankId_ += remainderRankNum; + } + uint32_t endRankId_ = startRankId_ + sendRankNum_; + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + + AscendC::GlobalTensor epStatusSpaceGlobalTensor_; + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)((*this)() + flag_offset)); + + if (startRankId_ < m_rankSize) { + AscendC::PipeBarrier(); + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + // DataCopyParams intriParams{static_cast(sendRankNum_), 1, + // static_cast((moeSendNum_ > 512) ? 7 : 15), 0}; + AscendC::DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast(15), 0}; + + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + AscendC::SumParams sumParams{1, sendRankNum_, sendRankNum_}; + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + + //unpermute + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } + + FORCE_INLINE_AICORE __gm__ int32_t* SyncBaseAddr() { uint64_t flag_offset = (m_segmentSize - MB_SIZE) / sizeof(int32_t); return (__gm__ int32_t*)(*this)() + flag_offset + 2048; } + +private: + GM_ADDR symmetricPtr; + int32_t m_rank; + int32_t m_rankSize; + size_t m_segmentSize; + float sumTarget_{0.0}; + int32_t epStateValue_; }; + + #endif