Skip to content

feat: preparing TRTLLM MoE backend to support more kernels#2794

Open
rosenrodt wants to merge 2 commits intoflashinfer-ai:mainfrom
rosenrodt:feat/prepare-trtllm-moe-tput-support-2
Open

feat: preparing TRTLLM MoE backend to support more kernels#2794
rosenrodt wants to merge 2 commits intoflashinfer-ai:mainfrom
rosenrodt:feat/prepare-trtllm-moe-tput-support-2

Conversation

@rosenrodt
Copy link
Contributor

@rosenrodt rosenrodt commented Mar 16, 2026

📌 Description

This supercedes #2741. Contrary to #2741, this PR assumes that TRTLLM MoE kernels interpret the indices of the routing table in CGA granularity. As result, routing kernel can stay unchanged.

This is for preparing to support non-SwapAb/... kernels in TRTLLM MoE. Cubins will be prepared separately.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

Unit tests passed

$ python -m pytest -v tests/moe/test_trtllm_gen_fused_moe.py
===465 passed, 9545 skipped in 6898.12s (1:54:58)===
  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added CUDA profiling support for benchmarks with NVTX context managers.
    • Introduced bandwidth (TB/s) metrics calculation alongside TFLOPS in benchmark results.
    • Added new logging macro for diagnostic output.
  • Improvements

    • Updated TFLOPS calculations to use actual routed token counts for accuracy.
    • Enhanced error diagnostics and validation messages in kernel configuration selection.
  • Bug Fixes

    • Removed hardcoded transpose behaviors for improved GEMM configuration flexibility.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 16, 2026

📝 Walkthrough

Walkthrough

This PR enhances MoE benchmarking with actual routed-token metrics and CUDA profiler support, refactors batched GEMM dimension handling with validation improvements, introduces a generic launcher selector mechanism, adjusts dtype parameter ordering in MoE runners, and adds logging infrastructure.

Changes

Cohort / File(s) Summary
Benchmarking enhancements
benchmarks/bench_moe_deepseek.py
Updated TFLOPS and bandwidth calculations to use actual routed token counts; added CUDA NVTX profiler integration with cuda_profiler_range context manager; extended BenchResult with bw_tb_s metric; modified histogram collection to compute and return local_tokens for accurate throughput metrics.
Batched GEMM validation & refactoring
csrc/trtllm_batched_gemm_runner.cu, include/flashinfer/trtllm/batched_gemm/KernelRunner.h
Introduced setProblemDimensions helper to centralize dimension initialization; added skipQuirks config filter to exclude hanging kernels; refactored validation logic in TrtllmGenBatchedGemmRunner with improved error messages and transpose-aware pointer assignments; added clarifying comments on dtype and transposeMmaOutput semantics.
MoE launcher & runner refactoring
csrc/trtllm_fused_moe_kernel_launcher.cu, csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h
Added generic get_launcher template function for unified launcher selection; refactored launcher instantiation flow to call instantiate_moe_runner before routing preparation; swapped A/B dtype parameters (activation/weights) and removed transposeMmaOutput flag from GEMM runner options; removed ProjUpTileN field from MoEWorkspace structure; added compatibility notes in header.
Logging infrastructure
include/flashinfer/exception.h
Introduced FLASHINFER_LOG macro and new flashinfer::Log class for structured logging with function name, file, line number, and message formatting to stderr.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • cyx-6
  • IwakuraRein
  • bkryu
  • kahyunnam
  • jimmyzho
  • jiahanc
  • nv-yunzheq

Poem

🐰 A launcher selector hops into the fold,
Metrics bloom bright from routed tokens bold,
Dimensions dance in harmony at last,
As CUDA profilers capture the blast,
MoE runners sprint, now lighter and fast!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 27.59% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The PR title directly relates to the main objective of preparing the TRTLLM MoE backend to support more kernels, as confirmed by the PR objectives and commit messages.
Description check ✅ Passed The PR description includes key sections (Description, Related Issues, Pre-commit Checks, Tests) with substantive content explaining the changes and superseding a previous PR.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
📝 Coding Plan
  • Generate coding plan for human review comments

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility and robustness of the TRTLLM Mixture-of-Experts (MoE) backend. By moving from hardcoded assumptions to dynamic determination of matrix transposition and introducing a mechanism to filter out problematic kernels, it lays crucial groundwork for supporting a broader array of future kernels. Additionally, the changes provide more precise performance metrics in benchmarks, offering deeper insights into MoE computation efficiency.

