update trtllm cutlass moe #2020
Conversation
…to feature/cutlass_moe_3xfp4
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Threads swizzled_input_sf, unpadded_hidden_size, router_scales, permuted_row_to_unpermuted_row, swap_ab and finalize-fusion flags through MOE/CUTLASS flows; adds SM90 scatter epilogue visitor; extends tile/cluster enums and SM100/SM120 candidate generation; renames many kernel namespaces to cutlass_kernels_oss; adds explicit template instantiations and launcher/signature updates.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant App
    participant Runner as CutlassMoeFCRunner
    participant Heuristic
    participant Profiler
    participant Dispatcher
    Note over App,Runner: runMoe(..., swizzled_input_sf, unpadded_hidden_size, router_scales, permuted_row_to_unpermuted_row, swap_ab)
    App->>Runner: runMoe(...)
    Runner->>Heuristic: getTactics(gemm_id, sm, supports_finalize_fusion)
    Heuristic-->>Runner: candidate CutlassGemmConfig (may include FINALIZE, swap_ab, dynamic cluster shapes)
    Runner->>Profiler: profile/select (uses unpadded_hidden_size, stage-specific tactic counts)
    Profiler-->>Runner: selected gemm_config
    Runner->>Dispatcher: dispatch(gemm_config, router_scales, permuted_row_to_unpermuted_row, swizzled_input_sf, swap_ab)
    Dispatcher-->>Runner: launches kernel (TMA warp specialized / finalize fused / scatter epilogue)
    Runner-->>App: results
```
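For orientation, here is a hedged sketch of the new knobs the walkthrough and diagram describe. This is not the actual TensorRT-LLM runMoe signature; the field names come from the summary above, while the types and the grouping into a struct are assumptions for readability.

```cpp
#include <cstdint>

struct MoeRunExtraArgs {
  void const* swizzled_input_sf = nullptr;               // swizzled input scaling factors
  int64_t unpadded_hidden_size = 0;                      // original hidden size before padding
  float const* router_scales = nullptr;                  // per-token expert scales from the router
  int const* permuted_row_to_unpermuted_row = nullptr;   // reverse row-permutation map
  bool swap_ab = false;                                  // select the swapped-operand GEMM variant
  bool use_fused_finalize = false;                       // enable the finalize-fusion epilogue
};
```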
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings, 1 inconclusive)
Summary of Changes

Hello @nv-yunzheq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the TensorRT-LLM (TRTLLM) CUTLASS Mixture-of-Experts (MoE) implementation, particularly for Hopper and Blackwell architectures. The main objective is to introduce a new…

Highlights
/bot run
[CANCELING] Pipeline #37982790: canceled
/bot run
Actionable comments posted: 1
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1)
221-236: Fix OOB read: s_local_experts indexed with global expert id.

In the min-latency map build, `s_local_experts` is sized to the number of local experts but is indexed by the global `expert`, causing an OOB read when `expert` ∉ [start_expert, end_expert). Guard and subtract `start_expert`. Apply:

```diff
- bool is_valid_expert =
-     smart_routing ? s_local_experts[expert] : (expert >= start_expert && expert < end_expert);
+ bool const expert_in_node = (expert >= start_expert && expert < end_expert);
+ bool is_valid_expert = smart_routing
+                            ? (expert_in_node && s_local_experts[expert - start_expert])
+                            : expert_in_node;
```

Also consider mirroring this guard wherever `s_store_experts[expert - start_expert]` is used, to avoid underflow when `expert_in_node == false`.
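For reference, a self-contained sketch of the guarded lookup the diff above suggests; the names and the routing split come from the review, while the surrounding kernel code is not reproduced.

```cpp
// Returns whether `expert` is valid for this node, indexing the node-local
// table only after confirming the expert lies in [start_expert, end_expert).
inline bool is_valid_expert(bool smart_routing, bool const* s_local_experts,
                            int expert, int start_expert, int end_expert) {
  bool const expert_in_node = (expert >= start_expert && expert < end_expert);
  return smart_routing ? (expert_in_node && s_local_experts[expert - start_expert])
                       : expert_in_node;
}
```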
♻️ Duplicate comments (5)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (5)
1724-1734: Restore defensive check for invalid permutation indices (debug-only OK).
`expanded_permuted_row = unpermuted_row_to_permuted_row[...]` has no validity guard. If upstream builds ever leave sentinel values, this will read OOB from `expanded_permuted_rows`.

```diff
- int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+ int64_t const expanded_permuted_row = unpermuted_row_to_permuted_row[expanded_original_row];
+#ifndef NDEBUG
+ if (expanded_permuted_row < 0) { continue; }
+#endif
```

Alternatively, add an unconditional `if (expanded_permuted_row < 0) continue;` if negative is a valid sentinel in production.
1031-1050: De-duplicate swizzled vs linear SF input handling.

Simplify by computing the layout once and making a single call to `cvt_quant_get_sf_out_offset`.

```diff
- if (swizzled_input_sf) {
-   auto const sf_in =
-       cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                   NumThreadsPerSF>(
-           std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-           num_cols / VecSize,
-           const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-           QuantizationSFLayout::SWIZZLED_128x4);
-   *sf_out = *sf_in;
- } else {
-   auto const sf_in =
-       cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
-                                   NumThreadsPerSF>(
-           std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
-           num_cols / VecSize,
-           const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
-           QuantizationSFLayout::LINEAR);
-   *sf_out = *sf_in;
- }
+ auto const layout = swizzled_input_sf ? QuantizationSFLayout::SWIZZLED_128x4
+                                       : QuantizationSFLayout::LINEAR;
+ auto const sf_in =
+     cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
+                                 NumThreadsPerSF>(
+         std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
+         num_cols / VecSize,
+         const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
+         layout);
+ *sf_out = *sf_in;
```
3937-3955: Min-latency path currently throws; add a safe fallback or gate dispatch.

`computeStridesTmaWarpSpecializedLowLatency` unconditionally throws, breaking callers (the `setupTmaWarpSpecializedInputs` min-latency branch). Options (see the sketch after this list):

- Route min-latency to the non-LL `computeStridesTmaWarpSpecialized` with a temporary `expert_first_token_offset` built from `num_active_experts_per`/`active_expert_global_ids`, or
- Gate all LL dispatch sites behind `TLLM_CHECK_WITH_INFO(!min_latency_mode)` to avoid calling this until LL is reintroduced.

Do you want a minimal fallback drafted?
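A minimal sketch of the gating option, assuming a dispatch entry point; the real code would use TLLM_CHECK_WITH_INFO rather than the plain exception used here to keep the sketch self-contained.

```cpp
#include <stdexcept>

// Option 2 from the list above: refuse the min-latency path up front so callers
// never reach the throwing computeStridesTmaWarpSpecializedLowLatency.
void dispatchTmaWarpSpecializedSketch(bool min_latency_mode) {
  if (min_latency_mode) {
    throw std::runtime_error("Min-latency TMA warp-specialized dispatch is not supported yet");
  }
  // ... continue with the regular (non-low-latency) dispatch path ...
}
```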
1684-1700: Align checks to element width, not constant 4.Hardcoding
% 4can break for dtypes whereFINALIZE_ELEM_PER_THREAD != 4. Use the computed constant.- assert(padded_cols % 4 == 0); - assert(unpadded_cols % 4 == 0); - assert(unpadded_cols <= padded_cols); + assert(unpadded_cols <= padded_cols); + constexpr int64_t FINALIZE_ELEM_PER_THREAD = + 128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value); + assert(padded_cols % FINALIZE_ELEM_PER_THREAD == 0); + assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);As per earlier feedback.
1761-1764: Same alignment issue + simplify loop bound.Mirror the
FINALIZE_ELEM_PER_THREAD-based asserts and iterate tonum_elems_in_orig_colto avoid per-iteration branch.- assert(padded_cols % 4 == 0); - assert(unpadded_cols % 4 == 0); - assert(unpadded_cols <= padded_cols); + assert(unpadded_cols <= padded_cols); + constexpr int64_t FINALIZE_ELEM_PER_THREAD = + 128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value); + assert(padded_cols % FINALIZE_ELEM_PER_THREAD == 0); + assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0); @@ - for (int elem_index = start_offset; elem_index < num_elems_in_padded_col; - elem_index += stride) { - if (elem_index >= num_elems_in_orig_col) continue; // Skip writing beyond original columns + for (int elem_index = start_offset; elem_index < num_elems_in_orig_col; + elem_index += stride) {As per earlier feedback.
Also applies to: 1799-1806
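A standalone sketch of the element-width-based check suggested in both comments above; `sizeof_bits` here mimics `cutlass::sizeof_bits` and the template types are placeholders.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

template <typename T>
struct sizeof_bits { static constexpr int value = 8 * static_cast<int>(sizeof(T)); };

template <typename OutputType, typename GemmOutputType>
void check_finalize_alignment(int64_t padded_cols, int64_t unpadded_cols) {
  // 128-bit vectorized stores: each thread handles 128 / min(bit-width) elements.
  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
      128 / std::min(sizeof_bits<OutputType>::value, sizeof_bits<GemmOutputType>::value);
  assert(unpadded_cols <= padded_cols);
  assert(padded_cols % FINALIZE_ELEM_PER_THREAD == 0);
  assert(unpadded_cols % FINALIZE_ELEM_PER_THREAD == 0);
  (void)padded_cols; (void)unpadded_cols;  // silence unused warnings in NDEBUG builds
}
```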
🧹 Nitpick comments (2)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (2)
280-286: CUDA dynamic shared memory attr check: allow == max.Guard rejects
shared_size >= max_smem_per_block; typically equality is valid. Prefer>to avoid unnecessary fallback.- if (shared_size >= static_cast<size_t>(max_smem_per_block)) { + if (shared_size > static_cast<size_t>(max_smem_per_block)) {Also applies to: 606-620
197-217: Smart routing: `active_expert_global_ids` semantics.

In the smart-routing branch, the stored id is `i` (local expert index), while in the else-branch it is `i + start_expert` (global id). If consumers expect global ids in both modes (per comment), convert to global via `i + start_expert`. Otherwise, please add a comment clarifying that LL mode uses local ids.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (1)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (56 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (3)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (3)
1886-1894: Incorrect reference in line citation.

The concern about `padded_cols` and `unpadded_cols` consistency applies only to lines 1886-1894 and 1901-1909 (the actual kernel launches within `finalizeMoeRoutingKernelLauncher`). Lines 3850-3858 reference a different function call (`Self::gemm2()`) and should not be included. All three `finalizeMoeRoutingKernelLauncher` call sites (2925-2930, 3302-3308, 3309-3315) correctly pass `hidden_size` and `unpadded_hidden_size` with consistent semantic mapping to the kernel parameters.

Likely an incorrect or invalid review comment.
270-286: Incorrect line references in review comment.

Lines 637-645 and 686-696 are kernel device code (the `globalExpertPrefixSumLargeKernel` and `globalExpertPrefixSumKernel` function implementations), not `cudaLaunchKernelEx` call sites. The actual launches in `cutlass_fused_moe_kernels.cuh` at lines 281, 617, 738, 743, and 801 all consistently set `attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization` with `.programmaticStreamSerializationAllowed = enable_pdl`.

Likely an incorrect or invalid review comment.
4040-4051: Aliasing and memset size are already correctly handled.

The code at lines 2811-2812 allocates both `fc1_fp4_act_scale_` and `fc2_fp4_act_scale_` from the same workspace buffer key ("fp4_act_scale"), ensuring they point to identical memory. The workspace allocation (lines 2606-2612) uses `std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size)` to reserve space, and the memset at lines 4066-4069 uses the identical `std::max(fc1_sf_offset, fc2_sf_offset)` logic to compute the fill size. The assertion at line 4049 confirms the aliasing invariant. Since both allocation and memset employ the same max-size calculation, the memset is guaranteed to fit within the workspace allocation.
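A standalone illustration of the invariant described above: both scale pointers alias one region sized by the max of the two sizes, and the memset uses the same max, so it cannot overrun. The sizes are made up for the sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

int main() {
  size_t const fc1_sf_size = 1024, fc2_sf_size = 2048;            // illustrative sizes
  std::vector<unsigned char> workspace(std::max(fc1_sf_size, fc2_sf_size));
  unsigned char* fc1_fp4_act_scale = workspace.data();            // same buffer key ->
  unsigned char* fc2_fp4_act_scale = workspace.data();            // identical memory
  assert(fc1_fp4_act_scale == fc2_fp4_act_scale);                 // the aliasing invariant
  std::memset(fc1_fp4_act_scale, 0, std::max(fc1_sf_size, fc2_sf_size));  // fits by construction
  return 0;
}
```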
[CANCELING] Pipeline #37985869: canceled
/bot run
yzh119 left a comment
The failed UT on gb300 is not relevant, LGTM on my side.
...v_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp
[FAILED] Pipeline #37989907: 12/17 passed
There are still some remaining cu126 compilation issues, likely because we didn't add a guard on the usage of `__nv_fp4_e2m1` (which CUDA 12.6 does not provide). We will retire cu126 at some point, but not now, considering cu126 is still one of the three CUDA versions supported by PyTorch.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
99-102: FP4 guards insufficient for CUDA 12.6 compatibility.

The guard only checks `ENABLE_FP4`, but `__nv_fp4_e2m1` requires CUDA 12.8+. The CI failure on cu126 (reported in PR objectives) confirms this: the identifier is undefined because CUDA 12.6 doesn't provide it. The same issue exists at lines 249-253, 742-746, and 755-759. Apply guards that also check the CUDA version:

```diff
-#if defined(ENABLE_FP4)
+#if defined(ENABLE_FP4) && CUDA_VERSION >= 12080
     cutlass::platform::is_same<WeightType, __nv_fp4_e2m1>::value ||
 #endif
```

Repeat for all FP4 type references at lines 249-253, 742-746, and 755-759.
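A self-contained sketch of the same guard pattern, using `std::is_same` in place of the cutlass trait. The `CUDA_VERSION` threshold follows the 12.8 availability noted above; the header providing the FP4 type is an assumption of this sketch.

```cpp
#include <cuda.h>        // defines CUDA_VERSION
#include <type_traits>

#if defined(ENABLE_FP4) && (CUDA_VERSION >= 12080)
#include <cuda_fp4.h>    // assumed location of __nv_fp4_e2m1 on CUDA 12.8+
#endif

template <typename WeightType>
constexpr bool is_fp4_weight() {
#if defined(ENABLE_FP4) && (CUDA_VERSION >= 12080)
  return std::is_same<WeightType, __nv_fp4_e2m1>::value;
#else
  return false;          // FP4 types are unavailable on older toolkits
#endif
}
```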
♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
672-676: Fix zero-argument call to `supportsTmaWarpSpecialized`.

This duplicates a past review concern: `isTmaWarpSpecialized` calls `supportsTmaWarpSpecialized()` without arguments on line 675, but the signature at lines 679-688 now requires an `int sm` parameter. The same issue occurs at line 920 in `calcMaxWorkspaceSize`. Apply this diff to forward the member's `sm_`:

```diff
- return supportsTmaWarpSpecialized() && config_is_tma_warp_specialized;
+ return supportsTmaWarpSpecialized(sm_) && config_is_tma_warp_specialized;
```

Also fix line 920:

```diff
- if (!supportsTmaWarpSpecialized()) {
+ if (!supportsTmaWarpSpecialized(sm_)) {
```

Alternatively, add a const wrapper in the class:

```cpp
bool supportsTmaWarpSpecialized() const { return supportsTmaWarpSpecialized(sm_); }
```

Based on learnings
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)
953-956: Consider extending the FINALIZE fusion workspace calculation beyond SM90.

FINALIZE fusion workspace size is currently only calculated for SM90 (line 954). If other architectures (e.g., SM100+) support finalize fusion, they should also be included in this calculation to avoid underestimating workspace requirements.
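A hedged sketch of the suggestion; the capability predicate and the per-architecture size are placeholders, not the real helpers in this file.

```cpp
#include <cstddef>

// Assumption for the sketch: SM90 and SM100+ support the finalize-fusion epilogue.
inline bool supports_finalize_fusion(int sm) { return sm == 90 || sm >= 100; }

// Include the finalize-fusion workspace for every architecture that can use it,
// instead of special-casing SM90 only.
inline size_t finalize_fusion_workspace_bytes(int sm, size_t fused_epilogue_bytes) {
  return supports_finalize_fusion(sm) ? fused_epilogue_bytes : 0;
}
```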
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (6)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (3)
  - tensorrt_llm (63-112), std (81-95), calcMaxWorkspaceSizeTmaWarpSpecialized (490-502)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  - tensorrt_llm (19-34)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h (1)
  - tensorrt_llm (60-274)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h (9)
  - tensorrt_llm (33-150), kernels (34-149), cutlass (114-116), cutlass (120-122), cutlass (127-129), cutlass (132-134), cutlass (140-142), cutlass_kernels (35-148), __nv_fp8_e5m2 (91-93)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp (2)
  - get_candidate_configs (638-689), get_candidate_configs (638-640)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h (1)
  - EpilogueScheduleType (197-433)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (4)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (4)
530-544: LGTM: Clean signature updates for finalize fusion support.

The addition of the `supports_finalize_fusion` parameter to both the const member and static `getConfigs` methods properly threads this capability flag through the config selection pipeline.
624-629: Verify the SM103 FP4 config selection strategy.

The code explicitly adds SM100 configs when running on SM103 with FP4. Ensure this cross-architecture config reuse is validated and doesn't cause performance regressions or compatibility issues.
631-666: Well-structured finalize fusion and swap_ab config expansion.

The logic correctly does the following (see the sketch after this list):
- Duplicates configs and marks them with FINALIZE fusion type when supported (lines 631-640)
- Removes FINALIZE configs that lack epilogue SMEM (lines 642-650)
- Adds swap_ab variants for all configs (lines 653-659) with a defensive check
- Filters to swap_ab=true only for w4_groupwise mode (lines 661-666)
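A hedged, self-contained sketch of that expansion pass (C++20); `ConfigSketch` stands in for `CutlassGemmConfig`, and the epilogue-SMEM predicate is reduced to a boolean field for the illustration.

```cpp
#include <vector>

struct ConfigSketch {
  bool finalize_fusion = false;
  bool swap_ab = false;
  bool has_epilogue_smem = true;
};

std::vector<ConfigSketch> expand_configs(std::vector<ConfigSketch> configs,
                                         bool supports_finalize_fusion,
                                         bool w4_groupwise) {
  if (supports_finalize_fusion) {
    // 1) Duplicate every config with the FINALIZE fusion flag set.
    auto const base = configs;
    for (auto c : base) { c.finalize_fusion = true; configs.push_back(c); }
    // 2) Drop FINALIZE variants that have no epilogue SMEM to fuse into.
    std::erase_if(configs, [](ConfigSketch const& c) {
      return c.finalize_fusion && !c.has_epilogue_smem;
    });
  }
  // 3) Add swap_ab variants of everything that survived.
  auto const before_swap = configs;
  for (auto c : before_swap) { c.swap_ab = true; configs.push_back(c); }
  // 4) For w4 groupwise, keep only the swap_ab variants.
  if (w4_groupwise) {
    std::erase_if(configs, [](ConfigSketch const& c) { return !c.swap_ab; });
  }
  return configs;
}
```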
978-1007: Activation type dispatch looks correct.

The switch statement appropriately handles the supported activation types (Relu, Gelu, Silu, Identity, Swiglu, Geglu) and throws for invalid types. Note that `Relu2` from the `ActivationType` enum is not handled, which appears intentional per the AI summary noting "Relu2 path removed (no longer supported)".
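For illustration, a standalone version of the dispatch shape being described; the enum values follow the comment (with Relu2 deliberately unhandled), and the returned names are placeholders rather than the real functor types.

```cpp
#include <stdexcept>

enum class ActivationType { Relu, Gelu, Silu, Identity, Swiglu, Geglu, Relu2 };

inline const char* activation_name(ActivationType t) {
  switch (t) {
    case ActivationType::Relu:     return "relu";
    case ActivationType::Gelu:     return "gelu";
    case ActivationType::Silu:     return "silu";
    case ActivationType::Identity: return "identity";
    case ActivationType::Swiglu:   return "swiglu";
    case ActivationType::Geglu:    return "geglu";
    default:  // includes Relu2, whose path was removed
      throw std::invalid_argument("Unsupported activation type");
  }
}
```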
/bot run
nvmbreughe left a comment
LGTM.
Perhaps just add the additional tests for DSR1 and autotuner we discussed.
```cpp
          cute::make_shape(gemm_n, gemm_k, 1));
    }
    if (layout_info.stride_c) {
      // TODO Enable 1xN bias matrix as C
```
Does this mean we don't support batch size = 1?
No, it's just that the bias tensor cannot be 1xN.
[FAILED] Pipeline #38037173: 14/17 passed
Per discussion offline, this PR should be ready to merge, but there are some problem shapes not covered in the backend (and the CI); we will follow up by adding more unit tests with different problem shapes in future PRs.
<!-- .github/pull_request_template.md -->

## 📝 Description

Patch sm103 for 3xfp4 moe generation

## 🔗 Related Issues

Following up of #2020 #1925

## 📋 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

```
$ ls csrc/nv_internal/tensorrt_llm/cutlass_instantiations/103/gemm_grouped
100 103 80
$ pytest tests/moe/test_trtllm_cutlass_fused_moe.py
22 passed, 3 skipped, 1 warning in 771.89s (0:12:51)
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added support for Blackwell (SM103) GPU architecture in MOE (Mixture of Experts) operations with specialized CUTLASS-optimized modules.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
📝 Description
🔗 Related Issues
📋 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Bug Fixes
Tests