Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ void TrtllmGenBatchedGemmRunner::run(

auto const err =
bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
enable_pdl, /*pinnedHostBuffer=*/nullptr, globalTrtllmGenBatchedGemmModuleCache);
enable_pdl, nullptr, globalTrtllmGenBatchedGemmModuleCache);

FLASHINFER_CHECK(err == 0,
"Error occurred when running GEMM!"
Expand Down
143 changes: 125 additions & 18 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1246,19 +1246,22 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
public:
static constexpr std::array<int32_t, 5> mSupportedTileNums = {8, 16, 32, 64, 128};

MxInt4BlockScaleLauncher(TensorView const& routing_logits,
MxInt4BlockScaleLauncher(Optional<TensorView> const& routing_logits,
Optional<TensorView> const& routing_bias,
TensorView const& hidden_states, TensorView const& gemm1_weights,
TensorView const& gemm1_weights_scale,
Optional<TensorView> const& gemm1_alpha,
Optional<TensorView> const& gemm1_beta,
Optional<TensorView> const& gemm1_clamp_limit,
TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale)
: FusedMoeLauncher(Optional<TensorView>(routing_logits), routing_bias, hidden_states,
gemm1_weights, Optional<TensorView>(), Optional<TensorView>(),
gemm2_weights, Optional<TensorView>()),
TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale,
TensorView const& expert_indices, TensorView const& expert_weights_in)
: FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights,
Optional<TensorView>(), Optional<TensorView>(), gemm2_weights,
Optional<TensorView>()),
gemm1_weights_scale(gemm1_weights_scale),
gemm2_weights_scale(gemm2_weights_scale) {}
gemm2_weights_scale(gemm2_weights_scale),
expert_indices(expert_indices),
expert_weights_in(expert_weights_in) {}

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type) {
Expand All @@ -1280,7 +1283,29 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int64_t>(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu);
}

void check_routing() const override { FusedMoeLauncher::check_routing_common(); }
// Validates routing inputs before launch.
// When pre-computed routing is supplied, the packed expert_indices tensor
// (and, optionally, expert_weights_in) must match this launch's token count
// and top_k; the shared base-class checks then run in all cases.
void check_routing() const override {
  if (has_precomputed_indices()) {
    // Pre-computed routing: expert_indices is a packed tensor
    // Format: (expert_id << 16) | (weight_bf16.view(int16))
    TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
        << "expert_indices and hidden_states must have same number of tokens.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
        << "expert_indices dim1 must match top_k.";
    // Packed (expert_id | weight) pairs require a 32-bit integer payload.
    TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
  }

  if (has_precomputed_weights()) {
    // Pre-computed expert weights: validate shape and dtype
    TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0))
        << "expert_weights_in and hidden_states must have same number of tokens.";
    TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k)
        << "expert_weights_in dim1 must match top_k.";
    TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16)
        << "expert_weights_in must be bfloat16.";
  }

  // NOTE(review): the common checks run unconditionally — confirm they tolerate
  // an absent routing_logits when pre-computed indices are used.
  FusedMoeLauncher::check_routing_common();
}

void prepare_routing() override {
FusedMoeLauncher::prepare_routing_common();
Expand All @@ -1292,11 +1317,24 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;

auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype, hidden_states.device());
if (has_precomputed_indices()) {
// Use expert_indices directly
workspace.routing_expert_indexes =
static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
} else {
// Use routing_logits directly
args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr());
}

workspace.expert_weights = expert_weights.data_ptr();
if (has_precomputed_weights()) {
workspace.expert_weights = const_cast<void*>(expert_weights_in.data_ptr());
} else {
// Allocate expert_weights buffer for routing output
auto expert_weights_dtype = mDtypeScore == btg::Dtype::Fp32 ? dl_float32 : dl_bfloat16;
expert_weights = alloc_tensor({args->num_tokens, args->top_k}, expert_weights_dtype,
hidden_states.device());
workspace.expert_weights = expert_weights.data_ptr();
}
}