Highlights

  • Dynamic Kernel Transposition: The TRTLLM MoE backend now dynamically determines matrix transposition (transposeMmaOutput) based on kernel configuration, removing previous hardcoded assumptions and enabling support for a wider range of kernels.
  • Kernel Filtering and Robustness: A new skipQuirks function has been introduced to identify and filter out known problematic kernels that might cause hangs or crashes, improving the stability of kernel selection.
  • Enhanced MoE Benchmarking: The DeepSeek MoE benchmark script has been significantly updated to calculate and display both TFLOPS and achieved bandwidth (TB/s) based on actual routed token counts, providing more accurate performance insights. CUDA profiler ranges were also added for detailed analysis.
  • Refactored MoE Launcher Logic: The FusedMoeLauncher and its derived classes in the TRTLLM MoE kernel launcher have been refactored to centralize runner instantiation and simplify the prepare_moe methods, improving code organization and maintainability.
  • Improved Logging and Error Handling: New logging capabilities (FLASHINFER_LOG) and more detailed error messages have been added to the TrtllmGenBatchedGemmRunner for better debugging and problem diagnosis.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • benchmarks/bench_moe_deepseek.py
    • Added cuda_profiler_range context manager for NVTX profiling.
    • Revised calc_tflops to use actual routed token counts for more accurate measurement.
    • Introduced calc_bw function to calculate achieved bandwidth in TB/s.
    • Updated BenchResult dataclass to include bw_tb_s for bandwidth reporting.
    • Modified benchmark output printing to display bandwidth and adjusted TFLOPS.
    • Updated _collect_expert_histogram to return local_tokens for accurate calculations.
    • Wrapped bench_gpu_time calls with cuda_profiler_range for profiling.
    • Adjusted header and row formatting to accommodate new performance metrics.
    • Updated footer with details on bandwidth calculation assumptions.
    • Corrected median_count formatting in the output table.
  • csrc/trtllm_batched_gemm_runner.cu
    • Introduced skipQuirks function to filter out known problematic kernels.
    • Added setProblemDimensions helper function to centralize problem dimension setup.
    • Modified TrtllmGenBatchedGemmRunner constructor to dynamically match kernel dtypes and layouts, removing hardcoded transposeMmaOutput.
    • Updated getWorkspaceSizeInBytes, run, and getValidConfigIndices to use setProblemDimensions and respect the kernel's transposeMmaOutput.
    • Improved error message in getValidConfigIndices when no valid configuration is found.
    • Added logging for kernel execution details when TRTLLM_BATCHED_GEMM_PRINT_NAME environment variable is enabled.
    • Adjusted isValidConfigIndex to utilize setProblemDimensions.
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added get_launcher template function to centralize launcher selection logic.
    • Refactored FusedMoeLauncher to move moe_runner instantiation into instantiate_moe_runner.
    • Updated prepare_routing_common to remove ProjUpTileN from the workspace.
    • Modified run methods in various launcher classes to call instantiate_moe_runner before prepare_routing.
    • Removed redundant FusedMoeLauncher::prepare_moe_common calls from derived launcher classes.
    • Ensured mUseDeepSeekFp8 is correctly set before runner instantiation in Fp8BlockScaleLauncher.
    • Updated trtllm_bf16_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp4_block_scale_moe, and trtllm_mxint4_block_scale_moe to use the new get_launcher helper and simplified run calls.
    • Removed ProjUpTileN from MoEWorkspace in FP4BlockScaleLauncher.
  • csrc/trtllm_fused_moe_runner.cu
    • Included <algorithm> header for standard library functions.
    • Removed hardcoded transposeMmaOutput = true and dtypeA/dtypeB swaps in getOptions functions, allowing the runner to determine transposition dynamically.
  • include/flashinfer/exception.h
    • Added FLASHINFER_LOG macro for general structured logging.
    • Introduced Log class to encapsulate structured log messages.
  • include/flashinfer/trtllm/batched_gemm/KernelRunner.h
    • Added comments to TrtllmGenBatchedGemmRunnerOptions to clarify the roles of dtypeA, dtypeB, dtypeC and the behavior of transposeMmaOutput.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Added comments to getMaxNumCtasInBatchDim and MoEWorkspace to clarify that 'CTA-based' names now refer to CGA granularity.
    • Removed ProjUpTileN from the MoEWorkspace struct.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant refactoring to the TRTLLM MoE backend, preparing it to support a wider variety of kernels, including non-transposed ones. The changes are well-structured and align with the goal of making the backend more robust and extensible. Key improvements include generalizing the TrtllmGenBatchedGemmRunner to handle kernel-specific transposeMmaOutput settings, refactoring problem dimension setup into a helper function, and adding a skipQuirks function to avoid problematic kernel configurations. The benchmark scripts have also been enhanced with bandwidth calculation and more accurate TFLOPS metrics. I've identified a couple of minor issues with duplicated parameters in log messages and have provided suggestions for them. Overall, this is a solid contribution.

