diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 64fece5021..d7fe85417b 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -85,27 +85,103 @@ inline int32_t nextPowerOfTwo(float value) { std::set computeSelectedTileN(std::vector const& supported_tile_nums, int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { + TVM_FFI_ICHECK(!supported_tile_nums.empty()) << "supported_tile_nums must not be empty."; + TVM_FFI_ICHECK(std::is_sorted(supported_tile_nums.begin(), supported_tile_nums.end())) + << "supported_tile_nums must be sorted in ascending order."; float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; + // NOTE: This differs from Python AutoTuner bucketing: + // - AutoTuner maps raw num_tokens with last_positive_power_of_2 (round-down). + // - Here we map derived avg_tokens_per_expert and use nextPowerOfTwo (round-up). + // Because they round different quantities in different directions, cache bucket and runtime + // tile candidates can diverge; launcher-side tactic resolution handles that mismatch. // assume supported_tile_nums is sorted int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), supported_tile_nums.front(), supported_tile_nums.back()); auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim); - + FLASHINFER_CHECK( + it != supported_tile_nums.end(), "computeSelectedTileN expected exact tile ", tile_tokens_dim, + " in supported_tile_nums (size=", supported_tile_nums.size(), + "). Please keep supported_tile_nums as a dense power-of-2 ladder for this launcher."); + + // Candidate tile set centered on the heuristic tile. + // This function returns nearby candidates (not a single final tile): + // center, +1, +2, and -1 neighbors when available. + // Final tile choice is made later (autotuner-provided tile if valid, otherwise fallback policy). std::set selected_tile_nums; selected_tile_nums.insert(tile_tokens_dim); if (std::next(it) != supported_tile_nums.end()) { + // Prefer exploring larger tiles too because they can win for dense routing/shape regimes. selected_tile_nums.insert(*std::next(it)); if (std::next(std::next(it)) != supported_tile_nums.end()) { selected_tile_nums.insert(*std::next(std::next(it))); } } if (it != supported_tile_nums.begin()) { + // Keep one smaller neighbor to avoid over-committing to an oversized center tile. selected_tile_nums.insert(*std::prev(it)); } return selected_tile_nums; } +int64_t selectDefaultTileN(std::vector const& supported_tile_nums, + int64_t const num_tokens, int64_t const top_k, + int64_t const num_local_experts) { + auto selected = computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + TVM_FFI_ICHECK(!selected.empty()) << "No selected tile_N candidates for current MoE input."; + return *selected.begin(); +} + +// Resolve the (tile_N, config) pair passed from Python into a launcher-safe pair. +// +// Interaction with Python AutoTuner: +// - Python AutoTuner stores and returns "tactics" as a 2-int payload [tile_N, config]. +// - Tuning/profiling uses bucketed num_tokens (power-of-2 buckets), while runtime uses actual +// num_tokens. Therefore, a cached tile_N chosen for the bucket may be suboptimal or unsupported +// for the runtime shape. +// - Tactics may also come from persisted config files and can be stale/malformed across versions. +// +// This resolver makes C++ robust to those mismatches: +// 1) malformed/missing tactic payload -> fallback to runtime-selected default tile/config +// 2) fallback marker from Python (-1, -1) -> runtime-selected default tile/config +// 3) stale tile_N not in current supported set -> runtime-selected default tile/config +// +// We intentionally do not fully validate `config` here because config validity depends on the +// concrete Runner instance and current shape. That validation/fallback happens later in +// FusedMoeLauncher::prepare_moe_common(). +std::pair resolveMoeTileAndConfig(Array const& config_index, + std::vector const& supported_tile_nums, + int64_t const num_tokens, int64_t const top_k, + int64_t const num_local_experts) { + int64_t tile_N = -1; + int64_t config = -1; + + // Python side convention: tactic is [tile_N, config] + if (config_index.size() >= 2) { + tile_N = config_index[0]; + config = config_index[1]; + } + + // Runtime default is selected from actual shape stats (num_tokens/top_k/local_num_experts), + // which may differ from the bucketed shape that produced the cached tactic. + auto const default_tile_N = + selectDefaultTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + if (tile_N == -1 || config == -1) { + // Use fallback tactic + return {default_tile_N, -1}; + } + + bool const tile_supported = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), + static_cast(tile_N)) != supported_tile_nums.end(); + if (!tile_supported) { + // Tactic can be stale due to cache/file mismatch across versions or shape bucketing mismatch. + // Fall back to a runtime-selected tile and default config instead of crashing. + return {default_tile_N, -1}; + } + + return {tile_N, config}; +} + class FusedMoeLauncher { protected: Optional routing_logits; @@ -334,19 +410,38 @@ class FusedMoeLauncher { this->activation_type, this->use_shuffled_weight, this->weight_layout); } + auto valid_cfgs = + moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, + args->local_num_experts, args->num_tokens); + FLASHINFER_CHECK(!valid_cfgs.empty(), "No valid MoE tactics for tile_N=", tile_tokens_dim, + " with shape (num_tokens=", args->num_tokens, + ", hidden_size=", args->hidden_size, + ", intermediate_size=", args->intermediate_size, ", top_k=", args->top_k, + ", local_num_experts=", args->local_num_experts, ")."); + int64_t default_tactic = -1; if (moe_tactic == -1) { - moe_tactic = moe_runner->getDefaultValidConfigIndex( + default_tactic = moe_runner->getDefaultValidConfigIndex( args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, args->num_tokens); + moe_tactic = default_tactic; } - auto valid_cfgs = - moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, - args->local_num_experts, args->num_tokens); auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic); - FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, - " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ", - valid_cfgs.size(), - ". This often indicates a stale or mismatched autotuner cache entry."); + if (valid_it == valid_cfgs.end()) { + if (default_tactic == -1) { + default_tactic = moe_runner->getDefaultValidConfigIndex( + args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, + args->num_tokens); + } + auto default_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), default_tactic); + FLASHINFER_CHECK( + default_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, + " for tile_N=", tile_tokens_dim, + ". No valid fallback default tactic found for shape (num_tokens=", args->num_tokens, + ", hidden_size=", args->hidden_size, ", intermediate_size=", args->intermediate_size, + ", top_k=", args->top_k, ", local_num_experts=", args->local_num_experts, + "). This indicates a deeper kernel/config mismatch."); + moe_tactic = default_tactic; + } this->moe_tactic = moe_tactic; auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); @@ -1632,13 +1727,14 @@ Array trtllm_bf16_moe(Optional const& routing_logits, // Calculate supported tile sizes std::vector mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(), Bf16MoeLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles (not just the computeSelectedTileN subset) + // so that autotuner-cached tactics always find their tile_N in the map. + // Launcher creation is cheap (no GPU allocation until run()), so this is safe. // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1666,17 +1762,14 @@ Array trtllm_bf16_moe(Optional const& routing_logits, launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from moe_tactic - int64_t tile_N = moe_tactic[0]; - int64_t config = moe_tactic[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(moe_tactic, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing BF16 MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); @@ -1724,13 +1817,12 @@ Array trtllm_fp8_per_tensor_scale_moe( // Calculate supported tile sizes std::vector mSupportedTileN(Fp8PerTensorLauncher::mSupportedTileNums.begin(), Fp8PerTensorLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1758,17 +1850,14 @@ Array trtllm_fp8_per_tensor_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP8 per-tensor MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input); @@ -1844,13 +1933,12 @@ Array trtllm_fp8_block_scale_moe( auto const hidden_size = hidden_states.size(1); auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); - std::set selected_tile_nums = - computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : supported_tile_nums) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1879,17 +1967,14 @@ Array trtllm_fp8_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - } + auto const [tile_N, config] = resolveMoeTileAndConfig(config_index, supported_tile_nums, + num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP8 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally return selected_launcher->run( @@ -1985,13 +2070,12 @@ Array trtllm_fp4_block_scale_moe( // Determine supported tile sizes std::vector mSupportedTileN = FP4BlockScaleLauncher::getSupportedTileNums(mDtypeAct); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -2023,18 +2107,14 @@ Array trtllm_fp4_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP4 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); @@ -2078,13 +2158,12 @@ Array trtllm_mxint4_block_scale_moe( // Determine supported tile sizes std::vector mSupportedTileN(MxInt4BlockScaleLauncher::mSupportedTileNums.begin(), MxInt4BlockScaleLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -2112,18 +2191,14 @@ Array trtllm_mxint4_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *selected_tile_nums.begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing MXINT4 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); diff --git a/tests/autotuner/test_autotuner_tile_mismatch.py b/tests/autotuner/test_autotuner_tile_mismatch.py new file mode 100644 index 0000000000..d88e373ead --- /dev/null +++ b/tests/autotuner/test_autotuner_tile_mismatch.py @@ -0,0 +1,260 @@ +"""Regression tests for MoE autotuner tile/config mismatch handling.""" + +import math + +import pytest +import torch + +from flashinfer import autotune +from flashinfer.autotuner import ( + AutoTuner, + DynamicTensorSpec, + TuningConfig, + TunableRunner, +) +from flashinfer.fused_moe.utils import ( + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, +) + + +def _next_power_of_two(n: float) -> int: + if n <= 1: + return 1 + return 1 << math.ceil(math.log2(n)) + + +def compute_selected_tile_n( + supported_tiles: list[int], + num_tokens: int, + top_k: int, + num_local_experts: int, +) -> set[int]: + """Python equivalent of C++ computeSelectedTileN.""" + avg = num_tokens * top_k / num_local_experts + tile = max(supported_tiles[0], min(supported_tiles[-1], _next_power_of_two(avg))) + try: + idx = supported_tiles.index(tile) + except ValueError: + # tile not in supported_tiles, find closest + idx = min( + range(len(supported_tiles)), + key=lambda i: abs(supported_tiles[i] - tile), + ) + selected = {supported_tiles[idx]} + if idx + 1 < len(supported_tiles): + selected.add(supported_tiles[idx + 1]) + if idx + 2 < len(supported_tiles): + selected.add(supported_tiles[idx + 2]) + if idx > 0: + selected.add(supported_tiles[idx - 1]) + return selected + + +class TileAwareDummyRunner(TunableRunner): + """Dummy runner whose valid tactics depend on num_tokens. + returns different valid tactics depending on shapes.""" + + SUPPORTED_TILES = [8, 16, 32, 64] # FP4 base tiles (no 128/256 for bf16 act) + + def __init__(self, top_k: int = 8, num_local_experts: int = 64): + self.top_k = top_k + self.num_local_experts = num_local_experts + + def get_valid_tactics(self, inputs, profile): + num_tokens = inputs[0].shape[0] + selected = compute_selected_tile_n( + self.SUPPORTED_TILES, num_tokens, self.top_k, self.num_local_experts + ) + tactics = [] + for tile in sorted(selected): + for cfg in range(2): # 2 configs per tile + tactics.append([tile, cfg]) + return tactics + + def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs): + return inputs[0] + + +TUNE_MAX = 8192 + +MOE_TUNING_CONFIG = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + input_idx=(0,), + dim_idx=(0,), + gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets(TUNE_MAX, 1), + map_to_tuning_buckets=lambda x: min(last_positive_power_of_2(x), TUNE_MAX), + ), + ), +) + + +def _reset_autotuner() -> AutoTuner: + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.reset_statistics() + tuner.is_tuning_mode = False + return tuner + + +def _find_bucket_actual_tile_mismatch_case( + *, + supported_tiles: list[int], + top_k: int, + num_local_experts: int, + tune_max: int = TUNE_MAX, +): + """Find (bucket, actual, tuned_tile) where tuned_tile is valid only for bucket.""" + for actual in range(2, tune_max + 1): + bucket = min(last_positive_power_of_2(actual), tune_max) + if bucket == actual: + continue + tiles_bucket = compute_selected_tile_n( + supported_tiles, bucket, top_k, num_local_experts + ) + tiles_actual = compute_selected_tile_n( + supported_tiles, actual, top_k, num_local_experts + ) + only_in_bucket = sorted(tiles_bucket - tiles_actual) + if only_in_bucket: + return bucket, actual, only_in_bucket[0] + return None + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 64), + (8, 128), # Qwen3-VL-MoE-like (text) routing fanout/expert count + (10, 512), # Qwen3.5-MoE-like (text) routing fanout/expert count + ], +) +def test_tile_set_differs_between_bucketed_and_actual_num_tokens(top_k, num_experts): + """Bucketed and actual shapes in the same bucket can still pick different tiles.""" + tiles = [8, 16, 32, 64, 128] + case = _find_bucket_actual_tile_mismatch_case( + supported_tiles=tiles, top_k=top_k, num_local_experts=num_experts + ) + assert case is not None, ( + f"No mismatch case found for top_k={top_k}, experts={num_experts}" + ) + bucket, actual, _ = case + tiles_bucket = compute_selected_tile_n(tiles, bucket, top_k, num_experts) + tiles_actual = compute_selected_tile_n(tiles, actual, top_k, num_experts) + + assert tiles_bucket != tiles_actual, ( + f"Expected tile sets to differ for bucket={bucket} vs actual={actual}, " + f"got {tiles_bucket} vs {tiles_actual}" + ) + + only_in_bucket = tiles_bucket - tiles_actual + assert len(only_in_bucket) > 0, ( + "Expected at least one tile valid for bucketed but not for actual shapes" + ) + + +@pytest.mark.parametrize( + "top_k,num_experts,hidden_size", + [ + (2, 8, 128), + (4, 16, 128), + (8, 64, 256), + (12, 128, 1024), + ( + 8, + 256, + 7168, + ), # DeepSeek-v3-like MoE dimensions (lightweight autotuner-only path) + (8, 128, 4096), # Qwen3-VL-MoE-like text dimensions (autotuner-only path) + (10, 512, 4096), # Qwen3.5-MoE-like text dimensions (autotuner-only path) + ], +) +def test_autotuner_returns_cached_tactic_for_different_actual_shape( + monkeypatch, top_k, num_experts, hidden_size +): + """Cache lookup uses bucketed profile, not raw runtime num_tokens.""" + tuner = _reset_autotuner() + runner = TileAwareDummyRunner(top_k=top_k, num_local_experts=num_experts) + case = _find_bucket_actual_tile_mismatch_case( + supported_tiles=runner.SUPPORTED_TILES, + top_k=top_k, + num_local_experts=num_experts, + ) + assert case is not None, ( + f"No mismatch case found for top_k={top_k}, experts={num_experts}" + ) + tune_bucket, infer_actual, tuned_tile = case + + def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + # Make the chosen mismatch tile "best" so autotuner caches it. + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == tuned_tile else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + tune_inputs = [torch.empty((tune_bucket, hidden_size), dtype=torch.float32)] + with autotune(tune_mode=True): + _, tuned_tactic = tuner.choose_one( + "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, tune_inputs + ) + + assert isinstance(tuned_tactic, list) + assert tuned_tactic[0] == tuned_tile, ( + f"Expected tile_N={tuned_tile} for top_k={top_k}, experts={num_experts}, " + f"got {tuned_tactic}" + ) + + assert last_positive_power_of_2(infer_actual) == tune_bucket + infer_inputs = [torch.empty((infer_actual, hidden_size), dtype=torch.float32)] + _, infer_tactic = tuner.choose_one( + "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, infer_inputs + ) + + assert infer_tactic == tuned_tactic, "Expected cache hit returning the tuned tactic" + + actual_tiles = compute_selected_tile_n( + TileAwareDummyRunner.SUPPORTED_TILES, + infer_actual, + runner.top_k, + runner.num_local_experts, + ) + assert tuned_tactic[0] not in actual_tiles, ( + f"tile_N={tuned_tactic[0]} should NOT be in actual tiles {actual_tiles} — " + "this is the mismatch that caused the C++ unordered_map::at crash" + ) + + +def test_different_actual_tokens_same_bucket_get_same_cached_tactic(monkeypatch): + """Multiple actual num_tokens that map to the same bucket should all + receive the same cached tactic — confirming the autotuner uses the + bucketed profile, not the actual shapes, for cache lookup.""" + tuner = _reset_autotuner() + runner = TileAwareDummyRunner(top_k=8, num_local_experts=64) + hidden_size = 128 + + def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + tune_inputs = [torch.empty((512, hidden_size), dtype=torch.float32)] + with autotune(tune_mode=True): + _, tuned_tactic = tuner.choose_one( + "test_same_bucket", [runner], MOE_TUNING_CONFIG, tune_inputs + ) + + assert tuned_tactic[0] == 32 + + for actual in [513, 600, 700, 800, 900, 1000, 1023]: + assert last_positive_power_of_2(actual) == 512 + infer_inputs = [torch.empty((actual, hidden_size), dtype=torch.float32)] + _, tactic = tuner.choose_one( + "test_same_bucket", [runner], MOE_TUNING_CONFIG, infer_inputs + ) + assert tactic == tuned_tactic, ( + f"Expected cached tactic {tuned_tactic} for num_tokens={actual}, got {tactic}" + ) diff --git a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py new file mode 100644 index 0000000000..227fc282e4 --- /dev/null +++ b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py @@ -0,0 +1,513 @@ +"""Integration tests for TRTLLM MoE launcher fallback and wrapper contracts.""" + +import pytest +import torch + +from flashinfer import autotune +from flashinfer.autotuner import AutoTuner +from flashinfer.utils import get_compute_capability + +TUNE_MAX = 8192 + + +def _reset_autotuner() -> AutoTuner: + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.reset_statistics() + tuner.is_tuning_mode = False + return tuner + + +def _prepare_bf16_moe_weights(num_experts, intermediate_size, hidden_size, device): + """Prepare shuffled BF16 weights in BlockMajorK layout.""" + from flashinfer import shuffle_matrix_a + from flashinfer.fused_moe import convert_to_block_layout + + gemm1 = torch.randn( + num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + gemm2 = torch.randn( + num_experts, hidden_size, intermediate_size, device=device, dtype=torch.bfloat16 + ) + g1_shuffled, g2_shuffled = [], [] + for i in range(num_experts): + g1_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm1[i].view(torch.uint8), 64), 128 + ) + ) + g2_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm2[i].view(torch.uint8), 64), 128 + ) + ) + return ( + torch.stack(g1_shuffled).view(torch.bfloat16), + torch.stack(g2_shuffled).view(torch.bfloat16), + ) + + +def _overwrite_cached_tactic_for_op(custom_op: str, new_tactic): + """Overwrite cached tactics for one op key.""" + tuner = AutoTuner.get() + updated = 0 + for key, (runner_id, _tactic, profile) in list(tuner.profiling_cache.items()): + if key[0] == custom_op: + tuner.profiling_cache[key] = (runner_id, new_tactic, profile) + updated += 1 + assert updated > 0, f"No autotuner cache entries found for {custom_op}" + + +def _tune_bf16_moe_once( + *, + device, + tune_num_tokens: int, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + gemm1_weights, + gemm2_weights, + tune_max: int, +): + from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout + + routing_tune = torch.rand( + tune_num_tokens, num_experts, device=device, dtype=torch.bfloat16 + ) + hidden_tune = torch.randn( + tune_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + with autotune(tune_mode=True): + trtllm_bf16_moe( + routing_logits=routing_tune, + routing_bias=None, + hidden_states=hidden_tune, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + + +def _run_bf16_moe_infer( + *, + device, + infer_num_tokens: int, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + gemm1_weights, + gemm2_weights, + tune_max: int, +): + from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout + + routing_infer = torch.rand( + infer_num_tokens, num_experts, device=device, dtype=torch.bfloat16 + ) + hidden_infer = torch.randn( + infer_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + output = trtllm_bf16_moe( + routing_logits=routing_infer, + routing_bias=None, + hidden_states=hidden_infer, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + assert output.shape[0] == infer_num_tokens + assert output.isfinite().all(), "Output should be finite" + + +def _require_sm100(): + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] not in [10]: + pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") + + +@pytest.mark.parametrize("fp8_quantization_type", ["DeepSeekFp8", "MxFp8"]) +def test_fp8_block_scale_moe_deepseek_contract_args(monkeypatch, fp8_quantization_type): + """Contract test: wrapper forwards DeepSeek-style routing/group arguments unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_fp8_block_scale_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type( + "FakeMoeModule", + (), + {"trtllm_fp8_block_scale_moe": fake_trtllm_fp8_block_scale_moe}, + )() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + intermediate_size = 256 + num_experts = 256 + top_k = 8 + n_group = 8 + topk_group = 4 + routed_scaling_factor = 2.5 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + routing_bias = torch.randn(num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + hidden_states_scale = torch.ones(hidden_size // 128, seq_len, dtype=torch.float32) + gemm1_weights = torch.empty(1, dtype=torch.float32) + gemm1_weights_scale = torch.empty(1, dtype=torch.float32) + gemm2_weights = torch.empty(1, dtype=torch.float32) + gemm2_weights_scale = torch.empty(1, dtype=torch.float32) + + quant_type = getattr(moe_core.Fp8QuantizationType, fp8_quantization_type) + output = moe_core.trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=2, # DeepSeekV3 routing mode + use_shuffled_weight=True, + weight_layout=moe_core.WeightLayout.BlockMajorK, + do_finalize=True, + enable_pdl=True, + tune_max_num_tokens=tune_max_num_tokens, + fp8_quantization_type=quant_type, + ) + + assert output.shape == hidden_states.shape + assert "args" in captured, ( + "Expected wrapper to call trtllm_fp8_block_scale_moe backend" + ) + + args = captured["args"] + assert args[12] == top_k + assert args[13] == n_group + assert args[14] == topk_group + assert args[18] == routed_scaling_factor + assert args[19] == 2 + assert args[20] is True + assert args[21] == int(moe_core.WeightLayout.BlockMajorK) + assert args[24] == tune_max_num_tokens + assert args[25] == quant_type + + +def test_fp4_block_scale_moe_contract_args(monkeypatch): + """Contract test: FP4 wrapper forwards core MoE routing/kernel args unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_fp4_block_scale_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type( + "FakeMoeModule", + (), + {"trtllm_fp4_block_scale_moe": fake_trtllm_fp4_block_scale_moe}, + )() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + intermediate_size = 256 + num_experts = 128 + top_k = 8 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + hidden_states_scale = torch.ones(seq_len, hidden_size // 16, dtype=torch.float32) + gemm1_weights = torch.empty(1, dtype=torch.uint8) + gemm1_weights_scale = torch.empty(1, dtype=torch.float32) + gemm2_weights = torch.empty(1, dtype=torch.uint8) + gemm2_weights_scale = torch.empty(1, dtype=torch.float32) + + output = moe_core.trtllm_fp4_block_scale_moe( + routing_logits=routing_logits, + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + gemm2_bias=None, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, + do_finalize=True, + enable_pdl=True, + activation_type=moe_core.ActivationType.Swiglu.value, + output=None, + tune_max_num_tokens=tune_max_num_tokens, + ) + + assert output[0].shape == hidden_states.shape + assert "args" in captured, ( + "Expected wrapper to call trtllm_fp4_block_scale_moe backend" + ) + + args = captured["args"] + assert args[18] == num_experts + assert args[19] == top_k + assert args[20] is None + assert args[21] is None + assert args[22] == intermediate_size + assert args[24] == num_experts + assert args[26] == 1 + assert args[27] is True + assert args[31] == tune_max_num_tokens + + +def test_bf16_moe_qwen35_contract_args(monkeypatch): + """Contract test: BF16 wrapper forwards Qwen3.5-style ungrouped routing args unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_bf16_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type("FakeMoeModule", (), {"trtllm_bf16_moe": fake_trtllm_bf16_moe})() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + intermediate_size = 1024 + num_experts = 512 + top_k = 10 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + gemm1_weights = torch.empty(1, dtype=torch.bfloat16) + gemm2_weights = torch.empty(1, dtype=torch.bfloat16) + + output = moe_core.trtllm_bf16_moe( + routing_logits=routing_logits, + routing_bias=None, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=0, + use_shuffled_weight=True, + weight_layout=moe_core.WeightLayout.BlockMajorK, + do_finalize=True, + enable_pdl=True, + tune_max_num_tokens=tune_max_num_tokens, + ) + + assert output.shape == hidden_states.shape + assert "args" in captured, "Expected wrapper to call trtllm_bf16_moe backend" + + args = captured["args"] + assert args[7] == num_experts + assert args[8] == top_k + assert args[9] is None + assert args[10] is None + assert args[11] == intermediate_size + assert args[13] == num_experts + assert args[16] is True + assert args[17] == int(moe_core.WeightLayout.BlockMajorK) + assert args[20] == tune_max_num_tokens + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 128), # Qwen3-VL-MoE-like (text) + (10, 128), # Routing kernel currently supports top_k <= 10 + ], +) +def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch, top_k, num_experts): + """SM100 BF16 integration: cached mismatched tile falls back without crash.""" + from flashinfer.fused_moe.utils import last_positive_power_of_2 + + _require_sm100() + _reset_autotuner() + torch.manual_seed(42) + device = torch.device("cuda:0") + + hidden_size = 1024 + intermediate_size = 1024 + + gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( + num_experts, intermediate_size, hidden_size, device + ) + + tune_num_tokens = 256 + infer_num_tokens = 500 + assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens + + tune_max = TUNE_MAX + + def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + """Force tile_N=32 to be selected as the 'fastest' tactic.""" + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + _tune_bf16_moe_once( + device=device, + tune_num_tokens=tune_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + tuner = AutoTuner.get() + assert len(tuner.profiling_cache) > 0, ( + "Autotuner cache should be populated after tuning" + ) + + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 128), # Qwen3-VL-MoE-like (text) + (10, 128), # Routing kernel currently supports top_k <= 10 + ], +) +@pytest.mark.parametrize( + "seed,stale_tactic", + [ + (123, [4096, 0]), # unsupported tile_N + (124, [32, 10_000_000]), # invalid config index + (125, [32]), # malformed tactic payload + ], +) +def test_bf16_moe_stale_cached_tactic_falls_back_no_crash( + monkeypatch, top_k, num_experts, seed, stale_tactic +): + """SM100 integration: stale cache payloads should fall back without crashing.""" + _require_sm100() + _reset_autotuner() + torch.manual_seed(seed) + device = torch.device("cuda:0") + + hidden_size, intermediate_size = 1024, 1024 + tune_num_tokens, infer_num_tokens = 256, 500 + tune_max = TUNE_MAX + + gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( + num_experts, intermediate_size, hidden_size, device + ) + + def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + _tune_bf16_moe_once( + device=device, + tune_num_tokens=tune_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + _overwrite_cached_tactic_for_op("flashinfer::trtllm_bf16_moe", stale_tactic) + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + )