Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 146 additions & 71 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,27 +85,103 @@ inline int32_t nextPowerOfTwo(float value) {
std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
                                       int64_t const num_tokens, int64_t const top_k,
                                       int64_t const num_local_experts) {
  TVM_FFI_ICHECK(!supported_tile_nums.empty()) << "supported_tile_nums must not be empty.";
  TVM_FFI_ICHECK(std::is_sorted(supported_tile_nums.begin(), supported_tile_nums.end()))
      << "supported_tile_nums must be sorted in ascending order.";
  // Average routed tokens per local expert drives the tile-size heuristic.
  float const tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
  // NOTE: This differs from Python AutoTuner bucketing:
  // - AutoTuner maps raw num_tokens with last_positive_power_of_2 (round-down).
  // - Here we map the derived tokens-per-expert ratio and use nextPowerOfTwo (round-up).
  // Because they round different quantities in different directions, cache bucket and runtime
  // tile candidates can diverge; launcher-side tactic resolution handles that mismatch.
  int32_t const center_tile = std::clamp(nextPowerOfTwo(tokens_per_expert),
                                         supported_tile_nums.front(), supported_tile_nums.back());
  auto const center_it =
      std::find(supported_tile_nums.begin(), supported_tile_nums.end(), center_tile);

  FLASHINFER_CHECK(
      center_it != supported_tile_nums.end(), "computeSelectedTileN expected exact tile ",
      center_tile, " in supported_tile_nums (size=", supported_tile_nums.size(),
      "). Please keep supported_tile_nums as a dense power-of-2 ladder for this launcher.");

  // Candidate tile set centered on the heuristic tile. This function intentionally returns a
  // small neighborhood (center, up to two larger tiles, one smaller tile) rather than a single
  // final choice; the final tile is picked later (autotuner-provided tile if valid, otherwise
  // the fallback policy).
  std::size_t const center_idx =
      static_cast<std::size_t>(center_it - supported_tile_nums.begin());
  std::size_t const ladder_size = supported_tile_nums.size();

  std::set<int32_t> candidates{center_tile};
  // Keep one smaller neighbor to avoid over-committing to an oversized center tile.
  if (center_idx > 0) {
    candidates.insert(supported_tile_nums[center_idx - 1]);
  }
  // Prefer exploring larger tiles too because they can win for dense routing/shape regimes.
  for (std::size_t offset = 1; offset <= 2 && center_idx + offset < ladder_size; ++offset) {
    candidates.insert(supported_tile_nums[center_idx + offset]);
  }

  return candidates;
}

int64_t selectDefaultTileN(std::vector<int32_t> const& supported_tile_nums,
                           int64_t const num_tokens, int64_t const top_k,
                           int64_t const num_local_experts) {
  // The default tile is the smallest candidate from the heuristic neighborhood; std::set
  // iterates in ascending order, so begin() is the minimum.
  auto const candidates =
      computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
  TVM_FFI_ICHECK(!candidates.empty()) << "No selected tile_N candidates for current MoE input.";
  return *candidates.cbegin();
}

// Resolve the (tile_N, config) pair passed from Python into a launcher-safe pair.
//
// Interaction with Python AutoTuner:
// - Python AutoTuner stores and returns "tactics" as a 2-int payload [tile_N, config].
// - Tuning/profiling uses bucketed num_tokens (power-of-2 buckets), while runtime uses actual
//   num_tokens. Therefore, a cached tile_N chosen for the bucket may be suboptimal or unsupported
//   for the runtime shape.
// - Tactics may also come from persisted config files and can be stale/malformed across versions.
//
// This resolver makes C++ robust to those mismatches:
//   1) malformed/missing tactic payload -> fallback to runtime-selected default tile/config
//   2) fallback marker from Python (-1, -1) -> runtime-selected default tile/config
//   3) stale tile_N not in current supported set -> runtime-selected default tile/config
//
// We intentionally do not fully validate `config` here because config validity depends on the
// concrete Runner instance and current shape. That validation/fallback happens later in
// FusedMoeLauncher::prepare_moe_common().
std::pair<int64_t, int64_t> resolveMoeTileAndConfig(Array<int64_t> const& config_index,
                                                    std::vector<int32_t> const& supported_tile_nums,
                                                    int64_t const num_tokens, int64_t const top_k,
                                                    int64_t const num_local_experts) {
  int64_t tile_N = -1;
  int64_t config = -1;

  // Python side convention: tactic is [tile_N, config]
  if (config_index.size() >= 2) {
    tile_N = config_index[0];
    config = config_index[1];
  }

  // Runtime default is selected from actual shape stats (num_tokens/top_k/local_num_experts),
  // which may differ from the bucketed shape that produced the cached tactic. Computed lazily:
  // selectDefaultTileN rebuilds the candidate tile set, so the common path (a valid cached
  // tactic) should not pay for it.
  auto const fallback = [&]() -> std::pair<int64_t, int64_t> {
    return {selectDefaultTileN(supported_tile_nums, num_tokens, top_k, num_local_experts), -1};
  };

  if (tile_N == -1 || config == -1) {
    // Missing/malformed payload, or the explicit (-1, -1) fallback marker from Python.
    return fallback();
  }

  bool const tile_supported = std::find(supported_tile_nums.begin(), supported_tile_nums.end(),
                                        static_cast<int32_t>(tile_N)) != supported_tile_nums.end();
  if (!tile_supported) {
    // Tactic can be stale due to cache/file mismatch across versions or shape bucketing mismatch.
    // Fall back to a runtime-selected tile and default config instead of crashing.
    return fallback();
  }

  return {tile_N, config};
}

