69 changes: 33 additions & 36 deletions custom_ops/metax_ops/mc_fused_moe_helper.h
@@ -1,18 +1,3 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
#include "fused_moe_helper.h"
@@ -45,63 +30,75 @@ void mc_grouped_gemm_basic_kernel(
mctlassExMatrixLayout_t matLayoutC;

// mat A: (m, k)
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorA, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// mat B: (num_experts, n, k)
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorB, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// mat C: (m, n)
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
&majorC, sizeof(mctlassExOrder_t));
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
&numExperts, sizeof(int));
// bias: (num_experts, n)
// scale: (num_experts, n)

mctlassExDesc_t mctlass_desc;
mctlassExCreateDesc(&mctlass_desc);
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
if (ptrBias) {
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
}
// set scale
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
&ptrScale, sizeof(ptrScale));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
&scale_type, sizeof(mctlassExDataType));
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
&input_type, sizeof(mctlassExDataType));
// set bias
if (ptrBias) {
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
&ptrBias, sizeof(ptrBias));
}
// set compute type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
&compute_type, sizeof(mctlassExDataType));
// set epilogue type
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
&epilogue_type, sizeof(mctlassExEpilogueType));

const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
mctlassExContiguousGroupedDescCreate(&contiguous_group_desc,
ptrSegInd,
nullptr,
ptrMNumTilesInd,
1);
int blocksizeM;
mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, &blocksizeM);
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, contiguous_group_desc, numExperts, blocksizeM, stream);

mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
ptrA, matLayoutA,
ptrB, matLayoutB,
ptrC, matLayoutC,
ptrSegInd, nullptr, ptrMNumTilesInd,
contiguous_group_desc,
&algo, nullptr, 0, stream);

mctlassExHandleDestroy(handle);
mctlassExMatrixLayoutDestroy(matLayoutA);
mctlassExMatrixLayoutDestroy(matLayoutB);
mctlassExMatrixLayoutDestroy(matLayoutC);
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
mctlassExDestroyDesc(mctlass_desc);
mcFreeAsync(ptrMNumTilesInd, stream);
}
@@ -334,8 +331,8 @@ class McMoeHelper {
total_rows_before_expert_,
stream);

mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;

mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA *>(permuted_data_),
14 changes: 7 additions & 7 deletions custom_ops/metax_ops/moe_ffn.cu
@@ -1,18 +1,18 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//

// http://www.apache.org/licenses/LICENSE-2.0
//

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


// BUILD_MARK
#pragma once
#include "mc_fused_moe_helper.h"
#include "helper.h"
@@ -47,8 +47,8 @@ void McMoeFFNKernel(const paddle::Tensor& permute_input,
{expanded_active_expert_rows, inter_size}, input_type, place);
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();

mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;

// ffn1
auto fc1_expert_biases =
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> MoeExpertFFN(
// ffn_out);
// break;
default:
PD_THROW("Only support bf16 for MoeExpertFFN");
PD_THROW("Unsupported data type for MoeExpertFFN");
}
return {ffn_out};
}
@@ -1,4 +1,5 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn
@@ -110,7 +112,7 @@ def apply_tp(
False,
)
if layer.reduce_results and layer.tp_size > 1:
fused_moe_out = tensor_model_parallel_all_reduce(fused_moe_out)
tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)

return fused_moe_out
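
The updated call makes the collective in place: `tensor_model_parallel_all_reduce` now receives the tensor-parallel process group explicitly and writes the reduced result back into `fused_moe_out`, so the previous reassignment is no longer needed. A minimal sketch of the same pattern, assuming a standard Paddle process group (the helper name `tp_all_reduce_in_place` is hypothetical):

```python
import paddle
import paddle.distributed as dist

def tp_all_reduce_in_place(tensor: paddle.Tensor, tp_group) -> None:
    # Hypothetical helper mirroring the updated call site:
    # dist.all_reduce sums across ranks and writes into `tensor`
    # in place, so the caller keeps using the same object.
    dist.all_reduce(tensor, group=tp_group)
```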

@@ -1,3 +1,4 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn
@@ -44,6 +46,7 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange:
"""process_prequanted_weights"""
pass

@paddle.no_grad()
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
"""
Triton MoE create weight process.
@@ -100,6 +103,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
),
)

@paddle.no_grad()
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE load weight process.
@@ -110,8 +114,6 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):

algo = layer.quant_method.quant_config.name()

assert algo == "wint8"

assert up_gate_proj_weights[0].shape == [
layer.hidden_size,
layer.moe_intermediate_size * 2,
@@ -151,31 +153,42 @@ def apply(
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
hidden_size = layer.hidden_size

gate_out = gate(x.cast("float32"))
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
top_k,
True, # apply_norm_weight,
layer.top_k,
True, # apply_norm_weight
False,
)

up_gate_proj_out = paddle.empty(
[token_num * top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)

config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
if self.quant_config is not None:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}
else:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
}

sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
topk_ids, num_local_experts, config["BLOCK_SIZE_M"]
)
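
`tritonmoe_preprocess` prepares the routing metadata for the Triton kernels: it groups the flattened (token, expert) assignments by expert and pads each expert's row count up to a multiple of `BLOCK_SIZE_M`, so every GEMM tile touches rows of exactly one expert. A pure-Python sketch of the assumed semantics (the padding sentinel and exact output layout are assumptions; the real op runs on the GPU):

```python
import numpy as np

def tritonmoe_preprocess_ref(topk_ids: np.ndarray, num_experts: int, block_m: int):
    # Reference sketch of the assumed behavior, for illustration only.
    flat = topk_ids.reshape(-1)
    pad_id = flat.shape[0]          # assumed out-of-range sentinel for padding
    sorted_token_ids, expert_ids = [], []
    for e in range(num_experts):
        rows = np.flatnonzero(flat == e).tolist()
        padded = ((len(rows) + block_m - 1) // block_m) * block_m
        rows += [pad_id] * (padded - len(rows))      # pad to a full M-tile
        sorted_token_ids += rows
        expert_ids += [e] * (padded // block_m)      # one expert id per tile
    return np.asarray(sorted_token_ids), np.asarray(expert_ids), len(sorted_token_ids)
```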
@@ -282,5 +295,5 @@ def apply(
down_proj_out.reshape_([token_num, top_k, hidden_size])
out = down_proj_out.sum(axis=1)
if layer.tp_size > 1:
out = tensor_model_parallel_all_reduce(out)
tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group)
return out
7 changes: 7 additions & 0 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -56,6 +56,13 @@ def get_moe_method():

return HpuMoEMethod(None)
# return HpuTensorWiseFP8MoEMethod(None)

elif current_platform.is_maca():
from fastdeploy.model_executor.layers.backends import (
MetaxCutlassWeightOnlyMoEMethod,
)

return MetaxCutlassWeightOnlyMoEMethod(None)
raise NotImplementedError
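
A hypothetical usage sketch of the new dispatch branch, assuming a MACA (Metax) build of FastDeploy:

```python
# Hypothetical usage: resolve the platform-specific MoE backend once.
from fastdeploy.model_executor.layers.moe.moe import get_moe_method

moe_method = get_moe_method()  # MetaxCutlassWeightOnlyMoEMethod(None) on MACA
```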

