From 261efcd7b9a8bae42d11ee78282bdc30d15e0976 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 17 Mar 2026 17:22:32 +0200 Subject: [PATCH 01/12] Add test for "unordered_map::at" failure with TRTLLM MoE Signed-off-by: Daniel Serebrenik --- .../autotuner/test_autotuner_tile_mismatch.py | 436 ++++++++++++++++++ 1 file changed, 436 insertions(+) create mode 100644 tests/autotuner/test_autotuner_tile_mismatch.py diff --git a/tests/autotuner/test_autotuner_tile_mismatch.py b/tests/autotuner/test_autotuner_tile_mismatch.py new file mode 100644 index 0000000000..a2ec41c23d --- /dev/null +++ b/tests/autotuner/test_autotuner_tile_mismatch.py @@ -0,0 +1,436 @@ +""" +Regression test for the tile_N mismatch bug (RuntimeError: unordered_map::at). + +The bug is in the C++ MoE kernel launcher (trtllm_fused_moe_kernel_launcher.cu). +During autotuning, the Python autotuner profiles kernels using bucketed +num_tokens (e.g. 256) and caches the best tactic [tile_N, config]. +During inference, the C++ launcher calls computeSelectedTileN with the +*actual* num_tokens (e.g. 500) to build launchers_map — a subset of +supported tile sizes. When the actual num_tokens differs from the +bucketed value, computeSelectedTileN can produce a different tile subset, +and the cached tile_N may not exist in launchers_map, causing +launchers_map.at(tile_N) to throw. + +The C++ fix adds a fallback when the cached tile_N is missing: + if (launchers_map.find(tile_N) == launchers_map.end()) { fallback } + +The fix to _find_nearest_profile (propagating the bucketed value to all +linked dimensions) is what exposed this latent C++ bug. Before that fix, +only the first linked dimension was bucketed during inference, so the +profile key never matched the tuning-time key — the autotuner always +returned the fallback tactic (-1) and the C++ fallback path was taken. +After the fix, cache hits occur and the cached tile_N reaches the C++ +launcher, where the missing guard causes the crash. 
Both fixes are needed: +the Python fix makes the autotuner cache work correctly for MoE, and the +C++ fix makes the launcher handle tile_N mismatches gracefully. +""" + +import math + +import pytest +import torch + +from flashinfer import autotune +from flashinfer.autotuner import ( + AutoTuner, + DynamicTensorSpec, + TuningConfig, + TunableRunner, +) +from flashinfer.fused_moe.utils import ( + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, +) +from flashinfer.utils import get_compute_capability + + +# --------------------------------------------------------------------------- +# Helper: Python mirror of C++ computeSelectedTileN +# (csrc/trtllm_fused_moe_kernel_launcher.cu, lines 85-107) +# --------------------------------------------------------------------------- +def _next_power_of_two(n: float) -> int: + if n <= 1: + return 1 + return 1 << math.ceil(math.log2(n)) + + +def compute_selected_tile_n( + supported_tiles: list[int], + num_tokens: int, + top_k: int, + num_local_experts: int, +) -> set[int]: + """Python equivalent of C++ computeSelectedTileN.""" + avg = num_tokens * top_k / num_local_experts + tile = max(supported_tiles[0], min(supported_tiles[-1], _next_power_of_two(avg))) + try: + idx = supported_tiles.index(tile) + except ValueError: + # tile not in supported_tiles, find closest + idx = min( + range(len(supported_tiles)), + key=lambda i: abs(supported_tiles[i] - tile), + ) + selected = {supported_tiles[idx]} + if idx + 1 < len(supported_tiles): + selected.add(supported_tiles[idx + 1]) + if idx + 2 < len(supported_tiles): + selected.add(supported_tiles[idx + 2]) + if idx > 0: + selected.add(supported_tiles[idx - 1]) + return selected + + +# --------------------------------------------------------------------------- +# TileAwareDummyRunner: returns different valid tactics depending on shapes +# --------------------------------------------------------------------------- +class TileAwareDummyRunner(TunableRunner): + """Runner whose 
valid tactics depend on input num_tokens, mimicking + the real trtllm-gen MoE runner where computeSelectedTileN filters tiles. + + Each tactic is a list [tile_N, config_idx] matching the real MoE convention. + """ + + SUPPORTED_TILES = [8, 16, 32, 64] # FP4 base tiles (no 128/256 for bf16 act) + + def __init__(self, top_k: int = 8, num_local_experts: int = 64): + self.top_k = top_k + self.num_local_experts = num_local_experts + + def get_valid_tactics(self, inputs, profile): + num_tokens = inputs[0].shape[0] + selected = compute_selected_tile_n( + self.SUPPORTED_TILES, num_tokens, self.top_k, self.num_local_experts + ) + tactics = [] + for tile in sorted(selected): + for cfg in range(2): # 2 configs per tile + tactics.append([tile, cfg]) + return tactics + + def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs): + return inputs[0] + + +# --------------------------------------------------------------------------- +# Shared MoE-like tuning config (mirrors the real one) +# --------------------------------------------------------------------------- +TUNE_MAX = 4096 + +MOE_TUNING_CONFIG = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + input_idx=(0,), + dim_idx=(0,), + gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets(TUNE_MAX, 1), + map_to_tuning_buckets=lambda x: min(last_positive_power_of_2(x), TUNE_MAX), + ), + ), +) + + +def _reset_autotuner() -> AutoTuner: + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.reset_statistics() + tuner.is_tuning_mode = False + return tuner + + +# =================================================================== +# Tests +# =================================================================== + + +def test_tile_set_differs_between_bucketed_and_actual_num_tokens(): + """Prove that computeSelectedTileN can return different tile sets + for the bucketed num_tokens vs. the actual num_tokens that maps to + the same autotuner bucket. + + This is the precondition for the unordered_map::at crash. 
+ """ + top_k = 8 + num_experts = 64 + + # With the base SUPPORTED_TILES=[8,16,32,64] and top_k=8, num_experts=64, + # both bucketed=1024 and actual=1624 clamp to max tile 64, so the sets match. + # Use a wider tile range that DOES produce a mismatch: + tiles_small = [8, 16, 32, 64, 128] + actual = 500 + bucket = last_positive_power_of_2(actual) # 256 + assert bucket == 256 + + tiles_bucket = compute_selected_tile_n(tiles_small, bucket, top_k, num_experts) + tiles_actual = compute_selected_tile_n(tiles_small, actual, top_k, num_experts) + + # bucket=256: avg=256*8/64=32, nextPow2=32 → idx=2, selected={16,32,64,128} + # actual=500: avg=500*8/64=62.5, nextPow2=64 → idx=3, selected={32,64,128} + assert tiles_bucket != tiles_actual, ( + f"Expected tile sets to differ for bucket={bucket} vs actual={actual}, " + f"got {tiles_bucket} vs {tiles_actual}" + ) + + # A tile_N valid for the bucket may not be valid for the actual + only_in_bucket = tiles_bucket - tiles_actual + assert len(only_in_bucket) > 0, ( + "Expected at least one tile valid for bucketed but not for actual shapes" + ) + + +def test_autotuner_returns_cached_tactic_for_different_actual_shape(monkeypatch): + """The autotuner returns a cached tactic when num_tokens differs from + the tuning shape but maps to the same bucket. + + Before the C++ fix, if the cached tile_N was not in + computeSelectedTileN(actual_num_tokens), this caused + RuntimeError: unordered_map::at. + """ + tuner = _reset_autotuner() + runner = TileAwareDummyRunner(top_k=8, num_local_experts=64) + hidden_size = 256 + + def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + # Make the tactic with tile_N=16 the "best" (lowest time) + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 16 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + # --- Phase 1: Tune with bucketed num_tokens --- + # The autotuner generates profiles for all buckets. 
+ # For bucket=256, avg=256*8/64=32, tiles include {16,32,64,...}, + # so tile_N=16 is valid and will be selected as best. + tune_inputs = [torch.empty((256, hidden_size), dtype=torch.float32)] + with autotune(tune_mode=True): + _, tuned_tactic = tuner.choose_one( + "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, tune_inputs + ) + + assert isinstance(tuned_tactic, list) + assert tuned_tactic[0] == 16, f"Expected tile_N=16, got {tuned_tactic}" + + # --- Phase 2: Inference with different actual num_tokens --- + # actual=500 maps to bucket=256 (same bucket) so we get a cache hit. + infer_inputs = [torch.empty((500, hidden_size), dtype=torch.float32)] + _, infer_tactic = tuner.choose_one( + "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, infer_inputs + ) + + # The autotuner returns the CACHED tactic from tuning (tile_N=16). + assert infer_tactic == tuned_tactic, "Expected cache hit returning the tuned tactic" + + # --- Verify the mismatch --- + # For actual=500: avg=500*8/64=62.5, nextPow2=64, tiles={32,64,128} + # tile_N=16 is NOT in this set → would crash without the C++ fix. 
+ actual_tiles = compute_selected_tile_n( + TileAwareDummyRunner.SUPPORTED_TILES + [128], + 500, + runner.top_k, + runner.num_local_experts, + ) + assert tuned_tactic[0] not in actual_tiles, ( + f"tile_N={tuned_tactic[0]} should NOT be in actual tiles {actual_tiles} — " + "this is the mismatch that caused the C++ unordered_map::at crash" + ) + + +def test_autotuner_cache_miss_returns_fallback_for_unseen_bucket(): + """When num_tokens maps to a bucket that was never tuned, + the autotuner correctly returns the fallback tactic=-1.""" + tuner = _reset_autotuner() + runner = TileAwareDummyRunner() + hidden_size = 256 + + # No tuning done — inference should always get fallback + inputs = [torch.empty((1624, hidden_size), dtype=torch.float32)] + _, tactic = tuner.choose_one("test_no_tune", [runner], MOE_TUNING_CONFIG, inputs) + assert tactic == -1, "Expected fallback tactic when no tuning was done" + + +def test_different_actual_tokens_same_bucket_get_same_cached_tactic(monkeypatch): + """Multiple actual num_tokens that map to the same bucket should all + receive the same cached tactic — confirming the autotuner uses the + bucketed profile, not the actual shapes, for cache lookup.""" + tuner = _reset_autotuner() + runner = TileAwareDummyRunner(top_k=8, num_local_experts=64) + hidden_size = 128 + + def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + tune_inputs = [torch.empty((512, hidden_size), dtype=torch.float32)] + with autotune(tune_mode=True): + _, tuned_tactic = tuner.choose_one( + "test_same_bucket", [runner], MOE_TUNING_CONFIG, tune_inputs + ) + + assert tuned_tactic[0] == 32 + + # All these map to bucket 512 via last_positive_power_of_2 + for actual in [513, 600, 700, 800, 900, 1000, 1023]: + assert last_positive_power_of_2(actual) == 512 + infer_inputs = 
[torch.empty((actual, hidden_size), dtype=torch.float32)] + _, tactic = tuner.choose_one( + "test_same_bucket", [runner], MOE_TUNING_CONFIG, infer_inputs + ) + assert tactic == tuned_tactic, ( + f"Expected cached tactic {tuned_tactic} for num_tokens={actual}, got {tactic}" + ) + + +# =================================================================== +# SM100 integration test — exercises the real C++ launcher +# =================================================================== + + +def _prepare_bf16_moe_weights(num_experts, intermediate_size, hidden_size, device): + """Prepare shuffled BF16 weights in BlockMajorK layout.""" + from flashinfer import shuffle_matrix_a + from flashinfer.fused_moe import convert_to_block_layout + + gemm1 = torch.randn( + num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + gemm2 = torch.randn( + num_experts, hidden_size, intermediate_size, device=device, dtype=torch.bfloat16 + ) + g1_shuffled, g2_shuffled = [], [] + for i in range(num_experts): + g1_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm1[i].view(torch.uint8), 64), 128 + ) + ) + g2_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm2[i].view(torch.uint8), 64), 128 + ) + ) + return ( + torch.stack(g1_shuffled).view(torch.bfloat16), + torch.stack(g2_shuffled).view(torch.bfloat16), + ) + + +def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch): + """SM100 integration test: tune BF16 MoE with one num_tokens, then + infer with a different num_tokens that maps to the same autotuner + bucket but selects different C++ tiles. + + Before the C++ fix this raised ``RuntimeError: unordered_map::at``. + After the fix the launcher falls back to a default tile gracefully. 
+ """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] not in [10]: + pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") + + from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout + + _reset_autotuner() + torch.manual_seed(42) + device = torch.device("cuda:0") + + # Model dimensions — keep small for speed + num_experts = 8 + top_k = 2 + hidden_size = 1024 + intermediate_size = 1024 + + # Prepare weights once (they don't depend on num_tokens) + gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( + num_experts, intermediate_size, hidden_size, device + ) + + # --- Choose num_tokens that trigger the tile mismatch --- + # BF16 MoE supported tiles: {8, 16, 32, 64, 128} + # tune_num_tokens=256 (bucket): avg=256*2/8=64 → tiles={32,64,128} + # infer_num_tokens=500 (bucket=256): avg=500*2/8=125 → tiles={64,128} + # A tactic with tile_N=32 is valid for bucket but NOT for actual. + tune_num_tokens = 256 + infer_num_tokens = 500 + assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens + + tune_max = 4096 + + def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + """Force tile_N=32 to be selected as the 'fastest' tactic.""" + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + + # --- Phase 1: Tune with tune_num_tokens --- + routing_tune = torch.rand( + tune_num_tokens, num_experts, device=device, dtype=torch.bfloat16 + ) + hidden_tune = torch.randn( + tune_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + + with autotune(tune_mode=True): + trtllm_bf16_moe( + routing_logits=routing_tune, + routing_bias=None, + hidden_states=hidden_tune, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + 
local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize (TopK/Default not supported for BF16) + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + + # Verify the autotuner populated its cache during tuning + tuner = AutoTuner.get() + assert len(tuner.profiling_cache) > 0, ( + "Autotuner cache should be populated after tuning" + ) + + # --- Phase 2: Inference with infer_num_tokens (same bucket, different tiles) --- + # Without the C++ fix, this would crash with RuntimeError: unordered_map::at + # because tile_N=32 is not in computeSelectedTileN(500, 2, 8) = {64, 128}. + routing_infer = torch.rand( + infer_num_tokens, num_experts, device=device, dtype=torch.bfloat16 + ) + hidden_infer = torch.randn( + infer_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + + # This call should NOT crash (with the C++ fix it falls back to a valid tile) + output = trtllm_bf16_moe( + routing_logits=routing_infer, + routing_bias=None, + hidden_states=hidden_infer, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + + assert output.shape[0] == infer_num_tokens + assert output.isfinite().all(), "Output should be finite" From 55c351b38a74e08fbaede442ab1052d5e0a28e80 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 17 Mar 2026 17:22:32 +0200 Subject: [PATCH 02/12] Fix crash when key not found in launchers_map Signed-off-by: Daniel Serebrenik --- csrc/trtllm_fused_moe_kernel_launcher.cu | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git 
a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 64fece5021..ab43305a7e 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1670,9 +1670,10 @@ Array trtllm_bf16_moe(Optional const& routing_logits, int64_t tile_N = moe_tactic[0]; int64_t config = moe_tactic[1]; - // Handle default case - if (tile_N == -1 || config == -1) { + // Handle default case or stale autotuner cache (tile_N not in current selected set) + if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { tile_N = *selected_tile_nums.begin(); + config = -1; // Let the runner choose default } // Get the launcher for the selected tile_N @@ -1762,9 +1763,10 @@ Array trtllm_fp8_per_tensor_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case - if (tile_N == -1 || config == -1) { + // Handle default case or stale autotuner cache (tile_N not in current selected set) + if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { tile_N = *selected_tile_nums.begin(); + config = -1; // Let the runner choose default } // Get the launcher for the selected tile_N @@ -1883,9 +1885,10 @@ Array trtllm_fp8_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case - if (tile_N == -1 || config == -1) { + // Handle default case or stale autotuner cache (tile_N not in current selected set) + if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { tile_N = *selected_tile_nums.begin(); + config = -1; // Let the runner choose default } // Get the launcher for the selected tile_N @@ -2027,8 +2030,8 @@ Array trtllm_fp4_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case - if (tile_N == -1 || config == -1) { + // Handle default case or stale autotuner cache (tile_N not in current selected 
set) + if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { tile_N = *selected_tile_nums.begin(); config = -1; // Let the runner choose default } @@ -2116,8 +2119,8 @@ Array trtllm_mxint4_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case - if (tile_N == -1 || config == -1) { + // Handle default case or stale autotuner cache (tile_N not in current selected set) + if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { tile_N = *selected_tile_nums.begin(); config = -1; // Let the runner choose default } From cd7ea99d1aca51bd9ae1245609deacf8b77adc91 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 17 Mar 2026 17:22:32 +0200 Subject: [PATCH 03/12] Add proper fix for tile selection Signed-off-by: Daniel Serebrenik --- csrc/trtllm_fused_moe_kernel_launcher.cu | 53 +++++++++++++----------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index ab43305a7e..d92732155f 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1632,8 +1632,10 @@ Array trtllm_bf16_moe(Optional const& routing_logits, // Calculate supported tile sizes std::vector mSupportedTileN(Bf16MoeLauncher::mSupportedTileNums.begin(), Bf16MoeLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles (not just the computeSelectedTileN subset) + // so that autotuner-cached tactics always find their tile_N in the map. + // Launcher creation is cheap (no GPU allocation until run()), so this is safe. 
+ std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -1670,9 +1672,9 @@ Array trtllm_bf16_moe(Optional const& routing_logits, int64_t tile_N = moe_tactic[0]; int64_t config = moe_tactic[1]; - // Handle default case or stale autotuner cache (tile_N not in current selected set) - if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { - tile_N = *selected_tile_nums.begin(); + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); config = -1; // Let the runner choose default } @@ -1725,8 +1727,8 @@ Array trtllm_fp8_per_tensor_scale_moe( // Calculate supported tile sizes std::vector mSupportedTileN(Fp8PerTensorLauncher::mSupportedTileNums.begin(), Fp8PerTensorLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
+ std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -1763,9 +1765,9 @@ Array trtllm_fp8_per_tensor_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case or stale autotuner cache (tile_N not in current selected set) - if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { - tile_N = *selected_tile_nums.begin(); + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); config = -1; // Let the runner choose default } @@ -1846,8 +1848,8 @@ Array trtllm_fp8_block_scale_moe( auto const hidden_size = hidden_states.size(1); auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); - std::set selected_tile_nums = - computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
+ std::set selected_tile_nums(supported_tile_nums.begin(), supported_tile_nums.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -1885,9 +1887,10 @@ Array trtllm_fp8_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case or stale autotuner cache (tile_N not in current selected set) - if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { - tile_N = *selected_tile_nums.begin(); + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = + *computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts).begin(); config = -1; // Let the runner choose default } @@ -1988,8 +1991,8 @@ Array trtllm_fp4_block_scale_moe( // Determine supported tile sizes std::vector mSupportedTileN = FP4BlockScaleLauncher::getSupportedTileNums(mDtypeAct); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
+ std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -2030,9 +2033,9 @@ Array trtllm_fp4_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case or stale autotuner cache (tile_N not in current selected set) - if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { - tile_N = *selected_tile_nums.begin(); + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); config = -1; // Let the runner choose default } @@ -2081,8 +2084,8 @@ Array trtllm_mxint4_block_scale_moe( // Determine supported tile sizes std::vector mSupportedTileN(MxInt4BlockScaleLauncher::mSupportedTileNums.begin(), MxInt4BlockScaleLauncher::mSupportedTileNums.end()); - std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
+ std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -2119,9 +2122,9 @@ Array trtllm_mxint4_block_scale_moe( int64_t tile_N = config_index[0]; int64_t config = config_index[1]; - // Handle default case or stale autotuner cache (tile_N not in current selected set) - if (tile_N == -1 || config == -1 || launchers_map.find(tile_N) == launchers_map.end()) { - tile_N = *selected_tile_nums.begin(); + // Handle default case + if (tile_N == -1 || config == -1) { + tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); config = -1; // Let the runner choose default } From 0feb41cc0f38fe4c7e6ce999f8f3d861ed48701d Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 17 Mar 2026 17:22:32 +0200 Subject: [PATCH 04/12] Improve bugfix code Signed-off-by: Daniel Serebrenik --- csrc/trtllm_fused_moe_kernel_launcher.cu | 209 +++++++++++++++-------- 1 file changed, 139 insertions(+), 70 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index d92732155f..d7fe85417b 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -85,27 +85,103 @@ inline int32_t nextPowerOfTwo(float value) { std::set computeSelectedTileN(std::vector const& supported_tile_nums, int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { + TVM_FFI_ICHECK(!supported_tile_nums.empty()) << "supported_tile_nums must not be empty."; + TVM_FFI_ICHECK(std::is_sorted(supported_tile_nums.begin(), supported_tile_nums.end())) + << "supported_tile_nums must be sorted in ascending order."; float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; + // NOTE: This differs from Python AutoTuner bucketing: + // - AutoTuner maps raw num_tokens with last_positive_power_of_2 (round-down). 
+ // - Here we map derived avg_tokens_per_expert and use nextPowerOfTwo (round-up). + // Because they round different quantities in different directions, cache bucket and runtime + // tile candidates can diverge; launcher-side tactic resolution handles that mismatch. // assume supported_tile_nums is sorted int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert), supported_tile_nums.front(), supported_tile_nums.back()); auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim); - + FLASHINFER_CHECK( + it != supported_tile_nums.end(), "computeSelectedTileN expected exact tile ", tile_tokens_dim, + " in supported_tile_nums (size=", supported_tile_nums.size(), + "). Please keep supported_tile_nums as a dense power-of-2 ladder for this launcher."); + + // Candidate tile set centered on the heuristic tile. + // This function returns nearby candidates (not a single final tile): + // center, +1, +2, and -1 neighbors when available. + // Final tile choice is made later (autotuner-provided tile if valid, otherwise fallback policy). std::set<int32_t> selected_tile_nums; selected_tile_nums.insert(tile_tokens_dim); if (std::next(it) != supported_tile_nums.end()) { + // Prefer exploring larger tiles too because they can win for dense routing/shape regimes. selected_tile_nums.insert(*std::next(it)); if (std::next(std::next(it)) != supported_tile_nums.end()) { selected_tile_nums.insert(*std::next(std::next(it))); } } if (it != supported_tile_nums.begin()) { + // Keep one smaller neighbor to avoid over-committing to an oversized center tile. 
selected_tile_nums.insert(*std::prev(it)); } return selected_tile_nums; } +int64_t selectDefaultTileN(std::vector<int32_t> const& supported_tile_nums, + int64_t const num_tokens, int64_t const top_k, + int64_t const num_local_experts) { + auto selected = computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + TVM_FFI_ICHECK(!selected.empty()) << "No selected tile_N candidates for current MoE input."; + return *selected.begin(); +} + +// Resolve the (tile_N, config) pair passed from Python into a launcher-safe pair. +// +// Interaction with Python AutoTuner: +// - Python AutoTuner stores and returns "tactics" as a 2-int payload [tile_N, config]. +// - Tuning/profiling uses bucketed num_tokens (power-of-2 buckets), while runtime uses actual +// num_tokens. Therefore, a cached tile_N chosen for the bucket may be suboptimal or unsupported +// for the runtime shape. +// - Tactics may also come from persisted config files and can be stale/malformed across versions. +// +// This resolver makes C++ robust to those mismatches: +// 1) malformed/missing tactic payload -> fallback to runtime-selected default tile/config +// 2) fallback marker from Python (-1, -1) -> runtime-selected default tile/config +// 3) stale tile_N not in current supported set -> runtime-selected default tile/config +// +// We intentionally do not fully validate `config` here because config validity depends on the +// concrete Runner instance and current shape. That validation/fallback happens later in +// FusedMoeLauncher::prepare_moe_common(). 
+std::pair<int64_t, int64_t> resolveMoeTileAndConfig(Array<int64_t> const& config_index, + std::vector<int32_t> const& supported_tile_nums, + int64_t const num_tokens, int64_t const top_k, + int64_t const num_local_experts) { + int64_t tile_N = -1; + int64_t config = -1; + + // Python side convention: tactic is [tile_N, config] + if (config_index.size() >= 2) { + tile_N = config_index[0]; + config = config_index[1]; + } + + // Runtime default is selected from actual shape stats (num_tokens/top_k/local_num_experts), + // which may differ from the bucketed shape that produced the cached tactic. + auto const default_tile_N = + selectDefaultTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + if (tile_N == -1 || config == -1) { + // Use fallback tactic + return {default_tile_N, -1}; + } + + bool const tile_supported = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), + static_cast<int32_t>(tile_N)) != supported_tile_nums.end(); + if (!tile_supported) { + // Tactic can be stale due to cache/file mismatch across versions or shape bucketing mismatch. + // Fall back to a runtime-selected tile and default config instead of crashing. 
+ return {default_tile_N, -1}; + } + + return {tile_N, config}; +} + class FusedMoeLauncher { protected: Optional routing_logits; @@ -334,19 +410,38 @@ class FusedMoeLauncher { this->activation_type, this->use_shuffled_weight, this->weight_layout); } + auto valid_cfgs = + moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, + args->local_num_experts, args->num_tokens); + FLASHINFER_CHECK(!valid_cfgs.empty(), "No valid MoE tactics for tile_N=", tile_tokens_dim, + " with shape (num_tokens=", args->num_tokens, + ", hidden_size=", args->hidden_size, + ", intermediate_size=", args->intermediate_size, ", top_k=", args->top_k, + ", local_num_experts=", args->local_num_experts, ")."); + int64_t default_tactic = -1; if (moe_tactic == -1) { - moe_tactic = moe_runner->getDefaultValidConfigIndex( + default_tactic = moe_runner->getDefaultValidConfigIndex( args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, args->num_tokens); + moe_tactic = default_tactic; } - auto valid_cfgs = - moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, - args->local_num_experts, args->num_tokens); auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic); - FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, - " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ", - valid_cfgs.size(), - ". This often indicates a stale or mismatched autotuner cache entry."); + if (valid_it == valid_cfgs.end()) { + if (default_tactic == -1) { + default_tactic = moe_runner->getDefaultValidConfigIndex( + args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, + args->num_tokens); + } + auto default_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), default_tactic); + FLASHINFER_CHECK( + default_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, + " for tile_N=", tile_tokens_dim, + ". 
No valid fallback default tactic found for shape (num_tokens=", args->num_tokens, + ", hidden_size=", args->hidden_size, ", intermediate_size=", args->intermediate_size, + ", top_k=", args->top_k, ", local_num_experts=", args->local_num_experts, + "). This indicates a deeper kernel/config mismatch."); + moe_tactic = default_tactic; + } this->moe_tactic = moe_tactic; auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); @@ -1635,12 +1730,11 @@ Array trtllm_bf16_moe(Optional const& routing_logits, // Build launchers for ALL supported tiles (not just the computeSelectedTileN subset) // so that autotuner-cached tactics always find their tile_N in the map. // Launcher creation is cheap (no GPU allocation until run()), so this is safe. - std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1668,18 +1762,14 @@ Array trtllm_bf16_moe(Optional const& routing_logits, launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from moe_tactic - int64_t tile_N = moe_tactic[0]; - int64_t config = moe_tactic[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(moe_tactic, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing BF16 MoE launcher for tile_N=", tile_N); 
+ auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); @@ -1728,12 +1818,11 @@ Array trtllm_fp8_per_tensor_scale_moe( std::vector mSupportedTileN(Fp8PerTensorLauncher::mSupportedTileNums.begin(), Fp8PerTensorLauncher::mSupportedTileNums.end()); // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. - std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1761,18 +1850,14 @@ Array trtllm_fp8_per_tensor_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP8 per-tensor MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input); @@ -1849,12 +1934,11 @@ Array trtllm_fp8_block_scale_moe( auto supported_tile_nums = 
Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. - std::set selected_tile_nums(supported_tile_nums.begin(), supported_tile_nums.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : supported_tile_nums) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -1883,19 +1967,14 @@ Array trtllm_fp8_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = - *computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts).begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = resolveMoeTileAndConfig(config_index, supported_tile_nums, + num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP8 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally return selected_launcher->run( @@ -1992,12 +2071,11 @@ Array trtllm_fp4_block_scale_moe( // Determine supported tile sizes std::vector mSupportedTileN = FP4BlockScaleLauncher::getSupportedTileNums(mDtypeAct); // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
- std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -2029,18 +2107,14 @@ Array trtllm_fp4_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing FP4 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); @@ -2085,12 +2159,11 @@ Array trtllm_mxint4_block_scale_moe( std::vector mSupportedTileN(MxInt4BlockScaleLauncher::mSupportedTileNums.begin(), MxInt4BlockScaleLauncher::mSupportedTileNums.end()); // Build launchers for ALL supported tiles so autotuner-cached tactics always find their tile_N. 
- std::set selected_tile_nums(mSupportedTileN.begin(), mSupportedTileN.end()); // Create a map of launchers for each tile size std::unordered_map> launchers_map; - for (int32_t curr_tile_N : selected_tile_nums) { + for (int32_t curr_tile_N : mSupportedTileN) { // Create MoE arguments for this launcher auto args = std::make_unique(); args->num_tokens = num_tokens; @@ -2118,18 +2191,14 @@ Array trtllm_mxint4_block_scale_moe( launchers_map[curr_tile_N] = std::move(launcher); } - // Extract tile_N and config from config_index - int64_t tile_N = config_index[0]; - int64_t config = config_index[1]; - - // Handle default case - if (tile_N == -1 || config == -1) { - tile_N = *computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts).begin(); - config = -1; // Let the runner choose default - } + auto const [tile_N, config] = + resolveMoeTileAndConfig(config_index, mSupportedTileN, num_tokens, top_k, local_num_experts); // Get the launcher for the selected tile_N - auto& selected_launcher = launchers_map.at(tile_N); + auto launcher_it = launchers_map.find(static_cast(tile_N)); + FLASHINFER_CHECK(launcher_it != launchers_map.end(), + "Internal error: missing MXINT4 block-scale MoE launcher for tile_N=", tile_N); + auto& selected_launcher = launcher_it->second; // Run the launcher - it will create its own runner internally return selected_launcher->run(config, enable_pdl); From 5611ea6b85c83a07d19714d33eb6721755ce695d Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik Date: Tue, 17 Mar 2026 17:22:32 +0200 Subject: [PATCH 05/12] Improve tests and increase coverage Signed-off-by: Daniel Serebrenik --- .../autotuner/test_autotuner_tile_mismatch.py | 358 ++++-------- ...st_trtllm_moe_tile_mismatch_integration.py | 513 ++++++++++++++++++ 2 files changed, 604 insertions(+), 267 deletions(-) create mode 100644 tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py diff --git a/tests/autotuner/test_autotuner_tile_mismatch.py 
b/tests/autotuner/test_autotuner_tile_mismatch.py index a2ec41c23d..d88e373ead 100644 --- a/tests/autotuner/test_autotuner_tile_mismatch.py +++ b/tests/autotuner/test_autotuner_tile_mismatch.py @@ -1,29 +1,4 @@ -""" -Regression test for the tile_N mismatch bug (RuntimeError: unordered_map::at). - -The bug is in the C++ MoE kernel launcher (trtllm_fused_moe_kernel_launcher.cu). -During autotuning, the Python autotuner profiles kernels using bucketed -num_tokens (e.g. 256) and caches the best tactic [tile_N, config]. -During inference, the C++ launcher calls computeSelectedTileN with the -*actual* num_tokens (e.g. 500) to build launchers_map — a subset of -supported tile sizes. When the actual num_tokens differs from the -bucketed value, computeSelectedTileN can produce a different tile subset, -and the cached tile_N may not exist in launchers_map, causing -launchers_map.at(tile_N) to throw. - -The C++ fix adds a fallback when the cached tile_N is missing: - if (launchers_map.find(tile_N) == launchers_map.end()) { fallback } - -The fix to _find_nearest_profile (propagating the bucketed value to all -linked dimensions) is what exposed this latent C++ bug. Before that fix, -only the first linked dimension was bucketed during inference, so the -profile key never matched the tuning-time key — the autotuner always -returned the fallback tactic (-1) and the C++ fallback path was taken. -After the fix, cache hits occur and the cached tile_N reaches the C++ -launcher, where the missing guard causes the crash. Both fixes are needed: -the Python fix makes the autotuner cache work correctly for MoE, and the -C++ fix makes the launcher handle tile_N mismatches gracefully. 
-""" +"""Regression tests for MoE autotuner tile/config mismatch handling.""" import math @@ -41,13 +16,8 @@ get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, ) -from flashinfer.utils import get_compute_capability -# --------------------------------------------------------------------------- -# Helper: Python mirror of C++ computeSelectedTileN -# (csrc/trtllm_fused_moe_kernel_launcher.cu, lines 85-107) -# --------------------------------------------------------------------------- def _next_power_of_two(n: float) -> int: if n <= 1: return 1 @@ -81,15 +51,9 @@ def compute_selected_tile_n( return selected -# --------------------------------------------------------------------------- -# TileAwareDummyRunner: returns different valid tactics depending on shapes -# --------------------------------------------------------------------------- class TileAwareDummyRunner(TunableRunner): - """Runner whose valid tactics depend on input num_tokens, mimicking - the real trtllm-gen MoE runner where computeSelectedTileN filters tiles. - - Each tactic is a list [tile_N, config_idx] matching the real MoE convention. - """ + """Dummy runner whose valid tactics depend on num_tokens. 
+ returns different valid tactics depending on shapes.""" SUPPORTED_TILES = [8, 16, 32, 64] # FP4 base tiles (no 128/256 for bf16 act) @@ -112,10 +76,7 @@ def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs): return inputs[0] -# --------------------------------------------------------------------------- -# Shared MoE-like tuning config (mirrors the real one) -# --------------------------------------------------------------------------- -TUNE_MAX = 4096 +TUNE_MAX = 8192 MOE_TUNING_CONFIG = TuningConfig( dynamic_tensor_specs=( @@ -137,94 +98,126 @@ def _reset_autotuner() -> AutoTuner: return tuner -# =================================================================== -# Tests -# =================================================================== - - -def test_tile_set_differs_between_bucketed_and_actual_num_tokens(): - """Prove that computeSelectedTileN can return different tile sets - for the bucketed num_tokens vs. the actual num_tokens that maps to - the same autotuner bucket. - - This is the precondition for the unordered_map::at crash. - """ - top_k = 8 - num_experts = 64 - - # With the base SUPPORTED_TILES=[8,16,32,64] and top_k=8, num_experts=64, - # both bucketed=1024 and actual=1624 clamp to max tile 64, so the sets match. 
- # Use a wider tile range that DOES produce a mismatch: - tiles_small = [8, 16, 32, 64, 128] - actual = 500 - bucket = last_positive_power_of_2(actual) # 256 - assert bucket == 256 - - tiles_bucket = compute_selected_tile_n(tiles_small, bucket, top_k, num_experts) - tiles_actual = compute_selected_tile_n(tiles_small, actual, top_k, num_experts) +def _find_bucket_actual_tile_mismatch_case( + *, + supported_tiles: list[int], + top_k: int, + num_local_experts: int, + tune_max: int = TUNE_MAX, +): + """Find (bucket, actual, tuned_tile) where tuned_tile is valid only for bucket.""" + for actual in range(2, tune_max + 1): + bucket = min(last_positive_power_of_2(actual), tune_max) + if bucket == actual: + continue + tiles_bucket = compute_selected_tile_n( + supported_tiles, bucket, top_k, num_local_experts + ) + tiles_actual = compute_selected_tile_n( + supported_tiles, actual, top_k, num_local_experts + ) + only_in_bucket = sorted(tiles_bucket - tiles_actual) + if only_in_bucket: + return bucket, actual, only_in_bucket[0] + return None + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 64), + (8, 128), # Qwen3-VL-MoE-like (text) routing fanout/expert count + (10, 512), # Qwen3.5-MoE-like (text) routing fanout/expert count + ], +) +def test_tile_set_differs_between_bucketed_and_actual_num_tokens(top_k, num_experts): + """Bucketed and actual shapes in the same bucket can still pick different tiles.""" + tiles = [8, 16, 32, 64, 128] + case = _find_bucket_actual_tile_mismatch_case( + supported_tiles=tiles, top_k=top_k, num_local_experts=num_experts + ) + assert case is not None, ( + f"No mismatch case found for top_k={top_k}, experts={num_experts}" + ) + bucket, actual, _ = case + tiles_bucket = compute_selected_tile_n(tiles, bucket, top_k, num_experts) + tiles_actual = compute_selected_tile_n(tiles, actual, top_k, num_experts) - # bucket=256: avg=256*8/64=32, nextPow2=32 → idx=2, selected={16,32,64,128} - # actual=500: avg=500*8/64=62.5, 
nextPow2=64 → idx=3, selected={32,64,128} assert tiles_bucket != tiles_actual, ( f"Expected tile sets to differ for bucket={bucket} vs actual={actual}, " f"got {tiles_bucket} vs {tiles_actual}" ) - # A tile_N valid for the bucket may not be valid for the actual only_in_bucket = tiles_bucket - tiles_actual assert len(only_in_bucket) > 0, ( "Expected at least one tile valid for bucketed but not for actual shapes" ) -def test_autotuner_returns_cached_tactic_for_different_actual_shape(monkeypatch): - """The autotuner returns a cached tactic when num_tokens differs from - the tuning shape but maps to the same bucket. - - Before the C++ fix, if the cached tile_N was not in - computeSelectedTileN(actual_num_tokens), this caused - RuntimeError: unordered_map::at. - """ +@pytest.mark.parametrize( + "top_k,num_experts,hidden_size", + [ + (2, 8, 128), + (4, 16, 128), + (8, 64, 256), + (12, 128, 1024), + ( + 8, + 256, + 7168, + ), # DeepSeek-v3-like MoE dimensions (lightweight autotuner-only path) + (8, 128, 4096), # Qwen3-VL-MoE-like text dimensions (autotuner-only path) + (10, 512, 4096), # Qwen3.5-MoE-like text dimensions (autotuner-only path) + ], +) +def test_autotuner_returns_cached_tactic_for_different_actual_shape( + monkeypatch, top_k, num_experts, hidden_size +): + """Cache lookup uses bucketed profile, not raw runtime num_tokens.""" tuner = _reset_autotuner() - runner = TileAwareDummyRunner(top_k=8, num_local_experts=64) - hidden_size = 256 + runner = TileAwareDummyRunner(top_k=top_k, num_local_experts=num_experts) + case = _find_bucket_actual_tile_mismatch_case( + supported_tiles=runner.SUPPORTED_TILES, + top_k=top_k, + num_local_experts=num_experts, + ) + assert case is not None, ( + f"No mismatch case found for top_k={top_k}, experts={num_experts}" + ) + tune_bucket, infer_actual, tuned_tile = case def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - # Make the tactic with tile_N=16 the "best" (lowest time) + # Make the chosen 
mismatch tile "best" so autotuner caches it. tile_n = tactic[0] if isinstance(tactic, list) else -1 - return 1.0 if tile_n == 16 else 5.0 + return 1.0 if tile_n == tuned_tile else 5.0 monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) - # --- Phase 1: Tune with bucketed num_tokens --- - # The autotuner generates profiles for all buckets. - # For bucket=256, avg=256*8/64=32, tiles include {16,32,64,...}, - # so tile_N=16 is valid and will be selected as best. - tune_inputs = [torch.empty((256, hidden_size), dtype=torch.float32)] + tune_inputs = [torch.empty((tune_bucket, hidden_size), dtype=torch.float32)] with autotune(tune_mode=True): _, tuned_tactic = tuner.choose_one( "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, tune_inputs ) assert isinstance(tuned_tactic, list) - assert tuned_tactic[0] == 16, f"Expected tile_N=16, got {tuned_tactic}" + assert tuned_tactic[0] == tuned_tile, ( + f"Expected tile_N={tuned_tile} for top_k={top_k}, experts={num_experts}, " + f"got {tuned_tactic}" + ) - # --- Phase 2: Inference with different actual num_tokens --- - # actual=500 maps to bucket=256 (same bucket) so we get a cache hit. - infer_inputs = [torch.empty((500, hidden_size), dtype=torch.float32)] + assert last_positive_power_of_2(infer_actual) == tune_bucket + infer_inputs = [torch.empty((infer_actual, hidden_size), dtype=torch.float32)] _, infer_tactic = tuner.choose_one( "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, infer_inputs ) - # The autotuner returns the CACHED tactic from tuning (tile_N=16). assert infer_tactic == tuned_tactic, "Expected cache hit returning the tuned tactic" - # --- Verify the mismatch --- - # For actual=500: avg=500*8/64=62.5, nextPow2=64, tiles={32,64,128} - # tile_N=16 is NOT in this set → would crash without the C++ fix. 
actual_tiles = compute_selected_tile_n( - TileAwareDummyRunner.SUPPORTED_TILES + [128], - 500, + TileAwareDummyRunner.SUPPORTED_TILES, + infer_actual, runner.top_k, runner.num_local_experts, ) @@ -234,19 +227,6 @@ def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw ) -def test_autotuner_cache_miss_returns_fallback_for_unseen_bucket(): - """When num_tokens maps to a bucket that was never tuned, - the autotuner correctly returns the fallback tactic=-1.""" - tuner = _reset_autotuner() - runner = TileAwareDummyRunner() - hidden_size = 256 - - # No tuning done — inference should always get fallback - inputs = [torch.empty((1624, hidden_size), dtype=torch.float32)] - _, tactic = tuner.choose_one("test_no_tune", [runner], MOE_TUNING_CONFIG, inputs) - assert tactic == -1, "Expected fallback tactic when no tuning was done" - - def test_different_actual_tokens_same_bucket_get_same_cached_tactic(monkeypatch): """Multiple actual num_tokens that map to the same bucket should all receive the same cached tactic — confirming the autotuner uses the @@ -269,7 +249,6 @@ def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw assert tuned_tactic[0] == 32 - # All these map to bucket 512 via last_positive_power_of_2 for actual in [513, 600, 700, 800, 900, 1000, 1023]: assert last_positive_power_of_2(actual) == 512 infer_inputs = [torch.empty((actual, hidden_size), dtype=torch.float32)] @@ -279,158 +258,3 @@ def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw assert tactic == tuned_tactic, ( f"Expected cached tactic {tuned_tactic} for num_tokens={actual}, got {tactic}" ) - - -# =================================================================== -# SM100 integration test — exercises the real C++ launcher -# =================================================================== - - -def _prepare_bf16_moe_weights(num_experts, intermediate_size, hidden_size, device): - """Prepare shuffled BF16 weights in 
BlockMajorK layout.""" - from flashinfer import shuffle_matrix_a - from flashinfer.fused_moe import convert_to_block_layout - - gemm1 = torch.randn( - num_experts, - 2 * intermediate_size, - hidden_size, - device=device, - dtype=torch.bfloat16, - ) - gemm2 = torch.randn( - num_experts, hidden_size, intermediate_size, device=device, dtype=torch.bfloat16 - ) - g1_shuffled, g2_shuffled = [], [] - for i in range(num_experts): - g1_shuffled.append( - convert_to_block_layout( - shuffle_matrix_a(gemm1[i].view(torch.uint8), 64), 128 - ) - ) - g2_shuffled.append( - convert_to_block_layout( - shuffle_matrix_a(gemm2[i].view(torch.uint8), 64), 128 - ) - ) - return ( - torch.stack(g1_shuffled).view(torch.bfloat16), - torch.stack(g2_shuffled).view(torch.bfloat16), - ) - - -def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch): - """SM100 integration test: tune BF16 MoE with one num_tokens, then - infer with a different num_tokens that maps to the same autotuner - bucket but selects different C++ tiles. - - Before the C++ fix this raised ``RuntimeError: unordered_map::at``. - After the fix the launcher falls back to a default tile gracefully. 
- """ - compute_capability = get_compute_capability(torch.device("cuda")) - if compute_capability[0] not in [10]: - pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") - - from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout - - _reset_autotuner() - torch.manual_seed(42) - device = torch.device("cuda:0") - - # Model dimensions — keep small for speed - num_experts = 8 - top_k = 2 - hidden_size = 1024 - intermediate_size = 1024 - - # Prepare weights once (they don't depend on num_tokens) - gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( - num_experts, intermediate_size, hidden_size, device - ) - - # --- Choose num_tokens that trigger the tile mismatch --- - # BF16 MoE supported tiles: {8, 16, 32, 64, 128} - # tune_num_tokens=256 (bucket): avg=256*2/8=64 → tiles={32,64,128} - # infer_num_tokens=500 (bucket=256): avg=500*2/8=125 → tiles={64,128} - # A tactic with tile_N=32 is valid for bucket but NOT for actual. - tune_num_tokens = 256 - infer_num_tokens = 500 - assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens - - tune_max = 4096 - - def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - """Force tile_N=32 to be selected as the 'fastest' tactic.""" - tile_n = tactic[0] if isinstance(tactic, list) else -1 - return 1.0 if tile_n == 32 else 5.0 - - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) - - # --- Phase 1: Tune with tune_num_tokens --- - routing_tune = torch.rand( - tune_num_tokens, num_experts, device=device, dtype=torch.bfloat16 - ) - hidden_tune = torch.randn( - tune_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 - ) - - with autotune(tune_mode=True): - trtllm_bf16_moe( - routing_logits=routing_tune, - routing_bias=None, - hidden_states=hidden_tune, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - num_experts=num_experts, - top_k=top_k, - n_group=None, - topk_group=None, - intermediate_size=intermediate_size, - 
local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=None, - routing_method_type=1, # Renormalize (TopK/Default not supported for BF16) - use_shuffled_weight=True, - weight_layout=WeightLayout.BlockMajorK, - tune_max_num_tokens=tune_max, - ) - - # Verify the autotuner populated its cache during tuning - tuner = AutoTuner.get() - assert len(tuner.profiling_cache) > 0, ( - "Autotuner cache should be populated after tuning" - ) - - # --- Phase 2: Inference with infer_num_tokens (same bucket, different tiles) --- - # Without the C++ fix, this would crash with RuntimeError: unordered_map::at - # because tile_N=32 is not in computeSelectedTileN(500, 2, 8) = {64, 128}. - routing_infer = torch.rand( - infer_num_tokens, num_experts, device=device, dtype=torch.bfloat16 - ) - hidden_infer = torch.randn( - infer_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 - ) - - # This call should NOT crash (with the C++ fix it falls back to a valid tile) - output = trtllm_bf16_moe( - routing_logits=routing_infer, - routing_bias=None, - hidden_states=hidden_infer, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - num_experts=num_experts, - top_k=top_k, - n_group=None, - topk_group=None, - intermediate_size=intermediate_size, - local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=None, - routing_method_type=1, # Renormalize - use_shuffled_weight=True, - weight_layout=WeightLayout.BlockMajorK, - tune_max_num_tokens=tune_max, - ) - - assert output.shape[0] == infer_num_tokens - assert output.isfinite().all(), "Output should be finite" diff --git a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py new file mode 100644 index 0000000000..227fc282e4 --- /dev/null +++ b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py @@ -0,0 +1,513 @@ +"""Integration tests for TRTLLM MoE launcher fallback and wrapper contracts.""" + +import pytest 
+import torch + +from flashinfer import autotune +from flashinfer.autotuner import AutoTuner +from flashinfer.utils import get_compute_capability + +TUNE_MAX = 8192 + + +def _reset_autotuner() -> AutoTuner: + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.reset_statistics() + tuner.is_tuning_mode = False + return tuner + + +def _prepare_bf16_moe_weights(num_experts, intermediate_size, hidden_size, device): + """Prepare shuffled BF16 weights in BlockMajorK layout.""" + from flashinfer import shuffle_matrix_a + from flashinfer.fused_moe import convert_to_block_layout + + gemm1 = torch.randn( + num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + gemm2 = torch.randn( + num_experts, hidden_size, intermediate_size, device=device, dtype=torch.bfloat16 + ) + g1_shuffled, g2_shuffled = [], [] + for i in range(num_experts): + g1_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm1[i].view(torch.uint8), 64), 128 + ) + ) + g2_shuffled.append( + convert_to_block_layout( + shuffle_matrix_a(gemm2[i].view(torch.uint8), 64), 128 + ) + ) + return ( + torch.stack(g1_shuffled).view(torch.bfloat16), + torch.stack(g2_shuffled).view(torch.bfloat16), + ) + + +def _overwrite_cached_tactic_for_op(custom_op: str, new_tactic): + """Overwrite cached tactics for one op key.""" + tuner = AutoTuner.get() + updated = 0 + for key, (runner_id, _tactic, profile) in list(tuner.profiling_cache.items()): + if key[0] == custom_op: + tuner.profiling_cache[key] = (runner_id, new_tactic, profile) + updated += 1 + assert updated > 0, f"No autotuner cache entries found for {custom_op}" + + +def _tune_bf16_moe_once( + *, + device, + tune_num_tokens: int, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + gemm1_weights, + gemm2_weights, + tune_max: int, +): + from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout + + routing_tune = torch.rand( + tune_num_tokens, num_experts, device=device, 
dtype=torch.bfloat16 + ) + hidden_tune = torch.randn( + tune_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + with autotune(tune_mode=True): + trtllm_bf16_moe( + routing_logits=routing_tune, + routing_bias=None, + hidden_states=hidden_tune, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + + +def _run_bf16_moe_infer( + *, + device, + infer_num_tokens: int, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + gemm1_weights, + gemm2_weights, + tune_max: int, +): + from flashinfer.fused_moe import trtllm_bf16_moe, WeightLayout + + routing_infer = torch.rand( + infer_num_tokens, num_experts, device=device, dtype=torch.bfloat16 + ) + hidden_infer = torch.randn( + infer_num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) + output = trtllm_bf16_moe( + routing_logits=routing_infer, + routing_bias=None, + hidden_states=hidden_infer, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize + use_shuffled_weight=True, + weight_layout=WeightLayout.BlockMajorK, + tune_max_num_tokens=tune_max, + ) + assert output.shape[0] == infer_num_tokens + assert output.isfinite().all(), "Output should be finite" + + +def _require_sm100(): + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] not in [10]: + pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") + + 
+@pytest.mark.parametrize("fp8_quantization_type", ["DeepSeekFp8", "MxFp8"]) +def test_fp8_block_scale_moe_deepseek_contract_args(monkeypatch, fp8_quantization_type): + """Contract test: wrapper forwards DeepSeek-style routing/group arguments unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_fp8_block_scale_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type( + "FakeMoeModule", + (), + {"trtllm_fp8_block_scale_moe": fake_trtllm_fp8_block_scale_moe}, + )() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + intermediate_size = 256 + num_experts = 256 + top_k = 8 + n_group = 8 + topk_group = 4 + routed_scaling_factor = 2.5 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + routing_bias = torch.randn(num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + hidden_states_scale = torch.ones(hidden_size // 128, seq_len, dtype=torch.float32) + gemm1_weights = torch.empty(1, dtype=torch.float32) + gemm1_weights_scale = torch.empty(1, dtype=torch.float32) + gemm2_weights = torch.empty(1, dtype=torch.float32) + gemm2_weights_scale = torch.empty(1, dtype=torch.float32) + + quant_type = getattr(moe_core.Fp8QuantizationType, fp8_quantization_type) + output = moe_core.trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=0, + 
local_num_experts=num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=2, # DeepSeekV3 routing mode + use_shuffled_weight=True, + weight_layout=moe_core.WeightLayout.BlockMajorK, + do_finalize=True, + enable_pdl=True, + tune_max_num_tokens=tune_max_num_tokens, + fp8_quantization_type=quant_type, + ) + + assert output.shape == hidden_states.shape + assert "args" in captured, ( + "Expected wrapper to call trtllm_fp8_block_scale_moe backend" + ) + + args = captured["args"] + assert args[12] == top_k + assert args[13] == n_group + assert args[14] == topk_group + assert args[18] == routed_scaling_factor + assert args[19] == 2 + assert args[20] is True + assert args[21] == int(moe_core.WeightLayout.BlockMajorK) + assert args[24] == tune_max_num_tokens + assert args[25] == quant_type + + +def test_fp4_block_scale_moe_contract_args(monkeypatch): + """Contract test: FP4 wrapper forwards core MoE routing/kernel args unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_fp4_block_scale_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type( + "FakeMoeModule", + (), + {"trtllm_fp4_block_scale_moe": fake_trtllm_fp4_block_scale_moe}, + )() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + intermediate_size = 256 + num_experts = 128 + top_k = 8 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + hidden_states_scale = torch.ones(seq_len, hidden_size // 16, dtype=torch.float32) + gemm1_weights = torch.empty(1, dtype=torch.uint8) + gemm1_weights_scale = torch.empty(1, dtype=torch.float32) + gemm2_weights = torch.empty(1, dtype=torch.uint8) + gemm2_weights_scale = torch.empty(1, 
dtype=torch.float32) + + output = moe_core.trtllm_fp4_block_scale_moe( + routing_logits=routing_logits, + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + gemm2_bias=None, + output1_scale_scalar=None, + output1_scale_gate_scalar=None, + output2_scale_scalar=None, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, + do_finalize=True, + enable_pdl=True, + activation_type=moe_core.ActivationType.Swiglu.value, + output=None, + tune_max_num_tokens=tune_max_num_tokens, + ) + + assert output[0].shape == hidden_states.shape + assert "args" in captured, ( + "Expected wrapper to call trtllm_fp4_block_scale_moe backend" + ) + + args = captured["args"] + assert args[18] == num_experts + assert args[19] == top_k + assert args[20] is None + assert args[21] is None + assert args[22] == intermediate_size + assert args[24] == num_experts + assert args[26] == 1 + assert args[27] is True + assert args[31] == tune_max_num_tokens + + +def test_bf16_moe_qwen35_contract_args(monkeypatch): + """Contract test: BF16 wrapper forwards Qwen3.5-style ungrouped routing args unchanged.""" + from flashinfer.fused_moe import core as moe_core + + captured = {} + + def fake_trtllm_bf16_moe(self, *args): + captured["args"] = args + hidden_states = args[4] + return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + + fake_module = type("FakeMoeModule", (), {"trtllm_bf16_moe": fake_trtllm_bf16_moe})() + monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) + + seq_len = 8 + hidden_size = 128 + 
intermediate_size = 1024 + num_experts = 512 + top_k = 10 + tune_max_num_tokens = TUNE_MAX + + routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) + hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) + gemm1_weights = torch.empty(1, dtype=torch.bfloat16) + gemm2_weights = torch.empty(1, dtype=torch.bfloat16) + + output = moe_core.trtllm_bf16_moe( + routing_logits=routing_logits, + routing_bias=None, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=0, + use_shuffled_weight=True, + weight_layout=moe_core.WeightLayout.BlockMajorK, + do_finalize=True, + enable_pdl=True, + tune_max_num_tokens=tune_max_num_tokens, + ) + + assert output.shape == hidden_states.shape + assert "args" in captured, "Expected wrapper to call trtllm_bf16_moe backend" + + args = captured["args"] + assert args[7] == num_experts + assert args[8] == top_k + assert args[9] is None + assert args[10] is None + assert args[11] == intermediate_size + assert args[13] == num_experts + assert args[16] is True + assert args[17] == int(moe_core.WeightLayout.BlockMajorK) + assert args[20] == tune_max_num_tokens + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 128), # Qwen3-VL-MoE-like (text) + (10, 128), # Routing kernel currently supports top_k <= 10 + ], +) +def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch, top_k, num_experts): + """SM100 BF16 integration: cached mismatched tile falls back without crash.""" + from flashinfer.fused_moe.utils import last_positive_power_of_2 + + _require_sm100() + _reset_autotuner() + torch.manual_seed(42) + device = torch.device("cuda:0") + + hidden_size = 1024 + intermediate_size = 1024 + + gemm1_weights, gemm2_weights = 
_prepare_bf16_moe_weights( + num_experts, intermediate_size, hidden_size, device + ) + + tune_num_tokens = 256 + infer_num_tokens = 500 + assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens + + tune_max = TUNE_MAX + + def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + """Force tile_N=32 to be selected as the 'fastest' tactic.""" + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + _tune_bf16_moe_once( + device=device, + tune_num_tokens=tune_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + tuner = AutoTuner.get() + assert len(tuner.profiling_cache) > 0, ( + "Autotuner cache should be populated after tuning" + ) + + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + +@pytest.mark.parametrize( + "top_k,num_experts", + [ + (2, 8), + (4, 16), + (8, 128), # Qwen3-VL-MoE-like (text) + (10, 128), # Routing kernel currently supports top_k <= 10 + ], +) +@pytest.mark.parametrize( + "seed,stale_tactic", + [ + (123, [4096, 0]), # unsupported tile_N + (124, [32, 10_000_000]), # invalid config index + (125, [32]), # malformed tactic payload + ], +) +def test_bf16_moe_stale_cached_tactic_falls_back_no_crash( + monkeypatch, top_k, num_experts, seed, stale_tactic +): + """SM100 integration: stale cache payloads should fall back without crashing.""" + _require_sm100() + _reset_autotuner() + torch.manual_seed(seed) + device = torch.device("cuda:0") + + hidden_size, intermediate_size = 1024, 1024 + tune_num_tokens, infer_num_tokens = 256, 
500 + tune_max = TUNE_MAX + + gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( + num_experts, intermediate_size, hidden_size, device + ) + + def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + tile_n = tactic[0] if isinstance(tactic, list) else -1 + return 1.0 if tile_n == 32 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + _tune_bf16_moe_once( + device=device, + tune_num_tokens=tune_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) + + _overwrite_cached_tactic_for_op("flashinfer::trtllm_bf16_moe", stale_tactic) + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=tune_max, + ) From 4a42e7a461356b8e9e68320c4346559b645771b3 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:12:37 +0000 Subject: [PATCH 06/12] Restore autotuner fix of support non-power-of-two cache lookup from https://github.com/flashinfer-ai/flashinfer/pull/2617 Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- flashinfer/autotuner.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index 641b1c79e6..1ec1b58fda 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -1037,11 +1037,12 @@ def _find_nearest_profile( base_profile = list(list(shape) for shape in shapes) for spec in tuning_config.dynamic_tensor_specs: - base_profile[spec.input_idx[0]][spec.dim_idx[0]] = ( - spec.map_to_tuning_buckets( - base_profile[spec.input_idx[0]][spec.dim_idx[0]] - ) + mapped_val = spec.map_to_tuning_buckets( + 
base_profile[spec.input_idx[0]][spec.dim_idx[0]] ) + # Apply the same mapped bucket to all linked dimensions in this spec. + for input_i, dim_i in zip(spec.input_idx, spec.dim_idx, strict=True): + base_profile[input_i][dim_i] = mapped_val # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile for constraint_spec in tuning_config.constraint_specs: From 0c843c0326be95e1252444cebd0e8b1e90e25b8b Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:27:21 +0000 Subject: [PATCH 07/12] Don't allow invalid tactic tuple, update tests accordingly, improve test of supporting tileN that was filtered out by computeSelectedTileN Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 80 +--- ...st_trtllm_moe_tile_mismatch_integration.py | 360 +++++------------- 2 files changed, 102 insertions(+), 338 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index d7fe85417b..dd206d71b0 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -132,50 +132,23 @@ int64_t selectDefaultTileN(std::vector const& supported_tile_nums, return *selected.begin(); } -// Resolve the (tile_N, config) pair passed from Python into a launcher-safe pair. -// -// Interaction with Python AutoTuner: -// - Python AutoTuner stores and returns "tactics" as a 2-int payload [tile_N, config]. -// - Tuning/profiling uses bucketed num_tokens (power-of-2 buckets), while runtime uses actual -// num_tokens. Therefore, a cached tile_N chosen for the bucket may be suboptimal or unsupported -// for the runtime shape. -// - Tactics may also come from persisted config files and can be stale/malformed across versions. 
-// -// This resolver makes C++ robust to those mismatches: -// 1) malformed/missing tactic payload -> fallback to runtime-selected default tile/config -// 2) fallback marker from Python (-1, -1) -> runtime-selected default tile/config -// 3) stale tile_N not in current supported set -> runtime-selected default tile/config -// -// We intentionally do not fully validate `config` here because config validity depends on the -// concrete Runner instance and current shape. That validation/fallback happens later in -// FusedMoeLauncher::prepare_moe_common(). +// Resolve the (tile_N, config) pair passed from Python side, applying fallback logic +// when tile_N is -1. std::pair resolveMoeTileAndConfig(Array const& config_index, std::vector const& supported_tile_nums, int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { - int64_t tile_N = -1; - int64_t config = -1; - // Python side convention: tactic is [tile_N, config] - if (config_index.size() >= 2) { - tile_N = config_index[0]; - config = config_index[1]; - } + TVM_FFI_ICHECK(config_index.size() == 2) + << "Invalid tactic, expected to be [tile_N, config], but got array of size " + << config_index.size(); + const int64_t tile_N = config_index[0]; + const int64_t config = config_index[1]; - // Runtime default is selected from actual shape stats (num_tokens/top_k/local_num_experts), - // which may differ from the bucketed shape that produced the cached tactic. - auto const default_tile_N = - selectDefaultTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); if (tile_N == -1 || config == -1) { // Use fallback tactic - return {default_tile_N, -1}; - } - - bool const tile_supported = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), - static_cast(tile_N)) != supported_tile_nums.end(); - if (!tile_supported) { - // Tactic can be stale due to cache/file mismatch across versions or shape bucketing mismatch. 
- // Fall back to a runtime-selected tile and default config instead of crashing. + auto const default_tile_N = + selectDefaultTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); return {default_tile_N, -1}; } @@ -410,38 +383,19 @@ class FusedMoeLauncher { this->activation_type, this->use_shuffled_weight, this->weight_layout); } - auto valid_cfgs = - moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, - args->local_num_experts, args->num_tokens); - FLASHINFER_CHECK(!valid_cfgs.empty(), "No valid MoE tactics for tile_N=", tile_tokens_dim, - " with shape (num_tokens=", args->num_tokens, - ", hidden_size=", args->hidden_size, - ", intermediate_size=", args->intermediate_size, ", top_k=", args->top_k, - ", local_num_experts=", args->local_num_experts, ")."); - int64_t default_tactic = -1; if (moe_tactic == -1) { - default_tactic = moe_runner->getDefaultValidConfigIndex( + moe_tactic = moe_runner->getDefaultValidConfigIndex( args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, args->num_tokens); - moe_tactic = default_tactic; } + auto valid_cfgs = + moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, + args->local_num_experts, args->num_tokens); auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic); - if (valid_it == valid_cfgs.end()) { - if (default_tactic == -1) { - default_tactic = moe_runner->getDefaultValidConfigIndex( - args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, - args->num_tokens); - } - auto default_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), default_tactic); - FLASHINFER_CHECK( - default_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, - " for tile_N=", tile_tokens_dim, - ". 
No valid fallback default tactic found for shape (num_tokens=", args->num_tokens, - ", hidden_size=", args->hidden_size, ", intermediate_size=", args->intermediate_size, - ", top_k=", args->top_k, ", local_num_experts=", args->local_num_experts, - "). This indicates a deeper kernel/config mismatch."); - moe_tactic = default_tactic; - } + FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, + " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ", + valid_cfgs.size(), + ". This often indicates a stale or mismatched autotuner cache entry."); this->moe_tactic = moe_tactic; auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); diff --git a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py index 227fc282e4..129e8d6f47 100644 --- a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py +++ b/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py @@ -3,9 +3,10 @@ import pytest import torch -from flashinfer import autotune +from flashinfer import autotune, RoutingMethodType from flashinfer.autotuner import AutoTuner from flashinfer.utils import get_compute_capability +import contextlib TUNE_MAX = 8192 @@ -97,7 +98,7 @@ def _tune_bf16_moe_once( local_expert_offset=0, local_num_experts=num_experts, routed_scaling_factor=None, - routing_method_type=1, # Renormalize + routing_method_type=RoutingMethodType.Renormalize.value, use_shuffled_weight=True, weight_layout=WeightLayout.BlockMajorK, tune_max_num_tokens=tune_max, @@ -138,7 +139,7 @@ def _run_bf16_moe_infer( local_expert_offset=0, local_num_experts=num_experts, routed_scaling_factor=None, - routing_method_type=1, # Renormalize + routing_method_type=RoutingMethodType.Renormalize.value, use_shuffled_weight=True, weight_layout=WeightLayout.BlockMajorK, tune_max_num_tokens=tune_max, @@ -147,238 +148,45 @@ def _run_bf16_moe_infer( assert output.isfinite().all(), 
"Output should be finite" -def _require_sm100(): - compute_capability = get_compute_capability(torch.device("cuda")) - if compute_capability[0] not in [10]: - pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") - - -@pytest.mark.parametrize("fp8_quantization_type", ["DeepSeekFp8", "MxFp8"]) -def test_fp8_block_scale_moe_deepseek_contract_args(monkeypatch, fp8_quantization_type): - """Contract test: wrapper forwards DeepSeek-style routing/group arguments unchanged.""" - from flashinfer.fused_moe import core as moe_core - - captured = {} - - def fake_trtllm_fp8_block_scale_moe(self, *args): - captured["args"] = args - hidden_states = args[4] - return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] - - fake_module = type( - "FakeMoeModule", - (), - {"trtllm_fp8_block_scale_moe": fake_trtllm_fp8_block_scale_moe}, - )() - monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) - - seq_len = 8 - hidden_size = 128 - intermediate_size = 256 - num_experts = 256 - top_k = 8 - n_group = 8 - topk_group = 4 - routed_scaling_factor = 2.5 - tune_max_num_tokens = TUNE_MAX - - routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) - routing_bias = torch.randn(num_experts, dtype=torch.float32) - hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) - hidden_states_scale = torch.ones(hidden_size // 128, seq_len, dtype=torch.float32) - gemm1_weights = torch.empty(1, dtype=torch.float32) - gemm1_weights_scale = torch.empty(1, dtype=torch.float32) - gemm2_weights = torch.empty(1, dtype=torch.float32) - gemm2_weights_scale = torch.empty(1, dtype=torch.float32) - - quant_type = getattr(moe_core.Fp8QuantizationType, fp8_quantization_type) - output = moe_core.trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - hidden_states_scale=hidden_states_scale, - gemm1_weights=gemm1_weights, - gemm1_weights_scale=gemm1_weights_scale, - 
gemm2_weights=gemm2_weights, - gemm2_weights_scale=gemm2_weights_scale, - num_experts=num_experts, - top_k=top_k, - n_group=n_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=2, # DeepSeekV3 routing mode - use_shuffled_weight=True, - weight_layout=moe_core.WeightLayout.BlockMajorK, - do_finalize=True, - enable_pdl=True, - tune_max_num_tokens=tune_max_num_tokens, - fp8_quantization_type=quant_type, - ) - - assert output.shape == hidden_states.shape - assert "args" in captured, ( - "Expected wrapper to call trtllm_fp8_block_scale_moe backend" - ) - - args = captured["args"] - assert args[12] == top_k - assert args[13] == n_group - assert args[14] == topk_group - assert args[18] == routed_scaling_factor - assert args[19] == 2 - assert args[20] is True - assert args[21] == int(moe_core.WeightLayout.BlockMajorK) - assert args[24] == tune_max_num_tokens - assert args[25] == quant_type - - -def test_fp4_block_scale_moe_contract_args(monkeypatch): - """Contract test: FP4 wrapper forwards core MoE routing/kernel args unchanged.""" - from flashinfer.fused_moe import core as moe_core - - captured = {} - - def fake_trtllm_fp4_block_scale_moe(self, *args): - captured["args"] = args - hidden_states = args[4] - return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] - - fake_module = type( - "FakeMoeModule", - (), - {"trtllm_fp4_block_scale_moe": fake_trtllm_fp4_block_scale_moe}, - )() - monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) - - seq_len = 8 - hidden_size = 128 - intermediate_size = 256 - num_experts = 128 - top_k = 8 - tune_max_num_tokens = TUNE_MAX - - routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) - hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) - hidden_states_scale = torch.ones(seq_len, hidden_size // 16, 
dtype=torch.float32) - gemm1_weights = torch.empty(1, dtype=torch.uint8) - gemm1_weights_scale = torch.empty(1, dtype=torch.float32) - gemm2_weights = torch.empty(1, dtype=torch.uint8) - gemm2_weights_scale = torch.empty(1, dtype=torch.float32) - - output = moe_core.trtllm_fp4_block_scale_moe( - routing_logits=routing_logits, - routing_bias=None, - hidden_states=hidden_states, - hidden_states_scale=hidden_states_scale, - gemm1_weights=gemm1_weights, - gemm1_weights_scale=gemm1_weights_scale, - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=gemm2_weights, - gemm2_weights_scale=gemm2_weights_scale, - gemm2_bias=None, - output1_scale_scalar=None, - output1_scale_gate_scalar=None, - output2_scale_scalar=None, - num_experts=num_experts, - top_k=top_k, - n_group=None, - topk_group=None, - intermediate_size=intermediate_size, - local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=None, - routing_method_type=1, - do_finalize=True, - enable_pdl=True, - activation_type=moe_core.ActivationType.Swiglu.value, - output=None, - tune_max_num_tokens=tune_max_num_tokens, - ) - - assert output[0].shape == hidden_states.shape - assert "args" in captured, ( - "Expected wrapper to call trtllm_fp4_block_scale_moe backend" - ) +def _compute_selected_tile_n_base_element( + num_tokens: int, top_k: int, num_experts: int +) -> int: + """Compute the base element used by computeSelectedTileN(num_tokens) to filter tile_N candidates.""" + from flashinfer.fused_moe.utils import next_positive_power_of_2 - args = captured["args"] - assert args[18] == num_experts - assert args[19] == top_k - assert args[20] is None - assert args[21] is None - assert args[22] == intermediate_size - assert args[24] == num_experts - assert args[26] == 1 - assert args[27] is True - assert args[31] == tune_max_num_tokens + return next_positive_power_of_2(int(num_tokens * top_k / num_experts)) -def test_bf16_moe_qwen35_contract_args(monkeypatch): 
- """Contract test: BF16 wrapper forwards Qwen3.5-style ungrouped routing args unchanged.""" - from flashinfer.fused_moe import core as moe_core +def _make_tile_bias(target_tile_n: int): + """Return a profiling stub that favours *target_tile_n* over every other tile.""" - captured = {} + def _bias(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + tile_n = -1 + with contextlib.suppress(BaseException): + tile_n = tactic[0] + return 1.0 if tile_n == target_tile_n else 5.0 - def fake_trtllm_bf16_moe(self, *args): - captured["args"] = args - hidden_states = args[4] - return [hidden_states.new_empty(hidden_states.shape, dtype=torch.bfloat16)] + return _bias - fake_module = type("FakeMoeModule", (), {"trtllm_bf16_moe": fake_trtllm_bf16_moe})() - monkeypatch.setattr(moe_core, "get_trtllm_moe_sm100_module", lambda: fake_module) - seq_len = 8 - hidden_size = 128 - intermediate_size = 1024 - num_experts = 512 - top_k = 10 - tune_max_num_tokens = TUNE_MAX - - routing_logits = torch.randn(seq_len, num_experts, dtype=torch.float32) - hidden_states = torch.randn(seq_len, hidden_size, dtype=torch.bfloat16) - gemm1_weights = torch.empty(1, dtype=torch.bfloat16) - gemm2_weights = torch.empty(1, dtype=torch.bfloat16) +def _make_tile_bias_min_tile_n(num_tokens: int, top_k: int, num_experts: int): + """Return a profiling stub that favours the minimum tile_N in computeSelectedTileN(num_tokens).""" - output = moe_core.trtllm_bf16_moe( - routing_logits=routing_logits, - routing_bias=None, - hidden_states=hidden_states, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - num_experts=num_experts, - top_k=top_k, - n_group=None, - topk_group=None, - intermediate_size=intermediate_size, - local_expert_offset=0, - local_num_experts=num_experts, - routed_scaling_factor=None, - routing_method_type=0, - use_shuffled_weight=True, - weight_layout=moe_core.WeightLayout.BlockMajorK, - do_finalize=True, - enable_pdl=True, - tune_max_num_tokens=tune_max_num_tokens, + # The 
values computeSelectedTileN keeps from a sorted supported tileN list: + # - "base element": roundUpToPowerOf2(num_tokens*top_k/num_experts) + # - its previous element + # - its next two elements + # So the min tileN is the "base element" divided by 2, with a minimum of 1. + target_tile_n = max( + 1, _compute_selected_tile_n_base_element(num_tokens, top_k, num_experts) // 2 ) + return _make_tile_bias(target_tile_n) - assert output.shape == hidden_states.shape - assert "args" in captured, "Expected wrapper to call trtllm_bf16_moe backend" - args = captured["args"] - assert args[7] == num_experts - assert args[8] == top_k - assert args[9] is None - assert args[10] is None - assert args[11] == intermediate_size - assert args[13] == num_experts - assert args[16] is True - assert args[17] == int(moe_core.WeightLayout.BlockMajorK) - assert args[20] == tune_max_num_tokens +def _require_sm100(): + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] not in [10]: + pytest.skip("This test requires an SM100/SM103 (Blackwell) GPU.") @pytest.mark.parametrize( @@ -390,8 +198,22 @@ def fake_trtllm_bf16_moe(self, *args): (10, 128), # Routing kernel currently supports top_k <= 10 ], ) -def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch, top_k, num_experts): - """SM100 BF16 integration: cached mismatched tile falls back without crash.""" +@pytest.mark.parametrize( + "tune_num_tokens,infer_num_tokens", + [ + (256, 500), + ], +) +def test_bf16_moe_cached_tile_excluded_at_inference_succeeds( + monkeypatch, + top_k: int, + num_experts: int, + tune_num_tokens: int, + infer_num_tokens: int, +): + """SM100 BF16 integration: Test that MoE works when given by the autotuner a cached + tileN that's filtered out by computeSelectedTileN for the given inference num tokens. 
+ """ from flashinfer.fused_moe.utils import last_positive_power_of_2 _require_sm100() @@ -406,18 +228,21 @@ def test_bf16_moe_tile_mismatch_no_crash_after_fix(monkeypatch, top_k, num_exper num_experts, intermediate_size, hidden_size, device ) - tune_num_tokens = 256 - infer_num_tokens = 500 + tile_n_base_autotune_cache = _compute_selected_tile_n_base_element( + tune_num_tokens, top_k, num_experts + ) + tile_n_base_inference = _compute_selected_tile_n_base_element( + infer_num_tokens, top_k, num_experts + ) + assert tile_n_base_autotune_cache < tile_n_base_inference, ( + "Test setup error: autotuning tile_N base element should be smaller than inference tile_N base element to trigger the intended scenario" + ) assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens - tune_max = TUNE_MAX - - def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - """Force tile_N=32 to be selected as the 'fastest' tactic.""" - tile_n = tactic[0] if isinstance(tactic, list) else -1 - return 1.0 if tile_n == 32 else 5.0 - - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + min_tile_n_bias_function = _make_tile_bias_min_tile_n( + tune_num_tokens, top_k, num_experts + ) + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", min_tile_n_bias_function) _tune_bf16_moe_once( device=device, tune_num_tokens=tune_num_tokens, @@ -427,7 +252,7 @@ def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw intermediate_size=intermediate_size, gemm1_weights=gemm1_weights, gemm2_weights=gemm2_weights, - tune_max=tune_max, + tune_max=TUNE_MAX, ) tuner = AutoTuner.get() @@ -444,49 +269,33 @@ def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw intermediate_size=intermediate_size, gemm1_weights=gemm1_weights, gemm2_weights=gemm2_weights, - tune_max=tune_max, + tune_max=TUNE_MAX, ) @pytest.mark.parametrize( - "top_k,num_experts", - [ - (2, 8), - (4, 16), - (8, 128), # Qwen3-VL-MoE-like 
(text) - (10, 128), # Routing kernel currently supports top_k <= 10 - ], -) -@pytest.mark.parametrize( - "seed,stale_tactic", + "invalid_tactic", [ - (123, [4096, 0]), # unsupported tile_N - (124, [32, 10_000_000]), # invalid config index - (125, [32]), # malformed tactic payload + [4096, 0], # unsupported tile_N + [32, 10_000_000], # invalid config index + [32], # malformed tactic payload ], ) -def test_bf16_moe_stale_cached_tactic_falls_back_no_crash( - monkeypatch, top_k, num_experts, seed, stale_tactic -): - """SM100 integration: stale cache payloads should fall back without crashing.""" +def test_bf16_moe_invalid_tactic_raises_runtime_error(monkeypatch, invalid_tactic): + """SM100 integration: invalid tactics should fail.""" _require_sm100() _reset_autotuner() - torch.manual_seed(seed) device = torch.device("cuda:0") hidden_size, intermediate_size = 1024, 1024 + top_k, num_experts = 4, 16 tune_num_tokens, infer_num_tokens = 256, 500 - tune_max = TUNE_MAX gemm1_weights, gemm2_weights = _prepare_bf16_moe_weights( num_experts, intermediate_size, hidden_size, device ) - def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - tile_n = tactic[0] if isinstance(tactic, list) else -1 - return 1.0 if tile_n == 32 else 5.0 - - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", bias_tile_32) + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", _make_tile_bias(32)) _tune_bf16_moe_once( device=device, tune_num_tokens=tune_num_tokens, @@ -496,18 +305,19 @@ def bias_tile_32(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw intermediate_size=intermediate_size, gemm1_weights=gemm1_weights, gemm2_weights=gemm2_weights, - tune_max=tune_max, + tune_max=TUNE_MAX, ) - _overwrite_cached_tactic_for_op("flashinfer::trtllm_bf16_moe", stale_tactic) - _run_bf16_moe_infer( - device=device, - infer_num_tokens=infer_num_tokens, - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - 
gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - tune_max=tune_max, - ) + _overwrite_cached_tactic_for_op("flashinfer::trtllm_bf16_moe", invalid_tactic) + with pytest.raises(RuntimeError): + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=TUNE_MAX, + ) From e90b66bf2342cfd8379ed2bbbab47358ff212c03 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:51:08 +0000 Subject: [PATCH 08/12] Unskip tests related to bug in _find_nearest_profile Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/autotuner/test_autotuner_core.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/tests/autotuner/test_autotuner_core.py b/tests/autotuner/test_autotuner_core.py index 1a80a5f924..289919daf0 100644 --- a/tests/autotuner/test_autotuner_core.py +++ b/tests/autotuner/test_autotuner_core.py @@ -126,17 +126,14 @@ def test_find_nearest_profile_single_tensor_bucketization_exact_powers( "num_tokens,expected_bucket", [ (1000, 512), + (1024, 1024), (4000, 2048), + (4096, 4096), (8000, 4096), - (12000, 8192), + (8192, 8192), + (10000, 8192), ], ) -@pytest.mark.skip( - reason=( - "_find_nearest_profile linked-dimension mapping was reverted; " - "re-enable when linked-dim bucket propagation is restored." 
- ) -) def test_find_nearest_profile_moe_shared_num_tokens_axis(num_tokens, expected_bucket): """MoE linked tensors should all map num_tokens together to one bucket.""" gen_tuning_buckets = (512, 1024, 2048, 4096, 8192, 16384) @@ -161,12 +158,6 @@ def test_find_nearest_profile_moe_shared_num_tokens_axis(num_tokens, expected_bu assert nearest_shape[1:] == original_shape[1:] -@pytest.mark.skip( - reason=( - "_find_nearest_profile linked-dimension mapping was reverted; " - "re-enable when linked-dim bucket propagation is restored." - ) -) def test_find_nearest_profile_moe_same_bucket_same_profile(): """MoE inputs mapping to the same bucket should share an identical profile.""" config = TuningConfig( @@ -185,12 +176,6 @@ def test_find_nearest_profile_moe_same_bucket_same_profile(): assert p1 == p2 -@pytest.mark.skip( - reason=( - "_find_nearest_profile linked-dimension mapping was reverted; " - "re-enable when linked-dim bucket propagation is restored." - ) -) def test_find_nearest_profile_maps_all_linked_dims(): """One logical dynamic axis should update every linked tensor/dimension.""" tuning_config = TuningConfig( From 887d7a8c575cf90d7c55bb9bf448b899d9abb3b1 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 19 Mar 2026 16:52:27 +0000 Subject: [PATCH 09/12] Remove is_sorted check in computeSelectedTileN, remove inline comments Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index dd206d71b0..eaece6c20c 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -86,8 +86,6 @@ std::set computeSelectedTileN(std::vector const& supported_til int64_t const num_tokens, int64_t const top_k, int64_t const num_local_experts) { TVM_FFI_ICHECK(!supported_tile_nums.empty()) << 
"supported_tile_nums must not be empty."; - TVM_FFI_ICHECK(std::is_sorted(supported_tile_nums.begin(), supported_tile_nums.end())) - << "supported_tile_nums must be sorted in ascending order."; float const avg_tokens_per_expert = static_cast(num_tokens * top_k) / num_local_experts; // NOTE: This differs from Python AutoTuner bucketing: // - AutoTuner maps raw num_tokens with last_positive_power_of_2 (round-down). @@ -110,14 +108,12 @@ std::set computeSelectedTileN(std::vector const& supported_til std::set selected_tile_nums; selected_tile_nums.insert(tile_tokens_dim); if (std::next(it) != supported_tile_nums.end()) { - // Prefer exploring larger tiles too because they can win for dense routing/shape regimes. selected_tile_nums.insert(*std::next(it)); if (std::next(std::next(it)) != supported_tile_nums.end()) { selected_tile_nums.insert(*std::next(std::next(it))); } } if (it != supported_tile_nums.begin()) { - // Keep one smaller neighbor to avoid over-committing to an oversized center tile. 
selected_tile_nums.insert(*std::prev(it)); } From 73428ee873f56c443d7bf7d00d10868ba30e2a30 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:19:43 +0000 Subject: [PATCH 10/12] Autotuner new tests improvements Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/autotuner/test_autotuner_core.py | 124 ++++++++- .../autotuner/test_autotuner_tile_mismatch.py | 260 ------------------ ...trtllm_fused_moe_autotuner_integration.py} | 100 +++---- tests/autotuner/utils.py | 9 + 4 files changed, 160 insertions(+), 333 deletions(-) delete mode 100644 tests/autotuner/test_autotuner_tile_mismatch.py rename tests/autotuner/{test_trtllm_moe_tile_mismatch_integration.py => test_trtllm_fused_moe_autotuner_integration.py} (78%) create mode 100644 tests/autotuner/utils.py diff --git a/tests/autotuner/test_autotuner_core.py b/tests/autotuner/test_autotuner_core.py index 289919daf0..e2bc18da05 100644 --- a/tests/autotuner/test_autotuner_core.py +++ b/tests/autotuner/test_autotuner_core.py @@ -1,3 +1,5 @@ +import random + import pytest import torch @@ -10,6 +12,7 @@ TunableRunner, ) from flashinfer.fused_moe.utils import last_positive_power_of_2 +from .utils import reset_autotuner def _moe_input_shapes( @@ -41,14 +44,6 @@ def forward(self, inputs, tactic: int = -1, do_preparation: bool = False, **kwar return inputs[0] -def _reset_autotuner() -> AutoTuner: - tuner = AutoTuner.get() - tuner.clear_cache() - tuner.reset_statistics() - tuner.is_tuning_mode = False - return tuner - - def test_find_nearest_profile_passthrough_without_specs(): """No dynamic/constraint specs should keep shape values unchanged.""" shapes = (torch.Size([3, 5]), torch.Size([7, 11, 13])) @@ -223,7 +218,7 @@ def test_get_cache_key_bucketization(shape_a, shape_b, expected_equal): def test_search_cache_hit_and_miss(): """search_cache should report miss before seeding and hit after seeding.""" - tuner = _reset_autotuner() + tuner = 
reset_autotuner() config = TuningConfig() runner = DummyRunner() shapes = (torch.Size([8, 16]),) @@ -239,7 +234,7 @@ def test_search_cache_hit_and_miss(): def test_search_cache_preserving_leading_dims_hits_while_flattened_misses(monkeypatch): """Shape-preserving reshape keeps cache-hit behavior; full flatten can change bucket/key.""" - tuner = _reset_autotuner() + tuner = reset_autotuner() runner = DummyRunner() config = TuningConfig( dynamic_tensor_specs=( @@ -295,7 +290,7 @@ def fake_profile( def test_choose_one_inference_uses_cache_or_fallback(): """Inference path should use cached tactic when present, else fallback -1.""" - tuner = _reset_autotuner() + tuner = reset_autotuner() runner = DummyRunner() inputs = [torch.empty((4, 8), dtype=torch.float32)] config = TuningConfig() @@ -315,7 +310,7 @@ def test_choose_one_inference_uses_cache_or_fallback(): def test_choose_one_tuning_selects_best_tactic_and_populates_cache(monkeypatch): """Tuning path should select lowest-profile-time tactic and cache result.""" - tuner = _reset_autotuner() + tuner = reset_autotuner() runner = DummyRunner(valid_tactics=(0, 1, 2)) inputs = [torch.empty((16, 32), dtype=torch.float32)] config = TuningConfig() @@ -338,7 +333,7 @@ def fake_profile( def test_prepare_input_tensors_reuses_static_and_recreates_dynamic(): """Profiles apply constraints, dynamic inputs are recreated, static inputs are reused.""" - tuner = _reset_autotuner() + tuner = reset_autotuner() config = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( @@ -370,3 +365,106 @@ def test_prepare_input_tensors_reuses_static_and_recreates_dynamic(): assert tuple(prepared[0].shape) == (8, 4) assert prepared[0] is not inputs[0] assert prepared[1] is inputs[1] + + +class TileTacticDummyRunner(TunableRunner): + def __init__(self, supported_tiles: tuple[int, ...], num_tactics_per_tile: int = 2): + self.supported_tiles = supported_tiles + self.num_tactics_per_tile = num_tactics_per_tile + + def get_valid_tactics(self, inputs, 
profile): + tactics = [] + for tile in sorted(self.supported_tiles): + for cfg in range(self.num_tactics_per_tile): + tactics.append([tile, cfg]) + return tactics + + def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs): + return inputs[0] + + +def test_choose_one_different_infer_tokens_same_bucket_get_same_cached_tactic( + monkeypatch, +): + """Multiple actual num_tokens that map to the same bucket should all + receive the same cached tactic - confirming the autotuner uses the + bucketed profile, not the actual shapes, for cache lookup.""" + tuner = reset_autotuner() + runner = TileTacticDummyRunner(supported_tiles=(8, 16, 32, 64)) + hidden_size = 128 + bucket_start = 512 + bucket_end = 1024 + tuning_buckets = ( + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + ) + tune_max = max(tuning_buckets) + + def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): + """When num_tokens is in the bucket [512, 1024): + return a low score (indicating good performance) if tile_n=32 and cfg=1 + When num_tokens < 512: + return a low score if tile_n=16 and cfg=1 + When num_tokens >= 1024: + return a low score if tile_n=64 and cfg=1 + For all other tile_n and cfg combinations, return a high score (indicating bad performance). 
+ """ + if isinstance(tactic, list): + tile_n = tactic[0] + tactic_cfg = tactic[1] + else: + tile_n = -1 + tactic_cfg = -1 + num_tokens = prof_inputs[0].shape[0] + if num_tokens < bucket_start: + target_tile_n = 16 + elif bucket_start <= num_tokens < bucket_end: + target_tile_n = 32 + else: + target_tile_n = 64 + return 1.0 if tile_n == target_tile_n and tactic_cfg == 1 else 5.0 + + monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) + + tune_inputs = [torch.empty((bucket_start, hidden_size), dtype=torch.float32)] + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + input_idx=(0,), + dim_idx=(0,), + gen_tuning_buckets=tuning_buckets, + map_to_tuning_buckets=lambda x: min( + last_positive_power_of_2(x), tune_max + ), + ), + ), + ) + with autotune(tune_mode=True): + tuner.choose_one("test_same_bucket", [runner], tuning_config, tune_inputs) + + num_tokens_with_expected_tactic_list = [ + (random.randrange(bucket_start, bucket_end), [32, 1]) for _ in range(3) + ] + num_tokens_with_expected_tactic_list += [ + (random.randrange(0, bucket_start), [16, 1]) for _ in range(3) + ] + num_tokens_with_expected_tactic_list += [ + (random.randrange(bucket_end, bucket_end * 2), [64, 1]) for _ in range(3) + ] + for actual, expected_tactic in num_tokens_with_expected_tactic_list: + infer_inputs = [torch.empty((actual, hidden_size), dtype=torch.float32)] + _, tactic = tuner.choose_one( + "test_same_bucket", [runner], tuning_config, infer_inputs + ) + assert tactic == expected_tactic, ( + f"Expected cached tactic {expected_tactic} for num_tokens={actual}, got {tactic}" + ) diff --git a/tests/autotuner/test_autotuner_tile_mismatch.py b/tests/autotuner/test_autotuner_tile_mismatch.py deleted file mode 100644 index d88e373ead..0000000000 --- a/tests/autotuner/test_autotuner_tile_mismatch.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Regression tests for MoE autotuner tile/config mismatch handling.""" - -import math - -import pytest -import torch - 
-from flashinfer import autotune -from flashinfer.autotuner import ( - AutoTuner, - DynamicTensorSpec, - TuningConfig, - TunableRunner, -) -from flashinfer.fused_moe.utils import ( - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, -) - - -def _next_power_of_two(n: float) -> int: - if n <= 1: - return 1 - return 1 << math.ceil(math.log2(n)) - - -def compute_selected_tile_n( - supported_tiles: list[int], - num_tokens: int, - top_k: int, - num_local_experts: int, -) -> set[int]: - """Python equivalent of C++ computeSelectedTileN.""" - avg = num_tokens * top_k / num_local_experts - tile = max(supported_tiles[0], min(supported_tiles[-1], _next_power_of_two(avg))) - try: - idx = supported_tiles.index(tile) - except ValueError: - # tile not in supported_tiles, find closest - idx = min( - range(len(supported_tiles)), - key=lambda i: abs(supported_tiles[i] - tile), - ) - selected = {supported_tiles[idx]} - if idx + 1 < len(supported_tiles): - selected.add(supported_tiles[idx + 1]) - if idx + 2 < len(supported_tiles): - selected.add(supported_tiles[idx + 2]) - if idx > 0: - selected.add(supported_tiles[idx - 1]) - return selected - - -class TileAwareDummyRunner(TunableRunner): - """Dummy runner whose valid tactics depend on num_tokens. 
- returns different valid tactics depending on shapes.""" - - SUPPORTED_TILES = [8, 16, 32, 64] # FP4 base tiles (no 128/256 for bf16 act) - - def __init__(self, top_k: int = 8, num_local_experts: int = 64): - self.top_k = top_k - self.num_local_experts = num_local_experts - - def get_valid_tactics(self, inputs, profile): - num_tokens = inputs[0].shape[0] - selected = compute_selected_tile_n( - self.SUPPORTED_TILES, num_tokens, self.top_k, self.num_local_experts - ) - tactics = [] - for tile in sorted(selected): - for cfg in range(2): # 2 configs per tile - tactics.append([tile, cfg]) - return tactics - - def forward(self, inputs, tactic=-1, do_preparation=False, **kwargs): - return inputs[0] - - -TUNE_MAX = 8192 - -MOE_TUNING_CONFIG = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - input_idx=(0,), - dim_idx=(0,), - gen_tuning_buckets=get_last_power_of_2_num_tokens_buckets(TUNE_MAX, 1), - map_to_tuning_buckets=lambda x: min(last_positive_power_of_2(x), TUNE_MAX), - ), - ), -) - - -def _reset_autotuner() -> AutoTuner: - tuner = AutoTuner.get() - tuner.clear_cache() - tuner.reset_statistics() - tuner.is_tuning_mode = False - return tuner - - -def _find_bucket_actual_tile_mismatch_case( - *, - supported_tiles: list[int], - top_k: int, - num_local_experts: int, - tune_max: int = TUNE_MAX, -): - """Find (bucket, actual, tuned_tile) where tuned_tile is valid only for bucket.""" - for actual in range(2, tune_max + 1): - bucket = min(last_positive_power_of_2(actual), tune_max) - if bucket == actual: - continue - tiles_bucket = compute_selected_tile_n( - supported_tiles, bucket, top_k, num_local_experts - ) - tiles_actual = compute_selected_tile_n( - supported_tiles, actual, top_k, num_local_experts - ) - only_in_bucket = sorted(tiles_bucket - tiles_actual) - if only_in_bucket: - return bucket, actual, only_in_bucket[0] - return None - - -@pytest.mark.parametrize( - "top_k,num_experts", - [ - (2, 8), - (4, 16), - (8, 64), - (8, 128), # Qwen3-VL-MoE-like (text) 
routing fanout/expert count - (10, 512), # Qwen3.5-MoE-like (text) routing fanout/expert count - ], -) -def test_tile_set_differs_between_bucketed_and_actual_num_tokens(top_k, num_experts): - """Bucketed and actual shapes in the same bucket can still pick different tiles.""" - tiles = [8, 16, 32, 64, 128] - case = _find_bucket_actual_tile_mismatch_case( - supported_tiles=tiles, top_k=top_k, num_local_experts=num_experts - ) - assert case is not None, ( - f"No mismatch case found for top_k={top_k}, experts={num_experts}" - ) - bucket, actual, _ = case - tiles_bucket = compute_selected_tile_n(tiles, bucket, top_k, num_experts) - tiles_actual = compute_selected_tile_n(tiles, actual, top_k, num_experts) - - assert tiles_bucket != tiles_actual, ( - f"Expected tile sets to differ for bucket={bucket} vs actual={actual}, " - f"got {tiles_bucket} vs {tiles_actual}" - ) - - only_in_bucket = tiles_bucket - tiles_actual - assert len(only_in_bucket) > 0, ( - "Expected at least one tile valid for bucketed but not for actual shapes" - ) - - -@pytest.mark.parametrize( - "top_k,num_experts,hidden_size", - [ - (2, 8, 128), - (4, 16, 128), - (8, 64, 256), - (12, 128, 1024), - ( - 8, - 256, - 7168, - ), # DeepSeek-v3-like MoE dimensions (lightweight autotuner-only path) - (8, 128, 4096), # Qwen3-VL-MoE-like text dimensions (autotuner-only path) - (10, 512, 4096), # Qwen3.5-MoE-like text dimensions (autotuner-only path) - ], -) -def test_autotuner_returns_cached_tactic_for_different_actual_shape( - monkeypatch, top_k, num_experts, hidden_size -): - """Cache lookup uses bucketed profile, not raw runtime num_tokens.""" - tuner = _reset_autotuner() - runner = TileAwareDummyRunner(top_k=top_k, num_local_experts=num_experts) - case = _find_bucket_actual_tile_mismatch_case( - supported_tiles=runner.SUPPORTED_TILES, - top_k=top_k, - num_local_experts=num_experts, - ) - assert case is not None, ( - f"No mismatch case found for top_k={top_k}, experts={num_experts}" - ) - tune_bucket, 
infer_actual, tuned_tile = case - - def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - # Make the chosen mismatch tile "best" so autotuner caches it. - tile_n = tactic[0] if isinstance(tactic, list) else -1 - return 1.0 if tile_n == tuned_tile else 5.0 - - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) - - tune_inputs = [torch.empty((tune_bucket, hidden_size), dtype=torch.float32)] - with autotune(tune_mode=True): - _, tuned_tactic = tuner.choose_one( - "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, tune_inputs - ) - - assert isinstance(tuned_tactic, list) - assert tuned_tactic[0] == tuned_tile, ( - f"Expected tile_N={tuned_tile} for top_k={top_k}, experts={num_experts}, " - f"got {tuned_tactic}" - ) - - assert last_positive_power_of_2(infer_actual) == tune_bucket - infer_inputs = [torch.empty((infer_actual, hidden_size), dtype=torch.float32)] - _, infer_tactic = tuner.choose_one( - "test_tile_mismatch", [runner], MOE_TUNING_CONFIG, infer_inputs - ) - - assert infer_tactic == tuned_tactic, "Expected cache hit returning the tuned tactic" - - actual_tiles = compute_selected_tile_n( - TileAwareDummyRunner.SUPPORTED_TILES, - infer_actual, - runner.top_k, - runner.num_local_experts, - ) - assert tuned_tactic[0] not in actual_tiles, ( - f"tile_N={tuned_tactic[0]} should NOT be in actual tiles {actual_tiles} — " - "this is the mismatch that caused the C++ unordered_map::at crash" - ) - - -def test_different_actual_tokens_same_bucket_get_same_cached_tactic(monkeypatch): - """Multiple actual num_tokens that map to the same bucket should all - receive the same cached tactic — confirming the autotuner uses the - bucketed profile, not the actual shapes, for cache lookup.""" - tuner = _reset_autotuner() - runner = TileAwareDummyRunner(top_k=8, num_local_experts=64) - hidden_size = 128 - - def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): - tile_n = tactic[0] if isinstance(tactic, 
list) else -1 - return 1.0 if tile_n == 32 else 5.0 - - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) - - tune_inputs = [torch.empty((512, hidden_size), dtype=torch.float32)] - with autotune(tune_mode=True): - _, tuned_tactic = tuner.choose_one( - "test_same_bucket", [runner], MOE_TUNING_CONFIG, tune_inputs - ) - - assert tuned_tactic[0] == 32 - - for actual in [513, 600, 700, 800, 900, 1000, 1023]: - assert last_positive_power_of_2(actual) == 512 - infer_inputs = [torch.empty((actual, hidden_size), dtype=torch.float32)] - _, tactic = tuner.choose_one( - "test_same_bucket", [runner], MOE_TUNING_CONFIG, infer_inputs - ) - assert tactic == tuned_tactic, ( - f"Expected cached tactic {tuned_tactic} for num_tokens={actual}, got {tactic}" - ) diff --git a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py similarity index 78% rename from tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py rename to tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py index 129e8d6f47..8df1d1ea37 100644 --- a/tests/autotuner/test_trtllm_moe_tile_mismatch_integration.py +++ b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py @@ -1,24 +1,17 @@ -"""Integration tests for TRTLLM MoE launcher fallback and wrapper contracts.""" +"""Integration tests for TRTLLM fused MoE launcher with autotuner.""" +import contextlib import pytest import torch from flashinfer import autotune, RoutingMethodType from flashinfer.autotuner import AutoTuner from flashinfer.utils import get_compute_capability -import contextlib +from .utils import reset_autotuner TUNE_MAX = 8192 -def _reset_autotuner() -> AutoTuner: - tuner = AutoTuner.get() - tuner.clear_cache() - tuner.reset_statistics() - tuner.is_tuning_mode = False - return tuner - - def _prepare_bf16_moe_weights(num_experts, intermediate_size, hidden_size, device): """Prepare shuffled BF16 weights in BlockMajorK layout.""" from 
flashinfer import shuffle_matrix_a @@ -154,7 +147,7 @@ def _compute_selected_tile_n_base_element( """Compute the base element used by computeSelectedTileN(num_tokens) to filter tile_N candidates.""" from flashinfer.fused_moe.utils import next_positive_power_of_2 - return next_positive_power_of_2(int(num_tokens * top_k / num_experts)) + return min(next_positive_power_of_2(int(num_tokens * top_k / num_experts)), 256) def _make_tile_bias(target_tile_n: int): @@ -169,20 +162,6 @@ def _bias(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw): return _bias -def _make_tile_bias_min_tile_n(num_tokens: int, top_k: int, num_experts: int): - """Return a profiling stub that favours the minimum tile_N in computeSelectedTileN(num_tokens).""" - - # The values computeSelectedTileN keeps from a sorted supported tileN list: - # - "base element": roundUpToPowerOf2(num_tokens*top_k/num_experts) - # - its previous element - # - its next two elements - # So the min tileN is the "base element" divided by 2, with a minimum of 1. - target_tile_n = max( - 1, _compute_selected_tile_n_base_element(num_tokens, top_k, num_experts) // 2 - ) - return _make_tile_bias(target_tile_n) - - def _require_sm100(): compute_capability = get_compute_capability(torch.device("cuda")) if compute_capability[0] not in [10]: @@ -204,20 +183,20 @@ def _require_sm100(): (256, 500), ], ) -def test_bf16_moe_cached_tile_excluded_at_inference_succeeds( +def test_bf16_moe_all_supported_tile_n_inference_succeed( monkeypatch, top_k: int, num_experts: int, tune_num_tokens: int, infer_num_tokens: int, ): - """SM100 BF16 integration: Test that MoE works when given by the autotuner a cached - tileN that's filtered out by computeSelectedTileN for the given inference num tokens. + """SM100 BF16 integration: Test that MoE works when given any supported tileN value, + including values filtered out by computeSelectedTileN for the given inference num tokens. 
""" from flashinfer.fused_moe.utils import last_positive_power_of_2 _require_sm100() - _reset_autotuner() + reset_autotuner() torch.manual_seed(42) device = torch.device("cuda:0") @@ -239,38 +218,39 @@ def test_bf16_moe_cached_tile_excluded_at_inference_succeeds( ) assert last_positive_power_of_2(infer_num_tokens) == tune_num_tokens - min_tile_n_bias_function = _make_tile_bias_min_tile_n( - tune_num_tokens, top_k, num_experts - ) - monkeypatch.setattr(AutoTuner, "_profile_single_kernel", min_tile_n_bias_function) - _tune_bf16_moe_once( - device=device, - tune_num_tokens=tune_num_tokens, - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - tune_max=TUNE_MAX, - ) + supported_tile_n_values = [8, 16, 32, 64, 128] + for tile_n in supported_tile_n_values: + monkeypatch.setattr( + AutoTuner, "_profile_single_kernel", _make_tile_bias(tile_n) + ) + _tune_bf16_moe_once( + device=device, + tune_num_tokens=tune_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + tune_max=TUNE_MAX, + ) - tuner = AutoTuner.get() - assert len(tuner.profiling_cache) > 0, ( - "Autotuner cache should be populated after tuning" - ) + tuner = AutoTuner.get() + assert len(tuner.profiling_cache) > 0, ( + "Autotuner cache should be populated after tuning" + ) - _run_bf16_moe_infer( - device=device, - infer_num_tokens=infer_num_tokens, - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - gemm1_weights=gemm1_weights, - gemm2_weights=gemm2_weights, - tune_max=TUNE_MAX, - ) + _run_bf16_moe_infer( + device=device, + infer_num_tokens=infer_num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + gemm1_weights=gemm1_weights, + 
gemm2_weights=gemm2_weights, + tune_max=TUNE_MAX, + ) @pytest.mark.parametrize( @@ -284,7 +264,7 @@ def test_bf16_moe_cached_tile_excluded_at_inference_succeeds( def test_bf16_moe_invalid_tactic_raises_runtime_error(monkeypatch, invalid_tactic): """SM100 integration: invalid tactics should fail.""" _require_sm100() - _reset_autotuner() + reset_autotuner() device = torch.device("cuda:0") hidden_size, intermediate_size = 1024, 1024 diff --git a/tests/autotuner/utils.py b/tests/autotuner/utils.py new file mode 100644 index 0000000000..86370c00d3 --- /dev/null +++ b/tests/autotuner/utils.py @@ -0,0 +1,9 @@ +from flashinfer.autotuner import AutoTuner + + +def reset_autotuner() -> AutoTuner: + tuner = AutoTuner.get() + tuner.clear_cache() + tuner.reset_statistics() + tuner.is_tuning_mode = False + return tuner From c67f600279cd1ea26ce0e244636612de65a0b491 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 19 Mar 2026 18:08:03 +0000 Subject: [PATCH 11/12] Fix setup in test_choose_one_different_infer_tokens_same_bucket_get_same_cached_tactic Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/autotuner/test_autotuner_core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/autotuner/test_autotuner_core.py b/tests/autotuner/test_autotuner_core.py index e2bc18da05..77f35d1e25 100644 --- a/tests/autotuner/test_autotuner_core.py +++ b/tests/autotuner/test_autotuner_core.py @@ -435,7 +435,11 @@ def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw monkeypatch.setattr(AutoTuner, "_profile_single_kernel", fake_profile) - tune_inputs = [torch.empty((bucket_start, hidden_size), dtype=torch.float32)] + tune_inputs = [ + [torch.empty((bucket_start // 2, hidden_size), dtype=torch.float32)], + [torch.empty((bucket_start, hidden_size), dtype=torch.float32)], + [torch.empty((bucket_end, hidden_size), dtype=torch.float32)], + ] tuning_config = TuningConfig( 
dynamic_tensor_specs=( DynamicTensorSpec( @@ -449,7 +453,8 @@ def fake_profile(self, runner_obj, prof_inputs, tactic, tuning_config=None, **kw ), ) with autotune(tune_mode=True): - tuner.choose_one("test_same_bucket", [runner], tuning_config, tune_inputs) + for inputs in tune_inputs: + tuner.choose_one("test_same_bucket", [runner], tuning_config, inputs) num_tokens_with_expected_tactic_list = [ (random.randrange(bucket_start, bucket_end), [32, 1]) for _ in range(3) From 7799f1914ddc7f73bd5212c6411281c3fe01c5f3 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 19 Mar 2026 19:52:38 +0000 Subject: [PATCH 12/12] Fix test_bf16_moe_all_supported_tile_n_inference_succeed by reset autotuner before every tileN Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py index 8df1d1ea37..07df7e2e45 100644 --- a/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py +++ b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py @@ -196,7 +196,6 @@ def test_bf16_moe_all_supported_tile_n_inference_succeed( from flashinfer.fused_moe.utils import last_positive_power_of_2 _require_sm100() - reset_autotuner() torch.manual_seed(42) device = torch.device("cuda:0") @@ -220,6 +219,7 @@ def test_bf16_moe_all_supported_tile_n_inference_succeed( supported_tile_n_values = [8, 16, 32, 64, 128] for tile_n in supported_tile_n_values: + reset_autotuner() monkeypatch.setattr( AutoTuner, "_profile_single_kernel", _make_tile_bias(tile_n) )