Comment on lines +280 to +281
FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCgasInBatchDim", maxNumCtasInBatchDim,
", MaxNumCtasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There seems to be a duplicated log parameter here. MaxNumCgasInBatchDim is logged twice. You might want to remove the duplicate.

    FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCgasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK",

Comment on lines +300 to +301
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous log message, maxNumCtasInBatchDim is duplicated in this error message. You might want to remove the duplicate.

                   ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (2)
include/flashinfer/exception.h (1)

115-128: Log class is consistent with Warning class.

The implementation correctly mirrors the Warning class pattern. Both use std::cerr for output, which is thread-safe in C++11+.

Consider whether log-level filtering (controlled by environment variable) would be useful for production debugging. The existing Python-side logging uses FLASHINFER_LOGLEVEL environment variable per the learnings. A similar mechanism for C++ logs could provide parity.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/exception.h` around lines 115 - 128, The Log class should
respect a FLASHINFER_LOGLEVEL environment variable like the Python logger and
the existing Warning class; modify Log (constructor or emit method) to read
FLASHINFER_LOGLEVEL (e.g., parse into levels like ERROR/WARN/INFO/DEBUG), store
or compare a numeric threshold, and only write to std::cerr when the message
level meets the configured threshold; ensure the same parsing/semantics as
Warning so both Log and Warning behave consistently.
benchmarks/bench_moe_deepseek.py (1)

43-52: Gate profiler calls behind a flag to match project patterns.

cuda_profiler_range() unconditionally invokes cudaProfilerStart/Stop() on every benchmark run. While these calls are safe no-ops when no profiling tool is attached, other benchmarks in the codebase (e.g., benchmarks/routines/moe_comm.py with nvtx_range(enabled=...) and bench_rope_quantize_fp8.py with conditional mode_ncu checks) gate profiler control behind explicit flags. Consider adding a command-line flag (e.g., --enable-profiler) to match this pattern and make profiling opt-in.

Also applies to: 375-384, 483-492, 610-619

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 43 - 52, The
cuda_profiler_range context manager currently always calls
cudaProfilerStart/Stop; make profiler control opt-in by adding a command-line
flag (e.g., --enable-profiler) and only invoke CUDA profiler/NVTX when the flag
is true: add a global or parsed argument (referenced from argparse in this file)
and update cuda_profiler_range to check that flag before calling
torch.cuda.cudart().cudaProfilerStart(), torch.cuda.nvtx.range_push(),
range_pop(), and cudaProfilerStop(); likewise wrap or gate the other
NVTX/profiler usages noted around lines 375-384, 483-492, and 610-619 to consult
the same flag so profiling becomes explicitly enabled rather than unconditional.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 106-108: Rename the ambiguous single-letter variable I to a
descriptive name (e.g., intermediate_size) wherever it's used: replace I with
intermediate_size in the block that assigns H = CFG.hidden_size and I =
CFG.intermediate_size and in the subsequent FLOPS computation flops =
local_tokens * (2 * H * 2 * I + 2 * I * H), and also update the same rename in
the later helper usages around the second occurrence (lines 119-123) so all
references (intermediate_size, H, local_tokens, flops) remain consistent.
- Around line 99-109: The calc_tflops function (and the other throughput helper
in the same file that computes TFLOPS/throughput) must guard against
non-positive latency: clamp ms to a small positive epsilon (e.g. ms = max(ms,
1e-6)) or return 0.0 immediately when ms <= 0 to avoid divide-by-zero or
negative throughput, and apply the same fix to the corresponding helper at lines
112-133 so both functions consistently handle ms <= 0.

In `@csrc/trtllm_batched_gemm_runner.cu`:
- Around line 299-301: The error message concatenation prints
maxNumCtasInBatchDim twice; update the string built around transposeMmaOutput,
configIndex, and maxNumCtasInBatchDim so the duplicate is removed and the
intended field is logged (either remove the second ", maxNumCtasInBatchDim: "
segment or replace it with the correct variable name if another field was
meant); look for the concatenation that references transposeMmaOutput,
configIndex, and maxNumCtasInBatchDim in trtllm_batched_gemm_runner.cu and
correct the duplicated token accordingly.
- Around line 279-286: The debug log call to FLASHINFER_LOG prints two fields
"MaxNumCgasInBatchDim" and "MaxNumCtasInBatchDim" but passes
maxNumCtasInBatchDim for both; update the arguments so the
"MaxNumCgasInBatchDim" label is paired with the correct variable
(maxNumCgasInBatchDim) and "MaxNumCtasInBatchDim" remains paired with
maxNumCtasInBatchDim in the FLASHINFER_LOG call.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 109-128: In get_launcher, the if (it == launchers_map.end()) block
contains a redundant FLASHINFER_CHECK(it != launchers_map.end(), ...); remove
the surrounding if and replace it with a single direct check: call
FLASHINFER_CHECK(it != launchers_map.end(), op_name, "...", tile_N,
"(selected_tile_count=", selected_tile_nums.size(), ")") after computing it, so
the error path is clear; reference symbols: get_launcher, launchers_map,
selected_tile_nums, tile_N, and FLASHINFER_CHECK.

---

Nitpick comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 43-52: The cuda_profiler_range context manager currently always
calls cudaProfilerStart/Stop; make profiler control opt-in by adding a
command-line flag (e.g., --enable-profiler) and only invoke CUDA profiler/NVTX
when the flag is true: add a global or parsed argument (referenced from argparse
in this file) and update cuda_profiler_range to check that flag before calling
torch.cuda.cudart().cudaProfilerStart(), torch.cuda.nvtx.range_push(),
range_pop(), and cudaProfilerStop(); likewise wrap or gate the other
NVTX/profiler usages noted around lines 375-384, 483-492, and 610-619 to consult
the same flag so profiling becomes explicitly enabled rather than unconditional.

In `@include/flashinfer/exception.h`:
- Around line 115-128: The Log class should respect a FLASHINFER_LOGLEVEL
environment variable like the Python logger and the existing Warning class;
modify Log (constructor or emit method) to read FLASHINFER_LOGLEVEL (e.g., parse
into levels like ERROR/WARN/INFO/DEBUG), store or compare a numeric threshold,
and only write to std::cerr when the message level meets the configured
threshold; ensure the same parsing/semantics as Warning so both Log and Warning
behave consistently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 42ea45b7-3c89-4025-bf3e-5475757042bc

📥 Commits

Reviewing files that changed from the base of the PR and between 043bc43 and 08ee7b9.

📒 Files selected for processing (7)
  • benchmarks/bench_moe_deepseek.py
  • csrc/trtllm_batched_gemm_runner.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/exception.h
  • include/flashinfer/trtllm/batched_gemm/KernelRunner.h
  • include/flashinfer/trtllm/fused_moe/runner.h

Comment on lines +99 to +109
def calc_tflops(local_tokens, ms):
"""Calculate TFLOPS using actual routed token count.

With EP, only tokens routed to local experts are computed.
Assumes uniform routing distribution across experts.
FC1: [M, H] x [H, 2I]
FC2: [M, I] x [I, H]
FLOPs = 2 * local_tokens * (H*2I + I*H) = 6 * local_tokens * H * I
"""
if num_local_experts is None:
num_local_experts = CFG.num_experts
H = CFG.hidden_size
I = CFG.intermediate_size
flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
return flops / (ms * 1e-3) / 1e12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Guard throughput computations against non-positive latency.

If ms <= 0, both helpers can raise/diverge during division. Add a small guard to keep benchmarking output resilient.

Proposed fix
 def calc_tflops(local_tokens, ms):
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
-    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
+    if ms <= 0:
+        return float("nan")
+    H = CFG.hidden_size
+    I = CFG.intermediate_size
+    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
     return flops / (ms * 1e-3) / 1e12
 
 
 def calc_bw(local_tokens, active_experts, ms):
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
+    if ms <= 0:
+        return float("nan")
+    H = CFG.hidden_size
+    I = CFG.intermediate_size

Also applies to: 112-133

🧰 Tools
🪛 Ruff (0.15.5)

[error] 107-107: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 99 - 109, The calc_tflops
function (and the other throughput helper in the same file that computes
TFLOPS/throughput) must guard against non-positive latency: clamp ms to a small
positive epsilon (e.g. ms = max(ms, 1e-6)) or return 0.0 immediately when ms <=
0 to avoid divide-by-zero or negative throughput, and apply the same fix to the
corresponding helper at lines 112-133 so both functions consistently handle ms
<= 0.

Comment on lines +106 to +108
H = CFG.hidden_size
I = CFG.intermediate_size
flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Rename single-letter I to avoid Ruff E741 lint errors.

I is flagged as ambiguous (E741). Rename to a descriptive identifier (e.g., intermediate_size) in both helpers.

Proposed fix
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
-    flops = local_tokens * (2 * H * 2 * I + 2 * I * H)
+    H = CFG.hidden_size
+    intermediate_size = CFG.intermediate_size
+    flops = local_tokens * (
+        2 * H * 2 * intermediate_size + 2 * intermediate_size * H
+    )
@@
-    H = CFG.hidden_size
-    I = CFG.intermediate_size
+    H = CFG.hidden_size
+    intermediate_size = CFG.intermediate_size
 
-    weight_bytes = active_experts * (H * 2 * I + I * H) * NVFP4_BYTES
+    weight_bytes = (
+        active_experts * (H * 2 * intermediate_size + intermediate_size * H) * NVFP4_BYTES
+    )

Also applies to: 119-123

🧰 Tools
🪛 Ruff (0.15.5)

[error] 107-107: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 106 - 108, Rename the
ambiguous single-letter variable I to a descriptive name (e.g.,
intermediate_size) wherever it's used: replace I with intermediate_size in the
block that assigns H = CFG.hidden_size and I = CFG.intermediate_size and in the
subsequent FLOPS computation flops = local_tokens * (2 * H * 2 * I + 2 * I * H),
and also update the same rename in the later helper usages around the second
occurrence (lines 119-123) so all references (intermediate_size, H,
local_tokens, flops) remain consistent.

Comment on lines +279 to +286
if (getBoolEnv("TRTLLM_BATCHED_GEMM_PRINT_NAME")) {
FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCgasInBatchDim", maxNumCtasInBatchDim,
", MaxNumCtasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK",
gemmData.mProblemDimensions.mM, gemmData.mProblemDimensions.mN,
gemmData.mProblemDimensions.mK, ", ValidShapeMNK",
gemmData.mProblemDimensions.mValidM, gemmData.mProblemDimensions.mValidN,
gemmData.mProblemDimensions.mValidK, ", Kernel", config.mFunctionName);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Duplicate field in debug logging.

Line 280-281 logs MaxNumCgasInBatchDim followed by MaxNumCtasInBatchDim with the same value (maxNumCtasInBatchDim). This appears to be a copy-paste error.

🔧 Proposed fix
   if (getBoolEnv("TRTLLM_BATCHED_GEMM_PRINT_NAME")) {
-    FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCgasInBatchDim", maxNumCtasInBatchDim,
-                   ", MaxNumCtasInBatchDim", maxNumCtasInBatchDim, ", ShapeMNK",
+    FLASHINFER_LOG("NumBatches", numBatches, ", MaxNumCtasInBatchDim", maxNumCtasInBatchDim,
+                   ", ShapeMNK",
                    gemmData.mProblemDimensions.mM, gemmData.mProblemDimensions.mN,
                    gemmData.mProblemDimensions.mK, ", ValidShapeMNK",
                    gemmData.mProblemDimensions.mValidM, gemmData.mProblemDimensions.mValidN,
                    gemmData.mProblemDimensions.mValidK, ", Kernel", config.mFunctionName);
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_batched_gemm_runner.cu` around lines 279 - 286, The debug log
call to FLASHINFER_LOG prints two fields "MaxNumCgasInBatchDim" and
"MaxNumCtasInBatchDim" but passes maxNumCtasInBatchDim for both; update the
arguments so the "MaxNumCgasInBatchDim" label is paired with the correct
variable (maxNumCgasInBatchDim) and "MaxNumCtasInBatchDim" remains paired with
maxNumCtasInBatchDim in the FLASHINFER_LOG call.

