
Conversation

@nvchenghaoz
Collaborator

@nvchenghaoz nvchenghaoz commented Oct 23, 2025

Summary by CodeRabbit

Release Notes

  • New Features

    • Added FP8 quantization support for Mixture of Experts (MoE) operations with optimized kernel implementations
    • Introduced configurable MLP style and activation function selection for MoE layers to enhance model flexibility
  • Tests

    • Added comprehensive test suite for MoE model validation and accuracy verification

Signed-off-by: Chenghao Zhang <[email protected]>
@coderabbitai
Contributor

coderabbitai bot commented Oct 23, 2025

📝 Walkthrough

Walkthrough

This PR introduces FP8-quantized Mixture of Experts (MoE) support to TensorRT-LLM's AutoDeploy system. The changes enable a new fuse_fp8_moe transform, add FP8-specific MoE operators for both the Torch and Triton backends, implement per-expert weight stacking and kernel-selection logic, and add validation tests comparing the Triton and Torch FP8 MoE implementations.

Changes

  • Configuration (tensorrt_llm/_torch/auto_deploy/config/default.yaml)
    Enables the new fuse_fp8_moe transform at stage post_load_fusion with enabled: true.
  • Torch FP8 MoE Operators (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py)
    Adds mlp_style and act_fn parameters to torch_quant_fp8_moe and torch_quant_fp8_moe_fake. Implements branching logic for the "gated_mlp" and "mlp" styles with per-expert weight construction and activation-function resolution (a minimal sketch follows this list).
  • Triton FP8 MoE Kernels (tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py)
    Adds the FP8 kernel fused_mlp_moe_kernel_w8a8, the helper _get_compute_type, the internal _fused_moe_mlp_relu2 function, and the public operator triton_quant_fp8_moe. Updates _invoke_kernel to accept FP8 scales and route between the quantized and unquantized paths.
  • FP8 MoE Transform (tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py)
    Adds the _stack_fp8_moe_weights function and the FuseFP8Moe transform class, which stack per-expert weights and scales and replace the original FP8 MoE nodes with calls to triton_quant_fp8_moe.
  • Quantization Logic (tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py)
    Removes the quantized_moe_op_map global mapping. Introduces runtime extraction of mlp_style and act_fn from the original MoE nodes and propagates them as kwargs to the quantized operators.
  • Integration Tests (tests/integration/defs/accuracy/test_llm_api_autodeploy.py)
    Removes the conditional skip on NemotronH.test_auto_dtype. Adds a new TestNemotronMOE test class with get_default_kwargs, get_default_sampling_params, and test_auto_dtype methods.
  • Unit Tests (tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py)
    Adds test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe to validate the Triton FP8 MoE against the Torch reference using FP8 weight/input scales and tolerance checks.
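
To make the mlp_style / act_fn distinction concrete, here is a deliberately naive, unquantized reference sketch of what the two styles compute per token and expert. Function and variable names (moe_reference, w1_list, ...) are illustrative, and the FP8 scaling performed by the real torch_quant_fp8_moe is intentionally omitted.

import torch
import torch.nn.functional as F

def moe_reference(x, selected_experts, routing_weights,
                  w1_list, w2_list, w3_list,
                  mlp_style="gated_mlp", act_fn="silu"):
    # x: [num_tokens, hidden]; selected_experts/routing_weights: [num_tokens, top_k]
    act = {"silu": F.silu, "relu2": lambda t: F.relu(t) ** 2}[act_fn]
    out = torch.zeros_like(x)
    for tok in range(x.shape[0]):
        for k in range(selected_experts.shape[1]):
            e = int(selected_experts[tok, k])
            h = x[tok]
            if mlp_style == "gated_mlp":
                # Gated MLP: act(W1 h) * (W3 h), projected back with W2.
                y = w2_list[e] @ (act(w1_list[e] @ h) * (w3_list[e] @ h))
            elif mlp_style == "mlp":
                # Plain two-layer MLP: W2 act(W1 h); no W3 projection.
                y = w2_list[e] @ act(w1_list[e] @ h)
            else:
                raise ValueError(f"unknown mlp_style: {mlp_style}")
            out[tok] += routing_weights[tok, k] * y
    return out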

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant Transform as FuseFP8Moe Transform
    participant WeightStack as _stack_fp8_moe_weights
    participant TritonOp as triton_quant_fp8_moe
    participant Kernel as FP8 Kernel

    User->>Transform: Process graph
    Transform->>WeightStack: Extract FP8 MoE nodes
    WeightStack->>WeightStack: Stack per-expert w1, w2, w3 weights
    WeightStack->>WeightStack: Stack input/weight scales
    WeightStack->>Transform: Register stacked parameters
    Transform->>TritonOp: Replace node with triton_quant_fp8_moe call
    
    User->>TritonOp: Forward pass
    TritonOp->>Kernel: Invoke FP8 kernel with stacked weights & scales
    Kernel->>Kernel: FP8 load & scale<br/>Per-block routing<br/>Scaled accumulate
    Kernel-->>TritonOp: Output tensor
    TritonOp-->>User: Return result
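
To relate the diagram to code, the forward-pass half is sketched below: standard top-k softmax routing produces selected_experts and routing_weights, which are handed to the fused op together with the pre-stacked weights and scales. The fused_op and stacked_kwargs arguments are placeholders, not the actual triton_quant_fp8_moe signature.

import torch

def moe_forward_sketch(x, router_logits, fused_op, stacked_kwargs, top_k=2):
    # Top-k softmax routing, renormalized over the selected experts.
    routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    # Placeholder call standing in for the registered fused FP8 MoE operator.
    return fused_op(x, selected_experts, routing_weights, **stacked_kwargs)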

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

The changes span multiple interconnected components with heterogeneous logic patterns: new FP8 kernel implementations with quantization-specific handling, per-expert weight stacking logic, runtime parameter extraction and propagation, and mlp_style branching logic. While individual files follow repetitive patterns (e.g., parameter handling), the kernel-level implementations and control-flow interactions require careful verification of numerical correctness, memory layout assumptions, and proper scale handling across quantized pathways.

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Description Check: ⚠️ Warning
    Explanation: The pull request description contains only "@coderabbitai summary" and is essentially empty. Per the template, the description must include a Description section explaining the issue and solution, a Test Coverage section listing relevant tests, and a PR Checklist with validation items; none of these are present, and the description gives no substantive information about what was changed, why, or how it is tested.
    Resolution: Fill out the PR description following the provided template: add a Description section explaining what FP8 MoE support was added for Nemotron and why, a Test Coverage section documenting the relevant tests (such as test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe and TestNemotronMOE), and complete the PR Checklist items to confirm the changes follow the coding guidelines and have appropriate test coverage.
  • Docstring Coverage: ⚠️ Warning
    Explanation: Docstring coverage is 38.71%, below the required threshold of 80.00%.
    Resolution: Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Title Check: ✅ Passed
    Explanation: The title "[None][feat] AutoDeploy: Add FP8 MOE for Nemotron" clearly describes the main change, follows the required format with a ticket placeholder [None] and type [feat], and accurately summarizes the primary objective: adding FP8 Mixture of Experts support for the Nemotron model in the AutoDeploy framework. It is concise and gives enough context for someone scanning the history.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)

1-1: Add NVIDIA 2025 Apache-2.0 header.

This Python source is missing the required header. Please prepend the standard 2025 NVIDIA SPDX header.

+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)

1-1: Add NVIDIA 2025 Apache-2.0 header.

Please prepend the required header.

+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py (1)

1-3: Add NVIDIA 2025 Apache-2.0 header to test file.

Tests are Python sources; add the header for compliance.

+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
🧹 Nitpick comments (5)
tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2)

154-177: Align config shape with other suites (use transforms/compile_model).

Match Llama3_1_8B/NemotronH patterns to avoid ignored keys.

-    def get_default_kwargs(self):
-        return {
-            "skip_tokenizer_init": False,
-            "trust_remote_code": True,
-            # SSMs do not support cache reuse.
-            "kv_cache_config": {
-                "enable_block_reuse": False
-            },
-            # Keep max_batch_size as in the PyTorch test to avoid OOM
-            "max_batch_size": 128,
-            # Model context length is 8K
-            "max_seq_len": 8192,
-            # Set explicitly to match default build_config behavior
-            "max_num_tokens": 8192,
-            "skip_loading_weights": False,
-            "compile_backend": "torch-cudagraph",
-            "free_mem_ratio": 0.7,
-            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
-        }
+    def get_default_kwargs(self):
+        return {
+            "skip_tokenizer_init": False,
+            "trust_remote_code": True,
+            "kv_cache_config": {"enable_block_reuse": False},
+            "max_batch_size": 128,
+            "max_seq_len": 8192,
+            "max_num_tokens": 8192,
+            "skip_loading_weights": False,
+            "transforms": {
+                "resize_kv_cache": {"free_mem_ratio": 0.7},
+                "compile_model": {
+                    "backend": "torch-cudagraph",
+                    "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
+                },
+            },
+        }

186-197: Prefer decorator skip and drop unreachable body.

Use @pytest.mark.skip to avoid executing dead code after pytest.skip().

-    @pytest.mark.skip_less_device_memory(32000)
-    def test_auto_dtype(self):
-        pytest.skip("Nemotron-MOE is not in CI yet")
-        kwargs = self.get_default_kwargs()
-        sampling_params = self.get_default_sampling_params()
-        with AutoDeployLLM(model=self.MODEL_PATH,
-                           tokenizer=self.MODEL_PATH,
-                           **kwargs) as llm:
-            task = MMLU(self.MODEL_NAME)
-            task.evaluate(llm, sampling_params=sampling_params)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.skip(reason="Nemotron-MOE is not in CI yet")
+    def test_auto_dtype(self):
+        pass

Optionally validate that MODEL_PATH exists when running locally; the harness usually supports hub IDs. Do you want a helper that switches to hub if the local path is absent?

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (3)

51-55: Tiny simplification: avoid extra None-dim reshape.

Direct indexing is clearer and avoids an unnecessary reshape.

-        tokens_for_this_expert = x[None, top_x].reshape(-1, hidden_dim)
+        tokens_for_this_expert = x[top_x]

332-349: Silence unused args in fake kernel.

Keep API stable but avoid ARG001 warnings by explicitly discarding.

 def torch_quant_fp8_moe_fake(
@@
-    mlp_style: str = "gated_mlp",
-    act_fn: str = "silu",
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
-    return torch.empty_like(x)
+    # Intentionally unused in fake; ensure lints stay quiet.
+    _ = (mlp_style, act_fn)
+    return torch.empty_like(x)

352-369: API parity suggestion: consider mlp_style/act_fn for NVFP4.

For a consistent surface, accept mlp_style/act_fn (even if you only support "gated_mlp" for now) and validate early. This avoids branching in callers and mirrors FP8.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e5865de and 1180437.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1 hunks)
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py
  • tests/integration/defs/accuracy/test_llm_api_autodeploy.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (5)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_op (179-202)
  • extract_op_args (407-444)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)
  • torch_quant_fp8_moe (218-329)
tensorrt_llm/module.py (1)
  • register_parameter (186-190)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (2)
  • triton_quant_fp8_moe (521-607)
  • triton_quant_fp8_moe (611-627)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • TransformRegistry (503-531)
  • register (509-516)
  • BaseTransform (213-500)
  • TransformInfo (121-174)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)
  • call_function (249-276)
tests/integration/defs/accuracy/test_llm_api_autodeploy.py (2)
tests/integration/defs/accuracy/accuracy_core.py (5)
  • LlmapiAccuracyTestHarness (846-857)
  • MMLU (317-331)
  • evaluate (184-247)
  • evaluate (765-775)
  • GSM8K (334-349)
tensorrt_llm/_torch/auto_deploy/models/factory.py (2)
  • model (125-127)
  • tokenizer (130-132)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (2)
  • triton_quant_fp8_moe (521-607)
  • triton_quant_fp8_moe (611-627)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)
  • torch_quant_fp8_moe (218-329)
🪛 Ruff (0.14.1)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

  • 612-613: try-except-continue detected, consider logging the exception (S112)
  • 612: Do not catch blind exception: Exception (BLE001)
  • 772: Unused method argument: cm (ARG002)
  • 773: Unused method argument: factory (ARG002)
  • 774: Unused method argument: shared_config (ARG002)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py

  • 354: Unused function argument: num_tokens_post_padded (ARG001)
  • 423: Avoid specifying long messages outside the exception class (TRY003)
  • 505: Unused function argument: selected_experts (ARG001)
  • 506: Unused function argument: routing_weights (ARG001)
  • 507: Unused function argument: w1_stacked_weight (ARG001)
  • 508: Unused function argument: w2_stacked_weight (ARG001)
  • 527: Unused function argument: w3_weight (ARG001)
  • 530: Unused function argument: w3_input_scale (ARG001)
  • 533: Unused function argument: w3_weight_scale (ARG001)
  • 535: Unused function argument: act_fn (ARG001)
  • 613: Unused function argument: selected_experts (ARG001)
  • 614: Unused function argument: routing_weights (ARG001)
  • 615: Unused function argument: w1_weight (ARG001)
  • 616: Unused function argument: w2_weight (ARG001)
  • 617: Unused function argument: w3_weight (ARG001)
  • 618: Unused function argument: w1_input_scale (ARG001)
  • 619: Unused function argument: w2_input_scale (ARG001)
  • 620: Unused function argument: w3_input_scale (ARG001)
  • 621: Unused function argument: w1_weight_scale (ARG001)
  • 622: Unused function argument: w2_weight_scale (ARG001)
  • 623: Unused function argument: w3_weight_scale (ARG001)
  • 624: Unused function argument: mlp_style (ARG001)
  • 625: Unused function argument: act_fn (ARG001)

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

  • 327: Avoid specifying long messages outside the exception class (TRY003)
  • 346: Unused function argument: mlp_style (ARG001)
  • 347: Unused function argument: act_fn (ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (15)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py (1)

220-360: Test function verified as single definition; parity test setup is solid.

Verification confirms only one definition of test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe exists in the file (line 220). The FP8 quantization parity test is well-structured with proper controlled routing, appropriate per-expert weight scaling, and reasonable tolerances for FP8 comparisons.
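
For context, a parity test of this kind typically boils down to running both ops on identical FP8-quantized inputs and comparing with loose tolerances. The skeleton below is a hypothetical outline, not the actual test: ref_op and triton_op stand in for the registered Torch and Triton FP8 MoE custom ops, and the tolerance values are illustrative.

import torch

def check_fp8_moe_parity(ref_op, triton_op, inputs: dict, rtol=2e-2, atol=2e-2):
    out_ref = ref_op(**inputs)
    out_triton = triton_op(**inputs)
    # FP8 rounding plus differing accumulation order justify loose tolerances.
    torch.testing.assert_close(out_triton, out_ref, rtol=rtol, atol=atol)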

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

119-121: Verification confirms transform is safely registered with proper no-op behavior.

The fuse_fp8_moe transform is registered at tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py:762 and implements graceful skipping: when _stack_fp8_moe_weights() finds no FP8 MoE weights (counter == 0), it returns TransformInfo(skipped=True). Safe to enable by default.

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (6)

616-627: LGTM - Good helper for retrieving parameters or buffers.

The helper correctly handles both direct attributes and nested submodule attributes, which is necessary since scales may be registered as buffers rather than parameters.
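
A generic version of that lookup might look like the sketch below; the helper name and error handling are hypothetical, and the real implementation in fused_moe.py is authoritative.

import torch
import torch.fx
import torch.nn as nn

def get_param_or_buffer(gm: torch.fx.GraphModule, target: str) -> torch.Tensor:
    # Resolve a dotted target (e.g. a hypothetical "experts.0.w1_weight_scale")
    # against the module tree, returning the tensor whether it was registered
    # as a parameter or as a buffer.
    module_path, _, attr_name = target.rpartition(".")
    owner = gm.get_submodule(module_path) if module_path else gm
    value = getattr(owner, attr_name, None)
    if not isinstance(value, (nn.Parameter, torch.Tensor)):
        raise AttributeError(f"{target} is neither a parameter nor a buffer")
    return value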


630-665: LGTM - Correct handling of optional w3 weights and scales.

The stacking logic correctly handles both gated MLP (with w3) and standard MLP (without w3) styles by creating empty tensors when w3_list is empty. This maintains consistent tensor shapes across different MLP styles.
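
The stacking step can be pictured as follows (names hypothetical, shapes illustrative); the empty tensor is what keeps the fused-op signature uniform when there is no w3:

import torch

def stack_expert_weights(w1_list, w2_list, w3_list):
    w1_stacked = torch.stack(w1_list, dim=0)  # [num_experts, intermediate, hidden]
    w2_stacked = torch.stack(w2_list, dim=0)  # [num_experts, hidden, intermediate]
    if w3_list:
        w3_stacked = torch.stack(w3_list, dim=0)  # gated_mlp only
    else:
        # Plain "mlp" style: placeholder so callers can always pass a w3 tensor.
        w3_stacked = torch.empty(0, dtype=w1_stacked.dtype, device=w1_stacked.device)
    return w1_stacked, w2_stacked, w3_stacked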


681-704: Verify that scales should be registered as parameters rather than buffers.

The original scales were likely registered as buffers (non-trainable tensors), but here they're being registered as parameters with requires_grad=False. While functionally similar for inference, this changes their semantic meaning and how they're handled by PyTorch's module system.

Consider whether these should be registered as buffers instead:

gm.register_buffer(new_key_w1_input_scale, w1_input_scale_stacked)

If the distinction matters for serialization, state dict handling, or other framework features, please verify the correct approach. If parameters are intentional, consider adding a comment explaining why scales are parameters rather than buffers.


707-728: LGTM - Correct graph node replacement.

The node replacement correctly:

  • Uses graph.get_attr() to reference the newly registered stacked parameters
  • Preserves the original node's kwargs (e.g., mlp_style, act_fn)
  • Replaces all uses before erasing the old node
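
In FX terms, the rewrite has roughly the following shape. This is a bare-bones sketch: fused_op and attr_names are placeholders, and the real transform also wires up the activation and routing inputs alongside the preserved kwargs.

import torch.fx

def replace_with_fused_op(gm: torch.fx.GraphModule, node: torch.fx.Node, fused_op, attr_names):
    graph = gm.graph
    with graph.inserting_before(node):
        # get_attr nodes referencing the newly registered stacked tensors.
        stacked = tuple(graph.get_attr(name) for name in attr_names)
        new_node = graph.call_function(fused_op, args=stacked, kwargs=dict(node.kwargs))
    node.replace_all_uses_with(new_node)
    graph.erase_node(node)
    gm.recompile()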

730-737: LGTM - Proper cleanup of unused parameters.

The cleanup correctly removes dead code and unused submodules after stacking, which is essential for memory efficiency when dealing with large models.


762-785: LGTM - Clean transform implementation following established patterns.

The FuseFP8Moe transform correctly:

  • Follows the same pattern as FuseMoe for consistency
  • Uses cuda_memory_tracker to monitor memory during the transformation
  • Sets skipped=True when no FP8 MoE patterns are found
  • Implements the BaseTransform interface (unused arguments are expected)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py (7)

347-412: LGTM - Clean refactoring with proper kernel dispatch.

The refactoring correctly:

  • Consolidates common arguments between unquantized and FP8 paths
  • Uses conditional dispatch based on scale presence to select the appropriate kernel
  • Maintains backward compatibility for the unquantized path

Note: The static analysis warning about num_tokens_post_padded being unused is a false positive—it's used in common_args at line 384.
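
The dispatch itself reduces to a presence check on the scales, roughly as below; the kernel callables and argument names are placeholders rather than the actual _invoke_kernel internals.

def invoke_kernel_sketch(fp8_kernel, default_kernel, common_args, a_scale=None, b_scale=None):
    if a_scale is not None and b_scale is not None:
        return fp8_kernel(*common_args, a_scale, b_scale)  # FP8 W8A8 path
    return default_kernel(*common_args)  # original unquantized path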


414-423: LGTM - Correct dtype mapping.

The helper correctly maps PyTorch dtypes to Triton compute types. The function is simple and handles the common FP8 use cases (BF16/FP16 output).
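
A minimal version of such a mapping, assuming BF16/FP16 outputs are the only supported cases (the real _get_compute_type is authoritative):

import torch
import triton.language as tl

def get_compute_type(out_dtype: torch.dtype):
    # Map the requested output dtype to the Triton compute type used by the kernel.
    if out_dtype == torch.bfloat16:
        return tl.bfloat16
    if out_dtype == torch.float16:
        return tl.float16
    raise ValueError(f"unsupported output dtype for the FP8 MoE kernel: {out_dtype}")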


426-481: LGTM - Well-factored internal MoE implementation.

The refactored _fused_moe_mlp_relu2 function correctly:

  • Handles token packing and kernel configuration
  • Performs two GEMMs with ReLU² activation in between
  • Uses the new _invoke_kernel interface for both passes

513-517: LGTM - Correct FP8 quantization with clamping.

The quantization helper correctly clamps values to the FP8 E4M3 range before conversion, which is necessary to prevent overflow and matches the behavior of torch_quant_fp8_linear.
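
The clamp-then-cast pattern it refers to looks roughly like this (a hedged sketch; the helper in triton_moe.py is the source of truth):

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for E4M3

def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale, clamp to the representable E4M3 range to avoid overflow, then cast.
    x_scaled = x / scale
    return x_scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)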


534-539: Document current limitations of FP8 MoE implementation.

The FP8 Triton MoE currently only supports:

  1. mlp_style=="mlp" (2-layer non-gated MLP)
  2. ReLU² activation (implicit, hardcoded)

However, the function signature accepts mlp_style and act_fn parameters to match the interface of torch_quant_fp8_moe. The act_fn parameter is unused.

Consider:

  1. Adding a check for act_fn to ensure it matches the hardcoded ReLU² behavior
  2. Updating the docstring to clearly document these limitations
  3. Adding a TODO comment for future gated MLP support
     """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
     if mlp_style != "mlp":
         raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
+    
+    # Currently only ReLU^2 activation is supported (hardcoded in kernel)
+    if act_fn not in ("relu2", "silu"):  # Accept silu for compatibility but treat as relu2
+        raise NotImplementedError(f"triton_quant_fp8_moe currently only supports relu2 activation, got {act_fn}")

549-550: Verify assumption that all experts share the same input scale.

Lines 549-550 extract only the first element of w1_input_scale and w2_input_scale, assuming all experts use the same input scale:

a1_scale = w1_input_scale[0].to(torch.float32).reshape(1).contiguous()
a2_scale = w2_input_scale[0].to(torch.float32).reshape(1).contiguous()

This differs from the weight scales (b1_scale, b2_scale) which are per-expert. Is this intentional?

If input quantization scales are truly shared across all experts (i.e., the input tensor has a single scale regardless of which expert processes it), this is correct. However, please verify this matches the quantization strategy used in the model and is consistent with how torch_quant_fp8_moe in torch_moe.py handles per-expert input scales.

If input scales should be per-expert, the kernel would need to be updated to load the appropriate scale based on off_experts.


610-627: LGTM - Fake implementation follows PyTorch custom op pattern.

The fake implementation correctly returns an output tensor with the same shape as the input, which is required for PyTorch's tracing and shape inference. The unused argument warnings from static analysis are expected for fake implementations.

Signed-off-by: Chenghao Zhang <[email protected]>
@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22245 [ run ] triggered by Bot. Commit: cbf0c9a

@tensorrt-cicd
Collaborator

PR_Github #22245 [ run ] completed with state SUCCESS. Commit: cbf0c9a
/LLM/main/L0_MergeRequest_PR pipeline #16771 completed with status: 'FAILURE'

@suyoggupta
Collaborator

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22266 [ run ] triggered by Bot. Commit: 17d335c

@tensorrt-cicd
Collaborator

PR_Github #22266 [ run ] completed with state FAILURE. Commit: 17d335c
/LLM/main/L0_MergeRequest_PR pipeline #16788 completed with status: 'FAILURE'

Signed-off-by: Fridah-nv <[email protected]>
@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22330 [ run ] triggered by Bot. Commit: aa9ee9b

@tensorrt-cicd
Collaborator

PR_Github #22330 [ run ] completed with state SUCCESS. Commit: aa9ee9b
/LLM/main/L0_MergeRequest_PR pipeline #16834 completed with status: 'FAILURE'

Signed-off-by: nvchenghaoz <[email protected]>
@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22369 [ run ] triggered by Bot. Commit: 9eaafd2

@tensorrt-cicd
Collaborator

PR_Github #22369 [ run ] completed with state SUCCESS. Commit: 9eaafd2
/LLM/main/L0_MergeRequest_PR pipeline #16862 completed with status: 'FAILURE'

@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22460 [ run ] triggered by Bot. Commit: 9eaafd2

@nvchenghaoz nvchenghaoz enabled auto-merge (squash) October 24, 2025 19:19
Collaborator

@Fridah-nv Fridah-nv left a comment


Rubber stamped per request.

@github-project-automation github-project-automation bot moved this from Backlog to In review in AutoDeploy Board Oct 24, 2025
@nvchenghaoz
Collaborator Author

/bot run

@nvchenghaoz
Collaborator Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #22478 [ run ] triggered by Bot. Commit: bab0f2e

@tensorrt-cicd
Collaborator

PR_Github #22460 [ run ] completed with state ABORTED. Commit: 9eaafd2
LLM/main/L0_MergeRequest_PR #16927 (Blue Ocean) completed with status: ABORTED

@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22479 [ kill ] triggered by Bot. Commit: bab0f2e

@tensorrt-cicd
Collaborator

PR_Github #22478 [ run ] completed with state ABORTED. Commit: bab0f2e

@tensorrt-cicd
Collaborator

PR_Github #22479 [ kill ] completed with state SUCCESS. Commit: bab0f2e
Successfully killed previous jobs for commit bab0f2e

@tensorrt-cicd
Collaborator

PR_Github #22480 [ run ] triggered by Bot. Commit: bab0f2e

@tensorrt-cicd
Collaborator

PR_Github #22480 [ run ] completed with state SUCCESS. Commit: bab0f2e
/LLM/main/L0_MergeRequest_PR pipeline #16940 completed with status: 'FAILURE'

@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22489 [ run ] triggered by Bot. Commit: 97fe71d

@tensorrt-cicd
Collaborator

PR_Github #22489 [ run ] completed with state SUCCESS. Commit: 97fe71d
/LLM/main/L0_MergeRequest_PR pipeline #16947 completed with status: 'FAILURE'

@nvchenghaoz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22508 [ run ] triggered by Bot. Commit: 97fe71d

@tensorrt-cicd
Collaborator

PR_Github #22508 [ run ] completed with state SUCCESS. Commit: 97fe71d
/LLM/main/L0_MergeRequest_PR pipeline #16965 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@nvchenghaoz nvchenghaoz merged commit a6d20f6 into NVIDIA:main Oct 25, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in AutoDeploy Board Oct 25, 2025
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
Signed-off-by: nvchenghaoz <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
Co-authored-by: Fridah-nv <[email protected]>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
Signed-off-by: nvchenghaoz <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
Co-authored-by: Fridah-nv <[email protected]>