class FusedMoeLauncher {
protected:
Optional<TensorView> routing_logits;
Expand Down Expand Up @@ -334,19 +410,38 @@ class FusedMoeLauncher {
this->activation_type, this->use_shuffled_weight, this->weight_layout);
}

auto valid_cfgs =
moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
args->local_num_experts, args->num_tokens);
FLASHINFER_CHECK(!valid_cfgs.empty(), "No valid MoE tactics for tile_N=", tile_tokens_dim,
" with shape (num_tokens=", args->num_tokens,
", hidden_size=", args->hidden_size,
", intermediate_size=", args->intermediate_size, ", top_k=", args->top_k,
", local_num_experts=", args->local_num_experts, ").");
int64_t default_tactic = -1;
if (moe_tactic == -1) {
moe_tactic = moe_runner->getDefaultValidConfigIndex(
default_tactic = moe_runner->getDefaultValidConfigIndex(
args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
args->num_tokens);
moe_tactic = default_tactic;
}
auto valid_cfgs =
moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
args->local_num_experts, args->num_tokens);
auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic);
FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
" for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ",
valid_cfgs.size(),
". This often indicates a stale or mismatched autotuner cache entry.");
if (valid_it == valid_cfgs.end()) {
if (default_tactic == -1) {
default_tactic = moe_runner->getDefaultValidConfigIndex(
args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
args->num_tokens);
}
auto default_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), default_tactic);
FLASHINFER_CHECK(
default_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
" for tile_N=", tile_tokens_dim,
". No valid fallback default tactic found for shape (num_tokens=", args->num_tokens,
", hidden_size=", args->hidden_size, ", intermediate_size=", args->intermediate_size,
", top_k=", args->top_k, ", local_num_experts=", args->local_num_experts,
"). This indicates a deeper kernel/config mismatch.");
moe_tactic = default_tactic;
}
this->moe_tactic = moe_tactic;

auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic);
Expand Down Expand Up @@ -1632,13 +1727,14 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
// Calculate supported tile sizes
std::vector<int32_t> mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(),
Bf16MoeLauncher::mSupportedTileNums.end());
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
// Build launchers for ALL supported tiles (not just the computeSelectedTileN subset)
// so that autotuner-cached tactics always find their tile_N in the map.
// Launcher creation is cheap (no GPU allocation until run()), so this is safe.

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<Bf16MoeLauncher>> launchers_map;