Comment on lines +299 to +301
", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Duplicate field in error message.

maxNumCtasInBatchDim is printed twice in the error message.

🔧 Proposed fix
   FLASHINFER_CHECK(err == 0,
                    "Error occurred when running GEMM!"
                    " (numBatches: ",
                    numBatches, ", GemmMNK: ", m, " ", n, " ", k, ", Kernel: ", config.mFunctionName,
-                   ", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
-                   ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim,
-                   ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
+                   ", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
+                   ", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
", transposeMmaOutput: ", transposeMmaOutput, ", configIndex: ", configIndex,
", maxNumCtasInBatchDim: ", maxNumCtasInBatchDim, ")");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_batched_gemm_runner.cu` around lines 299 - 301, The error message
concatenation prints maxNumCtasInBatchDim twice; update the string built around
transposeMmaOutput, configIndex, and maxNumCtasInBatchDim so the duplicate is
removed and the intended field is logged (either remove the second ",
maxNumCtasInBatchDim: " segment or replace it with the correct variable name if
another field was meant); look for the concatenation that references
transposeMmaOutput, configIndex, and maxNumCtasInBatchDim in
trtllm_batched_gemm_runner.cu and correct the duplicated token accordingly.

Comment on lines +109 to +128
template <typename LauncherType>
// Select a launcher
LauncherType& get_launcher(
std::unordered_map<int32_t, std::unique_ptr<LauncherType>>& launchers_map,
std::set<int32_t> const& selected_tile_nums, int64_t& tile_N, int64_t& config,
char const* op_name) {
FLASHINFER_CHECK(!selected_tile_nums.empty(), op_name, ": no available tile_N candidates");

if (tile_N == -1) {
tile_N = *selected_tile_nums.begin();
}

auto it = launchers_map.find(static_cast<int32_t>(tile_N));
if (it == launchers_map.end()) {
FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
tile_N, " (selected_tile_count=", selected_tile_nums.size(), ")");
}

return *(it->second);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Redundant condition check in get_launcher.

The FLASHINFER_CHECK on Line 123-124 is inside an if (it == launchers_map.end()) block, so the check it != launchers_map.end() will always fail. The outer if makes the inner check redundant.

🔧 Proposed fix
   auto it = launchers_map.find(static_cast<int32_t>(tile_N));
-  if (it == launchers_map.end()) {
-    FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
-                     tile_N, " (selected_tile_count=", selected_tile_nums.size(), ")");
-  }
+  FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
+                   tile_N, " (selected_tile_count=", selected_tile_nums.size(), ")");
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
template <typename LauncherType>
// Select a launcher
LauncherType& get_launcher(
std::unordered_map<int32_t, std::unique_ptr<LauncherType>>& launchers_map,
std::set<int32_t> const& selected_tile_nums, int64_t& tile_N, int64_t& config,
char const* op_name) {
FLASHINFER_CHECK(!selected_tile_nums.empty(), op_name, ": no available tile_N candidates");
if (tile_N == -1) {
tile_N = *selected_tile_nums.begin();
}
auto it = launchers_map.find(static_cast<int32_t>(tile_N));
if (it == launchers_map.end()) {
FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
tile_N, " (selected_tile_count=", selected_tile_nums.size(), ")");
}
return *(it->second);
}
template <typename LauncherType>
// Select a launcher
LauncherType& get_launcher(
std::unordered_map<int32_t, std::unique_ptr<LauncherType>>& launchers_map,
std::set<int32_t> const& selected_tile_nums, int64_t& tile_N, int64_t& config,
char const* op_name) {
FLASHINFER_CHECK(!selected_tile_nums.empty(), op_name, ": no available tile_N candidates");
if (tile_N == -1) {
tile_N = *selected_tile_nums.begin();
}
auto it = launchers_map.find(static_cast<int32_t>(tile_N));
FLASHINFER_CHECK(it != launchers_map.end(), op_name, ": failed to select launcher for tile_N ",
tile_N, " (selected_tile_count=", selected_tile_nums.size(), ")");
return *(it->second);
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 109 - 128, In
get_launcher, the if (it == launchers_map.end()) block contains a redundant
FLASHINFER_CHECK(it != launchers_map.end(), ...); remove the surrounding if and
replace it with a single direct check: call FLASHINFER_CHECK(it !=
launchers_map.end(), op_name, "...", tile_N, "(selected_tile_count=",
selected_tile_nums.size(), ")") after computing it, so the error path is clear;
reference symbols: get_launcher, launchers_map, selected_tile_nums, tile_N, and
FLASHINFER_CHECK.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants