
feat: preparing TRTLLM MoE backend to support more kernels#2741

Draft
rosenrodt wants to merge 3 commits into flashinfer-ai:main from rosenrodt:feat/prepare-trtllm-moe-tput-support

Conversation


@rosenrodt rosenrodt commented Mar 10, 2026

📌 Description

  • Prepares the TRTLLM MoE backend to support non-SwapAb kernels; the cubins will be provided separately. With or without locally updated cubins, python -m pytest -v tests/moe/test_trtllm_gen_fused_moe.py -k NvFP4 passes.
  • Revise routing info to take cluster size into account.
    • From the MoE caller’s view, the tile size is expressed in CGA granularity to ensure the token-to-expert mapping is correct. Inside the MoE, however, the routing info ctaIdxXyToBatchIdx, ctaIdxXyToMnLimit, and numNonExitingCtas is expected in CTA granularity. The fix is to expand the routing info to CTA granularity, which means the routing kernels must take the CGA size into account.
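The CGA→CTA expansion described above can be sketched as follows. This is a hypothetical pure-Python model, not the CUDA routing kernel: the field names (ctaIdxXyToBatchIdx, ctaIdxXyToMnLimit) mirror the PR description, but the function name, tile layout, and clamping logic are assumptions for illustration. numNonExitingCtas would scale the same way (numNonExitingCgas * clusterSize).

```python
def expand_routing_to_cta(cga_idx_to_batch_idx, cga_idx_to_mn_limit,
                          tile_tokens_dim, cluster_size):
    """Expand per-CGA routing info to per-CTA granularity (illustrative).

    Each CGA covers `cluster_size` CTAs along the batch dimension; every
    CTA within a CGA maps to the same expert (batch index), and each CTA's
    M/N limit is its own tile end clamped by the CGA's limit.
    """
    cta_idx_to_batch_idx = []
    cta_idx_to_mn_limit = []
    cta_tile = tile_tokens_dim // cluster_size  # tokens handled per CTA
    for cga_idx, batch_idx in enumerate(cga_idx_to_batch_idx):
        cga_limit = cga_idx_to_mn_limit[cga_idx]
        for cta_in_cga in range(cluster_size):
            cta_idx_to_batch_idx.append(batch_idx)
            cta_start = cga_idx * tile_tokens_dim + cta_in_cga * cta_tile
            cta_idx_to_mn_limit.append(min(cta_start + cta_tile, cga_limit))
    return cta_idx_to_batch_idx, cta_idx_to_mn_limit
```

For example, with a 64-token tile, cluster size 2, and one CGA routed to expert 3 holding only 40 valid tokens, the two CTAs both map to expert 3 with limits 32 and 40; with cluster size 1 the routing info passes through unchanged.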

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • CUDA profiler integration for scoped kernel profiling and NVTX ranges.
    • Bandwidth (TB/s) reporting added alongside TFLOPS and latency.
    • Lightweight logging macro for runtime diagnostics.
  • Improvements

    • More accurate TFLOPS and bandwidth calculations using actual routed token counts.
    • Cluster-size–aware routing and tile sizing for improved MoE performance and config validation.


coderabbitai bot commented Mar 10, 2026

📝 Walkthrough

Walkthrough

This PR adds cluster-size (CGA/CTA) awareness across fused MoE routing and batched-GEMM code, exposes per-config cluster-size accessors, updates routing and launcher flows to propagate clusterSize, and enhances benchmarking with CUDA profiling, bandwidth (TB/s) metrics, and reporting.

Changes

Cohort / File(s) Summary
Benchmarking Infrastructure
benchmarks/bench_moe_deepseek.py
Added cuda_profiler_range context manager, NVFP4_BYTES/BF16_BYTES constants, refactored calc_tflops(local_tokens, ms), added calc_bw(local_tokens, active_experts, ms), extended BenchResult with bw_tb_s, updated histogram collection to return local_tokens, wrapped runners with CUDA profiling, and widened output to include TB/s columns.
Batched GEMM Runner
csrc/trtllm_batched_gemm_runner.cu, include/flashinfer/trtllm/batched_gemm/KernelRunner.h
Introduced cluster-aware tile sizing (CGA), renamed max-CTAs API params to max-CGAs, added dtypeA/B/C options, updated kernel config selection to account for cluster sizing, and added getConfigClusterSizeInBatchDim() accessor.
MoE Launcher & Orchestration
csrc/trtllm_fused_moe_kernel_launcher.cu, csrc/trtllm_fused_moe_runner.cu, include/flashinfer/trtllm/fused_moe/runner.h
Added get_launcher() helper and instantiate_moe_runner, made prepare_routing() cluster-aware (accepts clusterSize), propagated clusterSize through launcher/runner lifecycles, updated Runner constructors to accept clusterSize, and exposed per-config cluster-size query methods.
Routing Kernels & Logic
csrc/trtllm_fused_moe_routing_deepseek.cu, csrc/trtllm_fused_moe_routing_llama4.cu, csrc/trtllm_fused_moe_routing_renormalize.cu, include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh, include/flashinfer/trtllm/fused_moe/RoutingKernel.h
Converted CGA→CTA granularity by scaling counts with clusterSize, adjusted MnLimit, padded-offset and permuted-index calculations for pow2 and non-pow2 paths, and added cluster metadata fields (mClusterSizeInBatchDim, mClusterSizeLog2, mTileTokensDim) in routing structures.
Logging Infrastructure
include/flashinfer/exception.h
Added a flashinfer::Log class and FLASHINFER_LOG(...) macro for structured variadic logging.
Public Headers / API
include/flashinfer/trtllm/fused_moe/RoutingKernel.h, include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh, include/flashinfer/trtllm/batched_gemm/KernelRunner.h
Extended public data structures and KernelParamsBase to carry cluster-size fields and padding/token metadata; adjusted public APIs to expose cluster-size accessors and renamed some members to CGA terminology.

Sequence Diagram(s)

sequenceDiagram
  participant Client as Client/Launcher
  participant Launcher as FusedMoeLauncher
  participant Routing as Routing::Runner
  participant MoE as MoE::Runner
  participant GEMM as BatchedGemmRunner
  participant Device as CUDA Kernel

  Client->>Launcher: select tile/config (may fallback)
  Launcher->>Routing: instantiate/prepare_routing(clusterSize)
  Routing->>MoE: provide routingData (mClusterSizeInBatchDim, ... )
  MoE->>GEMM: choose gemm configs, ensure cluster-size consistency
  GEMM->>Device: launch kernels with cluster-aware dims
  Device-->>Client: return results/metrics
  Note over Client,Launcher: Benchmark wraps operations with cuda_profiler_range

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • joker-eph
  • cyx-6
  • djmmoss
  • yzh119
  • wenscarl
  • bkryu
  • nvmbreughe
  • jiahanc

Poem

🐇 I hopped through code with cluster-sized cheer,
CGAs aligned and benchmarks now clear,
TB/s and TFLOPS counted from routed tokens,
Launchers and runners carry the tokens' tokens,
A tiny rabbit clap—deploy without fear!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 18.69% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main purpose: preparing TRTLLM MoE backend to support additional kernels, which aligns with the primary objective of the changeset.
Description check ✅ Passed The description adequately covers what the PR does (preparing non-SwapAb kernel support, revising routing info for cluster size), explains the technical rationale (CGA vs CTA granularity), and notes test validation. Most template sections are acknowledged or marked complete.


@gemini-code-assist

Summary of Changes


This pull request significantly enhances the flexibility and robustness of the TRTLLM MoE backend. The primary focus is on adapting the system to support a wider variety of kernels by making the routing logic cluster-aware and ensuring that the internal GEMM operations correctly handle these new dimensional considerations. This involves a fundamental shift in how routing information is processed, moving from a cluster-level view to a more granular CTA-level view for internal kernel execution. Additionally, the changes include improvements to configuration selection for batched GEMM, making it more adaptive and less reliant on hardcoded assumptions, and adding better debugging capabilities.

Highlights

  • TRTLLM MoE Backend Preparation: Prepared the TRTLLM Mixture-of-Experts (MoE) backend to support a broader range of kernels, specifically non-SwapAb kernels, by refining the underlying architecture and routing mechanisms.
  • Cluster Size Integration in Routing: Revised the routing information to explicitly account for cluster sizes. This involved expanding routing data from CGA (Cooperative Group Array) granularity to CTA (Cooperative Thread Array) granularity, ensuring correct token-to-expert mapping within the MoE kernels.
  • Refined Batched GEMM Configuration Matching: Improved the configuration matching logic for batched GEMM operations, making it more robust by separating data type and layout checks, and dynamically determining transposeMmaOutput rather than hardcoding it.
  • Enhanced Debugging and Profiling: Introduced a new FLASHINFER_LOG macro for structured logging and integrated CUDA profiler ranges into MoE benchmarks for more detailed performance analysis.


Changelog
  • benchmarks/bench_moe_deepseek.py
    • Introduced cuda_profiler_range context manager for NVTX profiling.
    • Applied cuda_profiler_range to various bench_gpu_time calls for detailed profiling.
  • csrc/trtllm_batched_gemm_runner.cu
    • Introduced getClusterSizeInBatchDim and setProblemDimensions helper functions.
    • Refined config matching logic in the constructor, separating dtype and layout checks.
    • Added getConfigClusterSizeInBatchDim method.
    • Updated getWorkspaceSizeInBytes, run, getValidConfigIndices, getDefaultValidConfigIndex, and isValidConfigIndex to use new helper functions and cluster size.
    • Enhanced error messages and added logging for kernel execution details.
  • csrc/trtllm_fused_moe_kernel_launcher.cu
    • Added get_launcher template function for robust launcher selection and fallback.
    • Modified prepare_routing_common to accept clusterSize and adjust CTA buffer sizing.
    • Removed ProjUpTileN from workspace.
    • Refactored prepare_moe_common into instantiate_moe_runner and updated its usage.
    • Integrated clusterSize into prepare_routing and Routing::Runner initialization.
    • Simplified launcher selection in various trtllm_moe functions using the new get_launcher helper.
  • csrc/trtllm_fused_moe_routing_deepseek.cu
    • Adjusted numCta calculation to incorporate mClusterSizeInBatchDim.
    • Modified mnLimit and padded offset calculations to align with CTA-level granularity considering cluster size.
  • csrc/trtllm_fused_moe_routing_llama4.cu
    • Adjusted numCta calculation to incorporate mClusterSizeInBatchDim.
    • Modified mnLimit and padded offset calculations to align with CTA-level granularity considering cluster size.
  • csrc/trtllm_fused_moe_routing_renormalize.cu
    • Adjusted numCta calculation to incorporate mClusterSizeInBatchDim.
    • Modified mnLimit and padded offset calculations to align with CTA-level granularity considering cluster size.
  • csrc/trtllm_fused_moe_runner.cu
    • Included <algorithm> header.
    • Updated Routing::Runner constructor and data structures to include clusterSizeInBatchDim.
    • Removed hardcoded transposeMmaOutput and associated dtype swaps in GEMM options.
    • Replaced maxNumCtasInBatchDim with maxNumCgasInBatchDim in various runner methods.
    • Added getConfigClusterSizeInBatchDim methods to PermuteGemm1, Gemm2, and MoE runners.
  • include/flashinfer/exception.h
    • Introduced FLASHINFER_LOG macro for consistent logging.
    • Added Log class to handle structured log messages.
  • include/flashinfer/trtllm/batched_gemm/KernelRunner.h
    • Added descriptive comments for dtypeA, dtypeB, dtypeC, and transposeMmaOutput in TrtllmGenBatchedGemmRunnerOptions.
    • Declared getConfigClusterSizeInBatchDim in TrtllmGenBatchedGemmRunner.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
    • Adjusted numCta calculation to incorporate mClusterSizeInBatchDim.
    • Modified mnLimit and padded offset calculations to align with CTA-level granularity considering cluster size.
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
    • Added mClusterSizeInBatchDim and mClusterSizeLog2 to DataBase and KernelParamsBase.
    • Provided clarifying comments for routing kernel parameters.
  • include/flashinfer/trtllm/fused_moe/runner.h
    • Renamed getMaxNumCtasInBatchDim to getMaxNumCgasInBatchDim and updated its implementation.
    • Added getCgaSizeInBatchDim helper function.
    • Updated getMaxPermutedPaddedCount to use CGA-based calculations.
    • Modified Routing::Runner constructor to accept clusterSizeInBatchDim.
    • Added mClusterSizeInBatchDim member to Routing::Runner.
    • Declared getConfigClusterSizeInBatchDim in PermuteGemm1, Gemm2, and MoE runners.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces CUDA profiling capabilities to MoE benchmarks and refactors the TrtllmGenBatchedGemmRunner in C++ to better handle problem dimensions and cluster sizes, including the introduction of helper functions getClusterSizeInBatchDim and setProblemDimensions. The routing kernels for MoE were updated to correctly account for mClusterSizeInBatchDim in CTA calculations and token-space padded offsets. Several potential integer overflow vulnerabilities were identified in calculations involving numTokens * topK, maxCgas * padding, and numCta * mClusterSizeInBatchDim, which could lead to incorrect buffer sizes or out-of-bounds memory access. Additionally, a code duplication issue was noted regarding the getClusterSizeInBatchDim logic.

- int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
- return maxCtas * padding;
+ int32_t maxCgas = getMaxNumCgasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
+ return maxCgas * padding;

security-high high

Potential integer overflow in maxCgas * padding. If the product exceeds the range of int32_t, it will wrap around, potentially resulting in a negative or truncated value. This value is used for memory allocation in csrc/trtllm_fused_moe_kernel_launcher.cu, which could lead to a buffer that is too small, causing out-of-bounds writes in the routing kernels.

Suggested change
- return maxCgas * padding;
+ return static_cast<int64_t>(maxCgas) * padding;
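The wrap-around the review warns about can be demonstrated with a small pure-Python model (Python integers never overflow, so 32-bit two's-complement arithmetic is emulated explicitly; the operand values below are made up purely to trigger the overflow and are not from the PR).

```python
def int32_wrap(x: int) -> int:
    """Reduce x to a signed 32-bit value, as a C int32_t multiply would."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

max_cgas, padding = 70_000, 70_000          # product = 4.9e9 > INT32_MAX

wrapped = int32_wrap(max_cgas * padding)    # what int32 `maxCgas * padding` yields
widened = max_cgas * padding                # static_cast<int64_t>(maxCgas) * padding

print(wrapped, widened)  # 605032704 4900000000
```

The truncated value (605032704 instead of 4900000000) is exactly the kind of silently-too-small buffer size the suggested int64_t widening avoids.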

Contributor Author

I will ignore the int overflow warnings, as they predate the changes in this PR.

Comment on lines +39 to +41
static inline int32_t getClusterSizeInBatchDim(BatchedGemmOptions const& options) {
return options.mTransposeMmaOutput ? options.mClusterDimY : options.mClusterDimX;
}

medium

This function getClusterSizeInBatchDim has the same logic as getCgaSizeInBatchDim introduced in include/flashinfer/trtllm/fused_moe/runner.h. To avoid code duplication and improve maintainability, consider creating a single utility function in a common header. The version in runner.h is more generic as it doesn't depend on BatchedGemmOptions. You could move that function to a shared location, and then call it here like getClusterSizeInBatchDim(options.mTransposeMmaOutput, options.mClusterDimX, options.mClusterDimY).

@rosenrodt rosenrodt marked this pull request as ready for review March 10, 2026 13:47
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
include/flashinfer/exception.h (2)

115-128: The Log class duplicates Warning class structure.

The new Log class is nearly identical to the existing Warning class—both have the same structure with only the message prefix differing ("Log" vs "Warning"). If more such classes are anticipated, consider extracting a common base or template to reduce duplication.

Additionally, Log::emit() writes unconditionally to stderr. Based on learnings, the project supports log level filtering via FLASHINFER_LOGLEVEL environment variable. The current implementation bypasses this mechanism.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/exception.h` around lines 115 - 128, The new Log class
duplicates Warning; refactor by extracting a common base (e.g., class BaseLog or
template LogBase) that takes the prefix ("Log"/"Warning") and constructs
message_ so both Log and Warning reuse the formatting logic, and replace
duplicated constructors; also change Log::emit (and Warning::emit) to honor the
FLASHINFER_LOGLEVEL environment variable (read FLASHINFER_LOGLEVEL and skip
emitting if current level is lower than the message level) instead of
unconditionally writing to stderr so logging follows the project's log-level
filtering.

73-82: Consider using or integrating with the existing logging infrastructure.

The new FLASHINFER_LOG macro writes directly to stderr via Log::emit(), but the project already has spdlog-based logging in include/flashinfer/logging.h with level-filtered macros (FLASHINFER_LOG_INFO, FLASHINFER_LOG_DEBUG, etc.). This creates two parallel logging systems:

  1. FLASHINFER_LOG(...) → unconditional stderr output
  2. FLASHINFER_LOG_INFO(...) → spdlog with level filtering via FLASHINFER_LOGLEVEL

The naming similarity (FLASHINFER_LOG vs FLASHINFER_LOG_*) may cause confusion. Consider either:

  • Renaming this macro (e.g., FLASHINFER_LOG_STDERR) to clarify its behavior
  • Integrating with spdlog to respect log level configuration
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/exception.h` around lines 73 - 82, The new macro
FLASHINFER_LOG currently bypasses the project's spdlog-based logging (e.g.,
FLASHINFER_LOG_INFO / FLASHINFER_LOG_DEBUG) by always calling
flashinfer::Log(...).emit(), which creates two parallel logging systems and
confusing names; either rename FLASHINFER_LOG to something explicit (e.g.,
FLASHINFER_LOG_STDERR) or change its implementation to forward to the existing
logging infrastructure in include/flashinfer/logging.h (map to the appropriate
spdlog macro or use the same level filtering), ensuring you update references to
FLASHINFER_LOG and keep Log::emit() usage consistent or removed so that log
level configuration via FLASHINFER_LOGLEVEL is respected.
include/flashinfer/trtllm/batched_gemm/KernelRunner.h (1)

128-129: Header declares method but parameter names are inconsistent with implementation.

The implementation in csrc/trtllm_batched_gemm_runner.cu uses maxNumCgasInBatchDim for several methods, but this header still declares parameters as maxNumCtasInBatchDim (e.g., lines 93, 133). While the parameter names don't affect binary compatibility, this inconsistency can cause confusion for API consumers.

Consider updating the header parameter names to maxNumCgasInBatchDim to match the implementation and the CGA terminology used throughout this PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/batched_gemm/KernelRunner.h` around lines 128 -
129, Change the inconsistent parameter name(s) in the KernelRunner.h
declarations from maxNumCtasInBatchDim to maxNumCgasInBatchDim so they match the
implementation in csrc/trtllm_batched_gemm_runner.cu; specifically update the
signature for getConfigClusterSizeInBatchDim and any other methods in
KernelRunner that currently declare maxNumCtasInBatchDim to use
maxNumCgasInBatchDim to keep the API terminology consistent with the CGA naming
used in the implementation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 95-96: The bandwidth calc in calc_bw currently assumes NVFP4
activations (using NVFP4_BYTES) which underestimates traffic when hidden_states
can be BF16 or other dtypes per trtllm_fp4_block_scale_moe; update calc_bw to
accept an input dtype parameter (e.g., input_dtype or act_bytes) or a flag and
use the corresponding byte-size constant (BF16_BYTES, NVFP4_BYTES, etc.) when
computing local_tokens * H * <bytes>, and update any callers (or default
behavior) so FC1 bandwidth uses the selected dtype; alternatively, add a clear
docstring on calc_bw and top-level comment stating the NVFP4-only assumption if
you choose not to change behavior.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 117-120: The current branch unconditionally overwrites a
caller-specified tile_N when config == -1; change the logic so we only pick a
default tile when the caller did not specify one. Replace the combined check
with a conditional that sets tile_N = *selected_tile_nums.begin() only when
tile_N == -1 (leave tile_N untouched if caller provided it) and do not overwrite
config; update the code around instantiate_moe_runner()/selected_tile_nums to
use this new conditional so explicit {tile_N, -1} stays {tile_N, -1}.
- Around line 1102-1104: The code instantiates the MoE runner before setting the
DeepSeek mode, causing the generic runner to be chosen incorrectly; change the
call order so prepare_routing(...) (which sets args->mUseDeepSeekFp8 / DeepSeek
mode) runs before instantiate_moe_runner(...), then call
moe_runner->getConfigClusterSizeInBatchDim(moe_tactic) afterwards so tactic
validation and workspace sizing match getValidConfigs() and the actual DeepSeek
runtime path.

In `@csrc/trtllm_fused_moe_runner.cu`:
- Around line 584-591: The function Runner::getConfigClusterSizeInBatchDim
currently returns max(gemm1ClusterSize, gemm2ClusterSize) which allows mixed
GEMM cluster sizes; instead, detect mismatches and fail or filter earlier: add a
guard in getConfigClusterSizeInBatchDim that retrieves gemm1ClusterSize via
mPermuteGemm1.getConfigClusterSizeInBatchDim(config.gemm1Config) and
gemm2ClusterSize via mGemm2.getConfigClusterSizeInBatchDim(config.gemm2Config),
and if gemm1ClusterSize != gemm2ClusterSize either assert/throw a clear error
referencing the config index and mPassingConfigs or remove such entries from
mPassingConfigs when building it so only same-cluster pairs remain; ensure the
error message names Runner::getConfigClusterSizeInBatchDim, mPassingConfigs,
gemm1ClusterSize and gemm2ClusterSize to make the guard obvious.

---

Nitpick comments:
In `@include/flashinfer/exception.h`:
- Around line 115-128: The new Log class duplicates Warning; refactor by
extracting a common base (e.g., class BaseLog or template LogBase) that takes
the prefix ("Log"/"Warning") and constructs message_ so both Log and Warning
reuse the formatting logic, and replace duplicated constructors; also change
Log::emit (and Warning::emit) to honor the FLASHINFER_LOGLEVEL environment
variable (read FLASHINFER_LOGLEVEL and skip emitting if current level is lower
than the message level) instead of unconditionally writing to stderr so logging
follows the project's log-level filtering.
- Around line 73-82: The new macro FLASHINFER_LOG currently bypasses the
project's spdlog-based logging (e.g., FLASHINFER_LOG_INFO /
FLASHINFER_LOG_DEBUG) by always calling flashinfer::Log(...).emit(), which
creates two parallel logging systems and confusing names; either rename
FLASHINFER_LOG to something explicit (e.g., FLASHINFER_LOG_STDERR) or change its
implementation to forward to the existing logging infrastructure in
include/flashinfer/logging.h (map to the appropriate spdlog macro or use the
same level filtering), ensuring you update references to FLASHINFER_LOG and keep
Log::emit() usage consistent or removed so that log level configuration via
FLASHINFER_LOGLEVEL is respected.

In `@include/flashinfer/trtllm/batched_gemm/KernelRunner.h`:
- Around line 128-129: Change the inconsistent parameter name(s) in the
KernelRunner.h declarations from maxNumCtasInBatchDim to maxNumCgasInBatchDim so
they match the implementation in csrc/trtllm_batched_gemm_runner.cu;
specifically update the signature for getConfigClusterSizeInBatchDim and any
other methods in KernelRunner that currently declare maxNumCtasInBatchDim to use
maxNumCgasInBatchDim to keep the API terminology consistent with the CGA naming
used in the implementation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e809d3fc-fd22-473e-8761-061ba59d0163

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and c3c9217.

📒 Files selected for processing (12)
  • benchmarks/bench_moe_deepseek.py
  • csrc/trtllm_batched_gemm_runner.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_routing_deepseek.cu
  • csrc/trtllm_fused_moe_routing_llama4.cu
  • csrc/trtllm_fused_moe_routing_renormalize.cu
  • csrc/trtllm_fused_moe_runner.cu
  • include/flashinfer/exception.h
  • include/flashinfer/trtllm/batched_gemm/KernelRunner.h
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
  • include/flashinfer/trtllm/fused_moe/RoutingKernel.h
  • include/flashinfer/trtllm/fused_moe/runner.h

Comment on lines +95 to +96
NVFP4_BYTES = 9 / 16 # 0.5 bytes value + 1/16 byte block scale
BF16_BYTES = 2

⚠️ Potential issue | 🟡 Minor

Bandwidth calculation assumes NVFP4 activations, but BF16 is also supported.

The calc_bw function assumes FC1 input is NVFP4 (local_tokens * H * NVFP4_BYTES), but according to trtllm_fp4_block_scale_moe documentation, hidden_states can be bfloat16, mxfp8, or nvfp4. For BF16 inputs, the bandwidth calculation would underestimate actual memory traffic.

Consider either:

  1. Adding a parameter to specify input dtype
  2. Documenting that this calculation assumes NVFP4 inputs

Also applies to: 112-132
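The reviewer's first option can be sketched as below. The names (calc_bw, NVFP4_BYTES, BF16_BYTES) mirror the benchmark, but the signature and traffic formula here are illustrative assumptions covering only the FC1 activation read, not the benchmark's full memory-traffic model.

```python
NVFP4_BYTES = 9 / 16  # 0.5 bytes per value + 1/16 byte block scale
BF16_BYTES = 2

def calc_bw(local_tokens: int, hidden_size: int, ms: float,
            act_bytes: float = NVFP4_BYTES) -> float:
    """Estimated FC1 activation bandwidth in TB/s for the given activation dtype.

    act_bytes selects the per-element byte cost (NVFP4_BYTES, BF16_BYTES, ...),
    so BF16 inputs are no longer counted at NVFP4 width.
    """
    traffic_bytes = local_tokens * hidden_size * act_bytes
    return traffic_bytes / (ms * 1e-3) / 1e12
```

A BF16 run would call calc_bw(tokens, H, ms, act_bytes=BF16_BYTES); counting BF16 traffic at NVFP4 width underestimates it by a factor of 32/9 (about 3.6x).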

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_moe_deepseek.py` around lines 95 - 96, The bandwidth calc in
calc_bw currently assumes NVFP4 activations (using NVFP4_BYTES) which
underestimates traffic when hidden_states can be BF16 or other dtypes per
trtllm_fp4_block_scale_moe; update calc_bw to accept an input dtype parameter
(e.g., input_dtype or act_bytes) or a flag and use the corresponding byte-size
constant (BF16_BYTES, NVFP4_BYTES, etc.) when computing local_tokens * H *
<bytes>, and update any callers (or default behavior) so FC1 bandwidth uses the
selected dtype; alternatively, add a clear docstring on calc_bw and top-level
comment stating the NVFP4-only assumption if you choose not to change behavior.

@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)

1399-1438: Consider extracting shared logic to avoid code duplication.

FP4BlockScaleLauncher::prepare_routing duplicates most of prepare_routing_common instead of calling it. This risks divergence if the common logic is updated. The FP4-specific parts (e.g., direct assignment to workspace.routing_expert_indexes and workspace.expert_weights) could be applied after calling the base method.

♻️ Suggested refactor
 void prepare_routing(int32_t clusterSize) override {
+  FusedMoeLauncher::prepare_routing_common(clusterSize);
+
+  // FP4-specific: use pre-computed indices and weights directly
+  workspace.routing_expert_indexes =
+      static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
+  workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
+
   args->mDtypeElt = mDtypeAct;
   auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
   mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
-  // ... rest of duplicated code removed
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1399 - 1438,
FP4BlockScaleLauncher::prepare_routing duplicates the logic in
prepare_routing_common; refactor by replacing the duplicated initialization with
a call to prepare_routing_common(...) (or the base implementation) and then
apply only the FP4-specific adjustments: set workspace.routing_expert_indexes
and workspace.expert_weights (using expert_indices/expert_weights), set any
FP4-only workspace fields such as workspace.permuted_idx_size if needed, and
preserve setting args->mDtypeElt and mRoutingBiasDtype; ensure prepare_routing
calls prepare_routing_common before performing these FP4-specific assignments so
shared logic isn’t duplicated.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1399-1438: FP4BlockScaleLauncher::prepare_routing duplicates the
logic in prepare_routing_common; refactor by replacing the duplicated
initialization with a call to prepare_routing_common(...) (or the base
implementation) and then apply only the FP4-specific adjustments: set
workspace.routing_expert_indexes and workspace.expert_weights (using
expert_indices/expert_weights), set any FP4-only workspace fields such as
workspace.permuted_idx_size if needed, and preserve setting args->mDtypeElt and
mRoutingBiasDtype; ensure prepare_routing calls prepare_routing_common before
performing these FP4-specific assignments so shared logic isn’t duplicated.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d2169147-5208-4d6f-b351-b9fe05156544

📥 Commits

Reviewing files that changed from the base of the PR and between c3c9217 and 50750da.

📒 Files selected for processing (2)
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu

@rosenrodt rosenrodt marked this pull request as draft March 12, 2026 09:46
@rosenrodt
Contributor Author

I will hold off on this PR for the moment because I am evaluating revising the cubins instead, so that we would not need to revise the routing kernels.
