Skip to content

Conversation

@symphonylyh
Copy link
Collaborator

@symphonylyh symphonylyh commented Aug 7, 2025

Summary by CodeRabbit

  • New Features

    • Enhanced support for Hopper (SM 90) GPU architecture in quantization and MoE groupwise weight handling.
    • Added architecture-specific processing for weight scale interleaving in quantization workflows.
  • Bug Fixes

    • Improved test logic to correctly handle and skip tests based on GPU architecture, ensuring accurate test coverage.
  • Tests

    • Updated and expanded test cases to differentiate behavior between Ada and Hopper GPUs, including new logic for interleaved scale tensors.

Description

MoE w4a8 (wINT4aFP8) groupwise quant status before this PR:
PyT path - support both Ada and Hopper
TRT path - kernels exist, but requires preprocessing on weights & scales. On Ada, it's supported; On Hopper, preprocessing is missing

After this PR:
TRT path supports w4a8 groupwise MoE quant on both Ada and Hopper. Hopper usage is demostrated with the provided unit test case. NOT added into E2E TRT run because PyT path is the recommended path for users.

Using PyT path to explain what preprocessing is needed:
On Ada: int4 weight preproc logic here (interleave), scale preproc logic here (FP16 w/o interleave)
On Hopper: int4 weight preproc (None, no interleave), scale preproc logic (BF16 w/ interleave. calc interleave factor here + dtype cast and do interleave here)

Solution: interleave weight_scale for Hopper path + disable weight interleave for Hopper path (but still need the subbyte transpose) + change in MoE layer definition to take the interleave factor into account.

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 7, 2025

📝 Walkthrough

Walkthrough

The changes introduce architecture-specific handling for weight scale interleaving in MoE groupwise quantization, particularly for Hopper (SM 90) GPUs. This includes updating tensor shape calculations, adding a utility function for interleave factor computation, modifying weight preprocessing logic, and refining unit tests to accommodate architecture differences and interleaving requirements.

Changes

Cohort / File(s) Change Summary
MoE Layer Interleave Handling
tensorrt_llm/layers/moe.py
Adjusted initialization of weights_scaling_factor in MOEWeightWrapper to account for interleaving based on quantization algorithm and architecture. Added use of get_weight_scale_interleave_factor to determine scaling factor tensor shape.
Quantization Functional Updates
tensorrt_llm/quantization/functional.py
Added get_weight_scale_interleave_factor function. Updated preprocess_weights_for_mixed_gemm to accept do_weight_interleave flag, gating interleaving logic.
MoE Quantization Groupwise Matmul Tests
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
Enhanced tests to differentiate Ada and Hopper GPU architectures. Added interleave logic for scale tensors, conditional skipping, and architecture-aware shape calculations. Updated test decorators and session creation logic to reflect interleaved tensor shapes and architecture-specific requirements.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Unit Test
    participant Quant as Quantization Functional
    participant MoE as MOE Layer

    Test->>Quant: Call preprocess_weights_for_mixed_gemm(..., do_weight_interleave)
    alt do_weight_interleave is True
        Quant->>Quant: Permute and interleave weights
    else do_weight_interleave is False
        Quant->>Quant: Skip interleaving
    end
    Quant->>Test: Return processed weights

    Test->>MoE: Initialize MOEWeightWrapper(...)
    MoE->>Quant: Call get_weight_scale_interleave_factor(...)
    MoE->>MoE: Set weights_scaling_factor shape based on interleave factor
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • chzblych
  • yizhang-nv

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@symphonylyh
Copy link
Collaborator Author

/bot run

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (1)

234-245: Rename ambiguous variable I to improve readability.

The variable name I on line 238 is ambiguous and could be confused with lowercase 'l' or the number '1'. Consider using a more descriptive name.

             def interleave_scales(scales: torch.Tensor, interleave_dim: int):
                 # [num_experts, num_groups, num_cols] --> [num_experts, num_groups // interleave, num_cols * interleave]
                 # Note: num_groups = num_rows // group_size
                 E, G, C = scales.shape
-                I = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
+                interleave_factor = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
                     interleave_dim, group_size)
