fix: Autotuner _find_nearest_profile non-power-of-2 num_tokens, create launchers for all supported tileN in trtllm fused MoE (#2821)
Conversation
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
…lashinfer-ai#2617 Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…est of supporting tileN that was filtered out by computeSelectedTileN Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
No actionable comments were generated in the recent review. 🎉
✅ Files skipped from review due to trivial changes (1)
📝 Walkthrough
Validated MoE tile selection, added helpers to resolve default tile/config, and expanded launcher construction to include all supported tiles. Adjusted autotuner bucket mapping to propagate a single mapped bucket across linked dimensions. Added/reset autotuner test utilities and new SM100 MoE integration tests.
🚥 Pre-merge checks: 2 passed, 1 failed (warning)
Summary of Changes (Gemini Code Assist)
This pull request significantly enhances the robustness and correctness of the autotuner and fused Mixture-of-Experts (MoE) kernel launchers. It rectifies a long-standing problem where the autotuner's cache lookup for non-power-of-2 …
Code Review
This pull request effectively addresses two bugs related to the autotuner. The first fix correctly handles non-power-of-2 num_tokens for cache lookups by ensuring all linked dimensions are mapped to the same bucket. The second fix prevents crashes in the fused MoE kernel launcher by creating launchers for all supported tileN values, rather than a filtered subset. The C++ changes also improve robustness by using find instead of at for map lookups and encapsulating fallback logic. The accompanying tests, including new integration tests, are comprehensive. I've found one issue in a new test case and provided a suggestion to fix it.
```python
tune_inputs = [torch.empty((bucket_start, hidden_size), dtype=torch.float32)]
tuning_config = TuningConfig(
    dynamic_tensor_specs=(
        DynamicTensorSpec(
            input_idx=(0,),
            dim_idx=(0,),
            gen_tuning_buckets=tuning_buckets,
            map_to_tuning_buckets=lambda x: min(
                last_positive_power_of_2(x), tune_max
            ),
        ),
    ),
)
with autotune(tune_mode=True):
    tuner.choose_one("test_same_bucket", [runner], tuning_config, tune_inputs)
```
The test logic for tuning is flawed. It only populates the autotuner cache for a single bucket (for num_tokens in [512, 1024)), but the subsequent inference checks assert expected tactics for three different buckets. The checks for buckets that were not tuned will fail because they will result in a cache miss and receive the fallback tactic, not the expected one.
To fix this, the tuning step should be performed for a representative num_tokens from each of the three bucket ranges being tested to ensure the cache is populated correctly before the inference-time checks.
```python
tuning_config = TuningConfig(
    dynamic_tensor_specs=(
        DynamicTensorSpec(
            input_idx=(0,),
            dim_idx=(0,),
            gen_tuning_buckets=tuning_buckets,
            map_to_tuning_buckets=lambda x: min(
                last_positive_power_of_2(x), tune_max
            ),
        ),
    ),
)
with autotune(tune_mode=True):
    # Tune for a representative num_tokens from each bucket range defined in fake_profile
    # to populate the cache correctly for the inference checks below.
    for tune_tokens in [bucket_start // 2, bucket_start, bucket_end]:
        tune_inputs = [torch.empty((tune_tokens, hidden_size), dtype=torch.float32)]
        tuner.choose_one("test_same_bucket", [runner], tuning_config, tune_inputs)
```
…ame_cached_tactic Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py`:
- Around line 221-253: The loop over supported_tile_n_values is not validating
each tile because trtllm_get_valid_moe_configs/computeSelectedTileN may not
include some tile_n and AutoTuner.choose_one consults AutoTuner.profiling_cache,
so later iterations reuse a cached tactic; fix by forcing a fresh profiling per
tile: in the loop, monkeypatch AutoTuner._profile_single_kernel as you already
do, then clear or overwrite AutoTuner.get().profiling_cache (or call
AutoTuner.get().clear_cache()) at the start of each iteration to avoid cache
hits, and after tuning assert that the cached tactic chosen has the expected
tile_n (inspect the cached entry’s tile_n) or alternatively bypass choose_one by
injecting the target [tile_n, config] directly into AutoTuner.profiling_cache
before calling _run_bf16_moe_infer so each iteration exercises and verifies the
requested tile_n.
In `@tests/autotuner/utils.py`:
- Around line 4-9: reset_autotuner currently clears cache, statistics and sets
is_tuning_mode but does not reset the internal counter that autotune() uses;
update the helper (reset_autotuner / AutoTuner.get()) to also reset the
AutoTuner._active_tuning_contexts counter back to zero (or its default empty
state) so that autotune() will not incorrectly derive tuning mode from leftover
contexts; locate the AutoTuner instance via AutoTuner.get(), set its
_active_tuning_contexts to the appropriate empty/zero value, then return the
tuner.
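The cache-reset pattern suggested in the first finding can be illustrated with a minimal, self-contained sketch. The `FakeAutoTuner` class below is a stand-in invented for this example (only `profiling_cache`, `clear_cache`, and a toy `choose_one` are modeled); it is not the actual flashinfer `AutoTuner` API.

```python
# Illustrative stand-in for a singleton autotuner whose choose_one() consults
# a profiling cache, as described in the review finding above.
class FakeAutoTuner:
    _instance = None

    def __init__(self):
        self.profiling_cache = {}

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def clear_cache(self):
        self.profiling_cache.clear()

    def choose_one(self, key, tile_n):
        # A cache hit returns the previously profiled tactic; otherwise
        # "profile" with the current tile_n and cache the result.
        if key not in self.profiling_cache:
            self.profiling_cache[key] = tile_n
        return self.profiling_cache[key]


chosen_without_reset, chosen_with_reset = [], []

for tile_n in (64, 128, 256):
    chosen_without_reset.append(FakeAutoTuner.get().choose_one("moe", tile_n))

for tile_n in (64, 128, 256):
    tuner = FakeAutoTuner.get()
    tuner.clear_cache()  # fresh profiling per tile, as the review suggests
    chosen_with_reset.append(tuner.choose_one("moe", tile_n))

print(chosen_without_reset)  # every iteration reuses the tactic cached for tile 64
print(chosen_with_reset)     # each iteration actually exercises its own tile
```

Without the reset, the loop silently validates only the first tile; with it, each iteration exercises its own `tile_n`.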
ℹ️ Review info
Configuration used: defaults | Review profile: CHILL | Plan: Pro | Run ID: 960704f7-b832-4d0c-a83f-dadd2535665e
📒 Files selected for processing (5)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/autotuner.py
- tests/autotuner/test_autotuner_core.py
- tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
- tests/autotuner/utils.py
```python
def reset_autotuner() -> AutoTuner:
    tuner = AutoTuner.get()
    tuner.clear_cache()
    tuner.reset_statistics()
    tuner.is_tuning_mode = False
    return tuner
```
Reset _active_tuning_contexts in the shared helper.
autotune() derives is_tuning_mode from _active_tuning_contexts, so leaving that counter untouched can leak tuning mode across tests even after this helper forces is_tuning_mode = False.
🛠 Suggested fix

```diff
 def reset_autotuner() -> AutoTuner:
     tuner = AutoTuner.get()
-    tuner.clear_cache()
-    tuner.reset_statistics()
-    tuner.is_tuning_mode = False
+    with tuner._lock:
+        tuner.clear_cache()
+        tuner.reset_statistics()
+        tuner._active_tuning_contexts = 0
+        tuner.is_tuning_mode = False
     return tuner
```
IwakuraRein left a comment:
LGTM. Thanks for the fix.
…otuner before every tileN Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
/bot run
[SUCCESS] Pipeline #46565266: 14/20 passed
📌 Description
This PR fixes two autotuner-related bugs:

1. The autotuner (Python side) failing to find a cache entry for non-power-of-2 num_tokens.
2. A crash in the fused MoE kernel launcher when the autotuner requests a tileN that was filtered out by computeSelectedTileN, fixed by creating kernel launchers for all supported tileN values.

This PR continues the work in #2695 by @danisereb to revert bugfix 1 and to fix bug 2.
More technical details:

Bug 1:
When given a num_tokens that isn't a power of 2, the autotuner (Python side) fails to find its appropriate entry in the autotuner cache, so it falls back to passing the default, which means passing [-1, -1] as the (tileN, tactic) to the C++ side.
It was fixed in this PR but, soon after merge, it was reverted here, as it exposed the next bug.
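To make bug 1 concrete, here is a hedged sketch of why the unmapped cache lookup misses. `last_positive_power_of_2` mirrors the helper used in the tests quoted above; the cache layout and the `(64, 3)` entry are made up for illustration and are not flashinfer's actual data structures.

```python
def last_positive_power_of_2(x: int) -> int:
    # Largest power of 2 that is <= x (assumes x >= 1).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

# During tuning only power-of-2 token counts are profiled, so the cache is
# keyed by bucket; the (tileN, tactic) value is illustrative.
profiling_cache = {2048: (64, 3)}

def lookup(num_tokens: int):
    # Bug 1: looking up the raw num_tokens (e.g. 3003) misses the cache and
    # falls back to the default (-1, -1); mapping to the bucket first hits.
    bucket = last_positive_power_of_2(num_tokens)
    return profiling_cache.get(bucket, (-1, -1))

print(profiling_cache.get(3003, (-1, -1)))  # unmapped lookup: cache miss -> (-1, -1)
print(lookup(3003))                         # bucket 2048 covers [2048, 4095] -> (64, 3)
```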
Bug 2 (exposed after fixing bug 1):
A crash in the fused MoE kernel launcher on the forward pass for some values of num_tokens. The crash is at launchers_map.at(tile_N) in trtllm_fused_moe_kernel_launcher.cu. It happens because:

- The Python side of the autotuner profiles num_tokens values that are powers of 2, and each such value represents the range up to the next power of 2. E.g., the profile for the range [2048, 4095] is done at num_tokens=2048.
- The computeSelectedTileN function in trtllm_fused_moe_kernel_launcher.cu reduces the set of supported tileN values (to shrink the autotuner's search space) by choosing specific values from the sorted list of supported tileN: roundUpToPowerOfTwo(num_tokens * topK / numExperts), its previous value, and its next two values (the max value is 256). So num_tokens values in the same range can get different sets of tileN values.

For example, on Nemotron 3 Super NVFP4 (topK=22, numExperts=512):

- num_tokens=2048 → 2048 * 22 / 512 = 88, which rounds up to 128, so the tileN set is (64, 128, 256)
- num_tokens=3003 → 3003 * 22 / 512 = 129.03, which rounds up to 256, so the tileN set is (128, 256)
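The selection rule and the two worked examples above can be reproduced with a short sketch. The SUPPORTED_TILE_N list and the Python helper names here are assumptions for illustration; the real logic is the C++ computeSelectedTileN in trtllm_fused_moe_kernel_launcher.cu.

```python
import math

SUPPORTED_TILE_N = [8, 16, 32, 64, 128, 256]  # assumed sorted list of supported tiles

def round_up_to_power_of_two(x: float) -> int:
    return 1 << math.ceil(math.log2(x))

def compute_selected_tile_n(num_tokens: int, top_k: int, num_experts: int) -> list:
    # roundUpToPowerOfTwo(num_tokens * topK / numExperts), capped at 256...
    target = min(round_up_to_power_of_two(num_tokens * top_k / num_experts), 256)
    idx = SUPPORTED_TILE_N.index(target)
    # ...plus its previous value and its next two values, clamped to the list.
    return SUPPORTED_TILE_N[max(idx - 1, 0): idx + 3]

# Nemotron 3 Super NVFP4: topK=22, numExperts=512
print(compute_selected_tile_n(2048, 22, 512))  # -> [64, 128, 256]
print(compute_selected_tile_n(3003, 22, 512))  # -> [128, 256]
```

Both calls fall in the same profiling bucket [2048, 4095], yet produce different tileN sets, which is exactly the mismatch described next.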
If tileN=64 was found to be the fastest at num_tokens=2048 for the range [2048, 4095], then when given num_tokens=3003 the Python side would pass [64, someTactic] to the C++ side, but for num_tokens=3003 there is no launcher for tileN=64, since computeSelectedTileN filtered it out.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added/updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
- Fixed autotuner bucket mapping for non-power-of-2 num_tokens and expanded fused MoE launcher construction to cover all supported tileN values.

Tests
- Added shared autotuner test utilities and new SM100 MoE integration tests.