
Conversation

@tcherckez-nvidia (Collaborator) commented Dec 8, 2025

  • Add enums module with MLPStyle, ActivationFunction, WeightsFormat, and WeightsFusion enums (a rough sketch of the intended layout follows this list)
  • Refactor torch_moe, triton_moe, and trtllm_moe to use enum-based configuration instead of strings
  • Update parameter names from w1_weight/w2_weight/w3_weight to weights_1/weights_2/weights_3
  • Add WeightsFusion enum to support different weight ordering formats:
    • GATE_UP_DOWN: w1, w2, w3 stored separately
    • GATEUP_DOWN: [w1, w3] concatenated, w2 separate (Llama4 native format)
    • UPGATE_DOWN: [w3, w1] concatenated, w2 separate (TRT-LLM format)
  • Update sharding logic to handle weight fusion format conversion (GATEUP_DOWN -> UPGATE_DOWN)
  • Add comprehensive tests for MoE operations including BMM pattern matching
  • Add single-GPU and multi-GPU tests for BMM MoE fusion with reference validation
  • Improve type safety and maintainability of MoE code
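
For orientation, here is a rough sketch of how such an enums module could look. The member names follow the bullets above; apart from "mlp", "relu2", and "w1_w2_w3_separate", which are quoted elsewhere on this page, the string values and the parser body are illustrative assumptions rather than the actual implementation.

from enum import Enum

class MLPStyle(Enum):
    MLP = "mlp"              # plain MLP (e.g. NemotronH)
    GATED_MLP = "gated_mlp"  # gated MLP (e.g. Mixtral/Llama4); member name assumed

class ActivationFunction(Enum):
    RELU2 = "relu2"
    SILU = "silu"  # value assumed

class WeightsFormat(Enum):
    PER_EXPERT = "per_expert"  # lists of per-expert tensors
    STACKED = "stacked"        # single stacked (E, ...) tensor per projection

class WeightsFusion(Enum):
    GATE_UP_DOWN = "w1_w2_w3_separate"  # w1, w2, w3 stored separately
    GATEUP_DOWN = "gateup_down"         # [w1, w3] concatenated, w2 separate (Llama4 native); value assumed
    UPGATE_DOWN = "upgate_down"         # [w3, w1] concatenated, w2 separate (TRT-LLM); value assumed

def weights_fusion_from_str(value: str) -> WeightsFusion:
    # Simplified stand-in for the real parser: normalize, match, fail loudly on typos.
    normalized = value.strip().lower()
    for member in WeightsFusion:
        if member.value == normalized:
            return member
    raise ValueError(f"Unknown weights_fusion: {value!r}. Valid: {[m.value for m in WeightsFusion]}")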

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced MoE configuration with enum-based parameters for activation functions, MLP styles, weight formats, and fusion strategies
    • Extended support for stacked and per-expert weight formats with multiple fusion options
    • Improved validation and error handling for unsupported configurations
  • Bug Fixes

    • Added expert index masking for distributed routing scenarios
  • Tests

    • Comprehensive test coverage for MoE fusion patterns and multi-GPU configurations


Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

- Add enums module with MLPStyle, ActivationFunction, WeightsFormat, and WeightsFusion enums
- Refactor torch_moe, triton_moe, and trtllm_moe to use enum-based configuration instead of strings
- Update parameter names from w1_weight/w2_weight/w3_weight to weights_1/weights_2/weights_3
- Add WeightsFusion enum to support different weight ordering formats:
  * GATE_UP_DOWN: w1, w2, w3 stored separately
  * GATEUP_DOWN: [w1, w3] concatenated, w2 separate (Llama4 native format)
  * UPGATE_DOWN: [w3, w1] concatenated, w2 separate (TRT-LLM format)
- Update sharding logic to handle weight fusion format conversion (GATEUP_DOWN -> UPGATE_DOWN)
- Add comprehensive tests for MoE operations including BMM pattern matching
- Add single-GPU and multi-GPU tests for BMM MoE fusion with reference validation
- Improve type safety and maintainability of MoE code

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
@tcherckez-nvidia tcherckez-nvidia requested a review from a team as a code owner December 8, 2025 16:00
@tcherckez-nvidia tcherckez-nvidia changed the title from "[#9717][feat]: Refactor MoE code to use enums for configuration" to "[#9717][chore] Refactor MoE code to use enums for configuration" on Dec 8, 2025
@tcherckez-nvidia (Collaborator, Author) commented:

/bot run

coderabbitai bot (Contributor) commented Dec 8, 2025

📝 Walkthrough

This PR refactors MoE configuration handling by introducing enum-based parameters (MLPStyle, ActivationFunction, WeightsFormat, WeightsFusion) in place of string-based parameters. API signatures are updated across torch_moe and quantized variants, with expanded support for multiple weight formats and fusion strategies. Validation logic and transform patterns are updated to handle the new configuration metadata and weight ordering semantics.

Changes

  • Enum Definitions and Helpers (tensorrt_llm/_torch/auto_deploy/enums.py)
    New module defining MLPStyle, ActivationFunction, WeightsFormat, WeightsFusion enums and corresponding string-to-enum parsers (mlp_style_from_str, act_fn_from_str, weights_format_from_str, weights_fusion_from_str) with normalization and validation.
  • MoE Operator APIs (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py)
    Updated signatures for torch_moe, torch_moe_fake, torch_fused_moe, torch_quant_fp8_moe, torch_quant_nvfp4_moe to use weights_1/2/3 and enum-backed parameters (weights_format, weights_fusion, mlp_style, act_fn). Expanded logic to support per_expert and stacked weight formats, multiple MLP styles, and fusion strategies with comprehensive validation.
  • Triton MoE Handler (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py)
    Introduced enum-based validation for mlp_style and act_fn; added expert masking for EP sharding to clamp selected_experts to valid range.
  • TRT-LLM MoE Handler (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py)
    Updated _validate_mlp_style_and_act_fn to accept enum types; replaced string comparisons with enum-based validation and updated error messaging.
  • Model Patches (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py)
    Updated to use enum values (ActivationFunction.RELU2.value, MLPStyle.MLP.value) instead of string literals in MoE forward path.
  • Transform Fusion Logic (tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py)
    Extended BMM MoE pattern matching to detect and propagate weights_fusion metadata; updated _find_gate_up_bmm to return fusion type; added support for GATEUP_DOWN and UPGATE_DOWN weight ordering in Llama4 stacked paths; introduced weights_1/2/3 parameter extraction and fusion-aware weight stacking.
  • Sharding Utilities (tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py)
    Enhanced to determine weights_fusion format at runtime via weights_fusion_from_str; clarified gate-up weight swapping semantics between GATEUP_DOWN and UPGATE_DOWN formats; updated weight transformation logic.
  • Unit Tests (Single GPU) (tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py)
    Added enum-based configuration parameters (weights_format, weights_fusion, mlp_style, act_fn) to MoE test invocations; introduced negative tests for string-to-enum conversions and comprehensive parameter-validation tests for torch_moe and TRT-LLM combinations.
  • Fusion Tests (Single GPU) (tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py)
    New test module introducing ReferenceMoeModel (ground-truth per-token routing) and BmmMoeModel (BMM pattern implementation); comprehensive tests comparing reference, unfused, and fused outputs across dtype variations with pattern matching validation.
  • Fusion Tests (Multi GPU) (tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py)
    New distributed test module with ReferenceMoeModel and BmmMoeModel for multi-GPU validation; includes test_bmm_moe_fusion_distributed that verifies BMM MoE fusion across multiple GPUs and dtypes.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

  • torch_moe.py requires careful review of the new weight format/fusion/mlp_style logic branches, parameter signature changes, and shape validation across all combinations (per_expert vs. stacked, multiple fusion orders, gated vs. non-gated MLPs).
  • fused_moe.py (transform library) has intricate weight fusion type detection and ordering logic that affects downstream transformations; the _find_gate_up_bmm return type change and fusion_type propagation paths need tracing.
  • Public API changes across multiple functions (torch_moe, torch_moe_fake, torch_quant_*_moe) impact compatibility; enum parameter defaults should be verified against expected behaviors.
  • Test coverage is comprehensive but the new distributed test setup and reference model logic should be validated for correctness.

Possibly related PRs

Suggested reviewers

  • galagam
  • nvchenghaoz
  • nzmora-nvidia

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 55.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (❓ Inconclusive): The PR description provides a bulleted list of changes but lacks a narrative explanation of why this refactoring is necessary and what problems it solves. Resolution: add a brief explanation of the rationale behind the refactoring, such as improved type safety, maintainability, or prevention of invalid configurations.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title clearly summarizes the main change, refactoring MoE code to use enums for configuration, which aligns with all modifications described in the raw summary.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)

640-648: Signature mismatch in register_fake function.

The fake registration function is missing the mlp_style and act_fn parameters that are present in the real function signature (lines 611-613). This will cause tracing/export issues when PyTorch tries to match the fake signature.

 @triton_fused_moe.register_fake
 def triton_fused_moe(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
     w1_stacked_weight: torch.Tensor,
     w2_stacked_weight: torch.Tensor,
+    mlp_style: str = "mlp",
+    act_fn: str = "relu2",
 ) -> torch.Tensor:
     return torch.empty_like(x)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (1)

257-258: Type mismatch: passing strings to function expecting enums.

_validate_mlp_style_and_act_fn now expects MLPStyle and ActivationFunction enum types (see lines 103-115), but this call passes raw string parameters. This will cause a KeyError when the function tries to look up the string in the supported_combinations dict.

 @trtllm_quant_fp8_moe_fused.register_fake
 def trtllm_quant_fp8_moe_fused_fake(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
     routing_weights: torch.Tensor,
     w1_weight: torch.Tensor,
     w2_weight: torch.Tensor,
     w3_weight: torch.Tensor,
     w1_input_scale: torch.Tensor,
     w2_input_scale: torch.Tensor,
     w3_input_scale: torch.Tensor,
     w1_weight_scale: torch.Tensor,
     w2_weight_scale: torch.Tensor,
     w3_weight_scale: torch.Tensor,
     gemm1_dequant: torch.Tensor,
     gemm2_act_quant: torch.Tensor,
     gemm2_dequant: torch.Tensor,
     mlp_style: str,
     act_fn: str,
 ) -> torch.Tensor:
-    _validate_mlp_style_and_act_fn(mlp_style, act_fn)
+    _validate_mlp_style_and_act_fn(mlp_style_from_str(mlp_style), act_fn_from_str(act_fn))
     return torch.empty_like(x)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

182-189: Function signature is missing weights_fusion_enum parameter.

The function is called with weights_fusion_enum at line 65, but the signature here doesn't include this parameter. This would cause a TypeError at runtime.

 def _process_llama4_stacked_moe_node(
     gm: GraphModule,
     graph: torch.fx.Graph,
     node: Node,
     replacement_op,
     act_fn_val: str,
     fused_key_counter: int,
+    weights_fusion_enum: WeightsFusion,
 ) -> None:

Also verify that the function body uses weights_fusion_enum appropriately for the weight ordering logic.

🧹 Nitpick comments (5)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (1)

399-424: Consider using raw strings for regex patterns.

The static analysis tool flagged that these match patterns contain regex metacharacters (*) but aren't raw strings. While the patterns work correctly, using raw strings (e.g., r"Unknown mlp_style.*invalid_style") is more explicit and prevents potential issues with escape sequences.

     def test_invalid_mlp_style(self):
         from tensorrt_llm._torch.auto_deploy.enums import mlp_style_from_str

-        with pytest.raises(ValueError, match="Unknown mlp_style.*invalid_style"):
+        with pytest.raises(ValueError, match=r"Unknown mlp_style.*invalid_style"):
             mlp_style_from_str("invalid_style")

     def test_invalid_activation_function(self):
         from tensorrt_llm._torch.auto_deploy.enums import act_fn_from_str

-        with pytest.raises(ValueError, match="Unknown act_fn.*invalid_act"):
+        with pytest.raises(ValueError, match=r"Unknown act_fn.*invalid_act"):
             act_fn_from_str("invalid_act")

     def test_invalid_weights_format(self):
         from tensorrt_llm._torch.auto_deploy.enums import weights_format_from_str

-        with pytest.raises(ValueError, match="Unknown weights_format.*invalid_format"):
+        with pytest.raises(ValueError, match=r"Unknown weights_format.*invalid_format"):
             weights_format_from_str("invalid_format")

     def test_invalid_weights_fusion(self):
         from tensorrt_llm._torch.auto_deploy.enums import weights_fusion_from_str

-        with pytest.raises(ValueError, match="Unknown weights_fusion.*invalid_fusion"):
+        with pytest.raises(ValueError, match=r"Unknown weights_fusion.*invalid_fusion"):
             weights_fusion_from_str("invalid_fusion")
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py (1)

17-241: Code duplication with single-GPU test file.

ReferenceMoeModel and BmmMoeModel classes are nearly identical to those in tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py. Consider extracting these to a shared test utilities module to reduce maintenance burden and ensure consistency.

This is a nice-to-have improvement that could be deferred to a follow-up PR.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py (1)

414-631: Consider reducing debug output verbosity or making it conditional.

This section contains ~220 lines of detailed graph tracing output. While useful for debugging pattern matching failures, it may clutter test output during normal CI runs.

Consider:

  1. Using logging.debug() instead of print() to allow log-level control
  2. Guarding with an environment variable like DEBUG_BMM_MOE_PATTERN
  3. Moving the tracing logic to a helper function that can be called only when debugging is needed

This is a nice-to-have improvement that can be addressed in a follow-up.
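
A minimal sketch of option 2 above, assuming a hypothetical DEBUG_BMM_MOE_PATTERN environment variable and helper name:

import os

_DEBUG_BMM_MOE_PATTERN = os.environ.get("DEBUG_BMM_MOE_PATTERN", "0") == "1"

def dump_graph_debug(gm) -> None:
    """Print the traced graph only when debugging is explicitly enabled."""
    if not _DEBUG_BMM_MOE_PATTERN:
        return
    for node in gm.graph.nodes:
        print(node.op, node.target, node.args)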

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (1)

133-142: Add strict=True to zip() for safety.

If w1_list and w3_list have different lengths due to a bug, the current zip() would silently produce fewer items than expected, potentially causing subtle correctness issues.

             fused_w_up_experts = torch.stack(
                 [
                     torch.cat(
                         [gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)],
                         dim=-2,
                     )
-                    for w1_node, w3_node in zip(w1_list, w3_list)
+                    for w1_node, w3_node in zip(w1_list, w3_list, strict=True)
                 ],
                 dim=0,
             )
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)

506-509: weights_fusion parameter is unused in quantized MoE ops.

The weights_fusion parameter is added for API consistency but is not used in the implementation. Consider either:

  1. Adding a note in the docstring that it's ignored for quantized ops
  2. Raising an error if a non-default value is passed
+    # Note: weights_fusion is accepted for API consistency but currently ignored.
+    # Quantized MoE always uses GATE_UP_DOWN (separate w1/w2/w3 weights).
+    if weights_fusion != WeightsFusion.GATE_UP_DOWN.value:
+        raise ValueError(
+            f"Quantized MoE only supports weights_fusion='w1_w2_w3_separate'. Got: {weights_fusion}"
+        )
+
     # Convert string parameters to enums
     mlp_style_enum = mlp_style_from_str(mlp_style)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1c7b7cd and 15c47f4.

📒 Files selected for processing (10)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (20 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (6 hunks)
  • tensorrt_llm/_torch/auto_deploy/enums.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (12 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use from package.subpackage import foo and then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)
Python filenames should use snake_case (e.g., some_file.py)
Python class names should use PascalCase (e.g., class SomeClass)
Python function and method names should use snake_case (e.g., def my_awesome_function():)
Python local variable names should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile = ...)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL = ...)
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., self.x = 5 followed by """<type>: Description of 'x'""" )
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic

Files:

  • tensorrt_llm/_torch/auto_deploy/enums.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top

Files:

  • tensorrt_llm/_torch/auto_deploy/enums.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
🧠 Learnings (14)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
📚 Learning: 2025-08-09T02:04:49.623Z
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/enums.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
📚 Learning: 2025-08-14T23:23:27.449Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
📚 Learning: 2025-10-20T17:07:18.745Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-10-13T19:45:03.518Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: tests/unittest/_torch/multi_gpu/test_nccl_device.py:138-149
Timestamp: 2025-10-13T19:45:03.518Z
Learning: In test_nccl_device.py, the NCCL device AllReduce implementation compares the entire residual tensor on each rank, unlike the UB implementation which compares per-rank chunks. The residual chunking calculations in the test are intentionally overridden to reflect this design difference.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
📚 Learning: 2025-08-21T02:39:12.009Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1475-1480
Timestamp: 2025-08-21T02:39:12.009Z
Learning: The min latency mode functionality in TensorRT-LLM MOE kernels (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu) is deprecated and no longer being maintained/updated, as confirmed by djns99. Bug reports and optimization suggestions for the computeStridesTmaWarpSpecializedLowLatencyKernel and related min latency code paths should be deprioritized.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-09-19T21:28:13.751Z
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device allreduce implementation (cpp/tensorrt_llm/thop/allreduceOp.cpp), the goto pattern in runNCCLAllReduceDeviceFusion is intentionally used for future extensibility, allowing multiple switch cases to fallback to the default handler. While not aesthetically ideal, this pattern supports adding more fusion cases later that can reuse the same fallback logic.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
🧬 Code graph analysis (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (1)
tensorrt_llm/_torch/auto_deploy/enums.py (4)
  • ActivationFunction (15-19)
  • MLPStyle (8-12)
  • act_fn_from_str (52-59)
  • mlp_style_from_str (42-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/auto_deploy/enums.py (2)
  • WeightsFusion (29-39)
  • weights_fusion_from_str (72-79)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py (4)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (276-344)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (23-78)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (198-221)
tensorrt_llm/_torch/auto_deploy/distributed/common.py (1)
  • spawn_multiprocess_job (248-252)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (2)
tensorrt_llm/_torch/auto_deploy/enums.py (8)
  • ActivationFunction (15-19)
  • MLPStyle (8-12)
  • WeightsFormat (22-26)
  • WeightsFusion (29-39)
  • mlp_style_from_str (42-49)
  • act_fn_from_str (52-59)
  • weights_format_from_str (62-69)
  • weights_fusion_from_str (72-79)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (1)
  • _validate_mlp_style_and_act_fn (103-115)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
tensorrt_llm/_torch/auto_deploy/enums.py (2)
  • ActivationFunction (15-19)
  • MLPStyle (8-12)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (3)
cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h (2)
  • tensorrt_llm (19-37)
  • ActivationType (24-37)
tensorrt_llm/_torch/auto_deploy/enums.py (4)
  • ActivationFunction (15-19)
  • MLPStyle (8-12)
  • act_fn_from_str (52-59)
  • mlp_style_from_str (42-49)
tensorrt_llm/_torch/utils.py (1)
  • ActivationType (38-47)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)
tensorrt_llm/_torch/auto_deploy/enums.py (8)
  • ActivationFunction (15-19)
  • MLPStyle (8-12)
  • WeightsFormat (22-26)
  • WeightsFusion (29-39)
  • act_fn_from_str (52-59)
  • mlp_style_from_str (42-49)
  • weights_format_from_str (62-69)
  • weights_fusion_from_str (72-79)
🪛 Ruff (0.14.7)
tensorrt_llm/_torch/auto_deploy/enums.py

49-49: Avoid specifying long messages outside the exception class

(TRY003)


59-59: Avoid specifying long messages outside the exception class

(TRY003)


69-69: Avoid specifying long messages outside the exception class

(TRY003)


79-79: Avoid specifying long messages outside the exception class

(TRY003)

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py

405-405: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


411-411: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


417-417: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


423-423: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


459-459: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


485-485: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


515-515: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


591-591: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


616-616: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)


626-626: Pattern passed to match= contains metacharacters but is neither escaped nor raw

(RUF043)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

56-56: Avoid specifying long messages outside the exception class

(TRY003)


62-62: Avoid specifying long messages outside the exception class

(TRY003)


64-64: Avoid specifying long messages outside the exception class

(TRY003)


112-115: Avoid specifying long messages outside the exception class

(TRY003)


208-208: Avoid specifying long messages outside the exception class

(TRY003)


215-215: Avoid specifying long messages outside the exception class

(TRY003)


217-217: Avoid specifying long messages outside the exception class

(TRY003)


301-301: Avoid specifying long messages outside the exception class

(TRY003)


306-306: Avoid specifying long messages outside the exception class

(TRY003)


308-308: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

31-31: Avoid specifying long messages outside the exception class

(TRY003)


207-210: Avoid specifying long messages outside the exception class

(TRY003)


219-222: Avoid specifying long messages outside the exception class

(TRY003)


224-227: Avoid specifying long messages outside the exception class

(TRY003)


229-231: Avoid specifying long messages outside the exception class

(TRY003)


237-240: Avoid specifying long messages outside the exception class

(TRY003)


262-265: Avoid specifying long messages outside the exception class

(TRY003)


267-270: Avoid specifying long messages outside the exception class

(TRY003)


272-275: Avoid specifying long messages outside the exception class

(TRY003)


283-286: Avoid specifying long messages outside the exception class

(TRY003)


303-306: Avoid specifying long messages outside the exception class

(TRY003)


308-311: Avoid specifying long messages outside the exception class

(TRY003)


313-315: Avoid specifying long messages outside the exception class

(TRY003)


321-324: Avoid specifying long messages outside the exception class

(TRY003)


339-339: Avoid specifying long messages outside the exception class

(TRY003)


342-345: Avoid specifying long messages outside the exception class

(TRY003)


351-353: Avoid specifying long messages outside the exception class

(TRY003)


371-374: Avoid specifying long messages outside the exception class

(TRY003)


390-392: Avoid specifying long messages outside the exception class

(TRY003)


403-403: Avoid specifying long messages outside the exception class

(TRY003)


413-413: Unused function argument: weights_1

(ARG001)


414-414: Unused function argument: weights_2

(ARG001)


415-415: Unused function argument: weights_3

(ARG001)


416-416: Unused function argument: weights_format

(ARG001)


417-417: Unused function argument: weights_fusion

(ARG001)


418-418: Unused function argument: mlp_style

(ARG001)


419-419: Unused function argument: act_fn

(ARG001)


506-506: Unused function argument: weights_fusion

(ARG001)


604-604: Avoid specifying long messages outside the exception class

(TRY003)


623-623: Unused function argument: weights_fusion

(ARG001)


624-624: Unused function argument: mlp_style

(ARG001)


625-625: Unused function argument: act_fn

(ARG001)


647-647: Unused function argument: weights_fusion

(ARG001)


756-756: Avoid specifying long messages outside the exception class

(TRY003)


778-778: Unused function argument: weights_fusion

(ARG001)


779-779: Unused function argument: mlp_style

(ARG001)


780-780: Unused function argument: act_fn

(ARG001)

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

139-139: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


145-145: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (33)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2)

12-12: LGTM on enum import.

The import correctly maintains namespace separation per coding guidelines.


145-154: Missing required parameters and inconsistent parameter naming.

Based on the test file changes, torch_moe now requires weights_format and weights_fusion parameters. Additionally, the PR objective mentions renaming w1_weight/w2_weight/w3_weight to weights_1/weights_2/weights_3, but this call still uses the old names.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

28-28: LGTM on enum imports.

Import correctly added for WeightsFusion handling.


1302-1302: Documentation accurately describes the weight format conversion.

The comments clearly explain the GATEUP_DOWN to UPGATE_DOWN conversion for TRT-LLM compatibility.

Also applies to: 1316-1316

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py (5)

8-13: LGTM on enum imports.

Correctly imports all new enum types needed for the test updates.


136-139: LGTM on enum-based parameter usage.

The test correctly passes enum .value strings for the new configuration parameters.


178-190: LGTM on BMM MoE test updates.

Good use of comments to clarify parameter semantics and proper enum usage with apply_routing_on_input.


427-604: Comprehensive negative test coverage for parameter validation.

The TestTorchMoeConfigValidation class thoroughly tests validation scenarios including fusion applicability, missing weights, expert count mismatches, and empty weights. Good test coverage for the new enum-based API.


606-627: LGTM on TRT-LLM enum validation tests.

Good coverage of unsupported mlp_style + act_fn combinations.

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (4)

17-22: LGTM on enum imports.

Correctly imports the required enum types and conversion functions for the refactored API.


615-620: LGTM on enum-based validation.

Clean pattern for converting string parameters to enums and asserting supported values. The assertions provide clear error messages for unsupported configurations.


628-631: Good defensive handling for EP sharding.

Clamping expert IDs to valid range prevents out-of-bounds access when expert IDs may be negative after EP sharding. The comment clearly explains the rationale.
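
A stand-alone sketch of the masking pattern described here; the parameter names are placeholders, and the weight-zeroing step is an assumption rather than a statement about what triton_moe.py actually does.

import torch

def mask_local_experts(selected_experts: torch.Tensor, routing_weights: torch.Tensor, num_local_experts: int):
    """Clamp expert IDs into [0, num_local_experts) and zero the routing weight of masked slots."""
    valid = (selected_experts >= 0) & (selected_experts < num_local_experts)
    clamped = selected_experts.clamp(min=0, max=num_local_experts - 1)
    masked_weights = routing_weights * valid.to(routing_weights.dtype)
    return clamped, masked_weights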


676-684: LGTM on FP8 MoE enum validation.

Correctly converts string parameters to enums and raises NotImplementedError for unsupported combinations, which is appropriate for a currently limited implementation.

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py (5)

19-24: LGTM on enum imports.

Correctly imports the required enum types and conversion functions.


103-115: Well-structured validation helper.

The _validate_mlp_style_and_act_fn function provides clear, centralized validation with informative error messages. Good use of a mapping to define supported combinations.
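
A mapping-based validator of this shape could look roughly like the following; the enum members and the allowed pairs (MLP with RELU2, GATED_MLP with SILU) are assumptions for illustration, and the real table in trtllm_moe.py may permit more combinations.

from enum import Enum

class MLPStyle(Enum):
    MLP = "mlp"
    GATED_MLP = "gated_mlp"

class ActivationFunction(Enum):
    RELU2 = "relu2"
    SILU = "silu"

_SUPPORTED_COMBINATIONS = {
    MLPStyle.MLP: {ActivationFunction.RELU2},
    MLPStyle.GATED_MLP: {ActivationFunction.SILU},
}

def _validate_mlp_style_and_act_fn(mlp_style: MLPStyle, act_fn: ActivationFunction) -> None:
    supported = _SUPPORTED_COMBINATIONS.get(mlp_style, set())
    if act_fn not in supported:
        raise ValueError(
            f"Unsupported combination: mlp_style={mlp_style.value}, act_fn={act_fn.value}; "
            f"supported activations for this style: {sorted(a.value for a in supported)}"
        )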


45-64: LGTM on enum-based activation type resolution.

Clean pattern for converting string parameters to enums and mapping to the appropriate ActivationType for the TRT-LLM kernel.


165-170: LGTM on FP8 MOE enum validation.

Correctly converts and validates parameters using the centralized helper function.


291-308: LGTM on NVFP4 MOE enum handling.

Consistent pattern with the other MOE functions for enum conversion and activation type resolution.

tensorrt_llm/_torch/auto_deploy/enums.py (2)

8-40: LGTM - Well-structured enum definitions.

The enums are clearly documented with inline comments explaining each value's purpose and usage context (e.g., Llama4 native format vs TRT-LLM format). The hierarchical organization (MLPStyle → ActivationFunction → WeightsFormat → WeightsFusion) follows a logical structure.


42-79: Parser functions are correctly implemented.

The parser functions follow a consistent pattern with clear error messages including valid values, which aids debugging. The linear search is appropriate given the small number of enum values.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_bmm_moe_fusion.py (2)

243-301: Test implementation looks correct.

The distributed test properly validates that the BMM MoE pattern matching works across multiple GPUs. The comment on lines 296-297 appropriately notes that the fused graph may not execute correctly without additional sharding transforms, which clarifies the test's scope.


1-2: Missing NVIDIA copyright header.

Per the coding guidelines, this test file should contain an NVIDIA copyright header at the top.

⛔ Skipped due to learnings
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_bmm_moe_fusion.py (3)

726-734: Tolerances are appropriately set for numerical comparison.

The 5e-2 relative and absolute tolerance is reasonable for MoE computations with bfloat16/float16, where floating-point accumulation order differences can cause minor numerical variations.


821-826: Main guard is useful for debugging.

The if __name__ == "__main__" block allows direct execution during development, which is helpful for debugging pattern matching issues.


1-13: Missing NVIDIA copyright header.

Per the coding guidelines, this test file should contain an NVIDIA copyright header at the top.

⛔ Skipped due to learnings
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (4)

1-30: Imports and module structure are well-organized.

The enum imports from the new enums module follow the coding guidelines for maintaining namespace when importing. The import organization is clean and groups related items together.


208-226: Helper function correctly handles FX graph list representations.

The unwrap_list function properly handles the various ways lists can be represented in FX graphs, including list() call nodes, tuples, and direct lists. The fallback to empty list for falsy inputs is appropriate.


897-913: Fusion type detection logic is correct.

The logic correctly determines weight ordering based on which chunk index feeds into the silu activation:

  • chunk[0] → silu means [gate, up] = [w1, w3] (Llama4 native format)
  • chunk[1] → silu means [up, gate] = [w3, w1] (TRT-LLM format)

The fallback to None when detection fails is appropriately handled by the caller.
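
In rough pseudocode, the detection described above reduces to checking which chunk output feeds the silu node; the function and variable names here are illustrative, not the actual helpers in fused_moe.py.

def detect_weights_fusion(chunk_outputs, silu_input):
    """Return the assumed weight ordering based on which chunk half feeds silu.

    chunk_outputs: the two getitem nodes produced by chunking the fused gate/up projection.
    silu_input: the node consumed by the silu activation.
    """
    if silu_input is chunk_outputs[0]:
        return "GATEUP_DOWN"  # [gate, up] = [w1, w3]: Llama4 native order
    if silu_input is chunk_outputs[1]:
        return "UPGATE_DOWN"  # [up, gate] = [w3, w1]: TRT-LLM order
    return None  # detection failed; the caller falls back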


1303-1318: Fused MoE node construction correctly uses enum values.

The call to torch.ops.auto_deploy.torch_moe correctly passes enum .value properties as strings, which matches the expected signature and allows serialization in the FX graph.

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (4)

18-31: Activation resolution is correctly implemented.

The function cleanly maps enum values to their corresponding activation callables, with a clear error message for unsupported activations.


108-197: Excellent documentation of weight interpretation.

The comprehensive docstring clearly explains how weights_1, weights_2, and weights_3 are interpreted based on the weights_format, weights_fusion, and mlp_style parameters. The examples for different model architectures (Mixtral, Llama4, NemotronH) are particularly helpful.


213-257: Validation logic is comprehensive.

The validation correctly checks for:

  • List lengths matching expected format
  • Tensor dimensionality (ndim == 3 for stacked)
  • Expert count consistency across weight tensors

The detailed error messages include actual values received, which aids debugging.


790-806: Interleaved gate/up weight format is correctly documented.

The updated docstrings clearly indicate that gate_up_w and gate_up_b use interleaved format (even indices = gate, odd indices = up), which matches the slicing on line 806.
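
A tiny, self-contained illustration of the interleaved layout; the shapes and names are placeholders rather than the actual code.

import torch

num_experts, hidden, intermediate = 4, 8, 16
# Fused gate/up projection with interleaved rows: even indices = gate, odd indices = up.
gate_up_w = torch.randn(num_experts, 2 * intermediate, hidden)
gate_w = gate_up_w[:, 0::2, :]  # (E, I, H) gate weights
up_w = gate_up_w[:, 1::2, :]    # (E, I, H) up weights
assert gate_w.shape == up_w.shape == (num_experts, intermediate, hidden)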

Comment on lines +1438 to 1449
     weights_fusion_enum = weights_fusion_from_str(args[7])
     # Transform gate_up_stacked: slice experts, swap [W1,W3]->[W3,W1] if GATEUP_DOWN, transpose (E,H,2I)->(E,2I,H)
     # GATEUP_DOWN means [w1, w3] order -> swap to TRT-LLM [w3, w1]
     # UPGATE_DOWN means [w3, w1] order -> already in TRT-LLM format, no swap needed
     if isinstance(w3_w1_tensor_node, Node):
         _transform_bmm_moe_weight_param(
-            gm, w3_w1_tensor_node, local_lo, local_hi, swap_gate_up=True
+            gm,
+            w3_w1_tensor_node,
+            local_lo,
+            local_hi,
+            swap_gate_up=weights_fusion_enum == WeightsFusion.GATEUP_DOWN,
         )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Potential IndexError when accessing args[7].

The code assumes args[7] (weights_fusion) is always present. If the node's args have fewer than 8 elements, this will raise an IndexError. Consider adding bounds validation or a defensive check.

+    if len(args) <= 7:
+        raise ValueError(f"Expected at least 8 args for stacked MoE sharding, got {len(args)}")
     weights_fusion_enum = weights_fusion_from_str(args[7])

@tensorrt-cicd (Collaborator) commented:

PR_Github #27329 [ run ] triggered by Bot. Commit: 15c47f4

@tensorrt-cicd (Collaborator) commented:

PR_Github #27329 [ run ] completed with state SUCCESS. Commit: 15c47f4
/LLM/main/L0_MergeRequest_PR pipeline #20878 completed with status: 'FAILURE'

- Update MOE custom ops to use centralized enums
- Refactor fused_moe and quantize_moe transformations
- Update model patches (deepseek, mixtral, nemotron_h, qwen3)
- Update tests to use new enum structure
- Remove obsolete test file

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
@tcherckez-nvidia (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #27450 [ run ] triggered by Bot. Commit: 4c76238

@tensorrt-cicd (Collaborator) commented:

PR_Github #27450 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User tcherckez-nvidia is not in the prioritized list. L0 testing cannot be triggered.

Add check to ensure at least 8 arguments are provided for stacked MoE
sharding to prevent index errors when accessing args[7].

Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com>
@tcherckez-nvidia (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #27473 [ run ] triggered by Bot. Commit: bbcd61a

@tensorrt-cicd (Collaborator) commented:

PR_Github #27473 [ run ] completed with state DISABLED
L0 testing is limited to prioritized users. User tcherckez-nvidia is not in the prioritized list. L0 testing cannot be triggered.

@lucaslie lucaslie requested a review from Fridah-nv December 9, 2025 16:13
@tcherckez-nvidia (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #27701 [ run ] triggered by Bot. Commit: bbcd61a

@tensorrt-cicd (Collaborator) commented:

PR_Github #27701 [ run ] completed with state SUCCESS. Commit: bbcd61a
/LLM/main/L0_MergeRequest_PR pipeline #21147 completed with status: 'FAILURE'

@tcherckez-nvidia (Collaborator, Author) commented:

@Fridah-nv I'm going to make major changes, you can skip the review for now

@tcherckez-nvidia (Collaborator, Author) commented:

Closing this because it became too complex and heavy a change.
Will divide it into lighter commits.