-                assert G % I == 0, f"Group dimension ({G}) must be divisible by interleave factor ({I})."
-                scales_interleaved = scales.reshape(E, G // I, I, C)
+                assert G % interleave_factor == 0, f"Group dimension ({G}) must be divisible by interleave factor ({interleave_factor})."
+                scales_interleaved = scales.reshape(E, G // interleave_factor, interleave_factor, C)
                 scales_interleaved = scales_interleaved.permute(0, 1, 3, 2)
                 scales_interleaved = scales_interleaved.reshape(
-                    E, G // I, C * I)
+                    E, G // interleave_factor, C * interleave_factor)
                 return scales_interleaved.contiguous()
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1b9781e and b390c4a.

📒 Files selected for processing (3)
  • tensorrt_llm/layers/moe.py (2 hunks)
  • tensorrt_llm/quantization/functional.py (3 hunks)
  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/layers/moe.py
  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
  • tensorrt_llm/quantization/functional.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/layers/moe.py
  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
  • tensorrt_llm/quantization/functional.py
🧠 Learnings (2)
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
📚 Learning: in tensorrt-llm, test files (files under tests/ directories) do not require nvidia copyright headers...
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
🪛 Ruff (0.12.2)
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py

238-238: Ambiguous variable name: I

(E741)

🔇 Additional comments (7)
tensorrt_llm/layers/moe.py (2)

43-46: LGTM: Import addition for weight scale interleaving.

The import of get_weight_scale_interleave_factor is correctly added to support architecture-specific weight scale interleaving for W4A8 groupwise MoE quantization.


493-504: Architecture-specific weight scale interleaving implementation looks correct.

The implementation correctly handles Hopper (SM 90) architecture-specific requirements for W4A8 groupwise MoE quantization:

  1. Weight parameter: Maintains the correct shape (experts_per_node, in_features, out_features // 4) for int4 quantization
  2. Conditional interleaving: Properly checks for W4A8_ALPHA flag to determine if interleaving is needed
  3. Scale factor computation: Uses get_weight_scale_interleave_factor to compute the architecture-specific interleave factor
  4. Shape adjustment: Correctly adjusts the weights_scaling_factor shape to account for interleaving by dividing the second dimension by the interleave factor and multiplying the third dimension by it

This change aligns with the PR objective of fixing Hopper w4a8 groupwise MoE interleave and maintains backward compatibility by defaulting to scale_interleave_factor = 1 when W4A8_ALPHA is not enabled.

tensorrt_llm/quantization/functional.py (1)

953-958: LGTM! Good backward compatibility with the default parameter.

The addition of the do_weight_interleave parameter with a default value of True maintains backward compatibility while enabling architecture-specific behavior.

tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (4)

17-17: LGTM! Proper extension of test coverage to multiple GPU architectures.

The changes correctly extend test support from Ada-only to both Ada and Hopper architectures.

Also applies to: 24-26, 31-32, 317-317


72-73: Good defensive programming to handle interleaved tensors.

Using fc2_prequant_scale.shape[-1] ensures we get the correct dimension regardless of whether weight or scale interleaving has been applied.


209-228: LGTM! Correct architecture-specific weight preprocessing logic.

The conditional logic properly handles:

  • Hopper (SM 90) with W4A8: disables weight interleaving
  • Ada (SM 89) with W4A8: converts scales to float16

229-233: LGTM! Appropriate handling of unsupported configuration.

Correctly skips the test when has_zero is enabled with Hopper W4A8, as this combination is not yet supported.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14465 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14465 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10930 completed with status: 'FAILURE'

symphonylyh and others added 2 commits August 7, 2025 10:43
Signed-off-by: Haohang Huang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Haohang Huang <[email protected]>
@symphonylyh symphonylyh force-pushed the user/haohangh/fix-w4a8-interleave branch from 33aef0c to 446f26c Compare August 7, 2025 17:43
@symphonylyh
Copy link
Collaborator Author

/bot run --disable-fail-fast

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (1)

209-252: Comprehensive architecture-specific preprocessing logic implemented correctly.

The implementation properly handles the differences between Ada (SM 89) and Hopper (SM 90) architectures:

  • Hopper w4a8 correctly disables weight interleaving (do_weight_interleave=False)
  • Architecture-specific type conversions (float16 for Ada, bfloat16 for Hopper)
  • Proper scale interleaving for Hopper with the nested interleave_scales function
  • Appropriate skipping of zero tensor support for Hopper w4a8

However, address the ambiguous variable name flagged by static analysis:

-                I = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
+                interleave_factor = tensorrt_llm.quantization.functional.get_weight_scale_interleave_factor(
                     interleave_dim, group_size)
-                assert G % I == 0, f"Group dimension ({G}) must be divisible by interleave factor ({I})."
+                assert G % interleave_factor == 0, f"Group dimension ({G}) must be divisible by interleave factor ({interleave_factor})."
-                scales_interleaved = scales.reshape(E, G // I, I, C)
+                scales_interleaved = scales.reshape(E, G // interleave_factor, interleave_factor, C)
-                scales_interleaved = scales_interleaved.reshape(
-                    E, G // I, C * I)
+                scales_interleaved = scales_interleaved.reshape(
+                    E, G // interleave_factor, C * interleave_factor)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 33aef0c and 446f26c.

📒 Files selected for processing (3)
  • tensorrt_llm/layers/moe.py (2 hunks)
  • tensorrt_llm/quantization/functional.py (3 hunks)
  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/layers/moe.py
  • tensorrt_llm/quantization/functional.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
🧠 Learnings (2)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py
🪛 Ruff (0.12.2)
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py

238-238: Ambiguous variable name: I

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (4)
tests/unittest/trt/quantization/test_moe_weight_only_groupwise_quant_matmul.py (4)

17-18: LGTM! Pytest import added for conditional test skipping.

The pytest import is correctly added to support the conditional test skipping functionality for Hopper GPU architecture handling.


24-26: Good architecture support expansion.

The utility imports have been appropriately updated to support both Ada and Hopper architectures, and the get_sm_version function import enables architecture-specific conditional logic.

Also applies to: 31-32


72-73: Correct parameter derivation for interleaved weights/scales.

Using fc2_prequant_scale.shape[-1] to derive the parameter n is the right approach since either weights or scales could be interleaved depending on the architecture, making the prequant scale a reliable source for the original dimension.


317-317: Test decorator correctly updated for expanded architecture support.

The change from skip_non_ada_unittest to skip_neither_ada_nor_hopper_unittest properly reflects the expanded GPU architecture support for both Ada and Hopper in the W4A8 quantization tests.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14501 [ run ] triggered by Bot

@symphonylyh symphonylyh changed the title [None][fix] Hopper w4a8 groupwise MoE interleave [https://nvbugs/5410687][fix] Hopper w4a8 groupwise MoE interleave Aug 7, 2025
@symphonylyh symphonylyh requested a review from achartier August 7, 2025 18:36
@tensorrt-cicd
Copy link
Collaborator

PR_Github #14501 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10952 completed with status: 'SUCCESS'

@achartier achartier merged commit 980929e into NVIDIA:main Aug 7, 2025
6 checks passed
Shunkangz pushed a commit to hcyezhang/TensorRT-LLM that referenced this pull request Aug 8, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants