diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp index 0206fa5cfdb..ff037903966 100644 --- a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp +++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.cpp @@ -42,8 +42,8 @@ enum NnopbaseHcclServerType { NNOPBASE_HCCL_SERVER_TYPE_END }; -extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, - const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, +extern aclnnStatus aclnnInnerDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, const aclTensor* probs, const char* group, int64_t maxOutputSize, bool transB, bool weightNz, @@ -55,8 +55,8 @@ extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, -aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, - const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, +aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, const aclTensor* probs, const char* group, int64_t maxOutputSize, const aclTensor* out, diff --git a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h index 153612637bc..871a30b0e62 100644 --- a/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_host/aclnn_dispatch_ffn_combine.h @@ -39,8 +39,8 @@ extern "C" { * @param [out] executor: op executor containing the operator compute flow. 
* @return aclnnStatus: status code. */ -__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2, - const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2, +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensorList* weight1, const aclTensorList* weight2, + const aclTensor* expertId, const aclTensorList* scale1, const aclTensorList* scale2, const aclTensor* probs, const char* group, int64_t maxOutputSize, const aclTensor* out, diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp index 649edf1fa40..6b8a7a33f3f 100644 --- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_def.cpp @@ -24,13 +24,13 @@ class DispatchFFNCombine : public OpDef { .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("w1") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) .IgnoreContiguous(); this->Input("w2") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ}) @@ -41,12 +41,12 @@ class DispatchFFNCombine : public OpDef { .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("scale1") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, 
ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("scale2") - .ParamType(REQUIRED) + .ParamType(DYNAMIC) .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp index 6342f1a1d5b..90b3c45ebdd 100644 --- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp +++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp @@ -91,27 +91,42 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info) { const char *nodeName = context->GetNodeName(); - // OPS_LOG_I(nodeName, "DispatchFFnCombine DispatchFFNCombineCheckShapeAndSetTiling."); const gert::StorageShape *aStorageShape = context->GetInputShape(X_INDEX); - const gert::StorageShape *bStorageShape = context->GetInputShape(WEIGHT_INDEX); - const gert::StorageShape *expertIdxShape = context->GetInputShape(EXPERTID_INDEX); + auto expertIdxTensor = context->GetDynamicInputTensor(EXPERTID_INDEX, 0); uint32_t M = aStorageShape->GetStorageShape().GetDim(0); uint32_t K = aStorageShape->GetStorageShape().GetDim(1); - uint32_t expertPerRank = bStorageShape->GetStorageShape().GetDim(0); - uint32_t N = bStorageShape->GetStorageShape().GetDim(2); - uint32_t topK = expertIdxShape->GetStorageShape().GetDim(1); + + auto wTensor = context->GetDynamicInputTensor(WEIGHT_INDEX, 0); + uint32_t wTensorDims = wTensor->GetOriginShape().GetDimNum(); + uint32_t N = wTensor->GetStorageShape().GetDim(wTensorDims - 1); + + uint32_t topK = expertIdxTensor->GetStorageShape().GetDim(1); + uint32_t listLen = 0; + while (true) { + auto wTensorT = 
context->GetDynamicInputTensor(WEIGHT_INDEX, ++listLen); + if (wTensorT == nullptr) {break;} + } + + uint32_t expertPerRank; + if (listLen == 1) { + expertPerRank = wTensor->GetStorageShape().GetDim(0); + } else { + expertPerRank = listLen; + } info.M = M; info.N = N; info.K = K; info.expertPerRank = expertPerRank; info.topK = topK; + info.listLen = listLen; OP_LOGD(K_INNER_DEBUG, "M=%d ", info.M); OP_LOGD(K_INNER_DEBUG, "K=%d ", info.K); OP_LOGD(K_INNER_DEBUG, "N=%d ", info.N); OP_LOGD(K_INNER_DEBUG, "expertPerRank=%d ", info.expertPerRank); OP_LOGD(K_INNER_DEBUG, "topK=%d ", info.topK); + OP_LOGD(K_INNER_DEBUG, "listLen=%d ", info.listLen); return ge::GRAPH_SUCCESS; } diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h index eb19ede9fca..704809dcc4a 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine.h @@ -100,6 +100,7 @@ class DispatchFFNCombine { int32_t expertPerRank; int32_t maxOutputSize; int32_t EP; + int32_t listLen; optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData; uint64_t initRoutingQuantTilingKey; @@ -138,6 +139,7 @@ __aicore__ inline void DispatchFFNCombine::Init(GM_ADDR xGM, topK = tilingData.dispatchFFNCombineInfo.topK; expertPerRank = tilingData.dispatchFFNCombineInfo.expertPerRank; maxOutputSize = tilingData.dispatchFFNCombineInfo.maxOutputSize; + listLen = tilingData.dispatchFFNCombineInfo.listLen; m0 = tilingData.cocTiling.m0; k0 = tilingData.cocTiling.k0; @@ -254,7 +256,7 @@ __aicore__ inline void DispatchFFNCombine::Process() uint32_t epilogueGranularity = expertPerRank - 1; typename MatmulKernel::Params params{ - problemShape, static_cast(EP), static_cast(expertPerRank), static_cast(maxOutputSize), + problemShape, static_cast(EP), static_cast(listLen), static_cast(expertPerRank), static_cast(maxOutputSize), static_cast(rank), static_cast(rankSize), 
static_cast(topK), initRoutingQuantTilingKey, epilogueCoreNum, epilogueGranularity, diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp index 179cdc8eca9..469a89e2e52 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp @@ -30,6 +30,7 @@ #include "utils/hccl_shmem.hpp" #include "utils/const_args.hpp" #include "utils/layout3d.hpp" +#include "utils/get_tensor_addr.hpp" #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2_tiling.h" #include "moe_init_routing_quant_v2/moe_init_routing_quant_v2.cpp" @@ -79,19 +80,20 @@ class DispatchFFNCombineKernel { __gm__ ElementA *ptrA; LayoutA layoutA; LayoutA layoutA2; - __gm__ ElementB *ptrB1; + GM_ADDR ptrB1; LayoutB layoutB1; - __gm__ ElementB *ptrB2; + GM_ADDR ptrB2; LayoutB layoutB2; - __gm__ ElementScale *ptrScale1; + GM_ADDR ptrScale1; LayoutScale layoutScale1; - __gm__ ElementScale *ptrScale2; + GM_ADDR ptrScale2; LayoutScale layoutScale2; __gm__ ElementD2 *ptrOutput; LayoutD1 layoutD1; LayoutD2 layoutD2; GM_ADDR ptrWorkspace; int32_t EP; + int32_t listLen; int32_t expertPerRank; uint32_t maxOutputSize; uint32_t rank; @@ -121,7 +123,7 @@ class DispatchFFNCombineKernel { CATLASS_HOST_DEVICE Params( GemmCoord problemShape_, - uint32_t EP_, uint32_t expertPerRank_, uint32_t maxOutputSize_, + uint32_t EP_, uint32_t listLen_, uint32_t expertPerRank_, uint32_t maxOutputSize_, uint32_t rank_, uint32_t rankSize_, int64_t topK_, uint64_t initRoutingQuantTilingKey_, uint32_t epilogueCoreNum_, uint32_t epilogueGranularity_, GM_ADDR ptrA_, LayoutA layoutA_, LayoutA layoutA2_, @@ -136,15 +138,15 @@ class DispatchFFNCombineKernel { GM_ADDR ptrWorkspace_, int32_t ubMoveNum_, optiling::MoeInitRoutingQuantV2TilingData moeInitRoutingQuantV2TilingData_ ) : problemShape(problemShape_), - EP(EP_), expertPerRank(expertPerRank_), 
maxOutputSize(maxOutputSize_), + EP(EP_), listLen(listLen_), expertPerRank(expertPerRank_), maxOutputSize(maxOutputSize_), rank(rank_), rankSize(rankSize_), topK(topK_), initRoutingQuantTilingKey(initRoutingQuantTilingKey_), epilogueCoreNum(epilogueCoreNum_), epilogueGranularity(epilogueGranularity_), ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), layoutA2(layoutA2_), - ptrB1(reinterpret_cast<__gm__ ElementB *>(ptrB1_)), layoutB1(layoutB1_), - ptrB2(reinterpret_cast<__gm__ ElementB *>(ptrB2_)), layoutB2(layoutB2_), - ptrScale1(reinterpret_cast<__gm__ ElementScale *>(ptrScale1_)), layoutScale1(layoutScale1_), - ptrScale2(reinterpret_cast<__gm__ ElementScale *>(ptrScale2_)), layoutScale2(layoutScale2_), + ptrB1(ptrB1_), layoutB1(layoutB1_), + ptrB2(ptrB2_), layoutB2(layoutB2_), + ptrScale1(ptrScale1_), layoutScale1(layoutScale1_), + ptrScale2(ptrScale2_), layoutScale2(layoutScale2_), ptrOutput(reinterpret_cast<__gm__ ElementD2 *>(ptrOutput_)), layoutD1(layoutD1_), layoutD2(layoutD2_), expertIdx(expertIdx_), moeInitRoutingQuantV2Scale(moeInitRoutingQuantV2Scale_), moeInitRoutingQuantV2Offset(moeInitRoutingQuantV2Offset_), @@ -212,11 +214,9 @@ class DispatchFFNCombineKernel { cumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM)); gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(workspaceInfo.ptrA)); - gmS.SetGlobalBuffer(params.ptrScale1); gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC)); gmPermutedToken.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD1 *>(workspaceInfo.ptrPermutedToken)); - gmS2.SetGlobalBuffer(params.ptrScale2); gmC2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(workspaceInfo.ptrC2)); gmPerTokenScale1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementPerTokenScale *>(workspaceInfo.ptrPerTokenScale)); @@ -224,7 +224,7 @@ class DispatchFFNCombineKernel { tokenPerExpert.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(shmem() + 
peermemInfo.offsetPeerTokenPerExpert)); - tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank + 8, params.expertPerRank); + tokenPerExpertLayout = Layout3D(params.EP * params.expertPerRank, params.expertPerRank); } template @@ -291,7 +291,7 @@ class DispatchFFNCombineKernel { AscendC::DataCopyPad( tmpBuffer1, tokenPerExpert[rankId * expertPerRank], - {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank + 8) * sizeof(int32_t)), 0}, + {U16(EP), U16(expertPerRank * sizeof(int32_t)), U16(((EP - 1) * expertPerRank) * sizeof(int32_t)), 0}, {} ); @@ -327,6 +327,18 @@ class DispatchFFNCombineKernel { AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul int64_t preCurrentmSum = 0; int32_t syncLoopIdx = -1; + + constexpr uint32_t MAX_EXPERTS_PER_RANK = 32; + __gm__ ElementB* weight1Array[MAX_EXPERTS_PER_RANK]; + __gm__ ElementScale * scale1Array[MAX_EXPERTS_PER_RANK]; + + int32_t loopCount = params.listLen == 1 ? 1 : params.expertPerRank; + for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { + weight1Array[loopIdx] = reinterpret_cast<__gm__ ElementB*>(GetTensorAddr(loopIdx, params.ptrB1)); + scale1Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale1)); + } + AscendC::PipeBarrier(); + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); if (preCurrentmSum >= params.maxOutputSize) { @@ -335,7 +347,13 @@ class DispatchFFNCombineKernel { currentM = params.maxOutputSize - preCurrentmSum; } AscendC::GlobalTensor gmB1; - gmB1.SetGlobalBuffer(params.ptrB1); + AscendC::GlobalTensor gmS; + int32_t arrayGroupIdx = params.listLen == 1 ? 
0 : groupIdx; + gmB1.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight1Array[arrayGroupIdx])); + gmS.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale1Array[arrayGroupIdx])); + + AscendC::PipeBarrier(); + if (currentM <= L1TileShape::M) { gmB1.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -364,7 +382,7 @@ class DispatchFFNCombineKernel { int64_t gmOffsetA = layoutA.GetOffset(offsetA); int64_t gmOffsetB = layoutB1.GetOffset(offsetB); int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // One scale group per expert + int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? groupIdx * params.problemShape.n() : 0); if (currentM > 0) { blockMmad( gmA[gmGroupOffsetA + gmOffsetA], layoutA, @@ -386,7 +404,9 @@ class DispatchFFNCombineKernel { preCurrentmSum += currentM; gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if (params.listLen == 1) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } @@ -420,6 +440,17 @@ class DispatchFFNCombineKernel { if (params.epilogueGranularity < params.expertPerRank) { lastDequantExpertNum = params.expertPerRank - params.epilogueGranularity; } + + constexpr uint32_t MAX_EXPERTS_PER_RANK = 8; + __gm__ ElementB* weight2Array[MAX_EXPERTS_PER_RANK]; + __gm__ ElementScale * scale2Array[MAX_EXPERTS_PER_RANK]; + int32_t loopCount = params.listLen == 1 ? 
1 : params.expertPerRank; + for (uint32_t loopIdx = 0; loopIdx < loopCount; ++loopIdx) { + weight2Array[loopIdx] = reinterpret_cast<__gm__ ElementB *>(GetTensorAddr(loopIdx, params.ptrB2)); + scale2Array[loopIdx] = reinterpret_cast<__gm__ ElementScale *>(GetTensorAddr(loopIdx, params.ptrScale2)); + } + AscendC::PipeBarrier(); + for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { uint32_t currentM = cumsumMM((params.EP - 1) * params.expertPerRank + groupIdx); if (preCurrentmSum >= params.maxOutputSize) { @@ -428,7 +459,12 @@ class DispatchFFNCombineKernel { currentM = params.maxOutputSize - preCurrentmSum; } AscendC::GlobalTensor gmB2; - gmB2.SetGlobalBuffer(params.ptrB2); + AscendC::GlobalTensor gmS2; + AscendC::PipeBarrier(); + int32_t arrayGroupIdx = params.listLen == 1 ? 0 : groupIdx; + gmB2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(weight2Array[arrayGroupIdx])); + gmS2.SetGlobalBuffer(reinterpret_cast<__gm__ ElementScale *>(scale2Array[arrayGroupIdx])); + if (currentM <= L1TileShape::M) { gmB2.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); } @@ -465,7 +501,7 @@ class DispatchFFNCombineKernel { int64_t gmOffsetA = layoutA.GetOffset(offsetA); int64_t gmOffsetB = layoutB2.GetOffset(offsetB); int64_t gmOffsetC = layoutC.GetOffset(offsetC); - int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // One scale group per expert + int64_t gmOffsetS = blockCoord.n() * L1TileShape::N + (params.listLen == 1 ? 
groupIdx * n2 : 0); // One scale group per expert if (currentM > 0) { blockMmad( gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA, @@ -478,7 +514,9 @@ class DispatchFFNCombineKernel { } preCurrentmSum += currentM; gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + if (params.listLen == 1) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); startCoreIdx = (startCoreIdx + coreLoops) % coreNum; @@ -491,12 +529,29 @@ class DispatchFFNCombineKernel { blockMmad.Finalize(params.expertPerRank - 1, 3); } + CATLASS_DEVICE + void ResetTokenPerExpert(AscendC::GlobalTensor & tokenPerExpert, int32_t num) + { + if (coreIdx != coreNum - 1) { + return; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::LocalTensor tmp = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmp, 0, num); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopy(tokenPerExpert, tmp, num); + } + CATLASS_DEVICE void CrossRankSyncAndlocalTokenPerExpertAllGather(Params const ¶ms, int64_t localTokenPerExpertOffset){ - uint64_t flag_offset = (shmem.SegmentSize() - MB_SIZE) / sizeof(int32_t); - __gm__ int32_t* sync_base = shmem.SyncBaseAddr(); - int count = gm_load(sync_base) + 1; - if (coreIdx < params.EP && coreIdx != params.rank) { + AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); + uint32_t numPerCore = params.EP * params.expertPerRank; + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx == params.rank) { + continue; + } AscendC::GlobalTensor srcAddress; srcAddress.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(shmem() + localTokenPerExpertOffset)); AscendC::GlobalTensor dstAddress; @@ -509,27 +564,42 @@ class DispatchFFNCombineKernel { using CopyUbToGm = Epilogue::Tile::CopyUb2Gm; 
CopyGmToUb copyGmToUb; CopyUbToGm copyUbToGm; - AscendC::LocalTensor tmpBuffer = resource.ubBuf.template GetBufferByByte(0); + AscendC::WaitFlag(EVENT_ID0); - uint32_t tmp = params.EP * params.expertPerRank; - copyGmToUb(tmpBuffer, srcAddress[0], - layout::RowMajor{ 1, tmp}, - layout::RowMajor{1, tmp}); - - tmpBuffer.SetValue(params.EP * params.expertPerRank, count); - AscendC::SetFlag(EVENT_ID0); - AscendC::WaitFlag(EVENT_ID0); - copyUbToGm(dstAddress[0], tmpBuffer, - layout::RowMajor{ 1, tmp + 1}, - layout::RowMajor{1, tmp + 1}); + + copyGmToUb(tmpBuffer, srcAddress[0], + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, 0x800000, numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyUbToGm(dstAddress[0], tmpBuffer, + layout::RowMajor{ 1, numPerCore}, + layout::RowMajor{1, numPerCore}); AscendC::SetFlag(EVENT_ID0); AscendC::WaitFlag(EVENT_ID0); - - __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(coreIdx, params.EP, 0); - gm_signal_wait_until_eq_for_barrier(sync_check, count); + } + for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) { + if (dstEpIdx == params.rank) { + continue; + } + int32_t intPer512 = CACHE_LINE / sizeof(int); + for(int32_t checkIdx = 0; checkIdx < params.EP * params.expertPerRank; checkIdx += intPer512) { + __gm__ int32_t* sync_check = reinterpret_cast<__gm__ int32_t*>(shmem() + peermemInfo.offsetPeerTokenPerExpert) + tokenPerExpertLayout(dstEpIdx, 0, checkIdx); + gm_signal_wait_until_ne(sync_check, 0); + } + AscendC::DataCopy(tmpBuffer, tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::Adds(tmpBuffer, tmpBuffer, -0x800000, numPerCore); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); 
+ AscendC::DataCopy(tokenPerExpert[tokenPerExpertLayout(dstEpIdx, 0, 0)], tmpBuffer, numPerCore); } AscendC::SyncAll(); - gm_store(sync_base, count); } @@ -569,7 +639,8 @@ class DispatchFFNCombineKernel { uint32_t prevGroupSum1 = 0; uint32_t dequantSum = 0; int32_t syncLoopIdx = -1; - BlockEpilogue1 blockEpilogue(resource); + uint32_t n = params.problemShape.n(); + BlockEpilogue1 blockEpilogue(resource, n); for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) { // The ith core reads data from the ith rank's peermem groupIdxDeq = groupIdx - 2; @@ -668,7 +739,8 @@ class DispatchFFNCombineKernel { typename BlockEpilogue2::Params epilogueParams{ static_cast(params.EP), static_cast(params.expertPerRank), - reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace) + reinterpret_cast<__gm__ int32_t *>(params.ptrWorkspace), + static_cast(n2) }; BlockEpilogue2 blockEpilogue(resource, epilogueParams); int32_t prevGroupSum2 = 0; @@ -704,6 +776,7 @@ class DispatchFFNCombineKernel { } blockEpilogue.Finalize(); AscendC::SyncAll(); + ResetTokenPerExpert(tokenPerExpert, params.EP * params.EP * params.expertPerRank); shmem.CrossRankSync(); MoeTokenUnpermuteTilingData tilingData; MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum); @@ -794,10 +867,8 @@ class DispatchFFNCombineKernel { AscendC::GlobalTensor gmA; AscendC::GlobalTensor gmC; - AscendC::GlobalTensor gmS; AscendC::GlobalTensor gmPermutedToken; - AscendC::GlobalTensor gmS2; AscendC::GlobalTensor gmC2; AscendC::GlobalTensor gmPerTokenScale1; diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h index de891e9f026..04ca56a4b97 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h +++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_tiling.h @@ -30,6 +30,7 @@ struct DispatchFFNCombineInfo { uint32_t totalUbSize; uint32_t topK; uint32_t 
worldSize; + uint32_t listLen; }; struct CoCTiling { diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp index 4b627a67d80..5616b1f87c9 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_row.hpp @@ -70,23 +70,24 @@ class BlockEpilogue < __gm__ int32_t *ptrTokenPerExpert{nullptr}; int32_t EP; int32_t expertPerRank; + int32_t n2; CATLASS_DEVICE Params() {}; CATLASS_DEVICE - Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_) {} + Params(int32_t EP_, int32_t expertPerRank_, __gm__ int32_t *ptrTokenPerExpert_, int32_t n2_) : ptrTokenPerExpert(ptrTokenPerExpert_), EP(EP_), expertPerRank(expertPerRank_), n2(n2_) {} }; CATLASS_DEVICE BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) { - size_t ubOffset = 4096; + size_t ubOffset = 0; int32_t eventVMTE2 = 0; int32_t eventMTE2V = 0; int32_t eventMTE3V = 0; int32_t eventVMTE3 = 0; - constexpr int32_t blockN = 12000; + int32_t blockN = params.n2; for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += blockN * sizeof(ElementC); diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp index 26630122567..5588e9cd302 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/block_epilogue_pertoken_swiglu.hpp @@ -84,16 +84,16 @@ class BlockEpilogue < }; CATLASS_DEVICE - BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + BlockEpilogue(Arch::Resource const &resource, int32_t n, 
Params const &params = Params{}) : params(params) { size_t ubOffset = 0; int32_t eventVMTE2 = 0; int32_t eventMTE2V = 0; int32_t eventMTE3V = 0; int32_t eventVMTE3 = 0; - constexpr uint32_t blockN = 4096; - constexpr uint32_t ChunkTileLen = blockN / 2; - constexpr uint32_t HalfChunkTileLen = ChunkTileLen / 2; + uint32_t blockN = n; + uint32_t ChunkTileLen = blockN / 2; + uint32_t HalfChunkTileLen = ChunkTileLen / 2; for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += blockN * sizeof(ElementC); diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp index 71b422c924d..61a3d866dcc 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/const_args.hpp @@ -3,4 +3,6 @@ #define CONST_ARGS_HPP constexpr static uint64_t MB_SIZE = 1024 * 1024UL; constexpr static int32_t NUMS_PER_FLAG = 16; +constexpr static int32_t CACHE_LINE = 512; +constexpr static int32_t RESET_VAL = 0xffff; #endif \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/get_tensor_addr.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/get_tensor_addr.hpp new file mode 100644 index 00000000000..67b32c25ba4 --- /dev/null +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/get_tensor_addr.hpp @@ -0,0 +1,16 @@ +#ifndef GET_TENSOR_ADDR_HPP +#define GET_TENSOR_ADDR_HPP +#include "kernel_operator.h" + +#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__ + +template +FORCE_INLINE_AICORE __gm__ T* GetTensorAddr(uint32_t index, GM_ADDR tensorPtr) { + __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr); + uint64_t tensorPtrOffset = *dataAddr; // The offset of the data address from the first address. + // Moving 3 bits to the right means dividing by sizeof(uint64_t). 
+ __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3); + return reinterpret_cast<__gm__ T*>(*(retPtr + index)); +} + +#endif // GET_TENSOR_ADDR_HPP \ No newline at end of file diff --git a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp index fd2b995cef2..cfbb4daf8b1 100644 --- a/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp +++ b/csrc/dispatch_ffn_combine/op_kernel/utils/hccl_shmem.hpp @@ -53,17 +53,34 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t * } +FORCE_INLINE_AICORE void gm_signal_wait_until_ne(__gm__ int32_t *sig_addr, int32_t cmp_val) { + do { + AscendC::LocalTensor ub; + ub.address_.logicPos = static_cast(TPosition::VECIN); + ub.address_.bufferAddr = 0; + AscendC::GlobalTensor sig; + sig.SetGlobalBuffer(sig_addr); + AscendC::DataCopy(ub, sig, 8); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if (ub(0) != cmp_val) { + return; + } + } while (true); + return; +} + + constexpr int32_t MAX_RANK_SIZE = 32; class HcclShmem { public: #ifdef HCCL_COMM // HCCL needs to initialize the HCCL context __gm__ HcclOpResParamCustom *WinContext_{nullptr}; Hccl hccl_; - GM_ADDR m_ptrArray[MAX_RANK_SIZE]; size_t m_segmentSize; int32_t m_rank; int32_t m_rankSize; - + FORCE_INLINE_AICORE HcclShmem(){ auto contextGM0 = AscendC::GetHcclContext(); @@ -73,18 +90,13 @@ class HcclShmem { m_rankSize = WinContext_->rankSize; m_segmentSize = WinContext_->winSize; - for (int i = 0; i < m_rankSize; i++) { - m_ptrArray[i] = (GM_ADDR)((i == m_rank) ? 
WinContext_->localWindowsIn : - ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[i].nextDevicePtr))->windowsIn); - } - } FORCE_INLINE_AICORE size_t SegmentSize() const { return m_segmentSize; } - + FORCE_INLINE_AICORE int32_t RankSize() const { return m_rankSize; @@ -94,7 +106,7 @@ class HcclShmem { FORCE_INLINE_AICORE GM_ADDR operator() () const { // No argument: return local peermem #ifdef HCCL_COMM - return m_ptrArray[m_rank]; + return (GM_ADDR)(WinContext_->localWindowsIn); #else return reinterpret_cast(shmemi_get_state()->heap_base); #endif @@ -103,7 +115,8 @@ class HcclShmem { FORCE_INLINE_AICORE GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address #ifdef HCCL_COMM - return m_ptrArray[index]; + return (GM_ADDR)((index == m_rank) ? WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[index].nextDevicePtr))->windowsIn); #else return reinterpret_cast(shmem_ptr(shmemi_get_state()->heap_base, index)); #endif @@ -120,7 +133,8 @@ class HcclShmem { if (rankId < 0 || rankId >= m_rankSize) { return nullptr; } - return m_ptrArray[rankId] + offset; + return (GM_ADDR)((rankId == m_rank) ? 
WinContext_->localWindowsIn : + ((HcclRankRelationResV2Custom *)(WinContext_->remoteRes[rankId].nextDevicePtr))->windowsIn) + offset; #else return shmem_ptr(shmemi_get_state()->heap_base + offset, rankId); #endif diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index ca701235e91..fe06fbe5f05 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -727,11 +727,11 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor at::Tensor& dispatch_ffn_combine( const at::Tensor& x, - const at::Tensor& weight1, - const at::Tensor& weight2, + const at::TensorList& weight1, + const at::TensorList& weight2, const at::Tensor& expert_idx, - const at::Tensor& scale1, - const at::Tensor& scale2, + const at::TensorList& scale1, + const at::TensorList& scale2, const at::Tensor& probs, c10::string_view group, int64_t max_output_size, @@ -1383,8 +1383,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) ops.impl("npu_sparse_flash_attention", torch::kPrivateUse1, &vllm_ascend::npu_sparse_flash_attention); ops.def( - "dispatch_ffn_combine(Tensor x, Tensor weight1, Tensor weight2, Tensor expert_idx," - " Tensor scale1, Tensor scale2, Tensor probs, str group," + "dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx," + " Tensor[] scale1, Tensor[] scale2, Tensor probs, str group," " int max_output_size, Tensor! 
out) -> Tensor" ); ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine); diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index a166ba166a0..c9949be6ae6 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -196,11 +196,11 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor at::Tensor& dispatch_ffn_combine_meta( const at::Tensor& x, - const at::Tensor& weight1, - const at::Tensor& weight2, + const at::TensorList& weight1, + const at::TensorList& weight2, const at::Tensor& expert_idx, - const at::Tensor& scale1, - const at::Tensor& scale2, + const at::TensorList& scale1, + const at::TensorList& scale2, const at::Tensor& probs, c10::string_view group, int64_t max_output_size, diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py index 90ce1f07fcf..9f86cc0f1c4 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_ffn_combine.py @@ -87,13 +87,13 @@ def generate_hcom(self): hcomm_info = hcomm_info_dist["default_pg_info"] self.hcomm_info = hcomm_info - def run_npu_out(self) -> bool: + def run_tensor_list(self) -> bool: torch_npu.npu.set_device(self.rank) - m = 2 # token-num 32 - k = 4 # hidden_size 7168 - n = 4 # mid-hidden-size 4096 - topk = 2 - e = 2 # expert-num-per-rank 16 + m = 64 + k = 1024 + n = 1024 + topk = 8 + e = 8 k2 = n // 2 n2 = k @@ -112,15 +112,79 @@ def run_npu_out(self) -> bool: scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu() scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu() probs = torch.randn(size=(m, topk), dtype=torch.float32).npu() + + weight1_nz_npu = [] + weight2_nz_npu = [] + scale1_npu = [] + scale2_npu = [] + for i in range(e): + weight1_nz_npu.append( + 
torch_npu.npu_format_cast(weight1[i].npu(), 29)) + scale1_npu.append(scale1[i].npu()) + weight2_nz_npu.append( + torch_npu.npu_format_cast(weight2[i].npu(), 29)) + scale2_npu.append(scale2[i].npu()) + + out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + + torch.ops._C_ascend.dispatch_ffn_combine( + x=x, + weight1=weight1_nz_npu, + weight2=weight2_nz_npu, + expert_idx=expert_idx, + scale1=scale1_npu, + scale2=scale2_npu, + probs=probs, + group=self.hcomm_info, + max_output_size=512, + out=out, + ) + return True + + def run_normal(self) -> bool: + torch_npu.npu.set_device(self.rank) + m = 64 + k = 1024 + n = 1024 + topk = 8 + e = 8 + k2 = n // 2 + n2 = k + + torch_npu.npu.config.allow_internal_format = True + x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() + weight1 = self.generate_random_tensor((e, k, n), + dtype=torch.int8).npu() + weight1 = torch_npu.npu_format_cast(weight1, 29) + weight2 = self.generate_random_tensor((e, k2, n2), + dtype=torch.int8).npu() + weight2 = torch_npu.npu_format_cast(weight2, 29) + + expert_idx = torch.randint(0, + self.world_size * e, (m, topk), + dtype=torch.int32).npu() + scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu() + scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu() + probs = torch.randn(size=(m, topk), dtype=torch.float32).npu() + + weight1_nz_npu = [] + weight2_nz_npu = [] + scale1_npu = [] + scale2_npu = [] + weight1_nz_npu.append(torch_npu.npu_format_cast(weight1.npu(), 29)) + scale1_npu.append(scale1.npu()) + weight2_nz_npu.append(torch_npu.npu_format_cast(weight2.npu(), 29)) + scale2_npu.append(scale2.npu()) + out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu() torch.ops._C_ascend.dispatch_ffn_combine( x=x, - weight1=weight1, - weight2=weight2, + weight1=weight1_nz_npu, + weight2=weight2_nz_npu, expert_idx=expert_idx, - scale1=scale1, - scale2=scale2, + scale1=scale1_npu, + scale2=scale2_npu, probs=probs, group=self.hcomm_info, 
max_output_size=512, @@ -142,8 +206,10 @@ def generate_random_tensor(self, size, dtype): def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue): op = TestDisptachFFNCombine(rank, world_size, port) op.generate_hcom() - out = op.run_npu_out() - q.put(out) + out1 = op.run_tensor_list() + q.put(out1) + out2 = op.run_normal() + q.put(out2) @torch.inference_mode() diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 8ad25e2fdf7..631830bb8b0 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -306,11 +306,11 @@ def fused_experts( out = torch.empty_like(hidden_states) torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore x=hidden_states, - weight1=w1[0], - weight2=w2[0], + weight1=w1, + weight2=w2, expert_idx=topk_ids, - scale1=w1_scale[0], - scale2=w2_scale[0], + scale1=w1_scale, + scale2=w2_scale, probs=topk_weights.to(torch.float32), group=self.token_dispatcher.moe_all_to_all_group_name, max_output_size=65536,