for (int32_t curr_tile_N : selected_tile_nums) {
for (int32_t curr_tile_N : mSupportedTileN) {
// Create MoE arguments for this launcher
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
Expand Down Expand Up @@ -1666,17 +1762,14 @@ Array<Tensor> trtllm_bf16_moe(Optional<TensorView> const& routing_logits,
launchers_map[curr_tile_N] = std::move(launcher);
}

// Extract tile_N and config from moe_tactic
int64_t tile_N = moe_tactic[0];
int64_t config = moe_tactic[1];

// Handle default case
if (tile_N == -1 || config == -1) {
tile_N = *selected_tile_nums.begin();
}
auto const [tile_N, config] =
resolveMoeTileAndConfig(moe_tactic, mSupportedTileN, num_tokens, top_k, local_num_experts);

// Get the launcher for the selected tile_N
auto& selected_launcher = launchers_map.at(tile_N);
auto launcher_it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(launcher_it != launchers_map.end(),
"Internal error: missing BF16 MoE launcher for tile_N=", tile_N);
auto& selected_launcher = launcher_it->second;

// Run the launcher - it will create its own runner internally
return selected_launcher->run(config, enable_pdl);
Expand Down Expand Up @@ -1724,13 +1817,12 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
// Calculate supported tile sizes
std::vector<int32_t> mSupportedTileN(Fp8PerTensorLauncher::mSupportedTileNums.begin(),
Fp8PerTensorLauncher::mSupportedTileNums.end());
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
// Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N.

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<Fp8PerTensorLauncher>> launchers_map;

for (int32_t curr_tile_N : selected_tile_nums) {
for (int32_t curr_tile_N : mSupportedTileN) {
// Create MoE arguments for this launcher
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
Expand Down Expand Up @@ -1758,17 +1850,14 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
launchers_map[curr_tile_N] = std::move(launcher);
}

// Extract tile_N and config from config_index
int64_t tile_N = config_index[0];
int64_t config = config_index[1];

// Handle default case
if (tile_N == -1 || config == -1) {
tile_N = *selected_tile_nums.begin();
}
auto const [tile_N, config] =
resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts);

// Get the launcher for the selected tile_N
auto& selected_launcher = launchers_map.at(tile_N);
auto launcher_it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(launcher_it != launchers_map.end(),
"Internal error: missing FP8 per-tensor MoE launcher for tile_N=", tile_N);
auto& selected_launcher = launcher_it->second;

// Run the launcher - it will create its own runner internally
return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
Expand Down Expand Up @@ -1844,13 +1933,12 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
auto const hidden_size = hidden_states.size(1);

auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
// Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N.

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;

for (int32_t curr_tile_N : selected_tile_nums) {
for (int32_t curr_tile_N : supported_tile_nums) {
// Create MoE arguments for this launcher
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
Expand Down Expand Up @@ -1879,17 +1967,14 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
launchers_map[curr_tile_N] = std::move(launcher);
}

// Extract tile_N and config from config_index
int64_t tile_N = config_index[0];
int64_t config = config_index[1];

// Handle default case
if (tile_N == -1 || config == -1) {
tile_N = *selected_tile_nums.begin();
}
auto const [tile_N, config] = resolveMoeTileAndConfig(config_index, supported_tile_nums,
num_tokens, top_k, local_num_experts);

// Get the launcher for the selected tile_N
auto& selected_launcher = launchers_map.at(tile_N);
auto launcher_it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(launcher_it != launchers_map.end(),
"Internal error: missing FP8 block-scale MoE launcher for tile_N=", tile_N);
auto& selected_launcher = launcher_it->second;

// Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally
return selected_launcher->run(
Expand Down Expand Up @@ -1985,13 +2070,12 @@ Array<Tensor> trtllm_fp4_block_scale_moe(

// Determine supported tile sizes
std::vector<int32_t> mSupportedTileN = FP4BlockScaleLauncher::getSupportedTileNums(mDtypeAct);
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
// Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N.

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<FP4BlockScaleLauncher>> launchers_map;

for (int32_t curr_tile_N : selected_tile_nums) {
for (int32_t curr_tile_N : mSupportedTileN) {
// Create MoE arguments for this launcher
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
Expand Down Expand Up @@ -2023,18 +2107,14 @@ Array<Tensor> trtllm_fp4_block_scale_moe(
launchers_map[curr_tile_N] = std::move(launcher);
}

// Extract tile_N and config from config_index
int64_t tile_N = config_index[0];
int64_t config = config_index[1];

// Handle default case
if (tile_N == -1 || config == -1) {
tile_N = *selected_tile_nums.begin();
config = -1; // Let the runner choose default
}
auto const [tile_N, config] =
resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts);

// Get the launcher for the selected tile_N
auto& selected_launcher = launchers_map.at(tile_N);
auto launcher_it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(launcher_it != launchers_map.end(),
"Internal error: missing FP4 block-scale MoE launcher for tile_N=", tile_N);
auto& selected_launcher = launcher_it->second;

// Run the launcher - it will create its own runner internally
return selected_launcher->run(config, enable_pdl);
Expand Down Expand Up @@ -2078,13 +2158,12 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
// Determine supported tile sizes
std::vector<int32_t> mSupportedTileN(MxInt4BlockScaleLauncher::mSupportedTileNums.begin(),
MxInt4BlockScaleLauncher::mSupportedTileNums.end());
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);
// Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N.

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<MxInt4BlockScaleLauncher>> launchers_map;

for (int32_t curr_tile_N : selected_tile_nums) {
for (int32_t curr_tile_N : mSupportedTileN) {
// Create MoE arguments for this launcher
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
Expand Down Expand Up @@ -2112,18 +2191,14 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
launchers_map[curr_tile_N] = std::move(launcher);
}

// Extract tile_N and config from config_index
int64_t tile_N = config_index[0];
int64_t config = config_index[1];

// Handle default case
if (tile_N == -1 || config == -1) {
tile_N = *selected_tile_nums.begin();
config = -1; // Let the runner choose default
}
auto const [tile_N, config] =
resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts);

// Get the launcher for the selected tile_N
auto& selected_launcher = launchers_map.at(tile_N);
auto launcher_it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(launcher_it != launchers_map.end(),
"Internal error: missing MXINT4 block-scale MoE launcher for tile_N=", tile_N);
auto& selected_launcher = launcher_it->second;

// Run the launcher - it will create its own runner internally
return selected_launcher->run(config, enable_pdl);
Expand Down
Loading
Loading