
Conversation

Collaborator

@jiahanc jiahanc commented Oct 20, 2025

📌 Description

Co-authored with @IwakuraRein

  • update the trtllm-gen fused moe headers
  • add new kernels for trtllm-gen fused moe
    • for NvFp4, add tile size 256
    • for MxFp8 x MxFp4, add tile sizes 128 and 256
    • for FP8 per-tensor, add tile sizes 192 and 256
    • for FP8 block scale, add tile size 128
  • update the logic of computeSelectedTileN
  • add tune_max_num_tokens to FP8 per-tensor and FP8 block scale
  • rename TLLM_GEN_BMM_CUBIN_PATH to TLLM_GEN_GEMM_CUBIN_PATH
  • add TLLM_GEN_EXPORT_FLASHINFER

NOTE: split-k kernels are temporarily disabled because they cause failures in the renormalize + 256-expert tests.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Expanded MoE tiling (adds 128/192/256), FP8 per‑tensor MoE path, FP8/FP4 autotuner benchmark, and new tune_max_num_tokens tuning parameter.
  • Improvements

    • Router now supports tile‑based (non‑power‑of‑two) layouts and propagates explicit valid M/N/K for safer sizing; autotuner logs include exception details; added export/compile flags and clearer kernel error messages.
  • Bug Fixes

    • Relaxed strict padding/power‑of‑two checks and made log2 handling safer.
  • Tests

    • Extended MoE tests to cover new FP8 block‑scale and routing scenarios.

Contributor

coderabbitai bot commented Oct 20, 2025

Walkthrough

Add tile-based (non-power-of-two) tiling support to fused MoE routing via a compile-time isPow2 switch and mTileTokensDim propagation; expand MOE tile-size sets and autotuning; thread valid M/N/K through BatchedGemm/GEMM interfaces and runners; adjust build flags, artifact paths, and logging.

Changes

Cohort / File(s) Summary
Fused MOE routing cores
csrc/trtllm_fused_moe_routing_deepseek.cu, csrc/trtllm_fused_moe_routing_llama4.cu, csrc/trtllm_fused_moe_routing_renormalize.cu
Replace log2-based tiling math with conditional KernelParams::isPow2 paths for CTA counts, mnLimits, offsets, and permuted-index sizing; remove a runtime paddingLog2 constraint; introduce local variables for conditional results.
Routing headers & device helpers
include/flashinfer/trtllm/fused_moe/RoutingKernel.h, .../RoutingKernel.cuh, include/flashinfer/trtllm/fused_moe/DevKernel.h
Add isPow2_ template parameter to KernelParams, add mTileTokensDim and change padding default; add device/host helpers mulTileN/divUpTileN/divUpMulTileN; add LAUNCH_TILEN macro and switch routing launches to it.
MOE launcher & runner
csrc/trtllm_fused_moe_kernel_launcher.cu, csrc/trtllm_fused_moe_runner.cu
Make tile selection neighborhood-based; expand supported tile_N (128/192/256 and conditional 256 for specific dtype combos); propagate mTileTokensDim into routingData; change computeLog2 to return -1 for non-pow2 instead of asserting.
Batched GEMM runner & interface changes
csrc/trtllm_batched_gemm_runner.cu, include/flashinfer/trtllm/batched_gemm/.../BatchedGemmInterface.h, .../KernelParams.h
Thread new valid-dimension fields (mValidM/mValidN/mValidK) into config validation, workspace sizing, and TMA shape/stride builders; add constructor flags for cubin export and adjust run signature; makeTmaShapeStrideAbc now accepts size + optional valid dims.
GEMM options, traits & enums
include/flashinfer/trtllm/batched_gemm/.../GemmOptions.h, .../BatchedGemmOptions.h, .../GemmGatedActOptions.h, .../KernelTraits.h, .../BatchedGemmEnums.h, .../TmaDescriptor.h
Add runtime/fusion/valid-dimension parameters and forward declarations (trtllm::gen::CudaRunner/GenCfg); extend constructors/members (mValidM/N/K, fuse/overlap flags, epilogue regs); strengthen validations; add RouteImpl::LdgPlusSts and helper; clarify TMA error messages.
Routing kernel arithmetic & call sites
include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh, include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Add tile-N arithmetic helpers and conditional branching so routingPermutation and index-offset kernels support both pow2 and tile-based arithmetic; propagate isPow2_ template parameter across KernelParams specializations and add mTileTokensDim.
Kernel decl cleanup
include/flashinfer/trtllm/batched_gemm/.../KernelParamsDecl.h
Remove a large explanatory comment block (no behavioral change).
Build flags & artifacts
flashinfer/artifacts.py, flashinfer/jit/fused_moe.py, flashinfer/jit/gemm/core.py
Update artifact path/checksum for TRTLLM GEMM BMM; add -DTLLM_GEN_EXPORT_FLASHINFER compile flag to GEMM and fused-MoE builds; rename BMM→GEMM cubin macro usages.
Python API & autotuner
flashinfer/fused_moe/core.py, flashinfer/autotuner.py, flashinfer/jit/fused_moe.py, flashinfer/jit/gemm/core.py
Make tile_tokens_dim optional (deprecated) and add tune_max_num_tokens param passed through MOE APIs; autotuner logs now include exception details; add export macro to fused MOE compile flags.
Benchmarks & tests
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py, tests/moe/test_trtllm_gen_fused_moe.py
Add FP8 autotuner benchmark path and constants; update FP4/FP8 dispatch; expand tests to include FP8 block-scale variants and routing permutations (DeepSeek/Renormalize) and broader parameter matrices.
Misc edits
flashinfer/artifacts.py, include/flashinfer/trtllm/batched_gemm/.../BatchedGemmInterface.h, .../TmaDescriptor.h, .../GemmGatedActOptions.h
Update artifact constants, adjust error messages, update dumpOptions signatures and public struct fields to include new runtime/export members and conditional dumping flags.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant KernelParams
    participant RoutingKernel
    participant TileHelpers

    Note over Caller,KernelParams: Initialize kernel params (isPow2, mTileTokensDim)
    Caller->>KernelParams: setBaseParams(data)

    alt isPow2 == true
        RoutingKernel->>TileHelpers: divUpLog2(expert_count)
        TileHelpers-->>RoutingKernel: numCta
        RoutingKernel->>TileHelpers: mulLog2(idx)
        TileHelpers-->>RoutingKernel: offset/permutedSize
    else isPow2 == false
        RoutingKernel->>TileHelpers: divUpTileN(expert_count, tileN)
        TileHelpers-->>RoutingKernel: numCta
        RoutingKernel->>TileHelpers: mulTileN(idx, tileN)
        TileHelpers-->>RoutingKernel: offset/permutedSize
    end
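
To make the two branches in the diagram concrete, here is a minimal host-side sketch of the pow2 and tile-N arithmetic. The helper names (divUpLog2/mulLog2, divUpTileN/mulTileN/divUpMulTileN) come from the walkthrough above; the signatures and bodies below are assumptions for illustration, not the actual code in RoutingKernel.cuh.

// Minimal sketch, assuming plain int32 helpers; the real device/host versions may differ.
#include <cassert>
#include <cstdint>

// Power-of-two path: the tile size is carried as log2(tile).
inline int32_t mulLog2(int32_t x, int32_t log2Tile) { return x << log2Tile; }
inline int32_t divUpLog2(int32_t x, int32_t log2Tile) {
  return (x + (1 << log2Tile) - 1) >> log2Tile;
}

// Tile-N path: the tile size is carried directly, so non-power-of-two tiles
// such as 192 work without any log2 arithmetic.
inline int32_t mulTileN(int32_t x, int32_t tileN) { return x * tileN; }
inline int32_t divUpTileN(int32_t x, int32_t tileN) { return (x + tileN - 1) / tileN; }
inline int32_t divUpMulTileN(int32_t x, int32_t tileN) {
  return divUpTileN(x, tileN) * tileN;  // round x up to the next multiple of tileN
}

int main() {
  // The two paths agree whenever the tile size is a power of two (128 == 1 << 7).
  assert(divUpLog2(1000, 7) == divUpTileN(1000, 128));
  assert(mulLog2(5, 7) == mulTileN(5, 128));
  // The tile-N path also covers tile 192, which has no integer log2.
  assert(divUpTileN(1000, 192) == 6 && divUpMulTileN(1000, 192) == 1152);
  return 0;
}

The real kernels pick between the two paths at compile time via the KernelParams::isPow2 switch; this snippet only illustrates why the tile-N helpers are needed once tile sizes like 192 enter the supported set.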

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Areas needing extra attention:
    • Template propagation: verify isPow2_ across KernelParamsBase, specializations, and all call sites.
    • Arithmetic parity: ensure log2-based and tile-N helpers produce identical indexing/limits where required.
    • BatchedGemm/GEMM validation and workspace sizing: confirm mValidM/mValidN/mValidK are consistently applied (DeepSeek FP8 constraints, TMA shapes).
    • Build/export flags and cubin loading: validate behavior when TLLM_GEN_EXPORT_FLASHINFER is defined and GEMM cubin macro resolution.

Possibly related PRs

Suggested reviewers

  • aleozlx
  • djmmoss
  • yongwww
  • joker-eph
  • cyx-6
  • nvmbreughe
  • yzh119

Poem

🐇
I hopped through tiles both small and grand,
two paths — pow2 and tiled land.
I threaded valid M, N, K through the night,
tuned tiles and flags until sunrise bright.
Hop, compile, run — the rabbit's delight.

Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title accurately summarizes the main changes: updating trtllm-gen fused MoE routing kernels and adding more kernels with expanded tile support.
  • Description check: ✅ Passed. The PR description covers all key changes requested, including kernel additions, tile expansions, parameter additions, macro renames, and includes a note about split-k kernel disablement.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e88417 and 67a5ffb.

📒 Files selected for processing (1)
  • flashinfer/artifacts.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/artifacts.py (1)

92-93: Verify artifact path, hash, and checksum match published artifacts.

The coordinated update of ArtifactPath.TRTLLM_GEN_BMM, MetaInfoHash.TRTLLM_GEN_BMM, and CheckSumHash.TRTLLM_GEN_BMM (lines 92-93, 108-109, 126-127) follows the correct pattern. However, please confirm:

  1. The artifact path "23daeee32b60bde7947ce1ee7a58d4ab701f134b/batched_gemm-0d28130-add42d1" corresponds to the actual published location in the artifactory.
  2. The MetaInfoHash "6cfade1395f9648aba5dcf2c329114619e175c0f238882555178f98c8f5c1968" matches the meta info header file at that location.
  3. The CheckSumHash "46ccf0492e3ed10135c2861a4f4ef9bb45846610f9a9d2ccaf2d5bf01d2006fd" matches the checksums.txt file at that location.

Incorrect values will cause verify_cubin failures (line 237-238) or download failures. Given the CI failures noted in the PR description, ensure these artifact references align with what the published build produced.



@jiahanc jiahanc force-pushed the updateTileNCalc branch 2 times, most recently from 9d9ad95 to 7cd156d on October 22, 2025 22:21
@IwakuraRein IwakuraRein force-pushed the updateTileNCalc branch 2 times, most recently from f060ab9 to e7ac015 on October 31, 2025 18:13
@IwakuraRein IwakuraRein changed the title from "Revise the calculation related to TileN in routing of MOE TRTLLM backend" to "Update trtllm-gen fused moe routing kernel and add more kernels" on November 1, 2025
@IwakuraRein IwakuraRein marked this pull request as ready for review November 3, 2025 17:50
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/trtllm_batched_gemm_runner.cu (1)

423-450: Make isValidConfigIndex honor the WAR override

getValidConfigIndices now copies a config and patches mValidK/mValidN/mValidM before calling isValidConfig, but isValidConfigIndex still forwards the unmodified config. That means a config reported as valid by getValidConfigIndices (or returned by getDefaultValidConfigIndex) can immediately fail isValidConfigIndex, breaking callers that double-check the chosen index. Please mirror the same WAR here so the validation helpers stay consistent with run.

-  auto const& config = configs[configIndex];
-
-  return bmm.isValidConfig(config, gemmData);
+  auto myConfig = configs[configIndex];
+  myConfig.mOptions.mValidK = k;
+  myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
+  myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
+
+  return bmm.isValidConfig(myConfig, gemmData);
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f9cd034 and 9b13c3b.

⛔ Files ignored due to path filters (2)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h is excluded by !**/gen/**
📒 Files selected for processing (23)
  • csrc/trtllm_batched_gemm_runner.cu (2 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (2 hunks)
  • csrc/trtllm_fused_moe_routing_llama4.cu (3 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (1 hunks)
  • csrc/trtllm_fused_moe_runner.cu (4 hunks)
  • flashinfer/artifacts.py (1 hunks)
  • flashinfer/autotuner.py (1 hunks)
  • flashinfer/fused_moe/core.py (6 hunks)
  • flashinfer/jit/fused_moe.py (1 hunks)
  • flashinfer/jit/gemm/core.py (2 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (2 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (10 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (25 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (11 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (7 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (6 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h (7 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
🧰 Additional context used
🧬 Code graph analysis (9)
flashinfer/jit/fused_moe.py (1)
flashinfer/artifacts.py (1)
  • ArtifactPath (83-98)
csrc/trtllm_batched_gemm_runner.cu (2)
csrc/trtllm_gemm_runner.cu (8)
  • m (111-126)
  • m (111-111)
  • m (128-179)
  • m (128-130)
  • m (181-236)
  • m (181-181)
  • m (238-250)
  • m (238-238)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
  • isValidConfig (710-720)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (229-583)
  • std (239-244)
  • std (279-300)
  • std (290-296)
include/flashinfer/trtllm/fmha/kernelParams.h (4)
  • std (215-230)
  • std (273-278)
  • std (361-366)
  • std (398-402)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1)
  • getTmemColStridePerGroup (99-103)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (4)
  • trtllm (58-63)
  • gen (59-62)
  • mExecPath (377-435)
  • mInstanceIdx (380-380)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6)
  • trtllm (82-87)
  • gen (83-86)
  • getShuffleBlockSize (602-608)
  • string (438-440)
  • string (445-447)
  • mInstanceIdx (421-421)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (6)
  • trtllm (48-53)
  • gen (49-52)
  • gemm (147-152)
  • gemmGatedAct (55-191)
  • ActType (62-180)
  • mOptions (213-213)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
  • RouteImpl (28-57)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (2)
csrc/trtllm_batched_gemm_runner.cu (10)
  • run (156-259)
  • run (156-164)
  • run (261-275)
  • run (261-265)
  • run (277-293)
  • run (277-283)
  • run (295-310)
  • run (295-299)
  • getWorkspaceSizeInBytes (129-154)
  • getWorkspaceSizeInBytes (129-131)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (229-583)
  • std (239-244)
  • std (279-300)
  • std (290-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
  • gemm (32-297)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (3)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h (3)
  • flashinfer (36-38)
  • gemm (42-347)
  • gemm (468-488)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (6)
  • string (133-181)
  • trtllm (38-271)
  • gen (39-269)
  • Dtype (43-268)
  • dtypeIsBlockFmt (96-99)
  • dtypeNumEltsPerSf (198-209)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
  • gemm (30-417)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
flashinfer/jit/gemm/core.py (2)

384-384: LGTM!

The addition of the -DTLLM_GEN_EXPORT_FLASHINFER flag is consistent with the existing preprocessor define pattern and is appropriately placed alongside related TRTLLM flags.


535-535: LGTM!

The flag addition is consistent with the change at Line 384, ensuring both TRTLLM GEMM module variants have the same FlashInfer export configuration.

flashinfer/artifacts.py (1)

91-93: Artifact path update looks good

The refreshed hash keeps the TRTLLM fused artifacts aligned with the latest cubin drop; no further action from my side.

flashinfer/jit/fused_moe.py (1)

235-242: Build flag alignment looks consistent

Adding -DTLLM_GEN_EXPORT_FLASHINFER and switching to TLLM_GEN_GEMM_CUBIN_PATH cleanly mirror the artifact rename; everything lines up with the updated loader flow.

csrc/trtllm_fused_moe_runner.cu (1)

35-177: Nice touch on the routing metadata

Letting computeLog2 fall back to -1 and threading mTileTokensDim into every routing path keeps the pow2 and non-pow2 kernels in sync. Looks solid.
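
As a rough sketch of that fallback behavior (an illustration only, not the actual implementation in csrc/trtllm_fused_moe_runner.cu):

#include <cstdint>

// Return log2(x) when x is a positive power of two, otherwise -1 so the caller
// can switch to the tile-N (non-pow2) path instead of asserting.
inline int32_t computeLog2(int32_t x) {
  if (x <= 0 || (x & (x - 1)) != 0) {
    return -1;  // not a power of two, e.g. tile size 192
  }
  int32_t out = 0;
  while ((int32_t{1} << out) < x) {
    ++out;
  }
  return out;
}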

csrc/trtllm_fused_moe_routing_renormalize.cu (1)

168-214: Tile-aware CTA math looks correct

Splitting the CTA/count arithmetic between divUpLog2/mulLog2 and the new divUpTileN/mulTileN helpers is exactly what the non-power-of-two path needs. No issues spotted.

Comment on lines 67 to 82
int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
supported_tile_nums.front(), supported_tile_nums.back());

std::set<int32_t> selected_tile_nums = {
std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};
auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);

std::set<int32_t> selected_tile_nums;
selected_tile_nums.insert(tile_tokens_dim);
if (std::next(it) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(it));
if (std::next(std::next(it)) != supported_tile_nums.end()) {
selected_tile_nums.insert(*std::next(std::next(it)));
}
}
if (it != supported_tile_nums.begin()) {
selected_tile_nums.insert(*std::prev(it));
}

Contributor

@coderabbitai coderabbitai bot Nov 3, 2025


⚠️ Potential issue | 🔴 Critical

Guard against missing tile entry before iterating neighbors

tile_tokens_dim is clamped to the numeric range of supported_tile_nums, but the clamped value is not guaranteed to be present in the container. If the closest supported value is different (e.g., the list is {16, 24, 40, 64} and nextPowerOfTwo returns 32), std::find returns end(). The very next statement calls std::next(it), invoking undefined behaviour and potentially crashing the process. Please snap tile_tokens_dim to an actual supported entry (e.g., via std::lower_bound) before walking neighbours.

-  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
-
-  std::set<int32_t> selected_tile_nums;
-  selected_tile_nums.insert(tile_tokens_dim);
+  auto it =
+      std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  if (it == supported_tile_nums.end()) {
+    it = std::prev(supported_tile_nums.end());
+  }
+  tile_tokens_dim = *it;
+
+  std::set<int32_t> selected_tile_nums;
+  selected_tile_nums.insert(tile_tokens_dim);
   if (std::next(it) != supported_tile_nums.end()) {
     selected_tile_nums.insert(*std::next(it));
     if (std::next(std::next(it)) != supported_tile_nums.end()) {
       selected_tile_nums.insert(*std::next(std::next(it)));
     }
   }
   if (it != supported_tile_nums.begin()) {
     selected_tile_nums.insert(*std::prev(it));
   }
🤖 Prompt for AI Agents
In csrc/trtllm_fused_moe_kernel_launcher.cu around lines 67-82, the code uses
std::find on supported_tile_nums with a clamped tile_tokens_dim but that value
may not exist in the container and iterating neighbors from an end() iterator is
undefined; fix by snapping tile_tokens_dim to an actual supported entry via
std::lower_bound: use lower_bound to get the first element >= tile_tokens_dim,
if lower_bound == supported_tile_nums.end() set the iterator to std::prev(end())
(the largest supported value), otherwise keep the lower_bound result (or if you
want the closest neighbor compare the lower_bound and its previous element and
pick the nearer value), then set tile_tokens_dim = *it and proceed to insert *it
and its neighbors into selected_tile_nums while guarding with iterator boundary
checks (only call std::next/std::prev after verifying they are not
end()/begin()).

Collaborator


Good catch. We should make sure supported_tile_nums contains consecutive powers of two.

Contributor


@IwakuraRein, I appreciate you acknowledging the issue! However, I notice that the current implementation already includes non-power-of-2 values in supported_tile_nums (e.g., 192 at line 256), so strictly enforcing "consecutive powers of 2" would require broader changes.

More importantly, even if we constrain supported_tile_nums to powers of 2, the current code still has a vulnerability: nextPowerOfTwo(avg_tokens_per_expert) can produce a power-of-2 value that's clamped to the numeric range but doesn't exist in the container. For example:

  • If supported_tile_nums = {8, 16, 64, 128} (missing 32)
  • And nextPowerOfTwo(...) returns 32
  • std::clamp(32, 8, 128) → 32
  • std::find returns end() → ❌ UB when calling std::next(it)

Recommendation:

I suggest implementing CodeRabbit's defensive fix using std::lower_bound, which gracefully handles any input:

auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
if (it == supported_tile_nums.end()) {
  it = std::prev(it);  // Use largest supported value
}
tile_tokens_dim = *it;  // Snap to actual supported value

This ensures robustness regardless of the supported tile set. If you want to enforce a design constraint (consecutive powers of 2), consider adding a debug assertion in the callers that construct supported_tile_nums, but the runtime logic should still be defensive.

Would you like me to open an issue to track adding this fix?

Comment on lines 389 to 415
std::stringstream ss;
ss << gemmGatedAct::dumpOptions(options) << ", ";
ss << "mBatchedM={}," << std::endl;
ss << "mBatchedN={}," << std::endl;
ss << gemmGatedAct::dumpOptions(options, dumpRuntimeParams) << ", ";
if (dumpRuntimeParams) {
ss << "mBatchedM={}," << std::endl;
ss << "mBatchedN={}," << std::endl;
}
ss << "mBatchMode=batchedGemm::BatchedGemmOptions::BatchMode("
<< static_cast<int32_t>(options.mBatchMode) << ")," << std::endl;
ss << "mNumBatches=" << options.mNumBatches << "," << std::endl;
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
ss << "mIsStaticBatch=" << options.mIsStaticBatch << "," << std::endl;
ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
if (dumpRuntimeParams) {
ss << "mNumBatches=" << options.mNumBatches << "," << std::endl;
}
ss << "mNumRegsPerThreadLoadB=" << options.mNumRegsPerThreadLoadB << "," << std::endl;
ss << "mNumRegsPerThreadLoadSfB=" << options.mNumRegsPerThreadLoadSfB << "," << std::endl;
if (dumpRuntimeParams) {
ss << "mNumTokens=" << options.mNumTokens << "," << std::endl;
}
ss << "mNumWarpsLoadB=" << options.mNumWarpsLoadB << "," << std::endl;
ss << "mNumWarpsLoadSfB=" << options.mNumWarpsLoadSfB << "," << std::endl;
ss << "mRouteImpl=batchedGemm::RouteImpl(" << static_cast<int32_t>(options.mRouteImpl) << "),"
<< std::endl;
ss << "mRouteSfsImpl={batchedGemm::RouteImpl("
<< static_cast<int32_t>(options.mRouteSfsImpl.value()) << ")}," << std::endl;
ss << "mGridWaitForPrimaryRouting=" << options.mGridWaitForPrimaryRouting << "," << std::endl;
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << ","
<< std::endl;
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << ","
<< std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
return ss.str();
Contributor


⚠️ Potential issue | 🟡 Minor

Restore batched-dimension values in dumpOptions

dumpOptions now emits mBatchedM={} / mBatchedN={} regardless of the actual contents, which removes the batch-dimension information we rely on when debugging tuning logs. Please keep printing the real vectors.

-    ss << "mBatchedM={}," << std::endl;
-    ss << "mBatchedN={}," << std::endl;
+    ss << "mBatchedM={";
+    for (size_t i = 0; i < options.mBatchedM.size(); ++i) {
+      ss << (i ? ", " : "") << options.mBatchedM[i];
+    }
+    ss << "}," << std::endl;
+    ss << "mBatchedN={";
+    for (size_t i = 0; i < options.mBatchedN.size(); ++i) {
+      ss << (i ? ", " : "") << options.mBatchedN[i];
+    }
+    ss << "}," << std::endl;
🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
around lines 389 to 415, the dumpOptions implementation emits literal
placeholders "mBatchedM={}" and "mBatchedN={}" when dumpRuntimeParams is true,
removing the actual batched-dimension data; replace those placeholders with the
real contents of options.mBatchedM and options.mBatchedN so the vectors are
printed (format them the same way other vector fields are emitted, including
separators/newlines and trailing commas) — e.g. serialize each vector into the
stringstream instead of hard-coded braces, preserving the existing comma and
std::endl layout.

Comment on lines +48 to +54
namespace trtllm {
namespace gen {
class CudaRunner;
class GenCfg;
} // namespace gen
} // namespace trtllm

Contributor


⚠️ Potential issue | 🔴 Critical

Remove the nested trtllm namespace block.

By adding namespace trtllm { … } inside namespace batchedGemm, every unqualified trtllm::gen in this file now resolves to batchedGemm::trtllm::gen, not ::trtllm::gen. As a result, the alias namespace tg = trtllm::gen; and members like trtllm::gen::CudaRunner* point at the new nested namespace, which only has the forward declarations and none of the real definitions (dtypeGetNumBits, Dtype, etc.). This breaks compilation immediately. Please drop this block (or move the forward declarations out to global scope / prefix uses with ::). One minimal fix is:

-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-}  // namespace gen
-}  // namespace trtllm

That restores lookup to the existing ::trtllm::gen symbols provided via GemmOptions.h.

🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
around lines 48-54, the nested "namespace trtllm { namespace gen { ... } }"
block creates a batchedGemm::trtllm::gen shadow that breaks lookup of the real
::trtllm::gen symbols; remove that nested namespace block (or move the forward
declarations to the global scope or use ::trtllm::gen prefixes) so all
unqualified trtllm::gen references resolve to the existing ::trtllm::gen
definitions provided by GemmOptions.h.

Comment on lines +116 to +123
#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \
if (data.mPaddingLog2 > 0) { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, \
stream); \
} else { \
LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
}
Contributor


⚠️ Potential issue | 🔴 Critical

Keep mPaddingLog2 == 0 on the pow2 path

The new LAUNCH_TILEN macro only treats mPaddingLog2 > 0 as a power-of-two configuration. Existing callers legitimately pass mPaddingLog2 == 0 (tile size 1) and rely on the previous pow2 code path. After this change those cases fall into the “tile-N” branch (KernelParams::isPow2 == false), but the host code never populates mTileTokensDim for pow2 flows, so the CUDA kernels dereference it as 0 and hit divUpTileN(count, 0) / mulTileN(..., 0) in RoutingKernel.cuh, yielding a divide-by-zero/device fault at runtime. Please keep zero in the pow2 branch (or initialise mTileTokensDim to 1 before launching).

-  if (data.mPaddingLog2 > 0) {
+  if (data.mPaddingLog2 >= 0) {
🤖 Prompt for AI Agents
In include/flashinfer/trtllm/fused_moe/DevKernel.h around lines 116-123, the
macro treats only mPaddingLog2 > 0 as the pow2 path causing mPaddingLog2 == 0
callers to take the tile-N branch and crash; change the pow2 check to include
zero (e.g., use data.mPaddingLog2 >= 0) so mPaddingLog2 == 0 follows the
pow2/LAUNCH_ESC(..., true) path, or alternatively ensure host code sets
mTileTokensDim = 1 for pow2 launches before invoking the macro.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9b13c3b and 7dc3a59.

📒 Files selected for processing (2)
  • csrc/trtllm_batched_gemm_runner.cu (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/trtllm_batched_gemm_runner.cu (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
  • isValidConfig (718-728)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4)
csrc/trtllm_batched_gemm_runner.cu (10)
  • run (160-261)
  • run (160-168)
  • run (263-277)
  • run (263-267)
  • run (279-295)
  • run (279-285)
  • run (297-312)
  • run (297-301)
  • getWorkspaceSizeInBytes (129-158)
  • getWorkspaceSizeInBytes (129-131)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (229-583)
  • std (239-244)
  • std (279-300)
  • std (290-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6)
  • if (1071-1079)
  • string (438-440)
  • string (445-447)
  • loadCubinData (1494-1514)
  • trtllm (82-87)
  • gen (83-86)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
  • trtllm (28-90)
  • gen (29-89)
  • launchKernel (34-84)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
csrc/trtllm_batched_gemm_runner.cu (4)

147-149: LGTM! Valid dimension initialization.

The initialization of valid dimensions to match the full problem dimensions is correct and aligns with the documented default behavior.


246-248: LGTM! Consistent valid dimension initialization.

The valid dimensions are correctly initialized before passing gemmData to the run method.


338-340: LGTM! Valid dimension initialization consistent across all methods.

The initialization pattern is consistent with getWorkspaceSizeInBytes and run methods.


402-402: LGTM! Minor refactoring.

Passing configs[configIndex] directly instead of using an intermediate reference is a valid simplification.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (7)

27-32: LGTM! Conditional compilation for export modes.

The conditional includes properly handle different export targets (FlashInfer vs standard).


76-87: LGTM! New valid dimension fields with clear documentation.

The addition of mValidM/N/K fields with documentation explaining their purpose is well done. Default initialization to 0 is appropriate given these fields must be explicitly set by callers.


459-469: LGTM! New constructor and compile-time method.

The constructor properly accepts cubin export and rotation parameters, and the generateAndCompileKernel method is appropriately guarded for non-export builds.


475-597: LGTM! Enhanced run method with proper implementation.

The rewritten run method includes several improvements:

  • Const reference parameter instead of copy (better performance)
  • Complete implementation with module caching using context-aware keys
  • Proper error handling and cleanup (unloading modules when no cache provided)
  • Conditional compilation support for different build modes

696-713: LGTM! Proper propagation of valid dimensions.

The method correctly propagates the valid dimension fields from BatchedGemmData to BatchedGemmOptions, ensuring they're available for validation.


718-728: LGTM! Config validation implementation.

The isValidConfig method properly validates configurations by extracting options and checking them without modification (updateOptions=false).


800-802: LGTM! Private member variables.

The mExportsCubin and mNumRotations members properly store the constructor parameters for later use.

Comment on lines +745 to 787
std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config,
BatchedGemmData const& data) const {
std::vector<size_t> workspaceSizes;

// Get options from config and data.
auto options = getOptionsFromConfigAndData(config, data);

if (options.mUseDeepSeekFp8 && options.mFusedAct) {
int32_t totalNumPaddedTokens = 0;
auto const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
if (!options.mEnablesEarlyExit || options.mNumTokens == 0) {
for (int32_t bi = 0; bi < options.mNumBatches; ++bi) {
totalNumPaddedTokens += batchM ? gemm::divUpMul(options.mBatchedM[bi], options.mTileM)
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
}
} else {
// Get tile in token dim.
auto tileTokensDim = batchM ? options.mTileM : options.mTileN;
totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
}
} else {
// Get tile in token dim.
auto tileTokensDim = batchM ? options.mTileM : options.mTileN;
totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
}

// Get options from config.
auto& options = config.mOptions;

int const tokenTile = batchM ? options.mTileM : options.mTileN;
// Get options from config.
auto& options = config.mOptions;

auto const numTokens = totalNumPaddedTokens;
auto const intermediateDim = batchM ? options.mN : options.mM;
auto const intermediateTile = batchM ? options.mTileN : options.mTileM;
int const tokenTile = batchM ? options.mTileM : options.mTileN;

auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float);
auto const numTokens = totalNumPaddedTokens;
auto const intermediateDim = batchM ? options.mN : options.mM;
auto const intermediateTile = batchM ? options.mTileN : options.mTileM;

auto const numTilesToken = numTokens / tokenTile;
auto const numTilesInt = intermediateDim / intermediateTile;
auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t);

// TODO: do we need to pad to 1024?
workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024));
workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024));
}
auto const numBytesRowMax = intermediateDim * totalNumPaddedTokens / 128 * sizeof(float);

return workspaceSizes;
}
auto const numTilesToken = numTokens / tokenTile;
auto const numTilesInt = intermediateDim / intermediateTile;
auto const numBytesRowMaxBars = numTilesToken * numTilesInt / 2 * sizeof(uint32_t);

////////////////////////////////////////////////////////////////////////////////////////////////////
int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
BatchedGemmData const& batchedGemmData, void* cudaStream,
int32_t /* multiProcessorCount */, bool usePdl,
std::optional<std::reference_wrapper<ModuleCache>> moduleCache) {
// Might be used.
(void)usePdl;
(void)moduleCache;
// Get options from config and data.
auto options = getOptionsFromConfigAndData(config, batchedGemmData);

bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 &&
options.mDtypeB == tg::Dtype::E4m3;

auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData);
float* dPtrRowMax{nullptr};
uint32_t* dPtrRowMaxBars{nullptr};

// Set the completion barriers to 0 if needed.
if (useDeepSeekFp8 && options.mFusedAct) {
dPtrRowMax = reinterpret_cast<float*>(alignPtr(reinterpret_cast<char*>(workspace), 1024));
dPtrRowMaxBars = reinterpret_cast<uint32_t*>(
alignPtr(reinterpret_cast<char*>(dPtrRowMax) + workspaceSizes[0], 1024));
auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1],
reinterpret_cast<cudaStream_t>(cudaStream));
if (err != cudaSuccess) {
return 1;
// TODO: do we need to pad to 1024?
workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMax, 1024));
workspaceSizes.push_back(getSizePaddedToAlignment(numBytesRowMaxBars, 1024));
}
}

auto [numCtaBatch, numCtaTile, numCtaInner] =
getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
auto kernelParams = KernelParamsSetup::setKernelParams(
options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit,
batchedGemmData.mInputBuffers.mPtrGatedActAlpha,
batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap,
dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);

// The size of the grid.
std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
: std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};

#ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule;
CUfunction cuFunction;

auto fiModuleLoadData = [&](CUmodule* module) {
const std::string sha256 = config.mHash ? config.mHash : "";
std::string fname_cubin = config.mFunctionName;
if (!fname_cubin.empty()) {
fname_cubin[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
}
fname_cubin = tllm_gen_bmm_cubin_path + "/" + fname_cubin + ".cubin";
std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
cuModuleLoadData(&cuModule, cubin.c_str());
};

if (moduleCache.has_value()) {
ModuleCache& moduleCacheRef = moduleCache.value().get();

// Modules are associated with a specific context, so the context is included in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);

// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
// string in decimal representation.
std::string const ctxName =
std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);

// Use cache if module is found, otherwise load and insert into cache
if (module != moduleCacheRef.end()) {
cuFunction = std::get<1>(module->second);
} else {
fiModuleLoadData(&cuModule);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
}
} else {
fiModuleLoadData(&cuModule);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
return workspaceSizes;
}
Contributor


⚠️ Potential issue | 🟠 Major

Variable shadowing issue in workspace size calculation.

Line 767 creates a new reference auto& options = config.mOptions that shadows the options variable created on line 750 via getOptionsFromConfigAndData(). This means:

  • Lines 752-764 use the complete options (with M/N/K from data)
  • Lines 767-779 use only config.mOptions (without runtime dimensions)

This shadowing likely loses the populated mM, mN, mK, mValidM, mValidN, mValidK, and other runtime fields from the data, which could lead to incorrect workspace size calculations.

Apply this diff to remove the shadowing:

-    // Get options from config.
-    auto& options = config.mOptions;
+    // Use the same options variable created earlier (no shadowing)

Then use the existing options variable for the calculations that follow.

🤖 Prompt for AI Agents
In
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
around lines 745-787, a new reference `auto& options = config.mOptions` (line
~767) shadows the earlier `options` obtained from getOptionsFromConfigAndData(),
causing loss of runtime-populated fields; remove that shadowing declaration and
reuse the original `options` variable for the subsequent calculations
(tokenTile, intermediateDim, intermediateTile, numBytesRowMax,
numBytesRowMaxBars, and workspace push_backs) so the runtime M/N/K/valid fields
from data are used when computing workspace sizes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

287-297: Restore the CLI entrypoint
Splitting the FP8/FP4 helpers removed the global bench_trtllm_gen_fused_moe_autotuner, so executing the script now raises a NameError. Reintroduce a dispatcher that forwards to the specialised implementations.

+def bench_trtllm_gen_fused_moe_autotuner(
+    tune_max_num_tokens: Optional[int],
+    quant_mode: str,
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+    intermediate_size: int,
+    top_k: int,
+    warmups: int,
+    iterations: int,
+):
+    if quant_mode == "Fp8-Per-Tensor":
+        return bench_trtllm_gen_fused_moe_autotuner_fp8(
+            tune_max_num_tokens,
+            quant_mode,
+            num_tokens,
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            top_k,
+            warmups,
+            iterations,
+        )
+    return bench_trtllm_gen_fused_moe_autotuner_fp4(
+        tune_max_num_tokens,
+        quant_mode,
+        num_tokens,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        top_k,
+        warmups,
+        iterations,
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7dc3a59 and 9c2ec07.

📒 Files selected for processing (1)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4)
flashinfer/fused_moe/core.py (3)
  • RoutingMethodType (58-72)
  • trtllm_fp4_block_scale_moe (1827-1961)
  • trtllm_fp8_per_tensor_scale_moe (1661-1739)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

29-29: Unused function argument: quant_mode

(ARG001)


67-89: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

Comment on lines +21 to +27
def fp8_quantize(x):
max = x.float().abs().nan_to_num().max()
scale = FLOAT8_E4M3_MAX / max
x = (x * scale).to(torch.float8_e4m3fn)
return x, 1.0 / scale

⚠️ Potential issue | 🔴 Critical

Guard FP8 quantization against all-zero inputs
All-zero inputs make max zero, scale infinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start out zeroed. Please handle the zero case before inverting the scale.

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 21 to 25, the
fp8_quantize function divides by max which can be zero for all-zero inputs;
guard against that by computing max = x.float().abs().nan_to_num().max(), then
check if max == 0 (or torch.isclose(max, torch.tensor(0., device=max.device)))
before inverting it; if it is zero, return the input cast to torch.float8_e4m3fn
(or an all-zero tensor of the same shape) and a safe inverse scale (e.g. 1.0),
otherwise compute scale = FLOAT8_E4M3_MAX / max and proceed with quantization
and return x.to(torch.float8_e4m3fn) and 1.0/scale.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

21-25: Guard FP8 quantization against all-zero inputs.

All-zero inputs make max zero, scale infinite, and the quantized tensor NaN (0 × ∞), which breaks the benchmark when buffers start out zeroed. Please handle the zero case before inverting the scale.

Apply this diff to fix the issue:

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale
🧹 Nitpick comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

67-89: Consider refactoring lambda to a named function.

The lambda expression is quite long (22 lines). While functional, a named function would improve readability and follow PEP 8 style guidelines.

Apply this diff to refactor:

+    def run_moe():
+        return trtllm_fp8_per_tensor_scale_moe(
+            routing_logits,
+            None,  # routing_bias
+            hidden_states,
+            w13,
+            output1_scale_scalar,
+            output1_scales_gate_scalar,
+            w2,
+            output2_scale_scalar,
+            num_experts,
+            top_k,
+            None,  # n_group
+            None,  # topk_group
+            intermediate_size,
+            0,  # local_expert_offset
+            num_experts,
+            1.0,  # routed_scaling_factor
+            False,  # use_routing_scales_on_input
+            None,
+            RoutingMethodType.TopK.value,
+            enable_pdl,
+            tune_max_num_tokens
+        )
+
-    fn = lambda: trtllm_fp8_per_tensor_scale_moe(
-        routing_logits,
-        None,  # routing_bias
-        hidden_states,
-        w13,
-        output1_scale_scalar,
-        output1_scales_gate_scalar,
-        w2,
-        output2_scale_scalar,
-        num_experts,
-        top_k,
-        None,  # n_group
-        None,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        1.0,  # routed_scaling_factor
-        False,  # use_routing_scales_on_input
-        None,
-        RoutingMethodType.TopK.value,
-        enable_pdl,
-        tune_max_num_tokens
-    )
+    fn = run_moe
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9c2ec07 and 99eb4ec.

📒 Files selected for processing (1)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)
flashinfer/fused_moe/core.py (3)
  • GatedActType (173-177)
  • trtllm_fp4_block_scale_moe (1827-1961)
  • trtllm_fp8_per_tensor_scale_moe (1661-1739)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp4_block_scale_moe (1177-1273)
  • trtllm_fp4_block_scale_moe (1177-1190)
  • trtllm_fp8_per_tensor_scale_moe (352-412)
  • trtllm_fp8_per_tensor_scale_moe (352-360)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
🪛 GitHub Actions: pre-commit
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

[error] 4-4: Unused import: torch.nn.functional. The import statement is present but not used in the code.


[error] 9-9: Removed import statement from flashinfer.fused_moe. The old import is replaced with a formatted multi-line import block.

🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

29-29: Unused function argument: quant_mode

(ARG001)


67-89: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)

12-12: LGTM!

The new import for trtllm_fp8_per_tensor_scale_moe is properly used in the FP8 benchmark function.


18-19: LGTM!

The quantization constants are properly defined and will be used in the FP8 quantization logic.


237-246: LGTM!

The benchmarking logic correctly runs with and without autotuning and collects timing measurements.


262-262: LGTM!

The CLI correctly extends quantization mode choices to include the new FP8 per-tensor option.


287-310: LGTM!

The routing logic correctly dispatches to the appropriate benchmark function based on the quantization mode, and all parameters are properly forwarded.

import argparse
from typing import Optional, Literal
import torch
import torch.nn.functional as F

⚠️ Potential issue | 🟡 Minor

Remove unused import.

The import torch.nn.functional as F is not used anywhere in the code.

Apply this diff to remove the unused import:

-import torch.nn.functional as F
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 4-4: Unused import: torch.nn.functional. The import statement is present but not used in the code.

🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around line 4, the import
"torch.nn.functional as F" is unused; remove that import line from the file to
eliminate the unused dependency and clean up imports.


def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
quant_mode: Literal["Fp8-Per-Tensor"],

⚠️ Potential issue | 🟡 Minor

Remove unused parameter.

The quant_mode parameter is not used within the function body. If it's intended for future use or external validation, consider adding a comment explaining its purpose.

Apply this diff if the parameter is not needed:

 def bench_trtllm_gen_fused_moe_autotuner_fp8(
     tune_max_num_tokens: Optional[int],
-    quant_mode: Literal["Fp8-Per-Tensor"],
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
     warmups: int,
     iterations: int,
 ):
🧰 Tools
🪛 Ruff (0.14.3)

29-29: Unused function argument: quant_mode

(ARG001)

🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around line 29, the
function signature includes an unused parameter quant_mode:
Literal["Fp8-Per-Tensor"]; remove this parameter from the signature and any
references to it, or if it is intentionally reserved for future use, keep it but
add a clear comment above the parameter explaining its purpose and why it is
unused (e.g., "reserved for future quantization modes"); update any call sites
if you remove it to avoid breaking callers.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

23-27: Guard FP8 quantization against zero max.

Line 23 lets scale = FLOAT8_E4M3_MAX / 0 when the tensor is all zeros, so x * scale turns into NaNs and the benchmark fails immediately. Bail out with zeros and a unit inverse-scale before dividing.

Apply this diff to harden the quantizer:

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().amax()
+    if max_val.item() == 0:
+        zeros = torch.zeros_like(x, dtype=torch.float8_e4m3fn)
+        return zeros, torch.tensor(1.0, device=x.device, dtype=torch.float32)
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, scale.reciprocal()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 99eb4ec and 8768aad.

📒 Files selected for processing (1)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4)
flashinfer/fused_moe/core.py (3)
  • trtllm_fp4_block_scale_moe (1827-1961)
  • trtllm_fp8_per_tensor_scale_moe (1661-1739)
  • RoutingMethodType (58-72)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/utils.py (1)
  • device_support_pdl (568-572)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

32-32: Unused function argument: quant_mode

(ARG001)


70-92: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines +60 to +68
output1_scale_scalar = torch.tensor(
[hidden_states_scale * w13_scale] * num_experts, device=device
)
output1_scales_gate_scalar = torch.ones(
num_experts, device=device, dtype=torch.float32
)
output2_scale_scalar = torch.tensor(
[hidden_states_scale * w2_scale] * num_experts, device=device
)

⚠️ Potential issue | 🔴 Critical

Construct FP8 scale vectors without CPU conversion errors.

Lines 60-68 call torch.tensor([...], device=device) on CUDA scalars, which raises TypeError: can't convert CUDA tensor to numpy(). That stops the FP8 path before benchmarking. Build the vectors on device without Python lists.

Apply this diff to keep the scales on CUDA:

-    output1_scale_scalar = torch.tensor(
-        [hidden_states_scale * w13_scale] * num_experts, device=device
-    )
+    scale_prod_1 = (hidden_states_scale * w13_scale).item()
+    output1_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_1,
+        device=device,
+        dtype=torch.float32,
+    )
@@
-    output2_scale_scalar = torch.tensor(
-        [hidden_states_scale * w2_scale] * num_experts, device=device
-    )
+    scale_prod_2 = (hidden_states_scale * w2_scale).item()
+    output2_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_2,
+        device=device,
+        dtype=torch.float32,
+    )
🤖 Prompt for AI Agents
In benchmarks/bench_trtllm_gen_fused_moe_autotuner.py around lines 60 to 68, the
code creates FP8 scale vectors using Python lists of CUDA scalars which triggers
"can't convert CUDA tensor to numpy()" on CUDA; replace those list constructions
with device-native tensor factories (e.g., use torch.full or torch.ones with
shape (num_experts,) and the desired dtype/device) to produce
output1_scale_scalar and output2_scale_scalar directly on the CUDA device (and
keep output1_scales_gate_scalar as torch.ones on device with correct dtype).

@yzh119 (Collaborator) commented Nov 4, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !107 has been created, and the CI pipeline #37845025 is currently running. I'll report back once the pipeline job completes.

@IwakuraRein IwakuraRein marked this pull request as draft November 4, 2025 03:16
@jiahanc jiahanc marked this pull request as ready for review November 5, 2025 02:12
@jiahanc (Collaborator, Author) commented Nov 5, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !107 has been updated with latest changes, and the CI pipeline #37954416 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (10)
include/flashinfer/trtllm/fused_moe/DevKernel.h (1)

116-123: Critical: Fix the pow2 condition to include zero

This issue was previously flagged: the condition data.mPaddingLog2 > 0 excludes mPaddingLog2 == 0, sending that case into the tile-N branch (false). However, mPaddingLog2 == 0 represents tile size 1 (2^0 = 1), which is a power of two and should follow the pow2 path. Because pow2 flows never populate mTileTokensDim, taking the false branch with mPaddingLog2 == 0 leaves mTileTokensDim at zero and leads to divide-by-zero faults in the CUDA kernels.

Apply this diff:

-  if (data.mPaddingLog2 > 0) {
+  if (data.mPaddingLog2 >= 0) {
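To make the failure mode concrete, here is a minimal sketch of the dispatch the fix implies; RoutingLaunchData and resolveTileTokensDim are hypothetical stand-ins for illustration, not the actual LAUNCH_TILEN macro from DevKernel.h.

#include <cstdint>
#include <stdexcept>

// Hypothetical view of the routing launch data; only the two fields discussed here.
struct RoutingLaunchData {
  int32_t mPaddingLog2 = -1;   // >= 0 selects the power-of-two path (tile = 1 << mPaddingLog2)
  int32_t mTileTokensDim = 0;  // read only on the non-pow2 path, so it must be set there
};

// mPaddingLog2 == 0 (tile size 1) still belongs to the pow2 branch; routing it to the
// tile-N branch would divide by an unset mTileTokensDim.
inline int32_t resolveTileTokensDim(RoutingLaunchData const& data) {
  if (data.mPaddingLog2 >= 0) {
    return int32_t{1} << data.mPaddingLog2;  // isPow2 == true
  }
  if (data.mTileTokensDim <= 0) {
    throw std::invalid_argument("non-pow2 path requires mTileTokensDim > 0");
  }
  return data.mTileTokensDim;  // isPow2 == false
}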
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

66-81: Fix UB when the requested tile is not in the list; snap to the nearest supported entry before walking neighbors.

std::find may return end() (e.g., nextPowerOfTwo yields 32 but 32 is not in the list), and calling std::next(it) on end() is undefined behavior. Replace it with lower_bound, clamp the iterator, and snap tile_tokens_dim to an actual element; also guard the neighbor accesses.

-  // assume supported_tile_nums is sorted
-  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
-                                       supported_tile_nums.front(), supported_tile_nums.back());
-  auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  // assume supported_tile_nums is sorted
+  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
+                                       supported_tile_nums.front(), supported_tile_nums.back());
+  auto it = std::lower_bound(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+  if (it == supported_tile_nums.end()) { it = std::prev(supported_tile_nums.end()); }
+  // Optionally pick the closer neighbor if lower_bound overshoots
+  if (it != supported_tile_nums.begin()) {
+    auto prev = std::prev(it);
+    if (std::abs(*prev - tile_tokens_dim) <= std::abs(*it - tile_tokens_dim)) it = prev;
+  }
+  tile_tokens_dim = *it;
 
   std::set<int32_t> selected_tile_nums;
   selected_tile_nums.insert(tile_tokens_dim);
-  if (std::next(it) != supported_tile_nums.end()) {
+  if (std::next(it) != supported_tile_nums.end()) {
     selected_tile_nums.insert(*std::next(it));
-    if (std::next(std::next(it)) != supported_tile_nums.end()) {
+    if (std::next(std::next(it)) != supported_tile_nums.end()) {
       selected_tile_nums.insert(*std::next(std::next(it)));
     }
   }
-  if (it != supported_tile_nums.begin()) {
+  if (it != supported_tile_nums.begin()) {
     selected_tile_nums.insert(*std::prev(it));
   }
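For illustration, a self-contained sketch of the neighborhood selection under the suggested lower_bound fix; the nextPowerOfTwo helper and the standalone harness are assumptions, and only the selection logic mirrors the diff above.

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <set>
#include <vector>

// Assumed helper: round up to the next power of two.
static int32_t nextPowerOfTwo(int32_t v) {
  int32_t p = 1;
  while (p < v) p <<= 1;
  return p;
}

// Pick the snapped tile plus up to two larger neighbors and one smaller neighbor.
static std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supportedTileNums,
                                              int32_t avgTokensPerExpert) {
  int32_t tile = std::clamp(nextPowerOfTwo(avgTokensPerExpert), supportedTileNums.front(),
                            supportedTileNums.back());
  auto it = std::lower_bound(supportedTileNums.begin(), supportedTileNums.end(), tile);
  if (it == supportedTileNums.end()) it = std::prev(supportedTileNums.end());
  if (it != supportedTileNums.begin()) {
    auto prev = std::prev(it);
    if (std::abs(*prev - tile) <= std::abs(*it - tile)) it = prev;  // snap to closer neighbor
  }
  std::set<int32_t> selected{*it};
  if (std::next(it) != supportedTileNums.end()) {
    selected.insert(*std::next(it));
    if (std::next(std::next(it)) != supportedTileNums.end())
      selected.insert(*std::next(std::next(it)));
  }
  if (it != supportedTileNums.begin()) selected.insert(*std::prev(it));
  return selected;
}

// Example: {8, 16, 32, 64, 128, 192, 256} with ~50 tokens per expert snaps to 64
// and selects {32, 64, 128, 192}; no iterator ever walks past end().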
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)

48-54: Remove nested trtllm::gen inside batchedGemm; it shadows real ::trtllm::gen.

This creates batchedGemm::trtllm::gen and breaks lookups for the actual ::trtllm::gen (dtype helpers, enums). Drop the nested block and reference global with ::.

-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-}  // namespace gen
-}  // namespace trtllm
+// Forward declarations, when needed, should reference global namespace or use fully qualified ::trtllm::gen.
+// (Remove nested shadow to keep tg = ::trtllm::gen valid.)

And make the alias explicit:

-namespace tg = trtllm::gen;
+namespace tg = ::trtllm::gen;

Also applies to: 59-60

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)

58-64: Avoid nested trtllm::gen in batchedGemm; use global to prevent shadowing.

Same shadowing issue as in GemmGatedActOptions.h. Remove the nested block and use ::trtllm::gen in aliases.

-namespace trtllm {
-namespace gen {
-class CudaRunner;
-class GenCfg;
-}  // namespace gen
-}  // namespace trtllm
+// Use ::trtllm::gen forward decls if needed; avoid introducing batchedGemm::trtllm::gen.
-namespace tg = trtllm::gen;
+namespace tg = ::trtllm::gen;

Also applies to: 69-71


388-415: Restore printing of mBatchedM/mBatchedN in dumpOptions.

Placeholders {} drop essential tuning logs. Serialize the actual vectors when dumpRuntimeParams is true.

-  if (dumpRuntimeParams) {
-    ss << "mBatchedM={}," << std::endl;
-    ss << "mBatchedN={}," << std::endl;
-  }
+  if (dumpRuntimeParams) {
+    ss << "mBatchedM={";
+    for (size_t i = 0; i < options.mBatchedM.size(); ++i) {
+      ss << (i ? ", " : "") << options.mBatchedM[i];
+    }
+    ss << "}," << std::endl;
+    ss << "mBatchedN={";
+    for (size_t i = 0; i < options.mBatchedN.size(); ++i) {
+      ss << (i ? ", " : "") << options.mBatchedN[i];
+    }
+    ss << "}," << std::endl;
+  }
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (3)

30-40: Remove or document the unused quant_mode parameter.

The quant_mode parameter at line 32 is not used within the function body.

If this parameter is reserved for future use, add a comment explaining its purpose. Otherwise, remove it:

 def bench_trtllm_gen_fused_moe_autotuner_fp8(
     tune_max_num_tokens: Optional[int],
-    quant_mode: Literal["Fp8-Per-Tensor"],
     num_tokens: int,
     num_experts: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
     warmups: int,
     iterations: int,
 ):

23-27: Guard FP8 quantization against all-zero inputs.

When the input tensor is all zeros, max becomes zero, scale becomes infinite, and the quantized result is NaN (0 × ∞). This will break benchmarking when buffers start out zeroed.

Apply this diff to handle the zero case:

 def fp8_quantize(x):
-    max = x.float().abs().nan_to_num().max()
-    scale = FLOAT8_E4M3_MAX / max
-    x = (x * scale).to(torch.float8_e4m3fn)
-    return x, 1.0 / scale
+    max_val = x.float().abs().nan_to_num().max()
+    if max_val == 0:
+        return torch.zeros_like(x, dtype=torch.float8_e4m3fn), 1.0
+    scale = FLOAT8_E4M3_MAX / max_val
+    quantized = (x * scale).to(torch.float8_e4m3fn)
+    return quantized, 1.0 / scale

60-68: Construct FP8 scale vectors without CPU conversion errors.

Lines 60-68 call torch.tensor([...], device=device) on CUDA scalars, which raises TypeError: can't convert CUDA tensor to numpy(). Build the vectors on device without Python lists.

Apply this diff to keep scales on CUDA:

-    output1_scale_scalar = torch.tensor(
-        [hidden_states_scale * w13_scale] * num_experts, device=device
-    )
+    scale_prod_1 = (hidden_states_scale * w13_scale).item()
+    output1_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_1,
+        device=device,
+        dtype=torch.float32,
+    )
@@
-    output2_scale_scalar = torch.tensor(
-        [hidden_states_scale * w2_scale] * num_experts, device=device
-    )
+    scale_prod_2 = (hidden_states_scale * w2_scale).item()
+    output2_scale_scalar = torch.full(
+        (num_experts,),
+        scale_prod_2,
+        device=device,
+        dtype=torch.float32,
+    )
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)

745-787: Fix variable shadowing that loses runtime dimensions in workspace calculation.

Line 767 creates a new reference auto& options = config.mOptions that shadows the options variable from line 750. This causes:

  • Lines 752-764 to use complete options (with runtime M/N/K/valid fields from data)
  • Lines 769-779 to use only config.mOptions (missing runtime dimensions)

The shadowing loses the populated mM, mN, mK, mValidM, mValidN, mValidK fields, potentially leading to incorrect workspace size calculations.

Apply this diff to remove the shadowing:

-    // Get options from config.
-    auto& options = config.mOptions;
+    // Use the same options variable created earlier (no shadowing)

Then use the existing options variable for all subsequent calculations.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (1)

652-665: Fix warning macro that prevents valid dimension clamping.

When options.mValidK > options.mK, the TLLM_LOG_WARNING at line 655 fires. Under TLLM_GEN_EXPORT_INTERFACE, this macro immediately returns false, so the subsequent clamp operation at lines 659-661 never executes. This breaks callers relying on updateOptions == true to sanitize oversized valid dimensions.

Replace the warning with a conditional log that doesn't trigger early return:

-  if (options.mValidM > options.mM || options.mValidN > options.mN ||
-      options.mValidK > options.mK) {
-    TLLM_LOG_WARNING(
-        options.mValidK <= options.mK,
-        "ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively.");
-    if (updateOptions) {
-      options.mValidM = std::min(options.mValidM, options.mM);
-      options.mValidN = std::min(options.mValidN, options.mN);
-      options.mValidK = std::min(options.mValidK, options.mK);
-    } else {
-      return false;
-    }
-  }
+  if (options.mValidM > options.mM || options.mValidN > options.mN ||
+      options.mValidK > options.mK) {
+#ifdef TLLM_GEN_DEBUG
+    printArgs("WARNING: ValidM, ValidN, and ValidK must be less than or equal to M, N, and K respectively.\n");
+#endif
+    if (updateOptions) {
+      options.mValidM = std::min(options.mValidM, options.mM);
+      options.mValidN = std::min(options.mValidN, options.mN);
+      options.mValidK = std::min(options.mValidK, options.mK);
+    } else {
+      return false;
+    }
+  }
🧹 Nitpick comments (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

381-392: Tile sets look good; ensure sorted+unique before selection to avoid duplicates.

When conditionally pushing 128/192/256, duplicates can appear across branches; keep the vector sorted and unique before passing to computeSelectedTileN.

-    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256};
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128, 192, 256};
+    mSupportedTileN.erase(std::unique(mSupportedTileN.begin(), mSupportedTileN.end()), mSupportedTileN.end());
-    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
+    std::vector<int32_t> mSupportedTileN = {8, 16, 32, 64, 128};
+    mSupportedTileN.erase(std::unique(mSupportedTileN.begin(), mSupportedTileN.end()), mSupportedTileN.end());
-  std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
+  std::vector<int32_t> supported_tile_nums = {8, 16, 32, 64};
   ...
   // after conditional push_backs
+  std::sort(supported_tile_nums.begin(), supported_tile_nums.end());
+  supported_tile_nums.erase(std::unique(supported_tile_nums.begin(), supported_tile_nums.end()),
+                            supported_tile_nums.end());

Also applies to: 730-740, 1322-1336

flashinfer/fused_moe/core.py (1)

125-133: Minor nit: duplicate enum in list.

MxE4m3 appears twice in trtllm_gen_dtype_has_scale. Harmless; remove duplicate for clarity.

-        DtypeTrtllmGen.MxE4m3,
         DtypeTrtllmGen.E2m1,
         DtypeTrtllmGen.MxE2m1,
-        DtypeTrtllmGen.MxE4m3,
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)

154-169: New validation is good; message tweaks optional.

DeepSeek FP8 and shuffled-MatrixA constraints on hidden/valid sizes are correct. Consider clarifying error text to include hiddenSizeStr for both fields.

include/flashinfer/trtllm/fused_moe/RoutingKernel.h (1)

53-57: Template split (isPow2_, UsePdl_) and new mTileTokensDim are wired correctly.

Defaults (mPaddingLog2=-1) are safe under non-pow2 path; Data->Params copy includes mTileTokensDim. Comment on mPtrPermutedIdxToTokenIdx shape still mentions a derived formula and may confuse readers.

-  // dim: [mTileTokensDim * mTopK + (mNumExperts × mTileTokensDim) - mNumExperts]
+  // dim: [total_permuted_tokens]; actual size equals padded tokens count reported via mPtrPermutedIdxSize[0]

Also applies to: 95-103, 104-157, 232-254, 274-301

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

70-92: Prefer def over lambda assignment for clarity.

Assigning a lambda expression to a variable reduces readability and makes debugging harder (stack traces show <lambda> instead of a meaningful function name).

Refactor to use a proper function definition:

-    fn = lambda: trtllm_fp8_per_tensor_scale_moe(
-        routing_logits,
-        None,  # routing_bias
-        hidden_states,
-        w13,
-        output1_scale_scalar,
-        output1_scales_gate_scalar,
-        w2,
-        output2_scale_scalar,
-        num_experts,
-        top_k,
-        None,  # n_group
-        None,  # topk_group
-        intermediate_size,
-        0,  # local_expert_offset
-        num_experts,
-        1.0,  # routed_scaling_factor
-        False,  # use_routing_scales_on_input
-        None,
-        RoutingMethodType.TopK.value,
-        enable_pdl,
-        num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
-    )
+    def fn():
+        return trtllm_fp8_per_tensor_scale_moe(
+            routing_logits,
+            None,  # routing_bias
+            hidden_states,
+            w13,
+            output1_scale_scalar,
+            output1_scales_gate_scalar,
+            w2,
+            output2_scale_scalar,
+            num_experts,
+            top_k,
+            None,  # n_group
+            None,  # topk_group
+            intermediate_size,
+            0,  # local_expert_offset
+            num_experts,
+            1.0,  # routed_scaling_factor
+            False,  # use_routing_scales_on_input
+            None,
+            RoutingMethodType.TopK.value,
+            enable_pdl,
+            num_tokens if tune_max_num_tokens is None else tune_max_num_tokens,
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 703ed28 and 0e88417.

⛔ Files ignored due to path filters (2)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h is excluded by !**/gen/**
📒 Files selected for processing (25)
  • benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (4 hunks)
  • csrc/trtllm_batched_gemm_runner.cu (5 hunks)
  • csrc/trtllm_fused_moe_kernel_launcher.cu (5 hunks)
  • csrc/trtllm_fused_moe_routing_deepseek.cu (2 hunks)
  • csrc/trtllm_fused_moe_routing_llama4.cu (3 hunks)
  • csrc/trtllm_fused_moe_routing_renormalize.cu (1 hunks)
  • csrc/trtllm_fused_moe_runner.cu (4 hunks)
  • flashinfer/artifacts.py (2 hunks)
  • flashinfer/autotuner.py (1 hunks)
  • flashinfer/fused_moe/core.py (6 hunks)
  • flashinfer/jit/fused_moe.py (1 hunks)
  • flashinfer/jit/gemm/core.py (2 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (2 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (10 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (4 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (25 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (11 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (0 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (7 hunks)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/DevKernel.h (2 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (6 hunks)
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h (7 hunks)
  • tests/moe/test_trtllm_gen_fused_moe.py (7 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
  • csrc/trtllm_fused_moe_runner.cu
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/fused_moe.py (1)
flashinfer/artifacts.py (1)
  • ArtifactPath (83-98)
csrc/trtllm_batched_gemm_runner.cu (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
  • isValidConfig (718-728)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h (1)
  • getTmemColStridePerGroup (99-103)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
csrc/trtllm_batched_gemm_runner.cu (10)
  • run (160-261)
  • run (160-168)
  • run (263-277)
  • run (263-267)
  • run (279-295)
  • run (279-285)
  • run (297-312)
  • run (297-301)
  • getWorkspaceSizeInBytes (129-158)
  • getWorkspaceSizeInBytes (129-131)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (229-583)
  • std (239-244)
  • std (279-300)
  • std (290-296)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
  • trtllm (28-90)
  • gen (29-89)
  • launchKernel (34-84)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (3)
  • trtllm (48-53)
  • gen (49-52)
  • gemm (147-152)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
  • RouteImpl (28-57)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (4)
  • trtllm (58-63)
  • gen (59-62)
  • mExecPath (377-435)
  • mInstanceIdx (380-380)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (6)
  • trtllm (82-87)
  • gen (83-86)
  • getShuffleBlockSize (602-608)
  • string (438-440)
  • string (445-447)
  • mInstanceIdx (421-421)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h (4)
  • getShuffleBlockSize (539-545)
  • string (407-409)
  • string (414-416)
  • string (420-521)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (5)
csrc/trtllm_fused_moe_kernel_launcher.cu (4)
  • trtllm_fp4_block_scale_moe (1177-1273)
  • trtllm_fp4_block_scale_moe (1177-1190)
  • trtllm_fp8_per_tensor_scale_moe (352-412)
  • trtllm_fp8_per_tensor_scale_moe (352-360)
flashinfer/fused_moe/core.py (3)
  • trtllm_fp4_block_scale_moe (1827-1961)
  • trtllm_fp8_per_tensor_scale_moe (1661-1739)
  • RoutingMethodType (58-72)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/utils.py (1)
  • device_support_pdl (569-573)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
  • gemm (32-297)
tests/moe/test_trtllm_gen_fused_moe.py (2)
flashinfer/fused_moe/core.py (2)
  • RoutingMethodType (58-72)
  • WeightLayout (161-168)
include/flashinfer/trtllm/fused_moe/runner.h (1)
  • RoutingMethodType (37-136)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (3)
include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h (3)
  • flashinfer (36-38)
  • gemm (42-347)
  • gemm (468-488)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (6)
  • string (133-181)
  • trtllm (38-271)
  • gen (39-269)
  • Dtype (43-268)
  • dtypeIsBlockFmt (96-99)
  • dtypeNumEltsPerSf (198-209)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
  • gemm (30-417)
🪛 Ruff (0.14.3)
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py

32-32: Unused function argument: quant_mode

(ARG001)


70-92: Do not assign a lambda expression, use a def

Rewrite fn as a def

(E731)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (29)
flashinfer/autotuner.py (1)

479-506: LGTM! Improved error visibility.

Including the exception details in the warning log enhances observability during profiling failures. The debug log on line 491 continues to provide full context with shapes and error details, while the warning now gives a quick overview of what went wrong.

include/flashinfer/trtllm/fused_moe/DevKernel.h (3)

127-132: LGTM: Clean replacement with preserved signatures

The replacement of LAUNCH_PDL with LAUNCH_TILEN correctly preserves type signatures and parameters for both FP32 and BF16 routing paths.


141-174: LGTM: Consistent replacement across all dtype combinations

The replacement of LAUNCH_PDL with LAUNCH_TILEN is correctly applied across all nine dtype combinations (score/bias/expW), preserving type signatures, parameters, and the extraFlag/numExperts propagation.


194-204: LGTM: Correct replacement for all expert-count paths

The replacement of LAUNCH_PDL with LAUNCH_TILEN is correctly applied across all four branches (FP32/BF16 × extraFlag1), preserving type signatures and the numExperts parameter.

flashinfer/jit/gemm/core.py (2)

533-539: LGTM! Consistent flag addition across TRTLLM modules.

Good to see the -DTLLM_GEN_EXPORT_FLASHINFER flag added consistently to both the regular and low-latency TRTLLM GEMM modules. This ensures uniform behavior across both compilation paths.


382-387: The flag addition is not used by the compiled sources and should be reviewed for correctness.

The TLLM_GEN_EXPORT_FLASHINFER preprocessor flag exists in the codebase but is only used within the batched_gemm/trtllmGen_bmm_export/ module, not in the gemm/trtllmGen_gemm_export/ path that the compiled runners include. The gemm runners compile against GemmInterface.h which uses TLLM_GEN_EXPORT_INTERFACE instead, not TLLM_GEN_EXPORT_FLASHINFER. Adding this flag to the gemm runners' compilation will have no effect on their behavior.

Verify whether:

  • The flag should be added to a different compilation target (batched_gemm runners)
  • The gemm runners should be using TLLM_GEN_EXPORT_FLASHINFER instead of TLLM_GEN_EXPORT_INTERFACE
  • This flag addition was unintended

Likely an incorrect or invalid review comment.

flashinfer/jit/fused_moe.py (1)

236-236: LGTM: New export flag added.

The addition of -DTLLM_GEN_EXPORT_FLASHINFER is straightforward and aligns with the PR's objective to add this export flag for the TRTLLM fused MOE SM100 module.

csrc/trtllm_fused_moe_kernel_launcher.cu (1)

560-566: Verify gemm1_output_scale shape (tokens dimension).

Block-scale path uses max_num_padded_tokens for gemm1_output_scale while gemm1_output uses max_num_padded_tokens_gemm1. Mismatch might over-allocate or later index past produced rows. Confirm intended dimension and align both if necessary.

include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh (1)

70-87: isPow2/tileN dual-path arithmetic looks consistent.

mulTileN/divUpTileN/divUpMulTileN are correct and used consistently in CTA count, mnLimit, and permuted size computations. No issues.

Also applies to: 321-347, 361-369, 554-564, 586-596, 603-613
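For intuition, a minimal sketch of what the dual-path tile arithmetic amounts to; the helper names mirror the ones referenced above, but the bodies here are illustrative rather than copied from RoutingKernel.cuh.

#include <cstdint>

// Power-of-two path: the tile is 2^log2Tile, so multiply and ceil-divide are shifts.
constexpr int32_t mulLog2(int32_t x, int32_t log2Tile) { return x << log2Tile; }
constexpr int32_t divUpLog2(int32_t x, int32_t log2Tile) {
  return (x + (1 << log2Tile) - 1) >> log2Tile;
}

// Generic tile path (e.g. tileN = 192): plain multiply and ceil-division.
constexpr int32_t mulTileN(int32_t x, int32_t tileN) { return x * tileN; }
constexpr int32_t divUpTileN(int32_t x, int32_t tileN) { return (x + tileN - 1) / tileN; }
constexpr int32_t divUpMulTileN(int32_t x, int32_t tileN) { return divUpTileN(x, tileN) * tileN; }

// 1000 tokens with tileN = 192 pad to 6 tiles, i.e. 1152 token slots.
static_assert(divUpTileN(1000, 192) == 6, "ceil-div");
static_assert(divUpMulTileN(1000, 192) == 1152, "padded token count");
// For a power-of-two tile the two paths agree: 64 == 2^6.
static_assert(divUpTileN(1000, 64) == divUpLog2(1000, 6), "paths agree on pow2 tiles");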

flashinfer/fused_moe/core.py (1)

1319-1347: Parameter passthrough of tune_max_num_tokens looks correct.

Autotuner reconfiguration and op wrappers correctly forward tune_max_num_tokens. No blocking issues.

Also applies to: 1540-1567, 1712-1739, 1763-1824, 1922-1961, 2061-2100

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)

184-191: dumpOptions(runtime flag) and new config fields LGTM.

Extended dump with dumpRuntimeParams and runtime wiring fields looks fine.

Also applies to: 199-215

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (1)

237-255: DeepSeek FP8 valid-dimension checks are correct; LdgPlusSts guards are helpful.

Constraint checks on (M/N/K)%128 and route Sfs impl combinations look consistent.

Also applies to: 276-284

csrc/trtllm_batched_gemm_runner.cu (1)

147-158: Good: propagate validM/N/K and explicitly gate split‑K.

Setting valid dims before queries and requiring mClusterDimZ==1 avoids known failures; config sorting and fallback look fine.

Also applies to: 246-255, 338-341, 449-452
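A minimal sketch of what gating split-K amounts to on the config-selection side; ConfigView and filterOutSplitK are hypothetical names, and only the clusterDimZ == 1 predicate follows the comment above.

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// Hypothetical view of a kernel config: only the field relevant to split-K gating.
struct ConfigView {
  int32_t index;
  int32_t clusterDimZ;  // split-K kernels use clusterDimZ > 1
};

// Keep only non-split-K configs, mirroring the temporary gating mentioned in the PR notes;
// split-K comes back by simply dropping this filter.
inline std::vector<ConfigView> filterOutSplitK(std::vector<ConfigView> const& configs) {
  std::vector<ConfigView> kept;
  std::copy_if(configs.begin(), configs.end(), std::back_inserter(kept),
               [](ConfigView const& c) { return c.clusterDimZ == 1; });
  return kept;
}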

csrc/trtllm_fused_moe_routing_renormalize.cu (1)

168-176: Tile/pow2 dual-path applied consistently in renormalize kernels.

divUp{Log2,TileN}, mul{Log2,TileN} are used coherently for CTA count, mnLimit, and permuted size. Looks correct.

Also applies to: 179-186, 193-203, 206-216

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)

34-36: Enum update aligns with new routing path.

The extra variant and its helper keep the routing helpers consistent and make the downstream checks straightforward. Looks good.

csrc/trtllm_fused_moe_routing_deepseek.cu (1)

396-421: Nice split between pow2 and tile flows.

The constexpr branch cleanly swaps in the tile arithmetic without touching the existing log2 path. That should really help with the new non‑power-of-two kernels.

csrc/trtllm_fused_moe_routing_llama4.cu (1)

188-244: Thanks for mirroring the tile-aware logic here.

The consistent min(mnLimit1, mnLimit2) handling across both branches is reassuring, especially with the new tile sizes.

tests/moe/test_trtllm_gen_fused_moe.py (1)

2088-2175: Great to see FP8 block-scale in the renorm matrix.

The added parameter coverage (tiles, routing configs, GeGlu guard) should catch regressions once the new kernels land.

benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)

291-314: LGTM!

The dispatch logic cleanly separates FP8 and FP4 benchmark paths based on the quantization mode. The implementation correctly routes to the appropriate function with all necessary parameters.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (2)

85-96: LGTM!

The addition of validM/N/K parameters with default values of -1 provides a clean backward-compatible interface. The initialization logic correctly defaults valid dimensions to their corresponding size dimensions when not explicitly provided.
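A rough sketch of that defaulting pattern; the struct and member names here are assumed for illustration, not copied from KernelParams.h.

#include <cstdint>

// Hypothetical problem-size holder showing the "-1 means use the padded size" default.
struct ProblemSizes {
  int32_t mM, mN, mK;
  int32_t mValidM, mValidN, mValidK;

  ProblemSizes(int32_t m, int32_t n, int32_t k, int32_t validM = -1, int32_t validN = -1,
               int32_t validK = -1)
      : mM(m),
        mN(n),
        mK(k),
        // A negative valid size means "no extra padding": fall back to the full dimension.
        mValidM(validM < 0 ? m : validM),
        mValidN(validN < 0 ? n : validN),
        mValidK(validK < 0 ? k : validK) {}
};

// Usage: ProblemSizes{1152, 4096, 7168, /*validM=*/1000} keeps strides based on the padded
// 1152 rows while shapes can clamp to the 1000 valid rows.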


109-209: LGTM!

The shape/stride computation correctly distinguishes between padded dimensions (for strides) and valid dimensions (for shapes). This optimization reduces unnecessary memory traffic by clamping TMA shapes to the valid data range while maintaining correct stride calculations for the full allocated memory.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)

76-87: LGTM!

The addition of mValidM/N/K fields to ProblemDimensions is well-documented and properly initialized. The comment clearly explains their purpose: tracking the valid range of dimensions separately from padded dimensions due to alignment constraints.


461-469: LGTM!

The new constructor properly initializes the mExportsCubin and mNumRotations member variables. The generateAndCompileKernel method is appropriately guarded with #ifndef TLLM_GEN_EXPORT_INTERFACE to keep compilation-related functionality separate from the export interface.


696-713: LGTM!

The getOptionsFromConfigAndData method correctly populates all problem dimension fields, including the newly added mValidM/N/K fields from the problem dimensions data.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (5)

641-650: LGTM!

The default initialization logic for mValidM/N/K is clean and correct. When these fields are not explicitly set (-1), they default to the full dimension sizes, which is the expected behavior for non-padded cases.


667-678: LGTM!

The validation correctly prevents use of validM/N/K parameters with BlockMajorK layout, as the swizzled memory layout is incompatible with the valid-dimension optimization. The check properly detects when any valid parameter differs from its corresponding size parameter.


1199-1201: LGTM!

The validation correctly ensures both mK and mValidK are multiples of 128 when using DeepSeek FP8. This is essential for the per-128-channel scaling mechanism to work correctly.


1014-1019: LGTM!

The validation correctly requires both mM and mValidM to be multiples of the shuffle block size when using shuffled matrix A. This ensures the memory reordering is consistent across both the allocated and valid data ranges.


1493-1514: LGTM!

The loadCubinData function cleanly abstracts the two cubin loading paths:

  1. FlashInfer mode: loads from storage using the cubin loader with proper path construction
  2. TensorRT-LLM mode: uses embedded cubin data from config.mData

The static_assert at line 1508 appropriately catches configuration errors at compile time when the required macro is missing.

"-DENABLE_FP8",
"-DENABLE_FP4",
f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify artifact path consistency after macro rename.

The macro has been renamed to TLLM_GEN_GEMM_CUBIN_PATH but still points to ArtifactPath.TRTLLM_GEN_BMM. From artifacts.py, there are two distinct artifact paths:

  • TRTLLM_GEN_BMM: "d60ea310477e47941ad9527c6d61c0531d82dd26/batched_gemm-ac83afb-96dfa16"
  • TRTLLM_GEN_GEMM: "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"

Please confirm whether TLLM_GEN_GEMM_CUBIN_PATH should point to the BMM artifact path (batched GEMM) or if it should be updated to use ArtifactPath.TRTLLM_GEN_GEMM.

Run the following script to verify which artifact path the C++ code expects:


🏁 Script executed:

#!/bin/bash
# Description: Check how TLLM_GEN_GEMM_CUBIN_PATH and TLLM_GEN_BMM_CUBIN_PATH 
# are used in the C++ source files to determine the correct artifact path.

# Search for usage of the macro in C++ files
echo "=== Searching for TLLM_GEN_GEMM_CUBIN_PATH usage ==="
rg -n -C3 --type=cpp --type=cu 'TLLM_GEN_GEMM_CUBIN_PATH' 

echo -e "\n=== Searching for TLLM_GEN_BMM_CUBIN_PATH usage ==="
rg -n -C3 --type=cpp --type=cu 'TLLM_GEN_BMM_CUBIN_PATH'

echo -e "\n=== Checking trtllm_batched_gemm_runner.cu (from compilation list) ==="
fd -t f 'trtllm_batched_gemm_runner.cu' --exec rg -n -C5 'CUBIN_PATH|batched.*gemm|TLLM_GEN'

Length of output: 2345


🏁 Script executed:

#!/bin/bash
# Search for the macro names without type restrictions
echo "=== Searching for TLLM_GEN_GEMM_CUBIN_PATH in all files ==="
rg -n -C2 'TLLM_GEN_GEMM_CUBIN_PATH' --max-count=20

echo -e "\n=== Searching for TLLM_GEN_BMM_CUBIN_PATH in all files ==="
rg -n -C2 'TLLM_GEN_BMM_CUBIN_PATH' --max-count=20

echo -e "\n=== Checking fused_moe.py around line 241 ==="
fd -t f 'fused_moe.py' --exec cat {} \; | head -250 | tail -50

echo -e "\n=== Looking for artifact path TRTLLM_GEN_BMM definition ==="
rg -n 'TRTLLM_GEN_BMM|TRTLLM_GEN_GEMM' artifacts.py -A1 -B1

Length of output: 5352


🏁 Script executed:

# Script to verify: search in codebase without skipping

Length of output: 59


Update fused_moe.py line 241 to use the correct artifact path.

The macro TLLM_GEN_GEMM_CUBIN_PATH at line 241 is defined with the wrong artifact path. Evidence shows:

  • flashinfer/jit/gemm/core.py (lines 386, 537) correctly uses ArtifactPath.TRTLLM_GEN_GEMM for the same macro
  • Both C++ header files in trtllm/gemm/trtllmGen_gemm_export/ and trtllm/batched_gemm/trtllmGen_bmm_export/ expect TLLM_GEN_GEMM_CUBIN_PATH (lines 30–33 and 1497–1509 respectively)
  • TLLM_GEN_BMM_CUBIN_PATH does not exist anywhere in the codebase

Line 241 should be:

f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
🤖 Prompt for AI Agents
In flashinfer/jit/fused_moe.py around line 241, the macro definition uses the
wrong artifact path (ArtifactPath.TRTLLM_GEN_BMM) for TLLM_GEN_GEMM_CUBIN_PATH;
change it to use ArtifactPath.TRTLLM_GEN_GEMM instead, preserving the existing
string escaping/formatting so the line becomes the same f-string but referencing
ArtifactPath.TRTLLM_GEN_GEMM.

Comment on lines +278 to 283
if (numEpilogueWarps) {
extraGmemCMultiplier *= numEpilogueWarps / 4;
}

// Number of bytes to store the output in smem.
auto const numBytesSmemStoreC = usesSmemForGmemC

⚠️ Potential issue | 🔴 Critical

Fix shared-memory sizing for epilogue warps.

Multiplying by numEpilogueWarps / 4 truncates toward zero. For example, tileM=192 → numEpilogueWarps = 6, so the multiplier stays at 1 and we under-allocate the DSMEM buffers. Even worse, 1–3 warps become 0 and allocate nothing. We need a ceil-div against the 4‑warp baseline.

Please replace the scaling with something like:

-        if (numEpilogueWarps) {
-          extraGmemCMultiplier *= numEpilogueWarps / 4;
-        }
+        if (numEpilogueWarps) {
+          auto const warpGroups = (numEpilogueWarps + 3) / 4;
+          extraGmemCMultiplier *= warpGroups;
+        }
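A quick standalone check of the truncation vs. ceil-div behavior described above (illustrative only; the helper names are not from the header).

// extraGmemCMultiplier scaling for epilogue warps: truncating division vs. a ceil-div
// against the 4-warp baseline.
constexpr int truncatedGroups(int numEpilogueWarps) { return numEpilogueWarps / 4; }
constexpr int ceilDivGroups(int numEpilogueWarps) { return (numEpilogueWarps + 3) / 4; }

static_assert(truncatedGroups(6) == 1, "tileM=192 -> 6 warps: truncation under-allocates");
static_assert(ceilDivGroups(6) == 2, "ceil-div reserves two warp groups");
static_assert(truncatedGroups(3) == 0, "1-3 warps would allocate nothing");
static_assert(ceilDivGroups(3) == 1, "ceil-div keeps at least one warp group");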

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #37954416: 12/17 passed

Signed-off-by: Siyuan Fu <[email protected]>
@IwakuraRein (Collaborator)

/bot run

@flashinfer-bot (Collaborator)

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yzh119 yzh119 (Collaborator) left a comment

LGTM

@yzh119 (Collaborator) commented Nov 6, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !107 has been updated with latest changes, and the CI pipeline #37976772 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #37976772: 12/17 passed

@yzh119 yzh119 merged commit b211926 into flashinfer-ai:main Nov 6, 2025
4 checks passed