Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
Expand All @@ -57,8 +58,9 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
Expand All @@ -73,7 +75,7 @@ aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
gmm2Weight, gmm2WeightScale, expertScales, expertSmoothScalesOptional, xActiveMaskOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertScales,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
const aclTensor *xActiveMaskOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,19 @@ class DispatchGmmCombineDecode : public OpDef
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
this->Input("x_active_mask")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.DataType({ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2;
constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3;
constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4;
constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7;
constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 6;
constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 7;
constexpr uint32_t INPUT_SHARE_X_ACTIVE_MASK_INDEX = 8;

constexpr uint32_t ATTR_GROUP_EP_INDEX = 0;
constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1;
Expand All @@ -51,6 +52,7 @@ constexpr uint32_t MIN_BATCH_SIZE = 1;
constexpr uint32_t MAX_BATCH_SIZE = 256;
constexpr uint32_t MAX_MOE_EXERT_NUM = 512;
constexpr uint32_t SUPPORT_TOP_K = 12;
constexpr uint32_t ONE_DIMS = 1;
constexpr uint32_t TWO_DIMS = 2;
constexpr uint32_t MIN_TOKEN_LENGTH = 512;
constexpr uint32_t MAX_TOKEN_LENGTH = 7168;
Expand All @@ -71,6 +73,7 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum;
uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum;
uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank;
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;

Expand Down Expand Up @@ -122,6 +125,18 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
return ge::GRAPH_FAILED);

const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
if (xActiveMaskStorageShape != nullptr) {
OPS_ERR_IF(xActiveMaskStorageShape->GetStorageShape().GetDimNum() != ONE_DIMS,
OPS_LOG_E(nodeName, " xActiveMask scale shape dims must be 1, but current dim num is %lu.",
xActiveMaskStorageShape->GetStorageShape().GetDimNum()),
return ge::GRAPH_FAILED);
const int64_t xActiveMaskDim0 = xActiveMaskStorageShape->GetStorageShape().GetDim(0);
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
"xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
return ge::GRAPH_FAILED);
Comment on lines +136 to +138
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The error message in this log appears to be a copy-paste error from another check. It refers to gmm2WeightScale Dim0 when it should be referring to xActiveMask Dim0. This could be misleading during debugging.

        OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
                    "xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
                    return ge::GRAPH_FAILED);

}
return ge::GRAPH_SUCCESS;
}

Expand Down Expand Up @@ -308,14 +323,22 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = aivNum;
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."),
return ge::GRAPH_FAILED);
OPS_ERR_IF(CheckTensorShape(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(
nodeName, "CheckTensorShape failed."), return ge::GRAPH_FAILED);
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
SetHcommCfg(context, tilingData, groupEp);
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {
context->SetTilingKey(0);
} else {
context->SetTilingKey(EXEC_FLAG_DEEP_FUSE);
const gert::StorageShape* xActiveMaskStorageShape = context->GetOptionalInputShape(
INPUT_SHARE_X_ACTIVE_MASK_INDEX);
bool xActiveMaskEnable = (xActiveMaskStorageShape != nullptr);
uint64_t tilingKey = 0;
if (xActiveMaskEnable) {
tilingKey |= EXEC_FLAG_X_ACTIVE_MASK;
}
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank != 1) {
tilingKey |= EXEC_FLAG_DEEP_FUSE;
}
context->SetTilingKey(tilingKey);
context->SetBlockDim(aicNum);
return ge::GRAPH_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
Expand All @@ -24,10 +25,10 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V
GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) {
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(4) || TILING_KEY_IS(5)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data);
expert_scales, expert_smooth_scales, x_active_mask, output, outputRecvCount, workspace, nullptr, &tiling_data);
op.Process();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD,
GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace,
GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx,
GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount,
uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum,
uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum,
uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen)
Expand Down Expand Up @@ -110,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;

Expand All @@ -136,6 +136,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
gmexpertIds,
gmExpandIdx,
gmEpSendCount,
xActiveMask,
gmResvered,
gmOutputRecvCount,
epRankSize,
Expand Down Expand Up @@ -241,7 +242,7 @@ class DispatchGmmCombineDecode
__aicore__ inline void Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
Expand All @@ -260,6 +261,7 @@ class DispatchGmmCombineDecode
GM_ADDR workspaceGM_;
GM_ADDR gmSmoothScales_;
GM_ADDR gmexpertScales_;
GM_ADDR xActiveMask_;

uint32_t maxTokenNum_{0};
uint32_t gmm1OutputDim_{0};
Expand Down Expand Up @@ -291,7 +293,8 @@ template <TemplateMC2TypeClass>
__aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
// input
GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales,
GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales,
GM_ADDR x_active_mask,
// output
GM_ADDR output, GM_ADDR outputRecvCount,
// system
Expand All @@ -312,6 +315,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Init(
gmOutputRecvCount_ = outputRecvCount;
workspaceGM_ = workspaceGM;
gmexpertScales_ = expert_scales;
xActiveMask_ = x_active_mask;
tilingData_ = tilingData;
epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
Expand Down Expand Up @@ -386,12 +390,12 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
GM_ADDR gmResvered = workspaceGM_ + workspaceOffset;
workspaceOffset += RoundUp<GM_ALIGN_BYTE>(resveredWorkSpaceSize);

if constexpr (EXEC_FLAG == 0) {
if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) {
if constexpr (g_coreType == AscendC::AIV) {
AscendC::TPipe tpipe;
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false>
MoeDistributeDispatchImpl::CamMoeDistributeDispatch<ExpandXType, int8_t, false, true, false, false, EXEC_FLAG>
dispatcher;
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList,
gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_);
dispatcher.Process();
tpipe.Destroy();
Expand All @@ -411,7 +415,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered,
layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered,
gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_,
sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_);
AscendC::PipeBarrier<PIPE_ALL>();
Expand All @@ -425,7 +429,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()

MoeDistributeCombineImpl::CamMoeDistributeCombine<TemplateMC2TypeFunc> combiner;
if (g_coreType == AscendC::AIV) {
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_,
combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_,
workspaceGM_, nullptr, tilingData_);
}
GmmDeq<TemplateMC2TypeFunc, Gmm2L1TileShape, Gmm2L0TileShape, Gmm2EpilogueTileShape, Gmm2BlockScheduler,
Expand Down
Loading
Loading