void check_moe() const override {
Expand Down Expand Up @@ -1364,10 +1402,64 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher {
Optional<TensorView> gemm1_beta;
Optional<TensorView> gemm1_clamp_limit;
TensorView gemm2_weights_scale;
TensorView expert_indices;
TensorView expert_weights_in;
int32_t max_num_padded_tokens_gemm1{};
int32_t max_num_padded_tokens_gemm2{};

/// True when the caller supplied packed routing indices.
/// A tensor only counts as "provided" when it is 2-D with a non-zero leading
/// dimension: empty placeholder tensors may still carry a non-null data_ptr,
/// so the pointer alone is not a reliable signal.
bool has_precomputed_indices() const {
  if (expert_indices.ndim() != 2) {
    return false;
  }
  return expert_indices.size(0) > 0;
}

/// True when the caller supplied pre-computed routing weights.
/// Mirrors has_precomputed_indices(): only a 2-D tensor with a non-zero
/// leading dimension is treated as provided, since empty placeholders may
/// have a non-null data_ptr.
bool has_precomputed_weights() const {
  if (expert_weights_in.ndim() != 2) {
    return false;
  }
  return expert_weights_in.size(0) > 0;
}

public:
// Override to handle pre-computed routing.
// Sequence: validate + wire up routing inputs, launch the routing kernel
// (skipped logits when indices are pre-computed), then validate and launch
// the grouped-GEMM MoE kernels on this device's stream.
Array<Tensor> run(int64_t moe_tactic, bool enable_pdl = true,
                  bool use_routing_scales_on_input = false,
                  bool use_deep_seek_fp8 = false) override {
  check_routing();
  prepare_routing();

  cudaStream_t routing_stream = get_stream(hidden_states.device());
  tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);

  // When using pre-computed routing, pass nullptr as routing_logits to tell the
  // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
  routing_runner.run(
      has_precomputed_indices() ? nullptr : args->routing_logits, args->routing_bias,
      args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group,
      args->local_expert_offset, args->local_num_experts, args->routed_scaling_factor,
      workspace.routing_expert_indexes, static_cast<int*>(expert_count_histogram.data_ptr()),
      static_cast<int*>(total_num_padded_tokens.data_ptr()),
      static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
      nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
      static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights,
      static_cast<int*>(num_tokens_per_expert.data_ptr()),
      static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
      static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
      static_cast<int*>(num_non_exiting_ctas.data_ptr()), mDtypeScore, args->mDtypeElt,
      mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
      static_cast<RoutingMethodType>(routing_method_type), routing_stream);

  // Validate MoE inputs and select the kernel configuration for this tactic.
  check_moe();
  prepare_moe(moe_tactic);

  cudaStream_t moe_stream = get_stream(hidden_states.device());
  moe_runner->run(*args, workspace, hidden_states.device().device_id, moe_stream, moe_tactic,
                  enable_pdl);

  if (args->do_finalize) {
    return {output};
  }
  // NOTE(review): FusedMoeLauncher::expert_weights is only allocated when the
  // caller did NOT supply pre-computed weights — confirm do_finalize is always
  // true in the pre-computed-weights path, or this may return an empty tensor.
  return {gemm2_output, FusedMoeLauncher::expert_weights, expanded_idx_to_permuted_idx};
}

static Array<Array<int64_t>> getValidConfigs(int64_t top_k, int64_t hidden_size,
int64_t intermediate_size, int64_t num_local_experts,
int64_t num_tokens) {
Expand Down Expand Up @@ -2114,8 +2206,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
}

Array<Tensor> trtllm_mxint4_block_scale_moe(
TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
TensorView gemm1_weights, TensorView gemm1_weights_scale, Optional<TensorView> gemm1_alpha,
Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights,
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView gemm1_weights,
TensorView gemm1_weights_scale, Optional<TensorView> gemm1_alpha,
Optional<TensorView> gemm1_beta, Optional<TensorView> gemm1_clamp_limit,
TensorView gemm2_weights, TensorView gemm2_weights_scale, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
Expand All @@ -2132,10 +2225,23 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(

TVM_FFI_ICHECK(weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size.";

TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "routing_logits must be float or bfloat16.";
TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape.";
// Either routing_logits or expert_indices must be provided
// expert_indices is a packed tensor: (expert_id << 16) | (weight_bf16.view(int16))
bool use_routing_logits = routing_logits.has_value();
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool use_precomputed_routing = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;

TVM_FFI_ICHECK(use_routing_logits || use_precomputed_routing)
<< "Either routing_logits or expert_indices must be provided.";

if (use_routing_logits) {
TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
routing_logits.value().dtype() == dl_bfloat16)
<< "routing_logits must be float or bfloat16.";
TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
<< "routing_logits has incorrect shape.";
}
if (routing_bias.has_value()) {
TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16) << "routing_bias must be bfloat16.";
TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D.";
Expand Down Expand Up @@ -2178,7 +2284,8 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
// Create and initialize launcher for this tile size
auto launcher = std::make_unique<MxInt4BlockScaleLauncher>(
routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_weights_scale,
gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale);
gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale,
expert_indices, expert_weights);
launcher->init(std::move(args), curr_tile_N, routing_method_type);

launchers_map[curr_tile_N] = std::move(launcher);
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
trtllm_bf16_moe,
trtllm_bf16_routed_moe,
trtllm_mxint4_block_scale_moe,
trtllm_mxint4_block_scale_routed_moe,
)

from .fused_routing_dsv3 import ( # noqa: F401
Expand Down Expand Up @@ -73,6 +74,7 @@
"trtllm_fp8_block_scale_routed_moe",
"trtllm_fp8_per_tensor_scale_moe",
"trtllm_mxint4_block_scale_moe",
"trtllm_mxint4_block_scale_routed_moe",
"fused_topk_deepseek",
]

Expand Down
Loading
Loading