feat: autotune tile_tokens_dim in trtllm-gen MOE #1980
Conversation
Walkthrough
The PR removes tile_tokens_dim propagation and replaces it with autotuning-driven MoE flows: per-tile runner/config selection, tactic/config indices, and new public MoERunner options (gated activation, shuffled-weight flag, weight layout); tests and benchmarks now pass tile_tokens_dim=None and rely on the autotuner (a usage sketch follows the sequence diagram).
Sequence Diagram(s)

sequenceDiagram
actor Caller
participant Py as Python: MoE core (AutoTuner)
participant Auto as AutoTuner
participant CPP as C++ Launcher
participant Runners as Per‑tile MoE Runners
Note over Caller,Py: Old flow (static tile_tokens_dim)
Caller->>Py: call_moe(..., tile_tokens_dim=N)
Py->>CPP: launcher(..., tile_tokens_dim=N)
CPP->>Runners: select single runner/config
Runners-->>CPP: run fixed config
Note over Caller,Py: New flow (autotune + per‑tile)
Caller->>Py: call_moe(..., tune_max_num_tokens)
Py->>Auto: select_tactic(instance_key, max_tokens)
Auto-->>Py: tactic / Array<int64_t> config_index
Py->>CPP: launcher(..., Array<int64_t> config_index, moe_runner)
CPP->>CPP: computeSelectedTileN(...) -> tile groups
CPP->>Runners: build/select per‑tile runner(s) using config_index
Runners-->>CPP: chosen runner per tile
CPP->>CPP: allocate per‑tile workspace & execute per‑tile runner/config
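As a rough usage sketch of the new flow (a hedged illustration only: the import path and the autotune(True) form are taken from the tests referenced later in this review, and moe_fn stands in for any of the trtllm-gen MoE entry points):

```python
from flashinfer.autotuner import autotune


def tuned_moe_call(moe_fn, **moe_kwargs):
    """Run a trtllm-gen MoE op with an autotuner-selected tile_tokens_dim."""
    # One warmup call inside the autotune context profiles candidate
    # (tile_N, config) pairs for these shapes and caches the best tactic.
    with autotune(True):
        moe_fn(tile_tokens_dim=None, **moe_kwargs)
    # Subsequent calls reuse the cached tactic; no static tile_tokens_dim is needed.
    return moe_fn(tile_tokens_dim=None, **moe_kwargs)
```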
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
Force-pushed from 7e0ed1b to 0cd4848
/bot run
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
99-101: Fix FC2 bias shape. gemm2 (FC2) bias must be sized by hidden_size, not 2*intermediate_size. The current shape will fail C++ checks and/or broadcast incorrectly.
Suggested diff:
```diff
- bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
- bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
+ bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
+ bias2 = torch.randn(num_experts, hidden_size, device=device) * 10
```
Also applies to: 120-126
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
195-206: args.mDtypeOut may be uninitialized. Initialize explicitly to Bfloat16. maybeGetMinTokenCount reads args.mDtypeOut bits; in the FP8/FP4 block-scale launchers it's not set. Set to BF16 to match the output dtype.
```diff
  // In trtllm_fp8_block_scale_moe_launcher
  auto dtype = hidden_states.dtype();
  if (dtype == dl_float16) {
    args.mDtypeElt = btg::Dtype::Fp16;
  } else if (dtype == dl_bfloat16) {
    args.mDtypeElt = btg::Dtype::Bfloat16;
  } else if (dtype == dl_float8_e4m3fn) {
    args.mDtypeElt = btg::Dtype::E4m3;
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
  }
+ args.mDtypeOut = btg::Dtype::Bfloat16;

  // In trtllm_fp4_block_scale_moe_launcher (after args setup)
+ // Output dtype is BF16 for block-scale path
+ args.mDtypeOut = btg::Dtype::Bfloat16;
```
Also applies to: 522-531, 919-925
548-556: gemm1_output_scale uses the wrong padded token count. It should use max_num_padded_tokens_gemm1 (like activation_output_scale) to match gemm1_output. Using max_num_padded_tokens can under-allocate.
```diff
- Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
+ Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1},
                                           dl_float32, hidden_states.device());
```
🧹 Nitpick comments (5)
tests/moe/test_trtllm_gen_fused_moe.py (1)
205-209: Deprecation path exercised as intended. Explicitly passing tile_tokens_dim=None to trtllm_fp4_block_scale_moe keeps tests aligned with autotune. Consider adding a brief comment noting that autotuner selection is used when None is given.
flashinfer/fused_moe/core.py (4)
21-21: typing_extensions.deprecated is type-checker only; add a runtime warning. @deprecated won't emit runtime warnings. Emit warnings.warn on first call to improve UX.
Apply in each deprecated wrapper, e.g.:
```diff
+import warnings
 ...
 @deprecated("tile_tokens_dim is deprecated and will be removed in trtllm_fp8_per_tensor_scale_moe after v0.5.0")
 def trtllm_fp8_per_tensor_scale_moe(...):
+    warnings.warn(
+        "tile_tokens_dim is deprecated and ignored; autotuner selects tactics. "
+        "Support will be removed after v0.5.0.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     return get_trtllm_moe_sm100_module()...
```
126-134: Duplicate entry in trtllm_gen_dtype_has_scale. MxE4m3 appears twice. Remove the duplicate for clarity.
```diff
- if dtype in [
-     DtypeTrtllmGen.MxE4m3,
-     DtypeTrtllmGen.E2m1,
-     DtypeTrtllmGen.MxE2m1,
-     DtypeTrtllmGen.MxE4m3,
- ]:
+ if dtype in {DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1}:
      return True
```
1643-1716: Silence ARG001 for the unused tile_tokens_dim in the deprecated wrapper. Ruff flags tile_tokens_dim as unused. Delete the name or add a no-op to satisfy the linter without changing the API.
```diff
 def trtllm_fp8_per_tensor_scale_moe(..., tile_tokens_dim: int = 8, ...):
-    """FP8 per tensor scale MoE operation.
+    """FP8 per tensor scale MoE operation.
     ...
     """
+    # tile_tokens_dim kept for backward-compat; ignored by autotuner.
+    del tile_tokens_dim  # ruff: ARG001
     return get_trtllm_moe_sm100_module()...
```
1797-1930: Silence ARG001 and add a consistent deprecation note. Do the same for the FP8 block scale and FP4 block scale wrappers. Also emit a warning (as above) for consistency; routed_moe already logs.
```diff
 def trtllm_fp8_block_scale_moe(..., tile_tokens_dim: Optional[int] = None, ...):
+    del tile_tokens_dim  # ruff: ARG001
+    warnings.warn("tile_tokens_dim is deprecated and ignored...", DeprecationWarning, stacklevel=2)
     ...
 def trtllm_fp4_block_scale_moe(..., tile_tokens_dim: Optional[int], ...):
+    del tile_tokens_dim  # ruff: ARG001
+    warnings.warn("tile_tokens_dim is deprecated and ignored...", DeprecationWarning, stacklevel=2)
     ...
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
- flashinfer/fused_moe/core.py (22 hunks)
- tests/conftest.py (1 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
flashinfer/utils.py (1)
device_support_pdl(568-572)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/utils.py (1)
get_compute_capability (251-254)
include/flashinfer/trtllm/fused_moe/runner.h (1)
intermediate_size(275-275)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (11)
top_k (270-270), maybeGetMinTokenCount (55-60), intermediate_size (275-275), hidden_size (265-265), mUseDeepSeekFp8 (285-342), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276), do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
dtypeGetNumBits(88-91)
flashinfer/fused_moe/core.py (4)
include/flashinfer/trtllm/fused_moe/runner.h (1)
GatedActType (141-158)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
trtllm_fp8_block_scale_moe (699-750), trtllm_fp8_block_scale_moe (699-706), trtllm_fp8_per_tensor_scale_moe (342-402), trtllm_fp8_per_tensor_scale_moe (342-350), trtllm_fp4_block_scale_moe (1175-1266), trtllm_fp4_block_scale_moe (1175-1188)
flashinfer/utils.py (1)
device_support_pdl (568-572)
flashinfer/autotuner.py (3)
AutoTuner (335-780), get (362-365), choose_one (400-525)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py
1738-1738: Unused function argument: tile_tokens_dim
(ARG001)
1825-1825: Unused function argument: tile_tokens_dim
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/conftest.py (2)
140-144: Verify PR scope: changes don't match PR objectives. The PR title indicates this should add autotuning for tile_tokens_dim in TRT-LLM MoE kernels, but the only change is a pytest hook refactoring in test infrastructure. Please confirm:
- Are there additional files/changes not included in this review?
- Is this an infrastructure prerequisite for the main feature?
- Should this be split into a separate PR?
140-144: Change is correct and follows pytest best practices. The migration from tryfirst=True to wrapper=True with yield is the recommended approach for wrapping test execution (a minimal sketch of the pattern follows this list). Verification confirms:
- Only one pytest_runtest_call hook exists (no conflicts)
- Exception handling for OutOfMemoryError and MissingJITCacheError remains intact
- The wrapper pattern properly delegates to pytest's execution chain while preserving error handling behavior
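As a rough illustration of that wrapper-style hook (a generic sketch, not the repository's exact conftest.py; the OOM-to-skip behavior mirrors what the review describes, and MissingJITCacheError is omitted because its import path is project-specific):

```python
import pytest
import torch


@pytest.hookimpl(wrapper=True)
def pytest_runtest_call(item):
    try:
        # Delegate to the remaining hook implementations, which run the actual test.
        yield
    except torch.cuda.OutOfMemoryError:
        # Turn GPU OOM during a test into a skip instead of a failure.
        pytest.skip(f"Skipping {item.nodeid}: GPU out of memory")
```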
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
136-143: tile_tokens_dim=None usage is correct. Passing None here aligns with the deprecated API and delegates selection to the autotuner; no action needed.
tests/moe/test_trtllm_gen_fused_moe.py (1)
1866-1884: Good guards for compatibility and speed. The added compatible_intermediate_size filters and targeted skips keep CI time reasonable while respecting kernel constraints. LGTM.
Also applies to: 1905-1917, 1930-1933, 1962-1981
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
371-399: Per‑tile autotune path looks correct.
- Runners constructed for selected tiles.
- Fallback when config_index contains -1 is well‑handled.
- computeSelectedTileN centers tile_N on avg tokens per expert with neighbors (a sketch of this heuristic follows below).
LGTM.
Also applies to: 720-747, 1234-1266, 1268-1331
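For reference, a minimal Python sketch of that tile-selection heuristic (an illustration based on the C++ quoted later in this thread; the supported tile list and clamping details are assumptions for the example):

```python
def compute_selected_tile_n(supported, num_tokens, top_k, num_local_experts):
    """Pick candidate tile_N values centered on the average tokens per expert.

    `supported` is assumed sorted ascending, e.g. [8, 16, 32, 64] (some dtypes also allow 128).
    """
    avg_tokens_per_expert = float(num_tokens) * top_k / num_local_experts
    # Next power of two of the average, clamped to the supported range.
    tile = 1
    while tile < avg_tokens_per_expert:
        tile *= 2
    tile = min(max(tile, supported[0]), supported[-1])
    # Keep the chosen tile plus its half/double/quadruple neighbors, clamped.
    return sorted({
        max(supported[0], tile // 2),
        tile,
        min(supported[-1], tile * 2),
        min(supported[-1], tile * 4),
    })


# Example: 256 tokens, top_k=8, 32 local experts -> 64 tokens/expert on average,
# so the candidates are [32, 64] when the supported tiles are [8, 16, 32, 64].
print(compute_selected_tile_n([8, 16, 32, 64], 256, 8, 32))
```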
flashinfer/fused_moe/core.py (1)
905-1072: Verified: Autotuner plumbing and tactics caching are correctly implemented. Static analysis confirms all review observations:
- Caching consistency: valid_tactics_dict properly caches per instance_key tuple (lines 945-967); the cache key is comprehensive and includes all relevant parameters
- refine_tuning_config usage: correctly invoked before each tuner.choose_one() call to customize token buckets (lines 458, 1164, 1301, 1487)
- choose_one pattern: consistently used across all 5 autotuner invocations (lines 481, 495, 1195, 1339, 1522)
- [-1, -1] fallback: correctly applied throughout for default tactic selection (lines 1046, 1070, 1105, 1237, 1382, 1584); a conceptual sketch of this caching pattern follows below
The smoke test cannot run in the sandbox environment (missing torch), but code structure verification confirms the implementation is sound.
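As a conceptual sketch of that caching-plus-fallback pattern (this is not flashinfer's actual AutoTuner or MoERunner API, just an illustration of per-instance-key tactic caching with a [-1, -1] default):

```python
from typing import Dict, List, Tuple


class TacticCache:
    """Caches valid (tile_N, config) pairs per instance key, mirroring the pattern above."""

    def __init__(self) -> None:
        self._valid_tactics: Dict[tuple, List[Tuple[int, int]]] = {}

    def get_valid_tactics(self, instance_key: tuple) -> List[Tuple[int, int]]:
        # Enumerate once per instance key; fall back to an empty list on failure.
        if instance_key not in self._valid_tactics:
            try:
                self._valid_tactics[instance_key] = self._enumerate(instance_key)
            except Exception:
                self._valid_tactics[instance_key] = []
        return self._valid_tactics[instance_key]

    def pick(self, instance_key: tuple) -> Tuple[int, int]:
        tactics = self.get_valid_tactics(instance_key)
        # (-1, -1) asks the backend to use its default tile_N and config.
        return tactics[0] if tactics else (-1, -1)

    def _enumerate(self, instance_key: tuple) -> List[Tuple[int, int]]:
        # Placeholder: a real implementation would query the kernel backend
        # for the valid configurations of this MoE instance.
        return []
```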
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (3)
flashinfer/fused_moe/core.py (2)
1026-1033: Consider avoiding repeated tensor allocation during profiling. The current_hidden_states_scale tensor is allocated fresh on every forward call. During autotuning, this function may be invoked many times. Consider whether the scale tensor could be allocated once and reused, or if this allocation is intentional for isolation between runs.
1659-1731: Consider consistent deprecation warnings across all deprecated functions. While all four deprecated functions correctly accept but ignore tile_tokens_dim, only trtllm_fp4_block_scale_routed_moe (lines 2048-2053) logs an informational message when the parameter is provided. For consistency and better user guidance, consider adding similar logging to the other three deprecated functions: trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, and trtllm_fp4_block_scale_moe.
Note: The static analysis warnings about unused tile_tokens_dim arguments are expected and correct for deprecated parameters.
Also applies to: 1734-1810, 1813-1945, 1948-2087
tests/moe/test_trtllm_gen_fused_moe.py (1)
1953-1956: Clarify skip reason for the RenormNaive test. The skip reason states "similar to RenormalizeNaive", but this test IS for RoutingMethodType.RenormalizeNaive. Did you mean "similar to Renormalize" (referring to RoutingMethodType.Renormalize)? If so, please update the message for clarity.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/fused_moe/core.py (23 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/fused_moe/core.py (4)
include/flashinfer/trtllm/fused_moe/runner.h (9)
GatedActType (141-158), top_k (270-270), intermediate_size (275-275), hidden_size (265-265), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
trtllm_fp8_block_scale_moe (699-750), trtllm_fp8_block_scale_moe (699-706), trtllm_fp8_per_tensor_scale_moe (342-402), trtllm_fp8_per_tensor_scale_moe (342-350), trtllm_fp4_block_scale_moe (1175-1266), trtllm_fp4_block_scale_moe (1175-1188)
flashinfer/utils.py (1)
device_support_pdl (568-572)
flashinfer/autotuner.py (3)
AutoTuner (335-780), get (362-365), choose_one (400-525)
tests/moe/test_trtllm_gen_fused_moe.py (4)
flashinfer/utils.py (1)
get_compute_capability (251-254)
flashinfer/autotuner.py (1)
autotune (251-262)
flashinfer/fused_moe/core.py (2)
trtllm_fp8_block_scale_moe (1737-1810), trtllm_fp8_per_tensor_scale_moe (1662-1731)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
trtllm_fp8_block_scale_moe (699-750), trtllm_fp8_block_scale_moe (699-706), trtllm_fp8_per_tensor_scale_moe (342-402), trtllm_fp8_per_tensor_scale_moe (342-350)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py
1754-1754: Unused function argument: tile_tokens_dim
(ARG001)
1841-1841: Unused function argument: tile_tokens_dim
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/fused_moe/core.py (5)
914-937: LGTM: Clean parameter addition with proper validation. The new parameters (gated_act_type, use_shuffled_weight, weight_layout) are well-defined with sensible defaults, and the validation logic correctly restricts non-shuffled weights or non-MajorK layouts to FP8 block-scale mode only.
954-976: LGTM: Proper cache key extension for autotuning. The instance_key correctly includes use_shuffled_weight and weight_layout to distinguish different MoE configurations, and the error handling gracefully returns an empty list when tactics retrieval fails.
1194-1204: Verify the hardcoded weight_layout assumption. Line 1202 hardcodes weight_layout=WeightLayout.MajorK for FP8 per-tensor MoE. According to the validation logic in MoERunner.__init__ (lines 933-937), non-MajorK layouts are only supported for FP8 block-scale. Confirm this is the intended behavior and document whether this restriction is permanent or planned for future extension.
1313-1375: LGTM: Autotuning integration for FP8 block-scale. The autotuning logic correctly instantiates the MoERunner with the use_shuffled_weight and weight_layout parameters, selects the appropriate tuning config variant, and passes the selected tactic to the underlying C++ operation.
1518-1519: LGTM: Consistent with FP4 MoE layout requirements. Hardcoding weight_layout=WeightLayout.MajorK and use_shuffled_weight=True aligns with the validation logic that restricts alternative layouts to FP8 block-scale only.
tests/moe/test_trtllm_gen_fused_moe.py (5)
48-48: LGTM: Test updated for the autotuning workflow. The removal of the calculate_tile_tokens_dim import and passing tile_tokens_dim=None correctly reflect the transition to autotuning-based tile dimension selection.
Also applies to: 205-205
772-796: LGTM: FP8 block-scale test properly uses autotuning. The autotune(True) context manager correctly enables autotuning for the FP8 block-scale MoE call, and tile_tokens_dim=None appropriately delegates tile dimension selection to the autotuner.
945-972: LGTM: FP8 per-tensor test properly uses autotuning. Consistent with the FP8 block-scale changes, the per-tensor test correctly enables autotuning and passes None for tile_tokens_dim.
1869-1869: Verify that the compatible_intermediate_size restrictions reflect actual kernel limitations. The new compatible_intermediate_size constraints reduce test coverage for specific routing configurations. Confirm whether these restrictions are due to:
- Known kernel limitations or performance characteristics that make certain combinations unsupported
- Test optimization to reduce CI time
If it is the latter (test-time optimization), consider documenting this as a test-only constraint to avoid confusion with actual runtime restrictions.
Also applies to: 1887-1887, 1905-1905, 1920-1920, 1935-1935, 1950-1950, 1968-1968, 1983-1983
2080-2083: LGTM: Clear skip logic for intermediate size compatibility. The skip logic properly enforces the compatible_intermediate_size constraint with a descriptive message.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
350-403: Add bounds checking and validation for the config_index array. Multiple safety concerns with the config_index parameter handling:
- Missing bounds check: lines 385-386 directly access config_index[0] and config_index[1] without verifying the array has at least 2 elements.
- Invalid tile_N risk: if config_index[0] contains a tile size not in selected_tile_nums (e.g., 128 when only {8, 16, 32, 64} were selected), line 399's *mRunners[tile_N] will dereference a non-existent map entry, causing undefined behavior.
- No config validation: there's no check that config_index[1] is a valid configuration index for the selected runner. The conditional at lines 388-389 only handles the sentinel value -1, not invalid values.
+ TVM_FFI_ICHECK_GE(config_index.size(), 2) + << "config_index must have at least 2 elements [tile_N, config]"; + // moeConfigIndex corresponds to pair (tile_N, config) int64_t tile_N = config_index[0]; int64_t config = config_index[1]; // Autotuner has requested a default or 'fallback' config index if (tile_N == -1 || config == -1) { tile_N = *selected_tile_nums.begin(); config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + } else { + TVM_FFI_ICHECK(mRunners.find(tile_N) != mRunners.end()) + << "Requested tile_N " << tile_N << " not in selected tile sizes"; }Note: Similar issues exist in
trtllm_fp8_block_scale_moe(lines 734-747) andtrtllm_fp4_block_scale_moe(lines 1251-1266).
♻️ Duplicate comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
21-22: LGTM: Required headers added. The <set> and <unordered_map> includes are now present and necessary for std::set usage in computeSelectedTileN (line 62) and std::unordered_map usage in the runner maps (lines 377, 726, 1242).
700-751: Same config_index validation issues as fp8_per_tensor_scale_moe. Lines 734-741 have identical validation gaps as the earlier function:
- No bounds check on the config_index array length
- No validation that config_index[0] exists in mRunners
- No validation of the config value
1188-1267: Same config_index validation issues (third occurrence). Lines 1251-1258 have the same validation gaps. The fix should be applied consistently across all three MoE functions.
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1269-1291: Inconsistent return value for the default config getter. The trtllm_get_default_moe_configs function returns only int64_t config (lines 1289-1290), whereas trtllm_get_valid_moe_configs returns Array<Array<int64_t>> with [tile_N, config] pairs (lines 1327-1329).
The single config value is ambiguous without knowing which tile_N it corresponds to. While the implementation uses *selected_tile_nums.begin() (line 1286), callers have no way to determine this tile size.
For API consistency, consider returning Array<int64_t> containing [tile_N, config] instead of just config:
```diff
-int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_,
+Array<int64_t> trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_,
                                        bool const useDeepSeekFp8, int64_t const top_k,
                                        int64_t const hidden_size, int64_t const intermediate_size,
                                        int64_t const num_local_experts, int64_t const gated_act_type,
                                        int64_t const num_tokens) {
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
   if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
       dtype_act != btg::Dtype::Bfloat16) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
+  int64_t tile_N = *selected_tile_nums.begin();
   std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
       std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-          dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
+          dtype_act, dtype_weights, useDeepSeekFp8, tile_N,
           static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
-  return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
-                                                num_local_experts, num_tokens);
+  int64_t config = moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+                                                          num_local_experts, num_tokens);
+  return {tile_N, config};
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu(12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (4)
top_k (270-270), maybeGetMinTokenCount (55-60), mUseDeepSeekFp8 (285-342), do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
dtypeGetNumBits(88-91)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
42-60: LGTM: Standard power-of-two calculation. The implementation uses a standard bit-manipulation algorithm to compute the next power of two. The logic correctly handles already-power-of-two values and edge cases for small inputs (a small reference sketch follows below).
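For reference, a minimal Python equivalent of a next-power-of-two helper (an integer-only sketch, not the launcher's exact C++ bit-twiddling implementation):

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n, with a floor of 1 for n <= 1."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


# Already-power-of-two values map to themselves; everything else rounds up.
assert [next_power_of_two(n) for n in (1, 3, 8, 9, 64, 65)] == [1, 4, 8, 16, 64, 128]
```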
86-88: LGTM: Launcher signature updated for per-tile autotuning. The added moe_runner reference and moeConfigIndex parameters support the per-tile runner selection model. This is an internal launcher function, so the signature change doesn't affect the public API.
201-206: LGTM: Per-stage workspace padding improves precision. The separate max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 calculations (using maybeGetMinTokenCount) account for different dimension requirements in each GEMM stage. The allocations correctly use stage-specific padding, and the workspace takes the maximum of both.
Also applies to: 224-234, 304-305
377-382: LGTM: Runner lifecycle management is safe. The mRunners map with unique_ptr ownership correctly manages runner lifetime. Since the launcher is synchronous and doesn't store the runner reference beyond the call, the stack-allocated map's destruction after the function returns is safe.
```cpp
std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
                                       int64_t const num_tokens, int64_t const top_k,
                                       int64_t const num_local_experts) {
  float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
                                       supported_tile_nums.front(), supported_tile_nums.back());

  std::set<int32_t> selected_tile_nums = {
      std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
      std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
      std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};

  return selected_tile_nums;
}
```
🧩 Analysis chain
Verify input constraints to prevent division by zero and overflow.
Two potential issues:
- Division by zero: If num_local_experts is zero, line 65 will cause division by zero. While upstream validation exists for expert counts (line 150), it doesn't explicitly guarantee local_num_experts > 0.
- Integer overflow: Line 65 computes num_tokens * top_k as an int64_t multiplication before the float cast. For large token counts and top_k values, this could overflow before conversion.
🏁 Script executed:
#!/bin/bash
# Check validation of local_num_experts parameter
rg -n -B5 -A5 'local_num_experts' --type=cpp | rg -i 'check|assert|validate|icheck'Length of output: 0
🏁 Script executed:
#!/bin/bash
# Find all calls to computeSelectedTileN with surrounding context
echo "=== CALLS TO computeSelectedTileN ==="
rg -n -B10 -A3 'computeSelectedTileN\(' --type=cpp
echo -e "\n=== FUNCTION SIGNATURES CONTAINING local_num_experts ==="
rg -n -B2 'local_num_experts' --type=cpp | grep -E 'void|bool|int|auto|std' | head -20
echo -e "\n=== CHECK FOR num_local_experts ZERO/VALIDATION ==="
rg -n -B3 -A3 'num_local_experts\s*[=!<>]' --type=cpp | head -40

Length of output: 393
🏁 Script executed:
#!/bin/bash
# Direct search in the specific file mentioned in review
echo "=== FILE CONTENT AROUND computeSelectedTileN ==="
find . -name "trtllm_fused_moe_kernel_launcher.cu" -exec cat {} \; 2>/dev/null | head -100
echo -e "\n=== SEARCH FOR FUNCTION IN .cu AND .h FILES ==="
rg -l 'computeSelectedTileN' --type=cpp
echo -e "\n=== SEARCH IN ALL FUSEDMOE FILES ==="
find . -path '*/fused_moe*' -name '*.cu' -o -name '*.h' | head -20

Length of output: 5488
🏁 Script executed:
#!/bin/bash
# Get the full file to search for all calls to computeSelectedTileN
echo "=== FIND ALL CALLS TO computeSelectedTileN IN THE FILE ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | grep -n "computeSelectedTileN"
echo -e "\n=== CHECK FUNCTION PARAMETER VALIDATION IN trtllm_fp8_per_tensor_scale_moe_launcher ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | sed -n '100,200p'
echo -e "\n=== CHECK FOR ICHECK GUARDS ON local_num_experts ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | grep -n -A2 -B2 "local_num_experts"

Length of output: 18497
Add guards against division by zero and integer overflow in computeSelectedTileN.
The validation at lines 149-150 does not prevent local_num_experts == 0. When it is zero, line 65 causes a division by zero:
```cpp
float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
```
Additionally, the multiplication num_tokens * top_k occurs as int64_t before the float cast. For large values (e.g., num_tokens ≈ 2^40, top_k ≈ 2^25), this overflows.
Required fixes:
- Add TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149
- Cast to float before multiplication: static_cast<float>(num_tokens) * top_k / num_local_experts
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_kernel_launcher.cu around lines 62-75 the computation
does integer multiplication then divides by num_local_experts which can cause
integer overflow and a division-by-zero if local_num_experts==0; add a guard
TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149 to ensure non-zero
experts, and change the averaging to perform the cast before multiplication
(i.e., cast num_tokens to float and multiply by top_k then divide by
num_local_experts) to avoid int64 overflow and then proceed with the existing
clamping and selection logic.
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)
1707-1712: Inline deprecation warnings for tile_tokens_dim — good. Warning in the call sites implements prior guidance to deprecate the field within functions, not via a decorator. Consider also noting the replacement (autotuning via AutoTuner) in the message; an example wording follows below.
Also applies to: 1784-1789, 1913-1918, 2052-2057
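A possible wording for that message, using the warning_once helper added in this PR (the exact text and version number are only suggestions):

```python
# Hypothetical message text; logger.warning_once is the once-only helper from flashinfer/jit/core.py.
logger.warning_once(
    "tile_tokens_dim is deprecated and ignored; the tile size is now selected by the "
    "autotuner (wrap the call in autotune() to tune it). The argument will be removed "
    "after v0.5.0."
)
```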
🧹 Nitpick comments (3)
flashinfer/jit/core.py (1)
64-89: Consider alternative approaches to avoid unbounded cache growth. The implementation is functional and the use of lru_cache for message deduplication is clever. However, the static analysis warning is worth addressing: using lru_cache on instance methods can prevent garbage collection. While this is mitigated here since logger is a global singleton (line 92), the unbounded cache (maxsize=None) will grow indefinitely with each unique log message, which could be problematic in long-running applications with high log diversity.
Consider refactoring to use a set-based approach instead:
```diff
+    def __init__(self, name):
+        super().__init__(name)
+        # ... existing initialization ...
+        self._logged_messages = set()
+
     def debug_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`debug`][logging.Logger.debug], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.debug, msg, *args)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.debug(msg, *args, stacklevel=2)

     def info_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`info`][logging.Logger.info], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.info, msg, *args)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.info(msg, *args, stacklevel=2)

     def warning_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`warning`][logging.Logger.warning], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.warning, msg, *args)
-
-    @functools.lru_cache(maxsize=None)
-    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
-        """Helper method to log messages only once per unique (msg, args) combination."""
-        # Note: stacklevel=3 to show the caller's location, not this helper method
-        log_method(msg, *args, stacklevel=3)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.warning(msg, *args, stacklevel=2)
```
This approach still has unbounded growth but makes memory usage more predictable and avoids the lru_cache decorator warning. Alternatively, consider using functools.cache (Python 3.9+) with periodic cache clearing if appropriate for your use case.
flashinfer/fused_moe/core.py (2)
913-937: Prefer an explicit exception over assert for the layout/shuffle guard. Asserts can be stripped with python -O. Raise a clear ValueError instead to enforce at runtime.
```diff
-        if (
-            not self.use_shuffled_weight
-            or self.weight_layout != WeightLayout.MajorK
-        ):
-            assert (
-                self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
-            ), (
-                "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale"
-            )
+        if (not self.use_shuffled_weight or self.weight_layout != WeightLayout.MajorK) and not (
+            self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
+        ):
+            raise ValueError(
+                "Non-shuffled weights or non-MajorK layout are only supported for FP8 block-scale "
+                "(use_deepseek_fp8=True, dtype_weights=E4m3)."
+            )
```
1017-1116: FP8 block-scale forward ignores the provided hidden_states_scale. Forward always builds a synthetic scale tensor and never uses the passed hidden_states_scale. If the scale layout/dtype differs at runtime, profiling may diverge. Consider preferring the provided tensor when available.
```diff
-                current_hidden_states_scale = torch.full(
-                    (current_hidden_size // 128, current_num_tokens),
-                    2.0,
-                    dtype=torch.float,
-                    device=hidden_states.device,
-                )
+                current_hidden_states_scale = (
+                    hidden_states_scale
+                    if 'hidden_states_scale' in locals() and hidden_states_scale is not None
+                    else torch.full(
+                        (current_hidden_size // 128, current_num_tokens),
+                        2.0,
+                        dtype=torch.float32,
+                        device=hidden_states.device,
+                    )
+                )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/fused_moe/core.py (20 hunks)
- flashinfer/jit/core.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (9)
GatedActType (141-158), top_k (270-270), intermediate_size (275-275), hidden_size (265-265), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351), trtllm_fp4_block_scale_moe (1176-1267), trtllm_fp4_block_scale_moe (1176-1189)
flashinfer/utils.py (1)
device_support_pdl (568-572)
flashinfer/autotuner.py (3)
AutoTuner (335-780), get (362-365), choose_one (400-525)
flashinfer/jit/core.py (1)
warning_once(78-83)
🪛 Ruff (0.14.1)
flashinfer/jit/core.py
85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/fused_moe/core.py (3)
952-965: Good: cache key now includes gating and layout knobs. Adding gated_act_type, use_shuffled_weight, and weight_layout to instance_key avoids tactic cache collisions across configurations.
1168-1227: Autotune path for FP8 per-tensor looks correct; confirm the output dtype assumption. Integration with AutoTuner and passing the tactic pair is sound. Output is allocated as bfloat16; please confirm the kernel always writes bf16 for this op across supported dtypes.
Also applies to: 1250-1251
1517-1519: FP4 MoERunner defaults (MajorK + shuffled) acknowledged. Defaults align with the supported FP4 path; gated_act_type flows through.
```python
    tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
    if enable_pdl is None:
        enable_pdl = device_support_pdl(hidden_states.device)

    # Use AutoTuner to select the best tactic - follow FP4 pattern exactly
    tuner = AutoTuner.get()
    MoERunner.refine_tuning_config(tune_max_num_tokens)

    num_tokens = hidden_states.shape[0]
    hidden_size = hidden_states.shape[-1]

    # Create workspace buffers
    output = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
    )
    topk_ids = torch.empty(
        num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
    )
    expert_weights = torch.empty(
        num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device
    )

    dtype_act = DtypeTrtllmGen.E4m3  # FP8 activation
    dtype_weights = DtypeTrtllmGen.E4m3  # FP8 weights

    moe_runner = MoERunner(
        top_k=top_k,
        num_local_experts=local_num_experts,
        dtype_act=dtype_act,
        dtype_weights=dtype_weights,
        use_deepseek_fp8=True,  # block_scale mode
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        weight_layout=weight_layout,
        use_shuffled_weight=use_shuffled_weight,
    )

    inputs = [
        output,
        routing_logits,
        topk_ids,
        expert_weights,
        hidden_states,
        hidden_states_scale,
    ]

    _, tactic = tuner.choose_one(
        "flashinfer::trtllm_fp8_block_scale_moe",
        [moe_runner],
        MoERunner.tuning_config_with_hidden_states_scales,  # FP8 block-scale uses hidden_states_scale
        inputs,
        routing_bias=routing_bias,
        gemm1_weights=gemm1_weights,
        gemm1_weights_scale=gemm1_weights_scale,
        gemm2_weights=gemm2_weights,
        gemm2_weights_scale=gemm2_weights_scale,
        num_experts=num_experts,
        n_group=n_group,
        topk_group=topk_group,
        local_expert_offset=local_expert_offset,
        local_num_experts=local_num_experts,
        routed_scaling_factor=routed_scaling_factor,
        routing_method_type=routing_method_type,
        use_shuffled_weight=use_shuffled_weight,
        weight_layout=weight_layout,
        enable_pdl=enable_pdl,
    )
```
Bug: re-allocating output discards caller-provided buffer
trtllm_fp8_block_scale_moe_op takes output: torch.Tensor but immediately overwrites it with a new tensor. This breaks API expectations for in-place/out-arg usage.
Apply this fix to validate and use the provided buffer instead of reallocating:
- # Create workspace buffers
- output = torch.empty(
- num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
- )
+ # Validate provided output buffer
+ check_shape_dtype_device(
+ output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+ )Also applies to: 1398-1399
🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 1307-1374 (and also check lines
~1398-1399), the function currently unconditionally re-allocates the `output`
tensor which discards a caller-provided buffer; instead validate and reuse the
provided `output` buffer: if `output` is None then allocate as before, otherwise
check device, dtype, shape (num_tokens x hidden_size) and raise a clear error if
mismatched; do the same for any other caller-provided buffers noted in the
comment (e.g., topk_ids/expert_weights if applicable), and remove the
unconditional reassignment so the op uses the validated buffer. Ensure
consistency for the FP8 dtype and device when performing the checks.
```cpp
auto const hidden_size = hidden_states.size(1);
bool mUseDeepSeekFp8{false};  // FP8 per-tensor doesn't use DeepSeek FP8

std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
```
TileN=128 config exists
[FAILED] Pipeline #37422250: 1/17 passed
Actionable comments posted: 0
♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
62-75: [Duplicate] Division by zero and overflow risks remain in computeSelectedTileN. Although marked as addressed in a previous review, the code still exhibits the same patterns:
- Division by zero: line 65 divides by num_local_experts without ensuring it's non-zero. The validation at lines 149-150 only checks local_num_experts + local_expert_offset <= num_experts, which permits local_num_experts == 0.
- Integer overflow: the multiplication num_tokens * top_k occurs as int64_t arithmetic before the static_cast<float>, risking overflow for large values.
Recommended fixes:
- Add TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149
- Reorder to cast before multiplication: static_cast<float>(num_tokens) * top_k / num_local_experts
🧹 Nitpick comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
356-366: Code duplication: dtype conversion logic repeated. The dtype-to-btg::Dtype mapping at lines 356-366 duplicates the same logic already present in the launcher function (lines 156-165). Consider extracting this into a helper function to reduce duplication and ensure consistency.
Apply this pattern:
```cpp
// Add at file scope
static btg::Dtype convertDLDataTypeToBtgDtype(DLDataType dtype) {
  if (dtype == dl_float16) {
    return btg::Dtype::Fp16;
  } else if (dtype == dl_bfloat16) {
    return btg::Dtype::Bfloat16;
  } else if (dtype == dl_float8_e4m3fn) {
    return btg::Dtype::E4m3;
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
  }
}
```
Then replace both occurrences with:
```cpp
btg::Dtype mDtypeElt = convertDLDataTypeToBtgDtype(dtype);
```
372-382: Consider caching runners to avoid repeated construction. On each call, this code builds a std::unique_ptr<RunnerType> for every selected tile size, but only one runner (determined by config_index[0]) is actually used. For non-autotuning inference workloads, this overhead could be eliminated by caching runners.
However, if this function is primarily used in autotuning contexts where different configurations are benchmarked across calls, the current approach is acceptable.
1282-1289: Consider whether selecting the first tile size is always appropriate for the default config. The function builds a runner using *selected_tile_nums.begin() (the smallest selected tile size) to get the default config. For some workloads, a larger tile size might yield better performance.
However, choosing the smallest tile size as the default is a reasonable conservative choice that should work across workloads. If this is intentional, consider adding a comment explaining the rationale.
```diff
+  // Use the smallest selected tile size for conservative default config
   std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
       std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
           dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
           static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu(12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (4)
top_k (270-270), maybeGetMinTokenCount (55-60), mUseDeepSeekFp8 (285-342), do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
dtypeGetNumBits (88-91)
include/flashinfer/trtllm/fused_moe/DevKernel.h (3)
mUseDeepSeekFp8 (213-226), mUseDeepSeekFp8 (333-344), mUseDeepSeekFp8 (400-418)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
201-206: LGTM! Per-tile padding metrics properly computed. The introduction of separate max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 correctly applies maybeGetMinTokenCount to each GEMM stage with appropriate dimensions. This enables accurate per-tile workspace sizing.
224-233: LGTM! Workspace allocations use correct per-tile metrics. The workspace tensors for GEMM1 outputs and activation now use max_num_padded_tokens_gemm1, while the GEMM2 output uses max_num_padded_tokens_gemm2. This aligns with the per-tile padding strategy.
304-305: LGTM! Correct aggregation of per-tile metrics. Using std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2) correctly sizes the workspace to accommodate both GEMM stages.
1227-1240: LGTM! Conditional tile size support correctly handles dtype variations. The logic properly adds tile size 128 for FP4 activation types excluding Bfloat16, aligning with the similar logic in trtllm_get_default_moe_configs and trtllm_get_valid_moe_configs. The runner construction correctly passes all required parameters including gated_act_type.
1291-1336: LGTM! Comprehensive config enumeration for autotuning. This function correctly enumerates all valid (tile_N, config) pairs by:
- Computing appropriate tile sizes based on workload and dtype
- Building the correct runner type (FP8 block-scale vs FP4) based on dtype combinations
- Gathering valid configs for each tile size
- Returning flattened pairs for the autotuner to benchmark
The conditional runner construction at lines 1317-1328 properly handles the different constructor signatures for FP8 vs FP4 modes.
721-731: No changes required—tile size exclusion is intentional. The codebase explicitly differentiates tile size support by quantization mode. FP8 per-tensor quantization (both weights and activation E4m3, not DeepSeek FP8) includes tile size 128, while FP4 conditionally includes 128 based on activation type, and the logic at lines 1268-1278 shows FP8 block-scale (DeepSeek FP8) does not include 128. This design choice is working as intended—no documentation or comment addition needed.
Force-pushed from ed7ecaa to 143c296
Actionable comments posted: 1
♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
62-75: Guard against division by zero and integer overflow. As flagged in a previous review, line 65 still lacks protection against num_local_experts == 0 and potential integer overflow in num_tokens * top_k. The validation at lines 149-150 does not guarantee local_num_experts > 0.
Apply the previously suggested fix:
```diff
+  // Add guard after line 149 in the calling function
+  TVM_FFI_ICHECK_GT(local_num_experts, 0) << "local_num_experts must be greater than zero";
+
 std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
                                        int64_t const num_tokens, int64_t const top_k,
                                        int64_t const num_local_experts) {
-  float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+  float const avg_tokens_per_expert = static_cast<float>(num_tokens) * top_k / num_local_experts;
```
flashinfer/fused_moe/core.py (1)
1320-1322: Bug: Unconditional output reallocation discards the caller-provided buffer. The function accepts output: torch.Tensor as a parameter (line 1294) but immediately overwrites it with a new allocation at lines 1320-1322, breaking the API contract for in-place operations.
Apply this fix to use the provided buffer:
```diff
-    # Create workspace buffers
-    output = torch.empty(
-        num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-    )
+    # Validate provided output buffer
+    check_shape_dtype_device(
+        output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+    )
```
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
372-374: Consider making mSupportedTileN a static constant. The vector mSupportedTileN is initialized identically on every call. Making it static const would avoid repeated allocations and improve performance.
Apply this diff:
```diff
-  std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
+  static const std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
- flashinfer/fused_moe/core.py (20 hunks)
- flashinfer/jit/core.py (2 hunks)
- tests/conftest.py (1 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
flashinfer/utils.py (1)
device_support_pdl(568-572)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (2)
top_k (270-270), maybeGetMinTokenCount (55-60)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
dtypeGetNumBits(88-91)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (8)
GatedActType (141-158), top_k (270-270), intermediate_size (275-275), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351), trtllm_fp4_block_scale_moe (1168-1259), trtllm_fp4_block_scale_moe (1168-1181)
flashinfer/utils.py (1)
device_support_pdl (568-572)
flashinfer/autotuner.py (3)
AutoTuner (335-780), get (362-365), choose_one (400-525)
flashinfer/jit/core.py (1)
warning_once (78-83)
tests/moe/test_trtllm_gen_fused_moe.py (4)
flashinfer/utils.py (1)
get_compute_capability (251-254)
flashinfer/autotuner.py (1)
autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351)
flashinfer/fused_moe/core.py (5)
trtllm_fp8_block_scale_moe (1736-1815), trtllm_fp8_per_tensor_scale_moe (1658-1733), RoutingMethodType (57-71), WeightLayout (160-167), GatedActType (172-176)
🪛 Ruff (0.14.2)
flashinfer/jit/core.py
85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
🔇 Additional comments (17)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
201-206: LGTM! Proper workspace size calculations. The per-tile max padded token calculations correctly account for minimum token requirements using maybeGetMinTokenCount, ensuring workspace buffers are adequately sized for both GEMM operations.
1261-1336: LGTM! Config discovery functions properly support per-tile autotuning. The updated signatures for trtllm_get_default_moe_configs and trtllm_get_valid_moe_configs correctly integrate with the new autotuning framework by returning tile_N and config pairs, enabling dynamic configuration selection at runtime.
flashinfer/fused_moe/core.py (5)
904-936: LGTM! Well-designed parameter expansion for autotuning. The new gated_act_type, use_shuffled_weight, and weight_layout parameters properly extend the MoERunner to support diverse MoE configurations. The validation at lines 928-936 correctly enforces that non-default shuffle/layout options require FP8 block-scale mode.
1017-1116: LGTM! Clear routing logic for different MoE variants. The forward method properly dispatches to FP8 block-scale, FP8 per-tensor, or FP4 operations based on dtype combinations. The tactic encoding ([-1, -1] for default) aligns with the backend's config index expectations.
1168-1227: LGTM! Proper autotuning integration for FP8 per-tensor MoE. The new operation correctly:
- Creates necessary workspace buffers
- Instantiates a MoERunner with appropriate FP8 per-tensor settings
- Uses AutoTuner to select the best tactic
- Returns the computed output tensor
1508-1519: LGTM! FP4 runner correctly uses fixed layout settings. The hardcoded weight_layout=WeightLayout.MajorK and use_shuffled_weight=True at lines 1517-1518 are appropriate since FP4 operations require these specific settings, as enforced by the C++ validation logic in the kernel launcher.
1707-1712: LGTM! Appropriate deprecation warning for tile_tokens_dim. The deprecation warning properly alerts users about the planned removal of tile_tokens_dim in v0.5.0, giving them time to adapt their code. Using warning_once prevents log spam.
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
14-14: LGTM! Benchmark correctly updated for the autotuning API. The removal of the calculate_tile_tokens_dim import and passing None for tile_tokens_dim properly align with the new autotuning-driven MoE execution path where tile dimensions are selected dynamically.
Also applies to: 136-136
flashinfer/jit/core.py (1)
64-83: LGTM! Well-designed once-only logging methods. The new debug_once, info_once, and warning_once methods provide a clean API for deduplication of log messages. The stacklevel=3 parameter correctly shows the original caller's location in logs.
140-168: LGTM! Improved test hook follows pytest best practices. The migration from @pytest.hookimpl(tryfirst=True) to @pytest.hookimpl(wrapper=True) with yield-based delegation is the recommended pattern. The error handling correctly manages both OOM and MissingJITCacheError scenarios while preserving test skip behavior.
17-17: LGTM: Import changes align with the autotuning approach. The removal of the calculate_tile_tokens_dim import is correct since tile_tokens_dim is now autotuned rather than statically calculated.
Also applies to: 48-48
205-205: LGTM: Correct tile_tokens_dim autotuning in the FP4 CUDA graph path. Passing tile_tokens_dim=None allows the autotuner to select optimal values during graph capture (wrapped in autotune(True) at line 108).
772-796: LGTM: Proper autotuning integration for FP8 block scale MoE. The autotune(True) context manager correctly enables autotuning, and passing tile_tokens_dim=None allows the autotuner to select optimal tile sizes based on runtime characteristics.
945-972: LGTM: Consistent autotuning for FP8 per-tensor MoE. The autotuning approach matches the FP8 block scale implementation and correctly delegates tile size selection to the autotuner.
1837-1897: LGTM: Good refactor to centralize skip logic. The new skip_checks helper function consolidates test skip conditions and adds compatible_intermediate_size filtering to reduce test execution time while maintaining coverage of valid configurations.
1899-2086: LGTM: Excellent test refactoring for maintainability. The run_moe_test helper function effectively consolidates the common test workflow (quantization → reference computation → actual computation → validation), eliminating duplication across test functions and improving maintainability.
2301-2401: Clarify intent for commented-out tests. The TopK and Llama4 routing tests are commented out. Please clarify:
- Are these temporarily disabled for testing speed optimization?
- Should they be re-enabled in a future PR?
- Or can they be removed entirely if no longer needed?
Consider adding a TODO comment explaining the rationale if these tests should be restored later.
```python
    @functools.lru_cache(maxsize=None)
    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
        """Helper method to log messages only once per unique (msg, args) combination."""
        # Note: stacklevel=3 to show the caller's location, not this helper method
        log_method(msg, *args, stacklevel=3)
```
Potential memory leak with lru_cache on instance method.
Using @functools.lru_cache on the instance method _print_once will cache self references, which can prevent garbage collection. Since FlashInferJITLogger is a singleton (line 92), this is less critical but still not ideal.
Consider using a module-level cache instead:
+_LOGGER_CACHE: Dict[Tuple[str, Hashable], bool] = {}
+
class FlashInferJITLogger(logging.Logger):
# ...
- @functools.lru_cache(maxsize=None)
- def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
+ def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
"""Helper method to log messages only once per unique (msg, args) combination."""
- # Note: stacklevel=3 to show the caller's location, not this helper method
- log_method(msg, *args, stacklevel=3)
+ cache_key = (msg, args)
+ if cache_key not in _LOGGER_CACHE:
+ _LOGGER_CACHE[cache_key] = True
+ log_method(msg, *args, stacklevel=3)
Based on static analysis hint.
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.2)
85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
1899-2086: Well-structured test helper reduces duplication. The run_moe_test helper effectively consolidates the common test workflow (quantize weights → quantize inputs → compute reference → compute production → compare). This improves maintainability and reduces code duplication across the parameterized tests.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/moe/test_trtllm_gen_fused_moe.py(14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_trtllm_gen_fused_moe.py (4)
- flashinfer/utils.py (1): get_compute_capability (251-254)
- flashinfer/autotuner.py (1): autotune (251-262)
- flashinfer/fused_moe/core.py (2): trtllm_fp8_block_scale_moe (1736-1815), trtllm_fp8_per_tensor_scale_moe (1658-1733)
- csrc/trtllm_fused_moe_kernel_launcher.cu (4): trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/moe/test_trtllm_gen_fused_moe.py (6)
17-48: LGTM! Import changes align with autotuning approach. The removal of `calculate_tile_tokens_dim` from imports is consistent with the PR's goal of autotuning `tile_tokens_dim` instead of using static calculation.
772-796: LGTM! Autotuning correctly integrated for FP8 block scale MoE. The `autotune(True)` context manager correctly wraps the kernel invocation, and `tile_tokens_dim=None` (Line 791) properly delegates tile dimension selection to the autotuner.
945-972: LGTM! Autotuning correctly integrated for FP8 per-tensor MoE. The `autotune(True)` context manager correctly wraps the kernel invocation, and `tile_tokens_dim=None` (Line 970) properly delegates tile dimension selection to the autotuner.
837-1886: Good refactoring to consolidate test logic. The new `skip_checks` helper consolidates skip logic and adds validation for `compatible_intermediate_size` (Lines 1882-1885), which appears to be a new constraint related to autotuning tile dimensions for different intermediate sizes.
2088-2439: Comprehensive test coverage with routing-specific constraints. The refactored parameterized tests provide good coverage across different routing methods (DeepSeekV3, Renormalize, TopK, Llama4) and quantization modes. The `compatible_intermediate_size` constraints vary by routing method, which appears intentional for the autotuning feature. Note: the constraints limit testing to specific intermediate sizes per routing method. Ensure these constraints are documented and align with production use cases.
If the `compatible_intermediate_size` constraints are temporary limitations during development, consider adding a TODO comment or issue reference to track expanding support.
108-127: No issues found: the autotune context pattern is correct as implemented. The original review comment is based on a misunderstanding of how the autotuner works. The pattern in lines 108-127 is correct:
Warmup (line 108): the `autotune(True)` context enables kernel configuration discovery and caching via `AutoTuner.profiling_cache`.
Graph capture (lines 119-128): the computation runs without an active `autotune` context, which is correct. After the warmup context exits, `is_tuning_mode` reverts to its previous state, but cached configurations persist in `profiling_cache` and are used during capture.
This pattern matches the established benchmark patterns throughout the codebase (e.g., benchmarks/routines/moe.py lines 726-728), where warmup discovers configs and subsequent execution uses those cached results without re-tuning. Wrapping the capture in `autotune(True)` would trigger unnecessary re-tuning during capture, which is not the intended design.
The code at lines 172-210 is also correct; `tile_tokens_dim=None` is properly passed to `trtllm_fp4_block_scale_moe`.
Likely an incorrect or invalid review comment.
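For reference, a minimal sketch of the warmup-then-capture pattern described above. `run_moe` and its arguments are placeholders, and the `autotune` import path is assumed to be flashinfer.autotuner; this is not the exact test code.

```python
import torch
from flashinfer.autotuner import autotune


def warmup_then_capture(run_moe, *args):
    # Warmup under autotune(True): tactics are profiled and stored in the profiling cache.
    with autotune(True):
        run_moe(*args)

    # Graph capture runs without an autotune context; the cached configs are reused as-is,
    # so no re-tuning happens inside the captured graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        output = run_moe(*args)
    return graph, output
```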
a54cfef to
4edbda3
Compare
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
146-151: Guard local_num_experts > 0. Add an explicit check before any use that depends on it (e.g., routing and tiling selection).
  TVM_FFI_ICHECK_LE(local_num_experts + local_expert_offset, num_experts)
      << "num_experts must be greater or equal to local_num_experts + local_expert_offset";
+ TVM_FFI_ICHECK_GT(local_num_experts, 0) << "local_num_experts must be > 0";
474-477: Apply the same `local_num_experts > 0` guard to the FP8 block-scale entry. Ensure consistent validation across entry points.
  TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k";
+ TVM_FFI_ICHECK_GT(local_num_experts, 0) << "local_num_experts must be > 0";
Also applies to: 700-708
♻️ Duplicate comments (3)
flashinfer/jit/core.py (1)
64-90: Replace lru_cache on instance method; use a small module-level dedupe set instead.
`@functools.lru_cache` on a bound method can retain `self` and unboundedly grow the cache. Also, keying on the `log_method` object is unnecessary. Use a module-level set keyed by (method_name, msg, args) to dedupe and avoid leaks.
Apply:
 @@
-    def warning_once(self, msg: str, *args: Hashable) -> None:
+    def warning_once(self, msg: str, *args: Hashable) -> None:
 @@
-    @functools.lru_cache(maxsize=None)
-    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
-        """Helper method to log messages only once per unique (msg, args) combination."""
-        # Note: stacklevel=3 to show the caller's location, not this helper method
-        log_method(msg, *args, stacklevel=3)
+    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
+        """Helper method to log messages only once per unique (msg, args) combination."""
+        # Use method name to avoid holding references to bound methods
+        cache_key = (getattr(log_method, "__name__", "log"), msg, args)
+        if cache_key in _LOG_ONCE_KEYS:
+            return
+        _LOG_ONCE_KEYS.add(cache_key)
+        # Note: stacklevel=3 to show the caller's location, not this helper method
+        log_method(msg, *args, stacklevel=3)
Add near imports (top of file):
+from typing import Hashable, Tuple
+_LOG_ONCE_KEYS: set[Tuple[str, str, tuple[Hashable, ...]]] = set()
Based on static analysis hint (B019).
flashinfer/fused_moe/core.py (1)
1319-1328: Don't reallocate `output`; validate and use the provided buffer. Overwriting `output` breaks out-arg semantics and wastes memory.
-    # Create workspace buffers
-    output = torch.empty(
-        num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-    )
+    # Validate provided output buffer (shape/dtype/device)
+    check_shape_dtype_device(
+        output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+    )
Also ensure `inputs = [...]` continues to pass the validated `output` (no changes needed).
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
62-75: Fix overflow and div-by-zero in computeSelectedTileN.
`num_tokens * top_k` can overflow before the float cast; `num_local_experts` can be zero. Cast first and clamp the denominator.
- float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+ int64_t denom = std::max<int64_t>(1, num_local_experts);
+ // compute in double to avoid overflow/precision loss, then cast
+ double avg = static_cast<double>(num_tokens) * static_cast<double>(top_k)
+     / static_cast<double>(denom);
+ float const avg_tokens_per_expert = static_cast<float>(avg);
🧹 Nitpick comments (5)
flashinfer/fused_moe/core.py (2)
1179-1188: Minor: avoid unused workspace tensors unless required by tuner.
`topk_ids` and `expert_weights` are allocated but unused by the C++ call here. If they are only for shaping the tuning profile, consider lazy-allocating them under tuning mode to reduce overhead.
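One possible shape for that suggestion, assuming the tuner exposes `is_tuning_mode` as referenced elsewhere in this review; the helper name, shapes, and dtypes below are illustrative:

```python
import torch
from flashinfer.autotuner import AutoTuner


def maybe_alloc_tuning_buffers(num_tokens, top_k, device):
    # Only materialize tuner-facing buffers when autotuning is active.
    tuner = AutoTuner.get()
    if not tuner.is_tuning_mode:
        return None, None
    topk_ids = torch.empty(num_tokens, top_k, dtype=torch.int32, device=device)
    expert_weights = torch.empty(num_tokens, top_k, dtype=torch.bfloat16, device=device)
    return topk_ids, expert_weights
```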
1001-1045: Constructor guard reads clearer with a positive condition. The current assertion triggers when (not shuffled) OR (layout != MajorK). Consider inverting for readability and to avoid accidental future regressions.
-        if (
-            not self.use_shuffled_weight
-            or self.weight_layout != WeightLayout.MajorK
-        ):
-            assert (
-                self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
-            ), (
-                "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale"
-            )
+        if self.use_shuffled_weight and self.weight_layout == WeightLayout.MajorK:
+            pass
+        else:
+            assert self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3, (
+                "Only FP8 block scale supports (not shuffled) or non-MajorK layouts"
+            )
tests/conftest.py (1)
97-125: Use fully qualified name for compile cache key to avoid collisions.
`_TORCH_COMPILE_CACHE[fn.__name__]` can collide across modules. You already compute `fullname`; use it for the cache key.
- compiled = _TORCH_COMPILE_CACHE.get(fullname)
+ compiled = _TORCH_COMPILE_CACHE.get(fullname)
 @@
- _TORCH_COMPILE_CACHE[fn.__name__] = compiled
+ _TORCH_COMPILE_CACHE[fullname] = compiled
tests/moe/test_trtllm_gen_fused_moe.py (2)
96-115: Avoid autotuning during graph capture to reduce capture-time overhead. You warm up with autotune before capture (good). During capture, ensure tuning is disabled to keep the graph minimal and deterministic.
- with torch.cuda.stream(torch_stream):
+ with torch.cuda.stream(torch_stream), autotune(False):
      self.output_tensor = self._run_moe_computation(runtime_args)
1846-1854: Clarify the skip message for unsupported architectures. The message says "only guaranteed to work on SM100 and SM103," but the condition skips when major != 10. Consider making the message explicit about the detected arch.
- if compute_capability[0] not in [10]:
-     pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
+ if compute_capability[0] != 10:
+     pytest.skip(f"Requires Blackwell (SM10x). Detected SM{compute_capability[0]}{compute_capability[1]}.")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
- flashinfer/fused_moe/core.py (20 hunks)
- flashinfer/jit/core.py (2 hunks)
- tests/conftest.py (1 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
- include/flashinfer/trtllm/fused_moe/runner.h (1): maybeGetMinTokenCount (55-60)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1): dtypeGetNumBits (88-91)
flashinfer/fused_moe/core.py (5)
- include/flashinfer/trtllm/fused_moe/runner.h (9): GatedActType (141-158), top_k (270-270), intermediate_size (275-275), hidden_size (265-265), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276)
- csrc/trtllm_fused_moe_kernel_launcher.cu (6): trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351), trtllm_fp4_block_scale_moe (1168-1259), trtllm_fp4_block_scale_moe (1168-1181)
- flashinfer/utils.py (1): device_support_pdl (568-572)
- flashinfer/autotuner.py (3): AutoTuner (335-780), get (362-365), choose_one (400-525)
- flashinfer/jit/core.py (1): warning_once (78-83)
tests/moe/test_trtllm_gen_fused_moe.py (5)
- flashinfer/utils.py (1): get_compute_capability (251-254)
- flashinfer/autotuner.py (1): autotune (251-262)
- csrc/trtllm_fused_moe_kernel_launcher.cu (4): trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351)
- flashinfer/fused_moe/core.py (3): trtllm_fp8_block_scale_moe (1736-1815), trtllm_fp8_per_tensor_scale_moe (1658-1733), RoutingMethodType (57-71)
- include/flashinfer/trtllm/fused_moe/runner.h (5): num_experts (263-263), top_k (270-270), intermediate_size (275-275), RoutingMethodType (37-136), hidden_size (265-265)
🪛 Ruff (0.14.2)
flashinfer/jit/core.py
85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Signed-off-by: jiahanc <[email protected]>
4edbda3 to
0316720
Compare
/bot run
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/conftest.py (1)
98-121: Fix cache-key mismatch in the torch.compile monkeypatch (prevents perpetual recompiles). You read from _TORCH_COMPILE_CACHE using `fullname` but write using `fn.__name__`. This guarantees a miss every call and re-compilation.
Apply:
-    def wrapper(*args, **kwargs):
-        compiled = _TORCH_COMPILE_CACHE.get(fullname)
+    def wrapper(*args, **kwargs):
+        cache_key = fullname
+        compiled = _TORCH_COMPILE_CACHE.get(cache_key)
         if compiled is None:
             ...
-            _TORCH_COMPILE_CACHE[fn.__name__] = compiled
+            _TORCH_COMPILE_CACHE[cache_key] = compiled
         return compiled(*args, **kwargs)
1259-1282: Fix Tensor truthiness bug in check_accuracy (prevents a runtime error). `mismatch_percent` is a 0-dim Tensor; using it in an `if` raises "bool value of Tensor is ambiguous".
Apply:
-    count = torch.sum(left > right)
-    mismatch_percent = count / a.numel()
-    if mismatch_percent > 1 - percent:
+    count = (left > right).sum().item()
+    mismatch_percent = count / float(a.numel())
+    if mismatch_percent > (1 - percent):
         print(a)
         print(b)
-        raise Exception(
+        raise Exception(
             f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
             f"(threshold: {1 - percent:.4f})"
         )
♻️ Duplicate comments (2)
flashinfer/jit/core.py (1)
85-89: Memory leak concern with lru_cache already flagged. Using `@functools.lru_cache` on the instance method `_print_once` can prevent garbage collection by caching `self` references. While the singleton pattern (line 92) mitigates the issue, a module-level cache would be cleaner.
Based on static analysis hint.
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
62-75: Division by zero concern already flagged. The division at line 65 (`/ num_local_experts`) and the integer overflow from `num_tokens * top_k` before the float cast have been identified in previous reviews.
🧹 Nitpick comments (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
372-374: Consider extracting the supported tile size logic to reduce duplication. The supported tile sizes and the conditional logic for adding 128 appear in multiple places (lines 372, 721, 1227-1230, 1268-1278, 1300-1310). Consolidating this into a shared helper function would improve maintainability.
flashinfer/fused_moe/core.py (1)
1520-1521: Hardcoded FP4 MoE parameters may limit flexibility. Lines 1520-1521 hardcode `weight_layout=WeightLayout.MajorK` and `use_shuffled_weight=True` for FP4 block scale MoE. Consider whether these should be exposed as parameters to match the flexibility offered in FP8 block scale MoE (lines 1307-1308).
123-125: Use logging instead of print for test infra noise. Swap `print(...)` for pytest's terminalreporter or logging to keep CI output clean. Minor but helps readability.
- print("Applied torch.compile to", fullname) + import logging + logging.getLogger(__name__).debug("Applied torch.compile to %s", fullname)tests/moe/test_trtllm_gen_fused_moe.py (2)
1837-1885: Skip gating looks fine; minor guard suggestion. Consider early-returning when CUDA is unavailable to avoid a ValueError from get_compute_capability, improving skip messaging in CPU-only environments.
- compute_capability = get_compute_capability(torch.device(device="cuda"))
+ if not torch.cuda.is_available():
+     pytest.skip("CUDA is required for these tests.")
+ compute_capability = get_compute_capability(torch.device("cuda"))
2088-2444: Param spaces look heavy; consider trimming for CI stability. Given the recent pipeline failure, you may want to temporarily reduce the Cartesian product (e.g., fewer intermediate_size values) behind an env flag (FAST_CI) to improve the pass rate while keeping local/full runs exhaustive.
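A hedged sketch of the env-flag idea; the FAST_CI name, the reduced value set, and the test name below are assumptions rather than code in this PR:

```python
import os

import pytest

FAST_CI = os.environ.get("FAST_CI", "0") == "1"

# Full grid locally, reduced grid when FAST_CI=1 on shared CI runners.
INTERMEDIATE_SIZES = [384, 768] if FAST_CI else [384, 768, 1024, 2048]


@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_intermediate_sizes(intermediate_size):
    ...  # existing test body unchanged
```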
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
- csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
- flashinfer/fused_moe/core.py (20 hunks)
- flashinfer/jit/core.py (2 hunks)
- tests/conftest.py (1 hunks)
- tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
- flashinfer/utils.py (1): device_support_pdl (568-572)
tests/moe/test_trtllm_gen_fused_moe.py (5)
- flashinfer/utils.py (1): get_compute_capability (251-254)
- flashinfer/autotuner.py (1): autotune (251-262)
- csrc/trtllm_fused_moe_kernel_launcher.cu (4): trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351)
- flashinfer/fused_moe/core.py (3): trtllm_fp8_block_scale_moe (1739-1818), trtllm_fp8_per_tensor_scale_moe (1661-1736), RoutingMethodType (58-72)
- include/flashinfer/trtllm/fused_moe/runner.h (5): num_experts (263-263), top_k (270-270), intermediate_size (275-275), RoutingMethodType (37-136), hidden_size (265-265)
flashinfer/fused_moe/core.py (5)
- include/flashinfer/trtllm/fused_moe/runner.h (8): GatedActType (141-158), top_k (270-270), intermediate_size (275-275), local_num_experts (277-277), num_experts (263-263), n_group (271-271), topk_group (273-273), local_expert_offset (276-276)
- csrc/trtllm_fused_moe_kernel_launcher.cu (6): trtllm_fp8_block_scale_moe (700-751), trtllm_fp8_block_scale_moe (700-707), trtllm_fp8_per_tensor_scale_moe (343-403), trtllm_fp8_per_tensor_scale_moe (343-351), trtllm_fp4_block_scale_moe (1168-1259), trtllm_fp4_block_scale_moe (1168-1181)
- flashinfer/utils.py (1): device_support_pdl (568-572)
- flashinfer/autotuner.py (3): AutoTuner (335-784), get (362-365), choose_one (400-529)
- flashinfer/jit/core.py (1): warning_once (78-83)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
- include/flashinfer/trtllm/fused_moe/runner.h (4): top_k (270-270), maybeGetMinTokenCount (55-60), mUseDeepSeekFp8 (285-342), do_finalize (295-296)
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1): dtypeGetNumBits (88-91)
🪛 GitHub Actions: pre-commit
tests/moe/test_trtllm_gen_fused_moe.py
[error] 1-1: pre-commit hook ruff-format reformatted 1 file and exited with code 1. Run 'pre-commit run --all-files' locally to apply formatting fixes.
🪛 Ruff (0.14.2)
flashinfer/jit/core.py
85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
flashinfer/fused_moe/core.py (3)
1183-1191: Output buffer reallocation issue already identified. The function unconditionally reallocates `output` even though it's not a parameter here (it's created locally). However, the registered op signature doesn't include output as a parameter, so this allocation is correct. Previous concerns about output reallocation apply to other functions where output is a parameter.
1710-1715: Good deprecation messaging. The deprecation warning for `tile_tokens_dim` is clear and uses the appropriate `logger.warning_once` to avoid log spam.
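For illustration only, a sketch of what such a one-shot deprecation warning can look like; the `logger` import path, function name, and message are assumptions, and the real call site lives in flashinfer/fused_moe/core.py:

```python
from flashinfer.jit.core import logger  # assumed export of the singleton logger


def _maybe_warn_tile_tokens_dim(tile_tokens_dim):
    # warning_once dedupes on (message, args), so repeated calls log only once.
    if tile_tokens_dim is not None:
        logger.warning_once(
            "tile_tokens_dim is deprecated and ignored; the autotuner now selects the tile size."
        )
```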
1323-1331: The output parameter reallocation does not break the API contract. The review comment assumes external callers provide an output buffer, but they don't. The public API function `trtllm_fp8_block_scale_moe()` (line 1739) has no `output` parameter. It creates its own buffer internally and passes it to the internal `trtllm_fp8_block_scale_moe_op()` function (line 1288). The internal _op function's `output` parameter is only supplied by this internal wrapper, not by external code. The reallocation inside _op does not violate any external API contract.
Likely an incorrect or invalid review comment.
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
136-136: Correctly aligned with the tile_tokens_dim deprecation. Passing `None` for `tile_tokens_dim` is consistent with the autotuning changes and deprecation messaging introduced in the core MoE implementation.
tests/moe/test_trtllm_gen_fused_moe.py (3)
206-210: Comment contradicts behavior; clarify tuple vs tensor. The comment says "Extract tensor from tuple" but returns output unchanged. Either return output[0] or update the comment. Given FP4Moe.call_moe indexes [0], prefer the comment fix.
[suggest_minor_issue]
- return output  # Extract tensor from tuple
+ return output  # Kernel may return a tuple; caller (FP4Moe.call_moe) extracts [0]
772-796: Good: autotuner usage aligns with the tile_tokens_dim deprecation. Passing None for tile_tokens_dim and using autotune(True) matches the new selection flow.
Please confirm CI has the autotuner enabled logs (e.g., “[Autotuner]: Autotuning process starts ...”) to ensure the context manager is effective.
2059-2074: Nice: enable_pdl propagated to the FP8 block-scale path. This matches the new kernel signature; the FP8 per-tensor path keeps the default None, which is acceptable until needed.
tests/conftest.py (1)
140-167: No changes required: current exception handling is correct. MissingJITCacheError is a subclass of RuntimeError, so it will be caught by the existing `except (torch.cuda.OutOfMemoryError, RuntimeError)` clause. The `isinstance` check that follows correctly identifies and handles it. Adding it to the exception tuple would be redundant.
Likely an incorrect or invalid review comment.
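To make the control flow concrete, a small sketch of the pattern this comment describes; the class and helper names are illustrative, and MissingJITCacheError is assumed to subclass RuntimeError as stated above:

```python
import pytest
import torch


class MissingJITCacheError(RuntimeError):
    """Illustrative stand-in for the real exception type in conftest.py."""


def run_with_cache_guard(fn):
    try:
        return fn()
    except (torch.cuda.OutOfMemoryError, RuntimeError) as err:
        # A MissingJITCacheError is a RuntimeError, so it is caught here and
        # identified by the isinstance check rather than a separate except clause.
        if isinstance(err, MissingJITCacheError):
            pytest.skip("JIT cache entry is missing for this configuration.")
        raise
```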
yzh119
left a comment
LGTM
[SUCCESS] Pipeline #37489603: 13/17 passed
📌 Description
- Autotune tile_tokens_dim; tune over the range [max(8, tile_token_dim/2), tile_token_dim, min(128, tile_token_dim*2), min(128, tile_token_dim*4)].
- Refactor test_trtllm_gen_fused_moe.py.
- Update conftest.py to speed up tests; previously used try_first, which introduced duplicate runs.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- All tests are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Deprecations
Tests