feat: preparing TRTLLM MoE backend to support more kernels #2741
rosenrodt wants to merge 3 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR adds cluster-size (CGA/CTA) awareness across fused MoE routing and batched-GEMM code, exposes per-config cluster-size accessors, updates routing and launcher flows to propagate clusterSize, and enhances benchmarking with CUDA profiling, bandwidth (TB/s) metrics, and reporting.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client/Launcher
    participant Launcher as FusedMoeLauncher
    participant Routing as Routing::Runner
    participant MoE as MoE::Runner
    participant GEMM as BatchedGemmRunner
    participant Device as CUDA Kernel
    Client->>Launcher: select tile/config (may fallback)
    Launcher->>Routing: instantiate/prepare_routing(clusterSize)
    Routing->>MoE: provide routingData (mClusterSizeInBatchDim, ...)
    MoE->>GEMM: choose gemm configs, ensure cluster-size consistency
    GEMM->>Device: launch kernels with cluster-aware dims
    Device-->>Client: return results/metrics
    Note over Client,Launcher: Benchmark wraps operations with cuda_profiler_range
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~70 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the flexibility and robustness of the TRTLLM MoE backend. The primary focus is on adapting the system to support a wider variety of kernels by making the routing logic cluster-aware and ensuring that the internal GEMM operations correctly handle these new dimensional considerations. This involves a fundamental shift in how routing information is processed, moving from a cluster-level view to a more granular CTA-level view for internal kernel execution. Additionally, the changes include improvements to configuration selection for batched GEMM, making it more adaptive and less reliant on hardcoded assumptions, and adding better debugging capabilities.
Code Review
The pull request introduces CUDA profiling capabilities to MoE benchmarks and refactors the TrtllmGenBatchedGemmRunner in C++ to better handle problem dimensions and cluster sizes, including the introduction of helper functions getClusterSizeInBatchDim and setProblemDimensions. The routing kernels for MoE were updated to correctly account for mClusterSizeInBatchDim in CTA calculations and token-space padded offsets. Several potential integer overflow vulnerabilities were identified in calculations involving numTokens * topK, maxCgas * padding, and numCta * mClusterSizeInBatchDim, which could lead to incorrect buffer sizes or out-of-bounds memory access. Additionally, a code duplication issue was noted regarding the getClusterSizeInBatchDim logic.
```diff
-  int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
-  return maxCtas * padding;
+  int32_t maxCgas = getMaxNumCgasInBatchDim(numTokens, expertsPerToken, numExperts, padding);
+  return maxCgas * padding;
```
Potential integer overflow in maxCgas * padding. If the product exceeds the range of int32_t, it will wrap around, potentially resulting in a negative or truncated value. This value is used for memory allocation in csrc/trtllm_fused_moe_kernel_launcher.cu, which could lead to a buffer that is too small, causing out-of-bounds writes in the routing kernels.
```diff
-  return maxCgas * padding;
+  return static_cast<int64_t>(maxCgas) * padding;
```
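A hypothetical standalone illustration of the wraparound described above (the values and helper names are examples, not taken from the PR): simulating the narrow 32-bit multiply next to the widened one shows how the allocation size silently shrinks.

```python
INT32_MAX = 2**31 - 1

def mul_int32(a: int, b: int) -> int:
    """Simulate a C++ int32_t multiply with two's-complement wraparound."""
    r = (a * b) & 0xFFFFFFFF
    return r - 2**32 if r > INT32_MAX else r

def mul_int64(a: int, b: int) -> int:
    """The suggested fix: widen an operand to int64 before multiplying."""
    return a * b  # Python ints are unbounded, standing in for int64 math

maxCgas, padding = 70000, 70000
assert mul_int64(maxCgas, padding) == 4_900_000_000
# The 32-bit product wraps to a much smaller value, i.e. an undersized buffer:
assert mul_int32(maxCgas, padding) == 605_032_704
```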
I will ignore the int overflow warnings, as they predate the changes in this PR.
```cpp
static inline int32_t getClusterSizeInBatchDim(BatchedGemmOptions const& options) {
  return options.mTransposeMmaOutput ? options.mClusterDimY : options.mClusterDimX;
}
```
There was a problem hiding this comment.
This function getClusterSizeInBatchDim has the same logic as getCgaSizeInBatchDim introduced in include/flashinfer/trtllm/fused_moe/runner.h. To avoid code duplication and improve maintainability, consider creating a single utility function in a common header. The version in runner.h is more generic as it doesn't depend on BatchedGemmOptions. You could move that function to a shared location, and then call it here like getClusterSizeInBatchDim(options.mTransposeMmaOutput, options.mClusterDimX, options.mClusterDimY).
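For reference, both duplicated C++ call sites reduce to a single rule, sketched here (the function name and signature are illustrative, not the proposed API):

```python
def cluster_size_in_batch_dim(transpose_mma_output: bool,
                              cluster_dim_x: int,
                              cluster_dim_y: int) -> int:
    """The batch dimension is Y when the MMA output is transposed, else X."""
    return cluster_dim_y if transpose_mma_output else cluster_dim_x

assert cluster_size_in_batch_dim(True, 1, 2) == 2
assert cluster_size_in_batch_dim(False, 1, 2) == 1
```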
Actionable comments posted: 4
🧹 Nitpick comments (3)
include/flashinfer/exception.h (2)
115-128: The `Log` class duplicates the `Warning` class structure.

The new `Log` class is nearly identical to the existing `Warning` class—both have the same structure, with only the message prefix differing ("Log" vs "Warning"). If more such classes are anticipated, consider extracting a common base or template to reduce duplication.

Additionally, `Log::emit()` writes unconditionally to stderr. Based on learnings, the project supports log level filtering via the `FLASHINFER_LOGLEVEL` environment variable. The current implementation bypasses this mechanism.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/exception.h` around lines 115 - 128, The new Log class duplicates Warning; refactor by extracting a common base (e.g., class BaseLog or template LogBase) that takes the prefix ("Log"/"Warning") and constructs message_ so both Log and Warning reuse the formatting logic, and replace duplicated constructors; also change Log::emit (and Warning::emit) to honor the FLASHINFER_LOGLEVEL environment variable (read FLASHINFER_LOGLEVEL and skip emitting if current level is lower than the message level) instead of unconditionally writing to stderr so logging follows the project's log-level filtering.
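The level filtering the comment asks for can be sketched as follows; the level names, their ordering, and the default are assumptions for illustration and should be checked against the project's actual `FLASHINFER_LOGLEVEL` semantics:

```python
import os

# Assumed level names and ordering (more verbose = larger number).
_LEVELS = {"error": 0, "warning": 1, "info": 2, "debug": 3}

def should_emit(message_level: str, env=None) -> bool:
    """Emit only when the configured level is at least as verbose as the message."""
    env = os.environ if env is None else env
    configured = env.get("FLASHINFER_LOGLEVEL", "info").lower()
    return _LEVELS[message_level] <= _LEVELS.get(configured, _LEVELS["info"])

assert should_emit("error", {"FLASHINFER_LOGLEVEL": "warning"})
assert not should_emit("debug", {"FLASHINFER_LOGLEVEL": "warning"})
```

A `Log::emit()`/`Warning::emit()` honoring this check would skip the stderr write instead of emitting unconditionally.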
73-82: Consider using or integrating with the existing logging infrastructure.

The new `FLASHINFER_LOG` macro writes directly to stderr via `Log::emit()`, but the project already has spdlog-based logging in `include/flashinfer/logging.h` with level-filtered macros (`FLASHINFER_LOG_INFO`, `FLASHINFER_LOG_DEBUG`, etc.). This creates two parallel logging systems:

- `FLASHINFER_LOG(...)` → unconditional stderr output
- `FLASHINFER_LOG_INFO(...)` → spdlog with level filtering via `FLASHINFER_LOGLEVEL`

The naming similarity (`FLASHINFER_LOG` vs `FLASHINFER_LOG_*`) may cause confusion. Consider either:

- Renaming this macro (e.g., `FLASHINFER_LOG_STDERR`) to clarify its behavior
- Integrating with spdlog to respect log level configuration
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/exception.h` around lines 73 - 82, The new macro FLASHINFER_LOG currently bypasses the project's spdlog-based logging (e.g., FLASHINFER_LOG_INFO / FLASHINFER_LOG_DEBUG) by always calling flashinfer::Log(...).emit(), which creates two parallel logging systems and confusing names; either rename FLASHINFER_LOG to something explicit (e.g., FLASHINFER_LOG_STDERR) or change its implementation to forward to the existing logging infrastructure in include/flashinfer/logging.h (map to the appropriate spdlog macro or use the same level filtering), ensuring you update references to FLASHINFER_LOG and keep Log::emit() usage consistent or removed so that log level configuration via FLASHINFER_LOGLEVEL is respected.

include/flashinfer/trtllm/batched_gemm/KernelRunner.h (1)
128-129: Header declares method but parameter names are inconsistent with implementation.

The implementation in `csrc/trtllm_batched_gemm_runner.cu` uses `maxNumCgasInBatchDim` for several methods, but this header still declares parameters as `maxNumCtasInBatchDim` (e.g., lines 93, 133). While the parameter names don't affect binary compatibility, this inconsistency can cause confusion for API consumers.

Consider updating the header parameter names to `maxNumCgasInBatchDim` to match the implementation and the CGA terminology used throughout this PR.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/batched_gemm/KernelRunner.h` around lines 128 - 129, Change the inconsistent parameter name(s) in the KernelRunner.h declarations from maxNumCtasInBatchDim to maxNumCgasInBatchDim so they match the implementation in csrc/trtllm_batched_gemm_runner.cu; specifically update the signature for getConfigClusterSizeInBatchDim and any other methods in KernelRunner that currently declare maxNumCtasInBatchDim to use maxNumCgasInBatchDim to keep the API terminology consistent with the CGA naming used in the implementation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 95-96: The bandwidth calc in calc_bw currently assumes NVFP4
activations (using NVFP4_BYTES) which underestimates traffic when hidden_states
can be BF16 or other dtypes per trtllm_fp4_block_scale_moe; update calc_bw to
accept an input dtype parameter (e.g., input_dtype or act_bytes) or a flag and
use the corresponding byte-size constant (BF16_BYTES, NVFP4_BYTES, etc.) when
computing local_tokens * H * <bytes>, and update any callers (or default
behavior) so FC1 bandwidth uses the selected dtype; alternatively, add a clear
docstring on calc_bw and top-level comment stating the NVFP4-only assumption if
you choose not to change behavior.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 117-120: The current branch unconditionally overwrites a
caller-specified tile_N when config == -1; change the logic so we only pick a
default tile when the caller did not specify one. Replace the combined check
with a conditional that sets tile_N = *selected_tile_nums.begin() only when
tile_N == -1 (leave tile_N untouched if caller provided it) and do not overwrite
config; update the code around instantiate_moe_runner()/selected_tile_nums to
use this new conditional so explicit {tile_N, -1} stays {tile_N, -1}.
- Around line 1102-1104: The code instantiates the MoE runner before setting the
DeepSeek mode, causing the generic runner to be chosen incorrectly; change the
call order so prepare_routing(...) (which sets args->mUseDeepSeekFp8 / DeepSeek
mode) runs before instantiate_moe_runner(...), then call
moe_runner->getConfigClusterSizeInBatchDim(moe_tactic) afterwards so tactic
validation and workspace sizing match getValidConfigs() and the actual DeepSeek
runtime path.
In `@csrc/trtllm_fused_moe_runner.cu`:
- Around line 584-591: The function Runner::getConfigClusterSizeInBatchDim
currently returns max(gemm1ClusterSize, gemm2ClusterSize) which allows mixed
GEMM cluster sizes; instead, detect mismatches and fail or filter earlier: add a
guard in getConfigClusterSizeInBatchDim that retrieves gemm1ClusterSize via
mPermuteGemm1.getConfigClusterSizeInBatchDim(config.gemm1Config) and
gemm2ClusterSize via mGemm2.getConfigClusterSizeInBatchDim(config.gemm2Config),
and if gemm1ClusterSize != gemm2ClusterSize either assert/throw a clear error
referencing the config index and mPassingConfigs or remove such entries from
mPassingConfigs when building it so only same-cluster pairs remain; ensure the
error message names Runner::getConfigClusterSizeInBatchDim, mPassingConfigs,
gemm1ClusterSize and gemm2ClusterSize to make the guard obvious.
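The guard suggested for `Runner::getConfigClusterSizeInBatchDim` can be sketched like this (a minimal illustration of the check, not the C++ implementation): fail loudly on mixed per-GEMM cluster sizes instead of silently taking the max.

```python
def config_cluster_size_in_batch_dim(gemm1_cluster_size: int,
                                     gemm2_cluster_size: int) -> int:
    """Return the shared cluster size, rejecting configs whose two GEMMs disagree."""
    if gemm1_cluster_size != gemm2_cluster_size:
        raise ValueError(
            f"mismatched cluster sizes in batch dim: "
            f"gemm1={gemm1_cluster_size}, gemm2={gemm2_cluster_size}"
        )
    return gemm1_cluster_size

assert config_cluster_size_in_batch_dim(2, 2) == 2
try:
    config_cluster_size_in_batch_dim(1, 2)
    assert False, "expected a mismatch error"
except ValueError:
    pass
```

Alternatively, as the comment notes, mismatched pairs could simply be filtered out when `mPassingConfigs` is built, so this accessor never sees them.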
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e809d3fc-fd22-473e-8761-061ba59d0163
📒 Files selected for processing (12)
- benchmarks/bench_moe_deepseek.py
- csrc/trtllm_batched_gemm_runner.cu
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_routing_deepseek.cu
- csrc/trtllm_fused_moe_routing_llama4.cu
- csrc/trtllm_fused_moe_routing_renormalize.cu
- csrc/trtllm_fused_moe_runner.cu
- include/flashinfer/exception.h
- include/flashinfer/trtllm/batched_gemm/KernelRunner.h
- include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh
- include/flashinfer/trtllm/fused_moe/RoutingKernel.h
- include/flashinfer/trtllm/fused_moe/runner.h
```python
NVFP4_BYTES = 9 / 16  # 0.5 bytes value + 1/16 byte block scale
BF16_BYTES = 2
```
Bandwidth calculation assumes NVFP4 activations, but BF16 is also supported.
The `calc_bw` function assumes FC1 input is NVFP4 (`local_tokens * H * NVFP4_BYTES`), but according to `trtllm_fp4_block_scale_moe` documentation, `hidden_states` can be bfloat16, mxfp8, or nvfp4. For BF16 inputs, the bandwidth calculation would underestimate actual memory traffic.
Consider either:
- Adding a parameter to specify input dtype
- Documenting that this calculation assumes NVFP4 inputs
Also applies to: 112-132
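A dtype-aware variant of the FC1 activation-traffic term could look like the sketch below (function and parameter names are illustrative; the real `calc_bw` also counts weight and FC2 traffic):

```python
NVFP4_BYTES = 9 / 16  # 0.5 bytes value + 1/16 byte block scale
BF16_BYTES = 2

ACT_BYTES = {"nvfp4": NVFP4_BYTES, "bf16": BF16_BYTES}

def fc1_input_bytes(local_tokens: int, hidden_size: int, act_dtype: str = "nvfp4") -> float:
    """Bytes read for FC1 activations, parameterized by the activation dtype."""
    return local_tokens * hidden_size * ACT_BYTES[act_dtype]

# BF16 activations move ~3.6x the bytes of NVFP4 for the same shape.
assert fc1_input_bytes(1024, 7168, "bf16") == 1024 * 7168 * 2
assert fc1_input_bytes(1024, 7168, "nvfp4") == 1024 * 7168 * 9 / 16
```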
🧹 Nitpick comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1399-1438: Consider extracting shared logic to avoid code duplication.
`FP4BlockScaleLauncher::prepare_routing` duplicates most of `prepare_routing_common` instead of calling it. This risks divergence if the common logic is updated. The FP4-specific parts (e.g., direct assignment to `workspace.routing_expert_indexes` and `workspace.expert_weights`) could be applied after calling the base method.

♻️ Suggested refactor
```diff
 void prepare_routing(int32_t clusterSize) override {
+  FusedMoeLauncher::prepare_routing_common(clusterSize);
+
+  // FP4-specific: use pre-computed indices and weights directly
+  workspace.routing_expert_indexes =
+      static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
+  workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
+
   args->mDtypeElt = mDtypeAct;
   auto routing_bias_dtype =
       routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
   mRoutingBiasDtype =
       routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
-  // ... rest of duplicated code removed
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1399 - 1438, FP4BlockScaleLauncher::prepare_routing duplicates the logic in prepare_routing_common; refactor by replacing the duplicated initialization with a call to prepare_routing_common(...) (or the base implementation) and then apply only the FP4-specific adjustments: set workspace.routing_expert_indexes and workspace.expert_weights (using expert_indices/expert_weights), set any FP4-only workspace fields such as workspace.permuted_idx_size if needed, and preserve setting args->mDtypeElt and mRoutingBiasDtype; ensure prepare_routing calls prepare_routing_common before performing these FP4-specific assignments so shared logic isn’t duplicated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d2169147-5208-4d6f-b351-b9fe05156544
📒 Files selected for processing (2)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- csrc/trtllm_fused_moe_runner.cu
I will hold off on this PR for the moment because I am evaluating revising the cubins, so we would not need to revise the routing kernel.
📌 Description
- `cubin` will be prepared separately. With or without locally updated cubins, `python -m pytest -v tests/moe/test_trtllm_gen_fused_moe.py -k NvFP4` passes.
- `ctaIdxXyToBatchIdx`, `ctaIdxXyToMnLimit`, and `numNonExitingCtas` are expected to be of CTA granularity. The fix is to expand the routing info to CTA granularity, so routing kernels must take CGA size into account.
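The granularity expansion described above can be sketched as follows (names and the replication scheme are illustrative, not the kernel implementation): with a CGA spanning `cluster_size` CTAs in the batch dimension, each CGA-level routing entry yields one entry per CTA in the cluster, so numCta = numCga * clusterSize.

```python
def expand_cga_to_cta(cga_idx_to_batch_idx, cluster_size_in_batch_dim):
    """Replicate each CGA-granularity routing entry once per CTA in the cluster."""
    return [batch_idx
            for batch_idx in cga_idx_to_batch_idx
            for _ in range(cluster_size_in_batch_dim)]

# Two CGAs with cluster size 2 expand to four CTA-granularity entries.
assert expand_cga_to_cta([0, 1], 2) == [0, 0, 1, 1]
assert expand_cga_to_cta([5], 1) == [5]
```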
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks against all files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements