Conversation

@jiahanc (Collaborator) commented Oct 25, 2025

📌 Description

  • Update the autotune logic in trtllm-gen MoE. Instead of using a fixed tile_tokens_dim, tune over the range
    [max(8, tile_tokens_dim/2), tile_tokens_dim, min(128, tile_tokens_dim*2), min(128, tile_tokens_dim*4)] (see the sketch after this list).
  • Add FP8 MoE autotune logic, based on the initial PR "Trtllm-gen Fp8 MoE Autotunner" #1494 from @aleozlx, with the logic updated to sync with the new autotuner.
  • Update logic in test_trtllm_gen_fused_moe.py.
  • Update conftest.py to speed up tests; it previously used tryfirst, which introduced duplicate runs.
  • Add log_once helpers (debug_once, info_once, warning_once) to the JIT logger.
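
For reference, a minimal Python sketch of the tile-size candidate selection described above (the actual implementation is the C++ helper computeSelectedTileN in csrc/trtllm_fused_moe_kernel_launcher.cu; the function and variable names here are illustrative):

```python
def next_power_of_two(x: int) -> int:
    """Smallest power of two >= x, with x clamped to at least 1."""
    return 1 << (max(x, 1) - 1).bit_length()


def candidate_tile_dims(num_tokens: int, top_k: int, num_local_experts: int,
                        min_tile: int = 8, max_tile: int = 128) -> list:
    # Center the search on the average number of tokens routed to each expert.
    avg_tokens_per_expert = num_tokens * top_k / num_local_experts
    tile = min(max(next_power_of_two(int(avg_tokens_per_expert)), min_tile), max_tile)
    # Tune over [max(8, tile/2), tile, min(128, tile*2), min(128, tile*4)].
    return sorted({max(min_tile, tile // 2), tile,
                   min(max_tile, tile * 2), min(max_tile, tile * 4)})


# Example: 64 tokens, top_k=8, 32 local experts -> 16 tokens/expert on average.
print(candidate_tile_dims(64, 8, 32))  # [8, 16, 32, 64]
```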

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Runtime autotuning with per-tile dynamic routing and selectable MoE runner options (gated activation, shuffled-weight, weight-layout).
    • One-time (deduplicated) logging helpers added to JIT logger.
  • Deprecations

    • tile_tokens_dim removed from new paths and marked deprecated in legacy entry points; new tuning parameters introduced for autotuning.
  • Tests

    • Tests refactored for autotuning/routing with new helpers and improved handling/reporting for missing JIT cache.

@coderabbitai coderabbitai bot (Contributor) commented Oct 25, 2025

Walkthrough

The PR removes tile_tokens_dim propagation and replaces it with autotuning-driven MoE flows: per‑tile runner/config selection, tactic/config indices, new public MoERunner options (gated activation, shuffled‑weight flag, weight‑layout); tests and benchmarks pass None for tile_tokens_dim and use autotuning paths.

Changes

  • Benchmark update (benchmarks/bench_trtllm_gen_fused_moe_autotuner.py): Removed calculate_tile_tokens_dim import and computation; benchmark now passes None for tile_tokens_dim into autotuner flows.
  • C++ kernel launcher refactor (csrc/trtllm_fused_moe_kernel_launcher.cu): Added helpers (nextPowerOfTwo, computeSelectedTileN); introduced per‑tile MoE::Runner construction and dynamic per‑tile config selection; launcher APIs updated to accept a moe_runner reference and per‑tile config_index arrays; config discovery APIs changed to accept dtype/activation params and return per‑tile config indices; added <set> and <unordered_map>.
  • Python MoE core & API (flashinfer/fused_moe/core.py): Removed tile_tokens_dim from MoERunner constructor; added public options gated_act_type, use_shuffled_weight, weight_layout; integrated AutoTuner and tactic/config selection (tune_max_num_tokens); updated op signatures to drop or deprecate tile_tokens_dim and route tactics/config indices to backend.
  • Tests & fixtures (tests/moe/test_trtllm_gen_fused_moe.py, tests/conftest.py): Tests no longer compute/consume tile_tokens_dim; kernel calls pass None for tile_tokens_dim; added skip_checks and run_moe_test helpers; routing configs include compatible_intermediate_size; pytest hook changed to @pytest.hookimpl(wrapper=True) and missing JIT cache handling/reporting added.
  • JIT logger enhancement (flashinfer/jit/core.py): Added deduplicated logging helpers debug_once, info_once, warning_once backed by an LRU‑cached _print_once helper to suppress duplicate messages (see the sketch below).
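
For illustration, a minimal self-contained sketch of these one-time logging helpers, assuming the lru_cache-backed _print_once approach described above (the class name and setup below are illustrative, not the exact FlashInfer code):

```python
import functools
import logging
from typing import Hashable


class DedupLogger(logging.Logger):
    """Logger with *_once helpers that drop repeated (msg, args) pairs."""

    def info_once(self, msg: str, *args: Hashable) -> None:
        self._print_once(self.info, msg, *args)

    def warning_once(self, msg: str, *args: Hashable) -> None:
        self._print_once(self.warning, msg, *args)

    @functools.lru_cache(maxsize=None)
    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
        # The cache key is (self, log_method, msg, args): the first call logs,
        # identical later calls hit the cache and emit nothing.
        # stacklevel=3 attributes the record to the original caller.
        log_method(msg, *args, stacklevel=3)


logging.setLoggerClass(DedupLogger)
logger = logging.getLogger("demo.jit")
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
logger.info_once("compiling %s", "module_a")  # printed once
logger.info_once("compiling %s", "module_a")  # suppressed
```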

Sequence Diagram(s)

sequenceDiagram
    actor Caller
    participant Py as Python: MoE core (AutoTuner)
    participant Auto as AutoTuner
    participant CPP as C++ Launcher
    participant Runners as Per‑tile MoE Runners

    Note over Caller,Py: Old flow (static tile_tokens_dim)
    Caller->>Py: call_moe(..., tile_tokens_dim=N)
    Py->>CPP: launcher(..., tile_tokens_dim=N)
    CPP->>Runners: select single runner/config
    Runners-->>CPP: run fixed config

    Note over Caller,Py: New flow (autotune + per‑tile)
    Caller->>Py: call_moe(..., tune_max_num_tokens)
    Py->>Auto: select_tactic(instance_key, max_tokens)
    Auto-->>Py: tactic / Array<int64_t> config_index
    Py->>CPP: launcher(..., Array<int64_t> config_index, moe_runner)
    CPP->>CPP: computeSelectedTileN(...) -> tile groups
    CPP->>Runners: build/select per‑tile runner(s) using config_index
    Runners-->>CPP: chosen runner per tile
    CPP->>CPP: allocate per‑tile workspace & execute per‑tile runner/config

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Pay extra attention to:
    • C++ ↔ Python ABI/signature changes and binding updates for launcher functions.
    • Per‑tile workspace sizing, padding, and memory-safety in the launcher.
    • Correct propagation, ordering, and validation of tactic/config indices from AutoTuner to C++.
    • Tests/benchmarks now passing None for tile_tokens_dim and exercising autotuner flows.
    • Pytest wrapper change and JIT‑cache report file locking and reporting logic.

Possibly related PRs

Suggested reviewers

  • aleozlx
  • yongwww
  • cyx-6
  • wenscarl
  • kahyunnam

Poem

🐰
I hopped through tiles both big and small,
Dropped the static dim and let tuners call.
Per‑tile runners jiggle, configs fall in line,
Kernels hum tuned rhythms—neat and fine.
A cheerful hop—autotune's mine!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 51.06% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title Check ✅ Passed The PR title "feat: autotune tile_tokens_dim in trtllm-gen MOE" is directly related to the main objective of the changeset. The changes across multiple files consistently implement autotuning logic for tile_tokens_dim over a specified range instead of using a fixed value, and this is clearly captured in the title. The title is concise, specific, and accurately conveys the primary feature being introduced to someone scanning the commit history.
Description Check ✅ Passed The pull request description is largely complete and well-structured. The 📌 Description section provides clear, specific details about what the PR does, including the autotune logic changes for tile_tokens_dim ranges, FP8 MOE autotune additions, test and conftest updates, and logger enhancements. The ✅ Pre-commit Checks section has all three items properly checked, and the 🧪 Tests section confirms both that tests have been updated and are passing (both boxes checked). The only notable omission is the 🔍 Related Issues section, which contains only the template placeholder comment without any actual issue links, though the author does reference PR #1494 in the description itself. The Reviewer Notes section, which is optional, is appropriately left empty.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@jiahanc jiahanc force-pushed the trtllm_gen_moe_autotune branch from 7e0ed1b to 0cd4848 on October 28, 2025 03:37
@jiahanc jiahanc marked this pull request as ready for review October 28, 2025 03:38
@jiahanc (Collaborator, Author) commented Oct 28, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !92 has been created, and the CI pipeline #37415173 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

99-101: Fix FC2 bias shape.

gemm2 (FC2) bias must be sized by hidden_size, not 2*intermediate_size. Current shape will fail C++ checks and/or broadcast incorrectly.

Suggested diff:

-    bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
-    bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
+    bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
+    bias2 = torch.randn(num_experts, hidden_size, device=device) * 10

Also applies to: 120-126

csrc/trtllm_fused_moe_kernel_launcher.cu (2)

195-206: args.mDtypeOut may be uninitialized. Initialize explicitly to Bfloat16.

maybeGetMinTokenCount reads args.mDtypeOut bits; in FP8/FP4 block-scale launchers it’s not set. Set to BF16 to match output dtype.

 // In trtllm_fp8_block_scale_moe_launcher
   auto dtype = hidden_states.dtype();
   if (dtype == dl_float16) {
     args.mDtypeElt = btg::Dtype::Fp16;
   } else if (dtype == dl_bfloat16) {
     args.mDtypeElt = btg::Dtype::Bfloat16;
   } else if (dtype == dl_float8_e4m3fn) {
     args.mDtypeElt = btg::Dtype::E4m3;
   } else {
     TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
   }
+  args.mDtypeOut = btg::Dtype::Bfloat16;

 // In trtllm_fp4_block_scale_moe_launcher (after args setup)
+  // Output dtype is BF16 for block-scale path
+  args.mDtypeOut = btg::Dtype::Bfloat16;

Also applies to: 522-531, 919-925


548-556: gemm1_output_scale uses wrong padded token count.

Should use max_num_padded_tokens_gemm1 (like activation_output_scale) to match gemm1_output. Using max_num_padded_tokens can under-allocate.

-  Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens},
+  Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1},
                                            dl_float32, hidden_states.device());
🧹 Nitpick comments (5)
tests/moe/test_trtllm_gen_fused_moe.py (1)

205-209: Deprecation path exercised as intended.

Explicitly passing tile_tokens_dim=None to trtllm_fp4_block_scale_moe keeps tests aligned with autotune. Consider adding a brief comment noting autotuner selection is used when None is given.

flashinfer/fused_moe/core.py (4)

21-21: typing_extensions.deprecated is type-checker only; add runtime warning.

@deprecated won’t emit runtime warnings. Emit warnings.warn on first call to improve UX.

Apply in each deprecated wrapper, e.g.:

+import warnings
...
 @deprecated("tile_tokens_dim is deprecated and will be removed in trtllm_fp8_per_tensor_scale_moe after v0.5.0")
 def trtllm_fp8_per_tensor_scale_moe(...):
+    warnings.warn(
+        "tile_tokens_dim is deprecated and ignored; autotuner selects tactics. "
+        "Support will be removed after v0.5.0.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     return get_trtllm_moe_sm100_module()...

126-134: Duplicate entry in trtllm_gen_dtype_has_scale.

MxE4m3 appears twice. Remove the duplicate for clarity.

-    if dtype in [
-        DtypeTrtllmGen.MxE4m3,
-        DtypeTrtllmGen.E2m1,
-        DtypeTrtllmGen.MxE2m1,
-        DtypeTrtllmGen.MxE4m3,
-    ]:
+    if dtype in {DtypeTrtllmGen.MxE4m3, DtypeTrtllmGen.E2m1, DtypeTrtllmGen.MxE2m1}:
         return True

1643-1716: Silence ARG001 for unused tile_tokens_dim in deprecated wrapper.

Ruff flags tile_tokens_dim as unused. Delete the name or add a no-op to satisfy the linter without changing the API.

 def trtllm_fp8_per_tensor_scale_moe(..., tile_tokens_dim: int = 8, ...):
-    """FP8 per tensor scale MoE operation.
+    """FP8 per tensor scale MoE operation.
     ...
     """
+    # tile_tokens_dim kept for backward-compat; ignored by autotuner.
+    del tile_tokens_dim  # ruff: ARG001
     return get_trtllm_moe_sm100_module()...

1797-1930: Silence ARG001 and add consistent deprecation note.

Do the same for FP8 block scale and FP4 block scale wrappers. Also emit a warning (as above) for consistency; routed_moe already logs.

 def trtllm_fp8_block_scale_moe(..., tile_tokens_dim: Optional[int] = None, ...):
+    del tile_tokens_dim  # ruff: ARG001
+    warnings.warn("tile_tokens_dim is deprecated and ignored...", DeprecationWarning, stacklevel=2)
     ...
 def trtllm_fp4_block_scale_moe(..., tile_tokens_dim: Optional[int], ...):
+    del tile_tokens_dim  # ruff: ARG001
+    warnings.warn("tile_tokens_dim is deprecated and ignored...", DeprecationWarning, stacklevel=2)
     ...
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ef687e9 and 0cd4848.

📒 Files selected for processing (5)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
  • flashinfer/fused_moe/core.py (22 hunks)
  • tests/conftest.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • intermediate_size (275-275)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (11)
  • top_k (270-270)
  • maybeGetMinTokenCount (55-60)
  • intermediate_size (275-275)
  • hidden_size (265-265)
  • mUseDeepSeekFp8 (285-342)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
  • do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
flashinfer/fused_moe/core.py (4)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • GatedActType (141-158)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (699-750)
  • trtllm_fp8_block_scale_moe (699-706)
  • trtllm_fp8_per_tensor_scale_moe (342-402)
  • trtllm_fp8_per_tensor_scale_moe (342-350)
  • trtllm_fp4_block_scale_moe (1175-1266)
  • trtllm_fp4_block_scale_moe (1175-1188)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-780)
  • get (362-365)
  • choose_one (400-525)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py

1738-1738: Unused function argument: tile_tokens_dim

(ARG001)


1825-1825: Unused function argument: tile_tokens_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/conftest.py (2)

140-144: Verify PR scope: Changes don't match PR objectives.

The PR title indicates this should add autotuning for tile_tokens_dim in TRT-LLM MoE kernels, but the only change is a pytest hook refactoring in test infrastructure. Please confirm:

  1. Are there additional files/changes not included in this review?
  2. Is this an infrastructure prerequisite for the main feature?
  3. Should this be split into a separate PR?

140-144: Change is correct and follows pytest best practices.

The migration from tryfirst=True to wrapper=True with yield is the recommended approach for wrapping test execution. Verification confirms:

  • Only one pytest_runtest_call hook exists (no conflicts)
  • Exception handling for OutOfMemoryError and MissingJITCacheError remains intact
  • The wrapper pattern properly delegates to pytest's execution chain while preserving error handling behavior
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

136-143: tile_tokens_dim=None usage is correct.

Passing None here aligns with the deprecated API and delegates selection to the autotuner; no action needed.

tests/moe/test_trtllm_gen_fused_moe.py (1)

1866-1884: Good guards for compatibility and speed.

The added compatible_intermediate_size filters and targeted skips keep CI time reasonable while respecting kernel constraints. LGTM.

Also applies to: 1905-1917, 1930-1933, 1962-1981

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

371-399: Per‑tile autotune path looks correct.

  • Runners constructed for selected tiles.
  • Fallback when config_index contains -1 is well‑handled.
  • computeSelectedTileN centers tile_N on avg tokens per expert with neighbors.
    LGTM.

Also applies to: 720-747, 1234-1266, 1268-1331

flashinfer/fused_moe/core.py (1)

905-1072: Verified: Autotuner plumbing and tactics caching are correctly implemented.

Static analysis confirms all review observations:

  • Caching consistency: valid_tactics_dict properly caches per instance_key tuple (lines 945-967); cache key is comprehensive and includes all relevant parameters
  • refine_tuning_config usage: Correctly invoked before each tuner.choose_one() call to customize token buckets (lines 458, 1164, 1301, 1487)
  • choose_one pattern: Consistently used across all 5 autotuner invocations (lines 481, 495, 1195, 1339, 1522)
  • [-1, -1] fallback: Correctly applied throughout for default tactic selection (lines 1046, 1070, 1105, 1237, 1382, 1584)

The smoke test cannot run in the sandbox environment (missing torch), but code structure verification confirms the implementation is sound.

@jiahanc (Collaborator, Author) commented Oct 28, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !92 has been updated with latest changes, and the CI pipeline #37422250 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
flashinfer/fused_moe/core.py (2)

1026-1033: Consider avoiding repeated tensor allocation during profiling.

The current_hidden_states_scale tensor is allocated fresh on every forward call. During autotuning, this function may be invoked many times. Consider whether the scale tensor could be allocated once and reused, or if this allocation is intentional for isolation between runs.


1659-1731: Consider consistent deprecation warnings across all deprecated functions.

While all four deprecated functions correctly accept but ignore tile_tokens_dim, only trtllm_fp4_block_scale_routed_moe (lines 2048-2053) logs an informational message when the parameter is provided. For consistency and better user guidance, consider adding similar logging to the other three deprecated functions: trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, and trtllm_fp4_block_scale_moe.

Note: The static analysis warnings about unused tile_tokens_dim arguments are expected and correct for deprecated parameters.

Also applies to: 1734-1810, 1813-1945, 1948-2087

tests/moe/test_trtllm_gen_fused_moe.py (1)

1953-1956: Clarify skip reason for RenormNaive test.

The skip reason states "similar to RenormalizeNaive" but this test IS for RoutingMethodType.RenormalizeNaive. Did you mean "similar to Renormalize" (referring to RoutingMethodType.Renormalize)? If so, please update the message for clarity.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0cd4848 and 55f95f0.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py (23 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/fused_moe/core.py (4)
include/flashinfer/trtllm/fused_moe/runner.h (9)
  • GatedActType (141-158)
  • top_k (270-270)
  • intermediate_size (275-275)
  • hidden_size (265-265)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (699-750)
  • trtllm_fp8_block_scale_moe (699-706)
  • trtllm_fp8_per_tensor_scale_moe (342-402)
  • trtllm_fp8_per_tensor_scale_moe (342-350)
  • trtllm_fp4_block_scale_moe (1175-1266)
  • trtllm_fp4_block_scale_moe (1175-1188)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-780)
  • get (362-365)
  • choose_one (400-525)
tests/moe/test_trtllm_gen_fused_moe.py (4)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/fused_moe/core.py (2)
  • trtllm_fp8_block_scale_moe (1737-1810)
  • trtllm_fp8_per_tensor_scale_moe (1662-1731)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (699-750)
  • trtllm_fp8_block_scale_moe (699-706)
  • trtllm_fp8_per_tensor_scale_moe (342-402)
  • trtllm_fp8_per_tensor_scale_moe (342-350)
🪛 Ruff (0.14.1)
flashinfer/fused_moe/core.py

1754-1754: Unused function argument: tile_tokens_dim

(ARG001)


1841-1841: Unused function argument: tile_tokens_dim

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
flashinfer/fused_moe/core.py (5)

914-937: LGTM: Clean parameter addition with proper validation.

The new parameters (gated_act_type, use_shuffled_weight, weight_layout) are well-defined with sensible defaults, and the validation logic correctly restricts non-shuffled weights or non-MajorK layouts to FP8 block-scale mode only.


954-976: LGTM: Proper cache key extension for autotuning.

The instance_key correctly includes use_shuffled_weight and weight_layout to distinguish different MoE configurations, and the error handling gracefully returns an empty list when tactics retrieval fails.


1194-1204: Verify hardcoded weight_layout assumption.

Line 1202 hardcodes weight_layout=WeightLayout.MajorK for FP8 per-tensor MoE. According to the validation logic in MoERunner.__init__ (lines 933-937), non-MajorK layouts are only supported for FP8 block-scale. Confirm this is the intended behavior and document if this restriction is permanent or planned for future extension.


1313-1375: LGTM: Autotuning integration for FP8 block-scale.

The autotuning logic correctly instantiates the MoERunner with use_shuffled_weight and weight_layout parameters, selects the appropriate tuning config variant, and passes the selected tactic to the underlying C++ operation.


1518-1519: LGTM: Consistent with FP4 MoE layout requirements.

Hardcoding weight_layout=WeightLayout.MajorK and use_shuffled_weight=True aligns with the validation logic that restricts alternative layouts to FP8 block-scale only.

tests/moe/test_trtllm_gen_fused_moe.py (5)

48-48: LGTM: Test updated for autotuning workflow.

The removal of calculate_tile_tokens_dim import and passing tile_tokens_dim=None correctly reflects the transition to autotuning-based tile dimension selection.

Also applies to: 205-205


772-796: LGTM: FP8 block-scale test properly uses autotuning.

The autotune(True) context manager correctly enables autotuning for the FP8 block-scale MoE call, and tile_tokens_dim=None appropriately delegates tile dimension selection to the autotuner.


945-972: LGTM: FP8 per-tensor test properly uses autotuning.

Consistent with the FP8 block-scale changes, the per-tensor test correctly enables autotuning and passes None for tile_tokens_dim.


1869-1869: Verify that compatible_intermediate_size restrictions reflect actual kernel limitations.

The new compatible_intermediate_size constraints reduce test coverage for specific routing configurations. Confirm whether these restrictions are due to:

  1. Known kernel limitations or performance characteristics that make certain combinations unsupported
  2. Test optimization to reduce CI time

If (2), consider documenting this as a test-only constraint to avoid confusion with actual runtime restrictions.

Also applies to: 1887-1887, 1905-1905, 1920-1920, 1935-1935, 1950-1950, 1968-1968, 1983-1983


2080-2083: LGTM: Clear skip logic for intermediate size compatibility.

The skip logic properly enforces the compatible_intermediate_size constraint with a descriptive message.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

350-403: Add bounds checking and validation for config_index array.

Multiple safety concerns with the config_index parameter handling:

  1. Missing bounds check: Lines 385-386 directly access config_index[0] and config_index[1] without verifying the array has at least 2 elements.

  2. Invalid tile_N risk: If config_index[0] contains a tile size not in selected_tile_nums (e.g., 128 when only {8,16,32,64} were selected), line 399's *mRunners[tile_N] will dereference a non-existent map entry, causing undefined behavior.

  3. No config validation: There's no check that config_index[1] is a valid configuration index for the selected runner.

The conditional at lines 388-389 only handles the sentinel value -1, not invalid values.

Apply defensive checks:

+  TVM_FFI_ICHECK_GE(config_index.size(), 2) 
+      << "config_index must have at least 2 elements [tile_N, config]";
+
   // moeConfigIndex corresponds to pair (tile_N, config)
   int64_t tile_N = config_index[0];
   int64_t config = config_index[1];
   // Autotuner has requested a default or 'fallback' config index
   if (tile_N == -1 || config == -1) {
     tile_N = *selected_tile_nums.begin();
     config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                                           local_num_experts, num_tokens);
+  } else {
+    TVM_FFI_ICHECK(mRunners.find(tile_N) != mRunners.end())
+        << "Requested tile_N " << tile_N << " not in selected tile sizes";
   }

Note: Similar issues exist in trtllm_fp8_block_scale_moe (lines 734-747) and trtllm_fp4_block_scale_moe (lines 1251-1266).

♻️ Duplicate comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)

21-22: LGTM: Required headers added.

The <set> and <unordered_map> includes are now present and necessary for std::set usage in computeSelectedTileN (line 62) and std::unordered_map usage in runner maps (lines 377, 726, 1242).


700-751: Same config_index validation issues as fp8_per_tensor_scale_moe.

Lines 734-741 have identical validation gaps as the earlier function:

  • No bounds check on config_index array length
  • No validation that config_index[0] exists in mRunners
  • No validation of the config value

1188-1267: Same config_index validation issues (third occurrence).

Lines 1251-1258 have the same validation gaps. The fix should be applied consistently across all three MoE functions.

🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1269-1291: Inconsistent return value for default config getter.

The trtllm_get_default_moe_configs function returns only int64_t config (line 1289-1290), whereas trtllm_get_valid_moe_configs returns Array<Array<int64_t>> with [tile_N, config] pairs (lines 1327-1329).

The single config value is ambiguous without knowing which tile_N it corresponds to. While the implementation uses *selected_tile_nums.begin() (line 1286), callers have no way to determine this tile size.

For API consistency, consider returning Array<int64_t> containing [tile_N, config] instead of just config.

-int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_,
+Array<int64_t> trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_,
                                        bool const useDeepSeekFp8, int64_t const top_k,
                                        int64_t const hidden_size, int64_t const intermediate_size,
                                        int64_t const num_local_experts,
                                        int64_t const gated_act_type, int64_t const num_tokens) {
   auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
   auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
   std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
   if ((dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) &&
       dtype_act != btg::Dtype::Bfloat16) {
     supported_tile_nums.push_back(128);
   }
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);

+  int64_t tile_N = *selected_tile_nums.begin();
   std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
       std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
-          dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
+          dtype_act, dtype_weights, useDeepSeekFp8, tile_N,
           static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);

-  return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
+  int64_t config = moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                                 num_local_experts, num_tokens);
+  return {tile_N, config};
 }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 55f95f0 and 64298f3.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (4)
  • top_k (270-270)
  • maybeGetMinTokenCount (55-60)
  • mUseDeepSeekFp8 (285-342)
  • do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)

42-60: LGTM: Standard power-of-two calculation.

The implementation uses a standard bit manipulation algorithm to compute the next power of two. The logic correctly handles already-power-of-two values and edge cases for small inputs.
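
For reference, the standard bit trick referenced here looks roughly like this in Python (illustrative only; the actual helper is the C++ nextPowerOfTwo in the launcher):

```python
def next_power_of_two_32(v: int) -> int:
    # Classic 32-bit bit-smearing trick: propagate the highest set bit to all
    # lower positions, then add one. Returns 1 for v <= 1.
    v = max(int(v), 1) - 1
    v |= v >> 1
    v |= v >> 2
    v |= v >> 4
    v |= v >> 8
    v |= v >> 16
    return v + 1


assert [next_power_of_two_32(x) for x in (1, 7, 16, 17)] == [1, 8, 16, 32]
```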


86-88: LGTM: Launcher signature updated for per-tile autotuning.

The added moe_runner reference and moeConfigIndex parameters support the per-tile runner selection model. This is an internal launcher function, so the signature change doesn't affect the public API.


201-206: LGTM: Per-stage workspace padding improves precision.

The separate max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 calculations (using maybeGetMinTokenCount) account for different dimension requirements in each GEMM stage. The allocations correctly use stage-specific padding, and the workspace takes the maximum of both.

Also applies to: 224-234, 304-305


377-382: LGTM: Runner lifecycle management is safe.

The mRunners map with unique_ptr ownership correctly manages runner lifetime. Since the launcher is synchronous and doesn't store the runner reference beyond the call, the stack-allocated map's destruction after the function returns is safe.

Comment on lines +62 to +75
std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
int64_t const num_tokens, int64_t const top_k,
int64_t const num_local_experts) {
float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
supported_tile_nums.front(), supported_tile_nums.back());

std::set<int32_t> selected_tile_nums = {
std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};

return selected_tile_nums;
}
coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify input constraints to prevent division by zero and overflow.

Two potential issues:

  1. Division by zero: If num_local_experts is zero, line 65 will cause division by zero. While upstream validation exists for expert counts (line 150), it doesn't explicitly guarantee local_num_experts > 0.

  2. Integer overflow: Line 65 computes num_tokens * top_k as int64_t multiplication before the float cast. For large token counts and top_k values, this could overflow before conversion.

Run the following script to check if callers guarantee local_num_experts > 0:


🏁 Script executed:

#!/bin/bash
# Check validation of local_num_experts parameter
rg -n -B5 -A5 'local_num_experts' --type=cpp | rg -i 'check|assert|validate|icheck'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Find all calls to computeSelectedTileN with surrounding context
echo "=== CALLS TO computeSelectedTileN ==="
rg -n -B10 -A3 'computeSelectedTileN\(' --type=cpp

echo -e "\n=== FUNCTION SIGNATURES CONTAINING local_num_experts ==="
rg -n -B2 'local_num_experts' --type=cpp | grep -E 'void|bool|int|auto|std' | head -20

echo -e "\n=== CHECK FOR num_local_experts ZERO/VALIDATION ==="
rg -n -B3 -A3 'num_local_experts\s*[=!<>]' --type=cpp | head -40

Length of output: 393



🏁 Script executed:

#!/bin/bash
# Direct search in the specific file mentioned in review
echo "=== FILE CONTENT AROUND computeSelectedTileN ==="
find . -name "trtllm_fused_moe_kernel_launcher.cu" -exec cat {} \; 2>/dev/null | head -100

echo -e "\n=== SEARCH FOR FUNCTION IN .cu AND .h FILES ==="
rg -l 'computeSelectedTileN' --type=cpp

echo -e "\n=== SEARCH IN ALL FUSEDMOE FILES ==="
find . -path '*/fused_moe*' -name '*.cu' -o -name '*.h' | head -20

Length of output: 5488



🏁 Script executed:

#!/bin/bash
# Get the full file to search for all calls to computeSelectedTileN
echo "=== FIND ALL CALLS TO computeSelectedTileN IN THE FILE ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | grep -n "computeSelectedTileN"

echo -e "\n=== CHECK FUNCTION PARAMETER VALIDATION IN trtllm_fp8_per_tensor_scale_moe_launcher ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | sed -n '100,200p'

echo -e "\n=== CHECK FOR ICHECK GUARDS ON local_num_experts ==="
cat ./csrc/trtllm_fused_moe_kernel_launcher.cu | grep -n -A2 -B2 "local_num_experts"

Length of output: 18497


Add guards against division by zero and integer overflow in computeSelectedTileN.

The validation at line 149-150 does not prevent local_num_experts == 0. When zero, line 65 causes division by zero:

float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;

Additionally, the multiplication num_tokens * top_k occurs as int64_t before the float cast. For large values (e.g., num_tokens ≈ 2^40, top_k ≈ 2^25), this overflows.

Required fixes:

  1. Add TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149
  2. Cast to float before multiplication: static_cast<float>(num_tokens) * top_k / num_local_experts
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_kernel_launcher.cu around lines 62-75 the computation
does integer multiplication then divides by num_local_experts which can cause
integer overflow and a division-by-zero if local_num_experts==0; add a guard
TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149 to ensure non-zero
experts, and change the averaging to perform the cast before multiplication
(i.e., cast num_tokens to float and multiply by top_k then divide by
num_local_experts) to avoid int64 overflow and then proceed with the existing
clamping and selection logic.

@jiahanc jiahanc requested a review from aleozlx as a code owner October 28, 2025 06:16
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)

1707-1712: Inline deprecation warnings for tile_tokens_dim — good

Warning in the call-sites implements prior guidance to deprecate the field within functions, not via a decorator. Consider also noting the replacement (autotuning via AutoTuner) in the message.

Also applies to: 1784-1789, 1913-1918, 2052-2057
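
As an illustration of that suggestion, a hedged sketch of what such a call-site warning could look like, assuming the module-level logger in flashinfer/jit/core.py and its warning_once helper (hypothetical helper name; the PR's actual wording may differ):

```python
from flashinfer.jit.core import logger  # exposes warning_once


def _warn_tile_tokens_dim_deprecated(tile_tokens_dim) -> None:
    # Hypothetical helper for the deprecated wrappers: warn once and point the
    # user at the replacement (autotuner-based tile selection).
    if tile_tokens_dim is not None:
        logger.warning_once(
            "tile_tokens_dim is deprecated and ignored; tile sizes are now "
            "selected by the autotuner. Support will be removed after v0.5.0."
        )
```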

🧹 Nitpick comments (3)
flashinfer/jit/core.py (1)

64-89: Consider alternative approaches to avoid unbounded cache growth.

The implementation is functional and the use of lru_cache for message deduplication is clever. However, the static analysis warning is worth addressing: using lru_cache on instance methods can prevent garbage collection. While this is mitigated here since logger is a global singleton (line 92), the unbounded cache (maxsize=None) will grow indefinitely with each unique log message, which could be problematic in long-running applications with high log diversity.

Consider refactoring to use a set-based approach instead:

+    def __init__(self, name):
+        super().__init__(name)
+        # ... existing initialization ...
+        self._logged_messages = set()
+
     def debug_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`debug`][logging.Logger.debug], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.debug, msg, *args)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.debug(msg, *args, stacklevel=2)
 
     def info_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`info`][logging.Logger.info], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.info, msg, *args)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.info(msg, *args, stacklevel=2)
 
     def warning_once(self, msg: str, *args: Hashable) -> None:
         """
         As [`warning`][logging.Logger.warning], but subsequent calls with
         the same message are silently dropped.
         """
-        self._print_once(self.warning, msg, *args)
-
-    @functools.lru_cache(maxsize=None)
-    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
-        """Helper method to log messages only once per unique (msg, args) combination."""
-        # Note: stacklevel=3 to show the caller's location, not this helper method
-        log_method(msg, *args, stacklevel=3)
+        key = (msg, args)
+        if key not in self._logged_messages:
+            self._logged_messages.add(key)
+            self.warning(msg, *args, stacklevel=2)

This approach still has unbounded growth but makes memory usage more predictable and avoids the lru_cache decorator warning. Alternatively, consider using functools.cache (Python 3.9+) with periodic cache clearing if appropriate for your use case.

flashinfer/fused_moe/core.py (2)

913-937: Prefer explicit exception over assert for layout/shuffle guard

Asserts can be stripped with python -O. Raise a clear ValueError instead to enforce at runtime.

-            if (
-                not self.use_shuffled_weight
-                or self.weight_layout != WeightLayout.MajorK
-            ):
-                assert (
-                    self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
-                ), (
-                    "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale"
-                )
+            if (not self.use_shuffled_weight or self.weight_layout != WeightLayout.MajorK) and not (
+                self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
+            ):
+                raise ValueError(
+                    "Non-shuffled weights or non-MajorK layout are only supported for FP8 block-scale "
+                    "(use_deepseek_fp8=True, dtype_weights=E4m3)."
+                )

1017-1116: FP8 block-scale forward ignores provided hidden_states_scale

Forward always builds a synthetic scale tensor and never uses the passed hidden_states_scale. If scale layout/dtype differs at runtime, profiling may diverge. Consider preferring the provided tensor when available.

-                    current_hidden_states_scale = torch.full(
-                        (current_hidden_size // 128, current_num_tokens),
-                        2.0,
-                        dtype=torch.float,
-                        device=hidden_states.device,
-                    )
+                    current_hidden_states_scale = (
+                        hidden_states_scale
+                        if 'hidden_states_scale' in locals() and hidden_states_scale is not None
+                        else torch.full(
+                            (current_hidden_size // 128, current_num_tokens),
+                            2.0,
+                            dtype=torch.float32,
+                            device=hidden_states.device,
+                        )
+                    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 64298f3 and c4be849.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/core.py (20 hunks)
  • flashinfer/jit/core.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (9)
  • GatedActType (141-158)
  • top_k (270-270)
  • intermediate_size (275-275)
  • hidden_size (265-265)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
  • trtllm_fp4_block_scale_moe (1176-1267)
  • trtllm_fp4_block_scale_moe (1176-1189)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-780)
  • get (362-365)
  • choose_one (400-525)
flashinfer/jit/core.py (1)
  • warning_once (78-83)
🪛 Ruff (0.14.1)
flashinfer/jit/core.py

85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/fused_moe/core.py (3)

952-965: Good: cache-key now includes gating and layout knobs

Adding gated_act_type, use_shuffled_weight, and weight_layout to instance_key avoids tactic cache collisions across configurations.


1168-1227: Autotune path for FP8 per-tensor looks correct; confirm output dtype assumption

Integration with AutoTuner and passing tactic pair is sound. Output is allocated as bfloat16; please confirm the kernel always writes bf16 for this op across supported dtypes.

Also applies to: 1250-1251


1517-1519: FP4 MoERunner defaults (MajorK + shuffled) acknowledged

Defaults align with supported FP4 path; gated_act_type flows through.

Comment on lines +1307 to +1377
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
if enable_pdl is None:
enable_pdl = device_support_pdl(hidden_states.device)

# Use AutoTuner to select the best tactic - follow FP4 pattern exactly
tuner = AutoTuner.get()
MoERunner.refine_tuning_config(tune_max_num_tokens)

num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[-1]

# Create workspace buffers
output = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
)
topk_ids = torch.empty(
num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
)
expert_weights = torch.empty(
num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device
)

dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation
dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights

moe_runner = MoERunner(
top_k=top_k,
num_local_experts=local_num_experts,
dtype_act=dtype_act,
dtype_weights=dtype_weights,
use_deepseek_fp8=True, # block_scale mode
hidden_size=hidden_size,
intermediate_size=intermediate_size,
weight_layout=weight_layout,
use_shuffled_weight=use_shuffled_weight,
)

inputs = [
output,
routing_logits,
topk_ids,
expert_weights,
hidden_states,
hidden_states_scale,
]

_, tactic = tuner.choose_one(
"flashinfer::trtllm_fp8_block_scale_moe",
[moe_runner],
MoERunner.tuning_config_with_hidden_states_scales, # FP8 block-scale uses hidden_states_scale
inputs,
routing_bias=routing_bias,
gemm1_weights=gemm1_weights,
gemm1_weights_scale=gemm1_weights_scale,
gemm2_weights=gemm2_weights,
gemm2_weights_scale=gemm2_weights_scale,
num_experts=num_experts,
n_group=n_group,
topk_group=topk_group,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type,
use_shuffled_weight=use_shuffled_weight,
weight_layout=weight_layout,
enable_pdl=enable_pdl,
)
coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🔴 Critical

Bug: re-allocating output discards caller-provided buffer

trtllm_fp8_block_scale_moe_op takes output: torch.Tensor but immediately overwrites it with a new tensor. This breaks API expectations for in-place/out-arg usage.

Apply this fix to validate and use the provided buffer instead of reallocating:

-        # Create workspace buffers
-        output = torch.empty(
-            num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-        )
+        # Validate provided output buffer
+        check_shape_dtype_device(
+            output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+        )

Also applies to: 1398-1399

🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 1307-1374 (and also check lines
~1398-1399), the function currently unconditionally re-allocates the `output`
tensor which discards a caller-provided buffer; instead validate and reuse the
provided `output` buffer: if `output` is None then allocate as before, otherwise
check device, dtype, shape (num_tokens x hidden_size) and raise a clear error if
mismatched; do the same for any other caller-provided buffers noted in the
comment (e.g., topk_ids/expert_weights if applicable), and remove the
unconditional reassignment so the op uses the validated buffer. Ensure
consistency for the FP8 dtype and device when performing the checks.

auto const hidden_size = hidden_states.size(1);
bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8

std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64};
Contributor comment:

TileN=128 config exists

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #37422250: 1/17 passed

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

62-75: [Duplicate] Division by zero and overflow risks remain in computeSelectedTileN.

Although marked as addressed in a previous review, the code still exhibits the same patterns:

  1. Division by zero: Line 65 divides by num_local_experts without ensuring it's non-zero. The validation at lines 149-150 only checks local_num_experts + local_expert_offset <= num_experts, which permits local_num_experts == 0.

  2. Integer overflow: The multiplication num_tokens * top_k occurs as int64_t arithmetic before the static_cast<float>, risking overflow for large values.

Recommended fixes:

  • Add TVM_FFI_ICHECK_GT(local_num_experts, 0) after line 149
  • Reorder to cast before multiplication: static_cast<float>(num_tokens) * top_k / num_local_experts
🧹 Nitpick comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)

356-366: Code duplication: dtype conversion logic repeated.

The dtype-to-btg::Dtype mapping at lines 356-366 duplicates the same logic already present in the launcher function (lines 156-165). Consider extracting this into a helper function to reduce duplication and ensure consistency.

Apply this pattern:

// Add at file scope
static btg::Dtype convertDLDataTypeToBtgDtype(DLDataType dtype) {
  if (dtype == dl_float16) {
    return btg::Dtype::Fp16;
  } else if (dtype == dl_bfloat16) {
    return btg::Dtype::Bfloat16;
  } else if (dtype == dl_float8_e4m3fn) {
    return btg::Dtype::E4m3;
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
  }
}

Then replace both occurrences with:

btg::Dtype mDtypeElt = convertDLDataTypeToBtgDtype(dtype);

372-382: Consider caching runners to avoid repeated construction.

On each call, this code builds a std::unique_ptr<RunnerType> for every selected tile size, but only one runner (determined by config_index[0]) is actually used. For non-autotuning inference workloads, this overhead could be eliminated by caching runners.

However, if this function is primarily used in autotuning contexts where different configurations are benchmarked across calls, the current approach is acceptable.


1282-1289: Consider if selecting the first tile size is always appropriate for default config.

The function builds a runner using *selected_tile_nums.begin() (the smallest selected tile size) to get the default config. For some workloads, a larger tile size might yield better performance.

However, choosing the smallest tile size as default is a reasonable conservative choice that should work across workloads. If this is intentional, consider adding a comment explaining the rationale.

+ // Use the smallest selected tile size for conservative default config
  std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner> moe_runner =
      std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
          dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(),
          static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c4be849 and 0e63887.

📒 Files selected for processing (1)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (3)
include/flashinfer/trtllm/fused_moe/runner.h (4)
  • top_k (270-270)
  • maybeGetMinTokenCount (55-60)
  • mUseDeepSeekFp8 (285-342)
  • do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
include/flashinfer/trtllm/fused_moe/DevKernel.h (3)
  • mUseDeepSeekFp8 (213-226)
  • mUseDeepSeekFp8 (333-344)
  • mUseDeepSeekFp8 (400-418)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)

201-206: LGTM! Per-tile padding metrics properly computed.

The introduction of separate max_num_padded_tokens_gemm1 and max_num_padded_tokens_gemm2 correctly applies maybeGetMinTokenCount to each GEMM stage with appropriate dimensions. This enables accurate per-tile workspace sizing.


224-233: LGTM! Workspace allocations use correct per-tile metrics.

The workspace tensors for GEMM1 outputs and activation now use max_num_padded_tokens_gemm1, while GEMM2 output uses max_num_padded_tokens_gemm2. This aligns with the per-tile padding strategy.


304-305: LGTM! Correct aggregation of per-tile metrics.

Using std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2) correctly sizes the workspace to accommodate both GEMM stages.


1227-1240: LGTM! Conditional tile size support correctly handles dtype variations.

The logic properly adds tile size 128 for FP4 activation types excluding Bfloat16, aligning with the similar logic in trtllm_get_default_moe_configs and trtllm_get_valid_moe_configs. The runner construction correctly passes all required parameters including gated_act_type.


1291-1336: LGTM! Comprehensive config enumeration for autotuning.

This function correctly enumerates all valid (tile_N, config) pairs by:

  1. Computing appropriate tile sizes based on workload and dtype
  2. Building the correct runner type (FP8 block-scale vs FP4) based on dtype combinations
  3. Gathering valid configs for each tile size
  4. Returning flattened pairs for the autotuner to benchmark

The conditional runner construction at lines 1317-1328 properly handles the different constructor signatures for FP8 vs FP4 modes.


721-731: No changes required—tile size exclusion is intentional.

The codebase explicitly differentiates tile size support by quantization mode. FP8 per-tensor quantization (both weights and activation E4m3, not DeepSeek FP8) includes tile size 128, while FP4 conditionally includes 128 based on activation type, and the logic at lines 1268–1278 shows FP8 block-scale (DeepSeek FP8) does not include 128. This design choice is working as intended—no documentation or comment addition needed.

@jiahanc jiahanc force-pushed the trtllm_gen_moe_autotune branch from ed7ecaa to 143c296 on October 28, 2025 20:44
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

62-75: Guard against division by zero and integer overflow.

As flagged in a previous review, line 65 still lacks protection against num_local_experts == 0 and potential integer overflow in num_tokens * top_k. The validation at lines 149-150 does not guarantee local_num_experts > 0.

Apply the previously suggested fix:

+  // Add guard after line 149 in the calling function
+  TVM_FFI_ICHECK_GT(local_num_experts, 0)
+      << "local_num_experts must be greater than zero";
+
 std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
                                        int64_t const num_tokens, int64_t const top_k,
                                        int64_t const num_local_experts) {
-  float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+  float const avg_tokens_per_expert = static_cast<float>(num_tokens) * top_k / num_local_experts;
flashinfer/fused_moe/core.py (1)

1320-1322: Bug: Unconditional output reallocation discards caller-provided buffer.

The function accepts output: torch.Tensor as a parameter (line 1294) but immediately overwrites it with a new allocation at lines 1320-1322, breaking the API contract for in-place operations.

Apply this fix to use the provided buffer:

-        # Create workspace buffers
-        output = torch.empty(
-            num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-        )
+        # Validate provided output buffer
+        check_shape_dtype_device(
+            output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+        )
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

372-374: Consider making mSupportedTileN a static constant.

The vector mSupportedTileN is initialized identically on every call. Making it static const would avoid repeated allocations and improve performance.

Apply this diff:

-    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
+    static const std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e63887 and 143c296.

📒 Files selected for processing (6)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
  • flashinfer/fused_moe/core.py (20 hunks)
  • flashinfer/jit/core.py (2 hunks)
  • tests/conftest.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (2)
  • top_k (270-270)
  • maybeGetMinTokenCount (55-60)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (8)
  • GatedActType (141-158)
  • top_k (270-270)
  • intermediate_size (275-275)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
  • trtllm_fp4_block_scale_moe (1168-1259)
  • trtllm_fp4_block_scale_moe (1168-1181)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-780)
  • get (362-365)
  • choose_one (400-525)
flashinfer/jit/core.py (1)
  • warning_once (78-83)
tests/moe/test_trtllm_gen_fused_moe.py (4)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
flashinfer/fused_moe/core.py (5)
  • trtllm_fp8_block_scale_moe (1736-1815)
  • trtllm_fp8_per_tensor_scale_moe (1658-1733)
  • RoutingMethodType (57-71)
  • WeightLayout (160-167)
  • GatedActType (172-176)
🪛 Ruff (0.14.2)
flashinfer/jit/core.py

85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)

🔇 Additional comments (17)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

201-206: LGTM! Proper workspace size calculations.

The per-tile max padded token calculations correctly account for minimum token requirements using maybeGetMinTokenCount, ensuring workspace buffers are adequately sized for both GEMM operations.
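
To make the sizing intuition concrete, here is a hedged back-of-the-envelope sketch in Python. It only illustrates the general idea of padding each expert's slice up to a tile multiple and clamping to a minimum token count; it is not the launcher's exact formula.

def padded_token_upper_bound(num_tokens: int, top_k: int,
                             num_local_experts: int, tile_n: int,
                             min_token_count: int = 0) -> int:
    routed = num_tokens * top_k                      # rows after top-k routing
    # Each local expert's slice is rounded up to a multiple of tile_n, so at
    # most (tile_n - 1) padding rows are added per local expert.
    padded = routed + num_local_experts * (tile_n - 1)
    return max(padded, min_token_count)              # clamp to the kernel minimum


print(padded_token_upper_bound(num_tokens=128, top_k=8, num_local_experts=32, tile_n=16))
# 128*8 + 32*15 = 1504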


1261-1336: LGTM! Config discovery functions properly support per-tile autotuning.

The updated signatures for trtllm_get_default_moe_configs and trtllm_get_valid_moe_configs correctly integrate with the new autotuning framework by returning tile_N and config pairs, enabling dynamic configuration selection at runtime.

flashinfer/fused_moe/core.py (5)

904-936: LGTM! Well-designed parameter expansion for autotuning.

The new gated_act_type, use_shuffled_weight, and weight_layout parameters properly extend the MoERunner to support diverse MoE configurations. The validation at lines 928-936 correctly enforces that non-default shuffle/layout options require FP8 block-scale mode.


1017-1116: LGTM! Clear routing logic for different MoE variants.

The forward method properly dispatches to FP8 block-scale, FP8 per-tensor, or FP4 operations based on dtype combinations. The tactic encoding ([-1, -1] for default) aligns with the backend's config index expectations.
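
A small sketch of that tactic encoding, with a hypothetical decode helper (the real code simply forwards these indices to the backend):

from typing import Optional, Tuple

DEFAULT_TACTIC = [-1, -1]   # "let the backend pick its default config"


def decode_tactic(tactic) -> Tuple[Optional[int], Optional[int]]:
    tile_n, config_index = tactic
    if tile_n < 0 or config_index < 0:
        return None, None                      # default path
    return tile_n, config_index


print(decode_tactic(DEFAULT_TACTIC))  # (None, None)
print(decode_tactic([32, 5]))         # (32, 5)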


1168-1227: LGTM! Proper autotuning integration for FP8 per-tensor MoE.

The new operation correctly:

  • Creates necessary workspace buffers
  • Instantiates a MoERunner with appropriate FP8 per-tensor settings
  • Uses AutoTuner to select the best tactic
  • Returns the computed output tensor
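
The "select the best tactic" step boils down to timing each candidate and keeping the fastest one. The snippet below is a generic schematic of that idea, not FlashInfer's AutoTuner.choose_one signature; the real tuner also caches the winner per input shape.

import time
from typing import Callable, Sequence, Tuple


def choose_best_tactic(run: Callable[[Tuple[int, int]], None],
                       tactics: Sequence[Tuple[int, int]],
                       warmup: int = 2, iters: int = 5) -> Tuple[int, int]:
    best, best_time = tactics[0], float("inf")
    for tactic in tactics:
        for _ in range(warmup):                  # warm caches before timing
            run(tactic)
        start = time.perf_counter()
        for _ in range(iters):
            run(tactic)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best, best_time = tactic, elapsed
    return best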

1508-1519: LGTM! FP4 runner correctly uses fixed layout settings.

The hardcoded weight_layout=WeightLayout.MajorK and use_shuffled_weight=True at lines 1517-1518 are appropriate since FP4 operations require these specific settings, as enforced by the C++ validation logic in the kernel launcher.


1707-1712: LGTM! Appropriate deprecation warning for tile_tokens_dim.

The deprecation warning properly alerts users about the planned removal of tile_tokens_dim in v0.5.0, giving them time to adapt their code. Using warning_once prevents log spam.

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

14-14: LGTM! Benchmark correctly updated for autotuning API.

The removal of calculate_tile_tokens_dim import and passing None for tile_tokens_dim properly aligns with the new autotuning-driven MoE execution path where tile dimensions are selected dynamically.

Also applies to: 136-136

flashinfer/jit/core.py (1)

64-83: LGTM! Well-designed once-only logging methods.

The new debug_once, info_once, and warning_once methods provide a clean API for deduplication of log messages. The stacklevel=3 parameter correctly shows the original caller's location in logs.

tests/conftest.py (1)

140-168: LGTM! Improved test hook follows pytest best practices.

The migration from @pytest.hookimpl(tryfirst=True) to @pytest.hookimpl(wrapper=True) with yield-based delegation is the recommended pattern. The error handling correctly manages both OOM and MissingJITCacheError scenarios while preserving test skip behavior.
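
For reference, a new-style wrapper hook typically has the shape below. The exceptions handled here (MemoryError as a stand-in) and the report mutation are illustrative assumptions, not the contents of the actual conftest.py.

import pytest


@pytest.hookimpl(wrapper=True)
def pytest_runtest_makereport(item, call):
    report = yield                               # run the default implementation first
    if call.when == "call" and call.excinfo is not None:
        exc = call.excinfo.value
        if isinstance(exc, MemoryError):         # stand-in for OOM / missing-JIT-cache handling
            report.outcome = "skipped"
            report.longrepr = (str(item.path), item.location[1], f"Skipped: {exc!r}")
    return report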

tests/moe/test_trtllm_gen_fused_moe.py (7)

17-17: LGTM: Import changes align with autotuning approach.

The removal of calculate_tile_tokens_dim import is correct since tile_tokens_dim is now autotuned rather than statically calculated.

Also applies to: 48-48


205-205: LGTM: Correct tile_tokens_dim autotuning in FP4 CUDA graph path.

Passing tile_tokens_dim=None allows the autotuner to select optimal values during graph capture (wrapped in autotune(True) at line 108).


772-796: LGTM: Proper autotuning integration for FP8 block scale MoE.

The autotune(True) context manager correctly enables autotuning, and passing tile_tokens_dim=None allows the autotuner to select optimal tile sizes based on runtime characteristics.


945-972: LGTM: Consistent autotuning for FP8 per-tensor MoE.

The autotuning approach matches the FP8 block scale implementation and correctly delegates tile size selection to the autotuner.


1837-1897: LGTM: Good refactor to centralize skip logic.

The new skip_checks helper function consolidates test skip conditions and adds compatible_intermediate_size filtering to reduce test execution time while maintaining coverage of valid configurations.


1899-2086: LGTM: Excellent test refactoring for maintainability.

The run_moe_test helper function effectively consolidates the common test workflow (quantization → reference computation → actual computation → validation), eliminating duplication across test functions and improving maintainability.
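
Schematically, such a helper has the shape below; the callables are placeholders, while the real run_moe_test wires in the FlashInfer kernels and reference implementations.

def run_moe_test(quantize, compute_reference, compute_actual, compare, inputs):
    quantized = quantize(inputs)                 # quantize weights / activations
    reference = compute_reference(quantized)     # emulated reference MoE
    actual = compute_actual(quantized)           # production kernel under test
    compare(actual, reference)                   # accuracy validation
    return actual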


2301-2401: Clarify intent for commented-out tests.

The TopK and Llama4 routing tests are commented out. Please clarify:

  • Are these temporarily disabled for testing speed optimization?
  • Should they be re-enabled in a future PR?
  • Or can they be removed entirely if no longer needed?

Consider adding a TODO comment explaining the rationale if these tests should be restored later.

Comment on lines +85 to +89
@functools.lru_cache(maxsize=None)
def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
"""Helper method to log messages only once per unique (msg, args) combination."""
# Note: stacklevel=3 to show the caller's location, not this helper method
log_method(msg, *args, stacklevel=3)
Contributor

⚠️ Potential issue | 🟡 Minor

Potential memory leak with lru_cache on instance method.

Using @functools.lru_cache on the instance method _print_once will cache self references, which can prevent garbage collection. Since FlashInferJITLogger is a singleton (line 92), this is less critical but still not ideal.

Consider using a module-level cache instead:

+_LOGGER_CACHE: Dict[Tuple[str, Hashable], bool] = {}
+
 class FlashInferJITLogger(logging.Logger):
     # ...
-    @functools.lru_cache(maxsize=None)
-    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
+    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
         """Helper method to log messages only once per unique (msg, args) combination."""
-        # Note: stacklevel=3 to show the caller's location, not this helper method
-        log_method(msg, *args, stacklevel=3)
+        cache_key = (msg, args)
+        if cache_key not in _LOGGER_CACHE:
+            _LOGGER_CACHE[cache_key] = True
+            log_method(msg, *args, stacklevel=3)

Based on static analysis hint.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.2)

85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)
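
A self-contained sketch of the module-level dedupe alternative; the names are illustrative, not the FlashInfer logger API.

import logging
from typing import Hashable, Tuple

_SEEN: set = set()
logger = logging.getLogger("example.jit")
logging.basicConfig(level=logging.WARNING)


def log_once(level: int, msg: str, *args: Hashable) -> None:
    key: Tuple = (level, msg, args)
    if key in _SEEN:
        return
    _SEEN.add(key)
    logger.log(level, msg, *args, stacklevel=2)


log_once(logging.WARNING, "tile_tokens_dim=%s is deprecated", 8)
log_once(logging.WARNING, "tile_tokens_dim=%s is deprecated", 8)  # deduplicated, nothing printed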

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)

1899-2086: Well-structured test helper reduces duplication.

The run_moe_test helper effectively consolidates the common test workflow (quantize weights → quantize inputs → compute reference → compute production → compare). This improves maintainability and reduces code duplication across the parameterized tests.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 143c296 and 20bd60b.

📒 Files selected for processing (1)
  • tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_trtllm_gen_fused_moe.py (4)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/fused_moe/core.py (2)
  • trtllm_fp8_block_scale_moe (1736-1815)
  • trtllm_fp8_per_tensor_scale_moe (1658-1733)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/moe/test_trtllm_gen_fused_moe.py (6)

17-48: LGTM! Import changes align with autotuning approach.

The removal of calculate_tile_tokens_dim from imports is consistent with the PR's goal of autotuning tile_tokens_dim instead of using static calculation.


772-796: LGTM! Autotuning correctly integrated for FP8 block scale MoE.

The autotune(True) context manager correctly wraps the kernel invocation, and tile_tokens_dim=None (Line 791) properly delegates tile dimension selection to the autotuner.


945-972: LGTM! Autotuning correctly integrated for FP8 per-tensor MoE.

The autotune(True) context manager correctly wraps the kernel invocation, and tile_tokens_dim=None (Line 970) properly delegates tile dimension selection to the autotuner.


837-1886: Good refactoring to consolidate test logic.

The new skip_checks helper consolidates skip logic and adds validation for compatible_intermediate_size (Lines 1882-1885), which appears to be a new constraint related to autotuning tile dimensions for different intermediate sizes.


2088-2439: Comprehensive test coverage with routing-specific constraints.

The refactored parameterized tests provide good coverage across different routing methods (DeepSeekV3, Renormalize, TopK, Llama4) and quantization modes. The compatible_intermediate_size constraints vary by routing method, which appears intentional for the autotuning feature.

Note: The constraints limit testing to specific intermediate sizes per routing method. Ensure these constraints are documented and align with production use cases.

If the compatible_intermediate_size constraints are temporary limitations during development, consider adding a TODO comment or issue reference to track expanding support.


108-127: No issues found—autotune context pattern is correct as implemented.

The original review comment is based on a misunderstanding of how the autotuner works. The pattern in lines 108–127 is correct:

  1. Warmup (line 108): autotune(True) context enables kernel configuration discovery and caching via AutoTuner.profiling_cache.

  2. Graph capture (lines 119–128): The computation runs without an active autotune context, which is correct. After the warmup context exits, is_tuning_mode reverts to its previous state, but cached configurations persist in profiling_cache and are used during capture.

This pattern matches the established benchmark patterns throughout the codebase (e.g., benchmarks/routines/moe.py lines 726–728), where warmup discovers configs and subsequent execution uses those cached results without re-tuning. Wrapping the capture in autotune(True) would trigger unnecessary re-tuning during capture, which is not the intended design.

The code at lines 172–210 is also correct; tile_tokens_dim=None is properly passed to trtllm_fp4_block_scale_moe.

Likely an incorrect or invalid review comment.
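
Condensed, the warmup-then-capture pattern described above looks roughly like this; it assumes autotune is importable from flashinfer.autotuner as in the test file, and run_kernel stands in for the fused MoE call with tile_tokens_dim=None.

import torch
from flashinfer.autotuner import autotune


def warmup_then_capture(run_kernel, *static_args):
    # 1) Warmup inside autotune(True): profiles candidate configs and fills the cache.
    with autotune(True):
        run_kernel(*static_args)
    torch.cuda.synchronize()

    # 2) Capture with tuning inactive: cached configs are reused, so no
    #    profiling work is recorded into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = run_kernel(*static_args)
    return graph, out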

@jiahanc jiahanc force-pushed the trtllm_gen_moe_autotune branch from a54cfef to 4edbda3 on October 28, 2025 21:40
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

146-151: Guard local_num_experts > 0.

Add an explicit check before any use that depends on it (e.g., routing and tiling selection).

   TVM_FFI_ICHECK_LE(local_num_experts + local_expert_offset, num_experts)
       << "num_experts must be greater or equal to local_num_experts + local_expert_offset";
+  TVM_FFI_ICHECK_GT(local_num_experts, 0) << "local_num_experts must be > 0";

474-477: Apply the same local_num_experts > 0 guard to FP8 block-scale entry.

Ensure consistent validation across entry points.

   TVM_FFI_ICHECK_GT(num_experts, top_k) << "num_experts must be greater than top_k";
+  TVM_FFI_ICHECK_GT(local_num_experts, 0) << "local_num_experts must be > 0";

Also applies to: 700-708

♻️ Duplicate comments (3)
flashinfer/jit/core.py (1)

64-90: Replace lru_cache on instance method; use a small module-level dedupe set instead.

@functools.lru_cache on a bound method can retain self and unboundedly grow the cache. Also, keying on the log_method object is unnecessary. Use a module-level set keyed by (method_name, msg, args) to dedupe and avoid leaks.

Apply:

@@
-    def warning_once(self, msg: str, *args: Hashable) -> None:
+    def warning_once(self, msg: str, *args: Hashable) -> None:
@@
-    @functools.lru_cache(maxsize=None)
-    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
-        """Helper method to log messages only once per unique (msg, args) combination."""
-        # Note: stacklevel=3 to show the caller's location, not this helper method
-        log_method(msg, *args, stacklevel=3)
+    def _print_once(self, log_method, msg: str, *args: Hashable) -> None:
+        """Helper method to log messages only once per unique (msg, args) combination."""
+        # Use method name to avoid holding references to bound methods
+        cache_key = (getattr(log_method, "__name__", "log"), msg, args)
+        if cache_key in _LOG_ONCE_KEYS:
+            return
+        _LOG_ONCE_KEYS.add(cache_key)
+        # Note: stacklevel=3 to show the caller's location, not this helper method
+        log_method(msg, *args, stacklevel=3)

Add near imports (top of file):

+from typing import Hashable, Tuple
+_LOG_ONCE_KEYS: set[Tuple[str, str, tuple[Hashable, ...]]] = set()

Based on static analysis hint (B019).

flashinfer/fused_moe/core.py (1)

1319-1328: Don’t reallocate output; validate and use the provided buffer.

Overwriting output breaks out-arg semantics and wastes memory.

-        # Create workspace buffers
-        output = torch.empty(
-            num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-        )
+        # Validate provided output buffer (shape/dtype/device)
+        check_shape_dtype_device(
+            output, (num_tokens, hidden_size), torch.bfloat16, hidden_states.device, "output"
+        )

Also ensure inputs = [...] continues to pass the validated output (no changes needed).

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

62-75: Fix overflow and div-by-zero in computeSelectedTileN.

num_tokens * top_k can overflow before float cast; num_local_experts can be zero. Cast first and clamp denominator.

-  float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+  int64_t denom = std::max<int64_t>(1, num_local_experts);
+  // compute in double to avoid overflow/precision loss, then cast
+  double avg = static_cast<double>(num_tokens) * static_cast<double>(top_k)
+               / static_cast<double>(denom);
+  float const avg_tokens_per_expert = static_cast<float>(avg);
🧹 Nitpick comments (5)
flashinfer/fused_moe/core.py (2)

1179-1188: Minor: avoid unused workspace tensors unless required by tuner.

topk_ids and expert_weights are allocated but unused by the C++ call here. If they are only for shaping the tuning profile, consider lazy-allocating them under tuning mode to reduce overhead.
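
One way to realize that suggestion, sketched with hypothetical names; the shapes and dtypes are placeholders.

import torch


def maybe_alloc_profiling_buffers(is_tuning: bool, num_tokens: int, top_k: int,
                                  device: torch.device):
    """Allocate tuner-only tensors lazily so the non-tuning path skips the cost."""
    if not is_tuning:
        return None, None
    topk_ids = torch.empty(num_tokens, top_k, dtype=torch.int32, device=device)
    expert_weights = torch.empty(num_tokens, top_k, dtype=torch.bfloat16, device=device)
    return topk_ids, expert_weights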


1001-1045: Constructor guard reads more clearly with a positive condition.

The current assertion triggers when (not shuffled) OR (layout != MajorK). Consider inverting for readability and to avoid accidental future regressions.

-            if (
-                not self.use_shuffled_weight
-                or self.weight_layout != WeightLayout.MajorK
-            ):
-                assert (
-                    self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
-                ), (
-                    "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale"
-                )
+            if self.use_shuffled_weight and self.weight_layout == WeightLayout.MajorK:
+                pass
+            else:
+                assert self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3, (
+                    "Only FP8 block scale supports (not shuffled) or non-MajorK layouts"
+                )
tests/conftest.py (1)

97-125: Use fully qualified name for compile cache key to avoid collisions.

_TORCH_COMPILE_CACHE[fn.__name__] can collide across modules. You already compute fullname; use it for the cache key.

-    compiled = _TORCH_COMPILE_CACHE.get(fullname)
+    compiled = _TORCH_COMPILE_CACHE.get(fullname)
@@
-            _TORCH_COMPILE_CACHE[fn.__name__] = compiled
+            _TORCH_COMPILE_CACHE[fullname] = compiled
tests/moe/test_trtllm_gen_fused_moe.py (2)

96-115: Avoid autotuning during graph capture to reduce capture-time overhead.

You warm up with autotune before capture (good). During capture, ensure tuning is disabled to keep the graph minimal and deterministic.

-        with torch.cuda.stream(torch_stream):
+        with torch.cuda.stream(torch_stream), autotune(False):
             self.output_tensor = self._run_moe_computation(runtime_args)

1846-1854: Clarify skip message for unsupported architectures.

The message says “only guaranteed to work on SM100 and SM103,” but the condition skips when major != 10. Consider making the message explicit about the detected arch.

-    if compute_capability[0] not in [10]:
-        pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
+    if compute_capability[0] != 10:
+        pytest.skip(f"Requires Blackwell (SM10x). Detected SM{compute_capability[0]}{compute_capability[1]}.")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a54cfef and 4edbda3.

📒 Files selected for processing (6)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
  • flashinfer/fused_moe/core.py (20 hunks)
  • flashinfer/jit/core.py (2 hunks)
  • tests/conftest.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • maybeGetMinTokenCount (55-60)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (9)
  • GatedActType (141-158)
  • top_k (270-270)
  • intermediate_size (275-275)
  • hidden_size (265-265)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
  • trtllm_fp4_block_scale_moe (1168-1259)
  • trtllm_fp4_block_scale_moe (1168-1181)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-780)
  • get (362-365)
  • choose_one (400-525)
flashinfer/jit/core.py (1)
  • warning_once (78-83)
tests/moe/test_trtllm_gen_fused_moe.py (5)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
flashinfer/fused_moe/core.py (3)
  • trtllm_fp8_block_scale_moe (1736-1815)
  • trtllm_fp8_per_tensor_scale_moe (1658-1733)
  • RoutingMethodType (57-71)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • num_experts (263-263)
  • top_k (270-270)
  • intermediate_size (275-275)
  • RoutingMethodType (37-136)
  • hidden_size (265-265)
🪛 Ruff (0.14.2)
flashinfer/jit/core.py

85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Signed-off-by: jiahanc <[email protected]> (10 commits)
@jiahanc jiahanc force-pushed the trtllm_gen_moe_autotune branch from 4edbda3 to 0316720 on October 29, 2025 01:56
@jiahanc
Collaborator Author

jiahanc commented Oct 29, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !92 has been updated with latest changes, and the CI pipeline #37489603 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/conftest.py (1)

98-121: Fix cache-key mismatch in torch.compile monkeypatch (prevents perpetual recompiles).

You read from _TORCH_COMPILE_CACHE using fullname but write using fn.__name__. This guarantees a cache miss on every call and repeated re-compilation.

Apply:

-    def wrapper(*args, **kwargs):
-        compiled = _TORCH_COMPILE_CACHE.get(fullname)
+    def wrapper(*args, **kwargs):
+        cache_key = fullname
+        compiled = _TORCH_COMPILE_CACHE.get(cache_key)
         if compiled is None:
             ...
-            _TORCH_COMPILE_CACHE[fn.__name__] = compiled
+            _TORCH_COMPILE_CACHE[cache_key] = compiled
         return compiled(*args, **kwargs)
tests/moe/test_trtllm_gen_fused_moe.py (1)

1259-1282: Convert the mismatch count to a Python number in check_accuracy.

mismatch_percent is currently a 0‑dim Tensor; the comparison and branch do work, but extracting the count with .item() keeps the threshold check and the formatted error message in plain Python and makes the intent explicit.

Apply:

-    count = torch.sum(left > right)
-    mismatch_percent = count / a.numel()
-    if mismatch_percent > 1 - percent:
+    count = (left > right).sum().item()
+    mismatch_percent = count / float(a.numel())
+    if mismatch_percent > (1 - percent):
         print(a)
         print(b)
-        raise Exception(
+        raise Exception(
             f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
             f"(threshold: {1 - percent:.4f})"
         )
♻️ Duplicate comments (2)
flashinfer/jit/core.py (1)

85-89: Memory leak concern with lru_cache already flagged.

Using @functools.lru_cache on the instance method _print_once can prevent garbage collection by caching self references. While the singleton pattern (line 92) mitigates the issue, a module-level cache would be cleaner.

Based on static analysis hint.

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

62-75: Division by zero concern already flagged.

The division at line 65 (/ num_local_experts) and integer overflow from num_tokens * top_k before float cast have been identified in previous reviews.

🧹 Nitpick comments (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

372-374: Consider extracting supported tile size logic to reduce duplication.

The supported tile sizes and conditional logic for adding 128 appear in multiple places (lines 372, 721, 1227-1230, 1268-1278, 1300-1310). Consolidating this into a shared helper function would improve maintainability.

flashinfer/fused_moe/core.py (1)

1520-1521: Hardcoded FP4 MoE parameters may limit flexibility.

Lines 1520-1521 hardcode weight_layout=WeightLayout.MajorK and use_shuffled_weight=True for FP4 block scale MoE. Consider whether these should be exposed as parameters to match the flexibility offered in FP8 block scale MoE (lines 1307-1308).

tests/conftest.py (1)

123-125: Use logging instead of print for test infra noise.

Swap print(...) with pytest’s terminalreporter or logging to keep CI output clean. Minor but helps readability.

-    print("Applied torch.compile to", fullname)
+    import logging
+    logging.getLogger(__name__).debug("Applied torch.compile to %s", fullname)
tests/moe/test_trtllm_gen_fused_moe.py (2)

1837-1885: Skip gating looks fine; minor guard suggestion.

Consider early‑returning when CUDA is unavailable to avoid ValueError from get_compute_capability, improving skip messaging in CPU-only environments.

-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for these tests.")
+    compute_capability = get_compute_capability(torch.device("cuda"))

2088-2444: Param spaces look heavy; consider trimming for CI stability.

Given recent pipeline failure, you may want to temporarily reduce the Cartesian product (e.g., fewer intermediate_size values) behind an env flag (FAST_CI) to improve pass rate while keeping local/full runs exhaustive.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4edbda3 and 0316720.

📒 Files selected for processing (6)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (2 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (12 hunks)
  • flashinfer/fused_moe/core.py (20 hunks)
  • flashinfer/jit/core.py (2 hunks)
  • tests/conftest.py (1 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
tests/moe/test_trtllm_gen_fused_moe.py (5)
flashinfer/utils.py (1)
  • get_compute_capability (251-254)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
flashinfer/fused_moe/core.py (3)
  • trtllm_fp8_block_scale_moe (1739-1818)
  • trtllm_fp8_per_tensor_scale_moe (1661-1736)
  • RoutingMethodType (58-72)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • num_experts (263-263)
  • top_k (270-270)
  • intermediate_size (275-275)
  • RoutingMethodType (37-136)
  • hidden_size (265-265)
flashinfer/fused_moe/core.py (5)
include/flashinfer/trtllm/fused_moe/runner.h (8)
  • GatedActType (141-158)
  • top_k (270-270)
  • intermediate_size (275-275)
  • local_num_experts (277-277)
  • num_experts (263-263)
  • n_group (271-271)
  • topk_group (273-273)
  • local_expert_offset (276-276)
csrc/trtllm_fused_moe_kernel_launcher.cu (6)
  • trtllm_fp8_block_scale_moe (700-751)
  • trtllm_fp8_block_scale_moe (700-707)
  • trtllm_fp8_per_tensor_scale_moe (343-403)
  • trtllm_fp8_per_tensor_scale_moe (343-351)
  • trtllm_fp4_block_scale_moe (1168-1259)
  • trtllm_fp4_block_scale_moe (1168-1181)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
flashinfer/autotuner.py (3)
  • AutoTuner (335-784)
  • get (362-365)
  • choose_one (400-529)
flashinfer/jit/core.py (1)
  • warning_once (78-83)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
include/flashinfer/trtllm/fused_moe/runner.h (4)
  • top_k (270-270)
  • maybeGetMinTokenCount (55-60)
  • mUseDeepSeekFp8 (285-342)
  • do_finalize (295-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • dtypeGetNumBits (88-91)
🪛 GitHub Actions: pre-commit
tests/moe/test_trtllm_gen_fused_moe.py

[error] 1-1: pre-commit hook ruff-format reformatted 1 file and exited with code 1. Run 'pre-commit run --all-files' locally to apply formatting fixes.

🪛 Ruff (0.14.2)
flashinfer/jit/core.py

85-85: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
flashinfer/fused_moe/core.py (3)

1183-1191: Output buffer reallocation issue already identified.

The function unconditionally reallocates output even though it's not a parameter here (it's created locally). However, the registered op signature doesn't include output as a parameter, so this allocation is correct. Previous concerns about output reallocation apply to other functions where output is a parameter.


1710-1715: Good deprecation messaging.

The deprecation warning for tile_tokens_dim is clear and uses the appropriate logger.warning_once to avoid log spam.


1323-1331: The output parameter reallocation does not break the API contract.

The review comment assumes external callers provide an output buffer, but they don't. The public API function trtllm_fp8_block_scale_moe() (line 1739) has no output parameter. It creates its own buffer internally and passes it to the internal trtllm_fp8_block_scale_moe_op() function (line 1288). The internal _op function's output parameter is only called by this internal wrapper—not by external code. The reallocation inside _op does not violate any external API contract.

Likely an incorrect or invalid review comment.

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

136-136: Correctly aligned with tile_tokens_dim deprecation.

Passing None for tile_tokens_dim is consistent with the autotuning changes and deprecation messaging introduced in the core MoE implementation.

tests/moe/test_trtllm_gen_fused_moe.py (3)

206-210: Comment contradicts behavior; clarify tuple vs tensor.

The comment says “Extract tensor from tuple” but returns output unchanged. Either return output[0] or update the comment. Given FP4Moe.call_moe indexes [0], prefer comment fix.

-        return output  # Extract tensor from tuple
+        return output  # Kernel may return a tuple; caller (FP4Moe.call_moe) extracts [0]

772-796: Good: autotuner usage aligns with tile_tokens_dim deprecation.

Passing None for tile_tokens_dim and using autotune(True) matches the new selection flow.

Please confirm the CI logs include the autotuner messages (e.g., "[Autotuner]: Autotuning process starts ...") to verify the context manager takes effect.


2059-2074: Nice: enable_pdl propagated to FP8 block-scale path.

This matches the new kernel signature; FP8 per‑tensor path keeps default None which is acceptable until needed.

tests/conftest.py (1)

140-167: No changes required—current exception handling is correct.

MissingJITCacheError is a subclass of RuntimeError, so it will be caught by the existing except (torch.cuda.OutOfMemoryError, RuntimeError) clause. The isinstance check that follows correctly identifies and handles it. Adding it to the exception tuple would be redundant.

Likely an incorrect or invalid review comment.

Collaborator

@yzh119 yzh119 left a comment

LGTM

@yzh119 yzh119 enabled auto-merge (squash) October 29, 2025 04:25
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #37489603: 13/17 passed

@yzh119 yzh119 disabled auto-merge October 29, 2025 06:47
@yzh119 yzh119 merged commit bb6b620 into flashinfer-ai:main Oct 29, 2025
4 checks passed
@aleozlx aleozlx mentioned this pull request Nov 11, 2025