
[None][feat] Optimize GDN of Qwen3-Next/3.5; adds BF16 TRTLLM MoE#12557

Open
rosenrodt wants to merge 10 commits into NVIDIA:main from rosenrodt:qwen3next-3_5-pyt-perf

Conversation

@rosenrodt
Collaborator

@rosenrodt rosenrodt commented Mar 26, 2026

Summary by CodeRabbit

  • New Features

    • Added BF16 support for Mixture of Experts models with FlashInfer backend
    • Enhanced tensor parallelism support for Qwen3 models
  • Performance

    • Optimized causal convolution kernel selection based on input characteristics
    • Improved state management and memory handling in model execution
    • Refined sequence processing efficiency for variable-length inputs
  • Tests

    • Extended test coverage for tensor parallelism configurations and MoE backends
    • Added parametrized test variants for BF16 and quantization modes

Description

  • Enable BF16 TRTLLM MoE through FlashInfer in the PyTorch backend.
  • Fix the Mamba2 metadata prefill bubble in chunked-prefill serving (by @Wong4j).
  • Improve Gated Delta Net kernel performance:
    • Tune causal-conv launches for varlen / short-sequence workloads.
    • Perform indexed state updates in-place inside the kernel.
    • Keep decode q/k/v tensors as views instead of instantiating a new packed tensor.
    • Change the raster order of fused_sigmoid_gating_delta_rule_update_kernel.
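The q/k/v-as-views optimization above can be sketched in plain PyTorch. This is an illustration of the general idea (avoiding a per-step packed copy by letting a stride-aware kernel read the original storage), not the PR's actual code; the shapes are hypothetical.

```python
import torch

# Hypothetical decode-step shapes: [num_tokens, num_heads, head_dim].
q = torch.randn(8, 4, 64)
k = torch.randn(8, 4, 64)
v = torch.randn(8, 4, 64)

# Packing copies all three tensors into a freshly allocated buffer on
# every decode step.
packed = torch.stack([q, k, v], dim=1)  # shape [8, 3, 4, 64], three copies
assert packed.data_ptr() != q.data_ptr()

# A stride-aware kernel can instead consume the original tensors directly:
# transposing (or narrowing) only reinterprets strides, sharing storage.
q_view = q.transpose(0, 1)
assert q_view.data_ptr() == q.data_ptr()  # no copy, same storage
assert not q_view.is_contiguous()         # the kernel must honor strides
```

The trade-off is that the consuming kernel must accept explicit stride parameters instead of assuming a contiguous packed layout, which is what "make the fused recurrent kernel stride-aware" refers to in the commit notes below.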

Perf

Qwen3.5-35B-A3B BF16 TP1
ISL/OSL=4k/1k synthetic (ignore_eos=True)
Tested on B200

| Concurrency | CUTLASS MoE (baseline) | CUTLASS MoE (this PR) | TRTLLM MoE (this PR) | Speedup |
|---|---|---|---|---|
| 1 | 180.42 | 178.74 | 210.98 | 1.17 |
| 8 | 1039.22 | 1077.51 | 1190.87 | 1.15 |
| 64 | 3786.6 | 3995.66 | 4307.72 | 1.14 |
| 128 | 4840.73 | 5272.36 | 5631.7 | 1.16 |
| 256 | 5232.83 | 6179.34 | 6485.89 | 1.24 |

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40455 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: --disable-fast-fail

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40456 [ run ] triggered by Bot. Commit: 85ec854 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40456 [ run ] completed with state SUCCESS. Commit: 85ec854
/LLM/main/L0_MergeRequest_PR pipeline #31545 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40481 [ run ] triggered by Bot. Commit: 252269f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40481 [ run ] completed with state SUCCESS. Commit: 252269f
/LLM/main/L0_MergeRequest_PR pipeline #31569 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40500 [ run ] triggered by Bot. Commit: 162777e Link to invocation

@rosenrodt rosenrodt force-pushed the qwen3next-3_5-pyt-perf branch from 162777e to 6b67c8e Compare March 27, 2026 13:55
@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40502 [ run ] triggered by Bot. Commit: 6b67c8e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40502 [ run ] completed with state SUCCESS. Commit: 6b67c8e
/LLM/main/L0_MergeRequest_PR pipeline #31590 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40530 [ run ] triggered by Bot. Commit: 5b0a3fb Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40530 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 9 PM PST on 3/28.

Link to invocation

@rosenrodt
Collaborator Author

cc @VALLIS-NERIA @nv-guomingz, as this PR modifies some of the GDN and Mamba state kernels.

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40575 [ run ] triggered by Bot. Commit: 5b0a3fb Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40575 [ run ] completed with state SUCCESS. Commit: 5b0a3fb
/LLM/main/L0_MergeRequest_PR pipeline #31616 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40620 [ run ] triggered by Bot. Commit: 19840ed Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40620 [ run ] completed with state SUCCESS. Commit: 19840ed
/LLM/main/L0_MergeRequest_PR pipeline #31660 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt rosenrodt force-pushed the qwen3next-3_5-pyt-perf branch from 19840ed to 8a2c936 Compare March 31, 2026 05:57
rosenrodt and others added 3 commits March 31, 2026 14:02
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Three optimizations to eliminate GPU idle bubbles during prefill in
Mamba2Metadata.prepare() for hybrid GDN models (e.g. Qwen3.5):

1. Remove tl.constexpr from num_seqs and N in _cu_seqlens_triton_kernel.
   Triton JIT recompiles for each unique constexpr value (~120ms each).
   In serving, num_seqs varies every prefill step, causing repeated
   recompilation. With dynamic parameters, only one compilation occurs.

2. Accept total_seqlens from caller to skip first GPU->CPU sync.
   cu_seqlens[-1].item() blocked on all pending GPU work. The caller
   (Mamba2Metadata.prepare) already has num_ctx_tokens on CPU.

3. Compute extra_chunks with pure Python arithmetic on CPU seq_lens
   to eliminate the second GPU->CPU sync (cumsum + p[-1].item()).

Before: _prepare_inputs ~120-460ms per prefill step (Triton recompile +
        GPU sync bubbles)
After:  _prepare_inputs ~1-2ms steady state

Verified: 9200+ random equivalence tests + e2e serving assertion with
1000 requests (0 mismatches). GSM8K accuracy unchanged (90.07% on full
1319 samples).
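The CPU-side arithmetic in point 3 can be sketched as follows. This is illustrative only: the `chunk_size` parameter and this particular definition of "extra chunks" are assumptions for the sketch, not the commit's actual code; the point is that plain integer arithmetic over the CPU-resident `seq_lens` replaces a GPU cumsum plus a blocking `.item()` sync.

```python
def extra_chunks_cpu(seq_lens, chunk_size):
    # A sequence of length L spans ceil(L / chunk_size) chunks; every chunk
    # beyond the first contributes one extra boundary entry. Pure Python
    # arithmetic means no GPU kernel launch and no GPU->CPU sync.
    return sum((L + chunk_size - 1) // chunk_size - 1 for L in seq_lens)

# 4096 -> 2 chunks (1 extra), 100 -> 1 chunk (0), 8192 -> 4 chunks (3 extra)
print(extra_chunks_cpu([4096, 100, 8192], 2048))  # → 4
```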

Signed-off-by: Shijie Wang <jaywan@nvidia.com>
- update chunked Gated Delta Rule prefill to use indexed in-kernel state updates
- remove explicit Qwen3Next prefill state gather/scatter in forward_extend
- retune causalConv1d forward launch selection for varlen and short sequences

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
- keep decode qkv views and make the fused recurrent kernel stride-aware
- restore the decode tile choice that wins on the representative bs256 pure-decode benchmark

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
@rosenrodt rosenrodt force-pushed the qwen3next-3_5-pyt-perf branch from 8a2c936 to 30759a5 Compare March 31, 2026 06:09
@rosenrodt
Collaborator Author

/bot run —disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40876 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: —disable-fail-fast

Link to invocation

@rosenrodt
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40933 [ run ] triggered by Bot. Commit: 30759a5 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40933 [ run ] completed with state FAILURE. Commit: 30759a5
/LLM/main/L0_MergeRequest_PR pipeline #31925 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@rosenrodt
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41255 [ run ] triggered by Bot. Commit: 30759a5 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41255 [ run ] completed with state SUCCESS. Commit: 30759a5
/LLM/main/L0_MergeRequest_PR pipeline #32213 completed with status: 'SUCCESS'

CI Report

Link to invocation

@rosenrodt rosenrodt marked this pull request as ready for review April 2, 2026 04:53
@rosenrodt rosenrodt requested review from a team as code owners April 2, 2026 04:53
@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

📝 Walkthrough

Walkthrough

This PR introduces BF16 unquantized fused MoE execution via FlashInfer backend, adds indexed state update support to FLA kernels with parameter threading through multiple layers, refactors Triton/CUDA kernels for improved memory efficiency, updates model forward passes to leverage new kernel features, and parametrizes integration tests across tensor parallelism configurations and backends.

Changes

  • Fused MoE — BF16 Infrastructure (tensorrt_llm/_torch/modules/fused_moe/create_moe.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py, tensorrt_llm/_torch/modules/fused_moe/quantization.py): Introduces a BF16 unquantized execution path with a resolve_moe_cls() function that falls back to CUTLASS when appropriate; adds FlashInfer backend detection and routing-method-aware selection; implements BF16TRTLLMGenFusedMoEMethod with BlockMajorK weight layout and post-load processing.
  • Fused MoE — Backend Selection & Configuration (tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py): Switches from get_moe_cls() to resolve_moe_cls() for backend determination; introduces config deep-copying for quantization overrides to avoid mutating the original model config.
  • Fused MoE — Operation Execution (tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py): Adds a pack_topk_ids() Triton kernel and a run_bf16_moe() method to the MoE backend interface; implements the BF16 execution path for the FlashInfer backend with routing-method conversion; consolidates topk-packing logic across quantization modes.
  • FLA Kernels — Indexed State Threading (tensorrt_llm/_torch/modules/fla/chunk.py, tensorrt_llm/_torch/modules/fla/chunk_delta_h.py): Adds initial_state_indices and inplace_indexed_state_update parameters throughout function signatures and kernel invocations; updates validation logic to enforce indexed-update constraints; modifies final-state handling based on indexed vs. non-indexed modes.
  • Triton Kernel — Sigmoid Gating (tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py): Replaces input_guard with custom_device_ctx; refactors the kernel to use per-tensor token strides instead of packed layouts; introduces a grid-striding loop over N*HV tiles and changes the launch grid shape from (N*HV, NV, NK) to (NK, NV, min(N*HV, 65535)); adds stride parameters to the kernel signature.
  • Mamba Modules — Helper Extraction & Metadata (tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py, tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py): Introduces a shared extract_transpose_prefill_slice() helper; updates Triton kernel parameters from constexpr to runtime values; extends cu_seqlens_to_chunk_indices_offsets_triton() to accept optional total_seqlens and extra_chunks parameters for cached computations.
  • Model — Qwen3Next Forward Passes (tensorrt_llm/_torch/models/modeling_qwen3_next.py): Uses extract_transpose_prefill_slice() for transposed prefill handling; introduces TP-aware derived attributes (*_per_tp dimensions); refactors forward_decode to use per-TP dimensions; changes forward_extend to perform in-place state updates via chunk_gated_delta_rule with indexed state parameters instead of manual indexing.
  • CUDA Kernel — Causal Conv1D (cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu): Adds a causal_conv1d_fwd_dispatch() helper to select between 64- and 128-thread configurations based on varlen status and a sequence-length threshold; updates the copyright year to 2022–2026.
  • Integration Tests — Qwen3.5 35B Parametrization (tests/integration/defs/accuracy/test_llm_api_pytorch.py): Parametrizes test_bf16 over MoE backends (CUTLASS, TRTLLM) and tensor-parallel sizes (1, 2); parametrizes test_fp8 over tensor-parallel sizes (1, 2); adds backend-specific skip logic and reduces the batch size to 32.
  • Integration Test Lists — Test Selection (tests/integration/test_lists/qa/llm_function_core.txt, tests/integration/test_lists/qa/llm_function_core_sanity.txt, tests/integration/test_lists/test-db/l0_b200.yml): Updates test-selection entries to reflect the parameterized test_bf16 and test_fp8 variants across MoE backends and tensor-parallel configurations.
  • Integration Tests — Waiver Cleanup (tests/integration/test_lists/waives.txt): Removes the waiver for TestQwen3_5_35B_A3B::test_fp8.
  • Unit Tests — MoE Skip Logic (tests/unittest/_torch/modules/moe/moe_test_utils.py): Updates should_skip_trtllm to allow the unquantized (BF16) path with 128-multiple size constraints; refines should_skip_to_accelerate_ci to preserve TRTLLM BF16 coverage while skipping gated unquantized for other backends.
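The resolve-then-fallback backend selection described in the first cohort can be sketched as below. The names mirror the walkthrough (resolve_moe_cls, the CUTLASS fallback, the routing-method check), but the signatures and the string-based routing check are illustrative assumptions, not the PR's actual code.

```python
class CutlassFusedMoE:
    pass

class TRTLLMGenFusedMoE:
    pass

def supports_flashinfer_bf16_routing(routing_method: str) -> bool:
    # Stand-in for the routing-method-aware check; per the walkthrough and
    # review comments, DeepSeekV3-style routing is rejected on the BF16
    # FlashInfer path due to a suspected accuracy bug.
    return routing_method != "deepseek_v3"

def resolve_moe_cls(dtype: str, quantized: bool, routing_method: str):
    cls = TRTLLMGenFusedMoE  # stand-in for get_moe_cls()'s initial pick
    # The BF16 unquantized path requires FlashInfer; fall back to CUTLASS
    # when the routing method cannot use it.
    if (dtype == "bf16" and not quantized
            and not supports_flashinfer_bf16_routing(routing_method)):
        return CutlassFusedMoE
    return cls

print(resolve_moe_cls("bf16", False, "deepseek_v3").__name__)  # → CutlassFusedMoE
```

Keeping the fallback inside a single resolver (rather than at each call site) is what lets configurable_moe.py switch from get_moe_cls() to resolve_moe_cls() without duplicating the check.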

Sequence Diagram(s)

sequenceDiagram
    participant Model as Model Init
    participant CreateMoE as create_moe()
    participant ResolveCls as resolve_moe_cls()
    participant GetCls as get_moe_cls()
    participant Fallback as Fallback Logic
    participant Backend as MoE Backend
    participant Router as run_bf16_moe()

    Model->>CreateMoE: routing_method, dtype
    CreateMoE->>ResolveCls: model_config, routing_method, dtype, override_quant_config
    ResolveCls->>GetCls: initial class selection
    GetCls-->>ResolveCls: TRTLLMGenFusedMoE or other
    ResolveCls->>Fallback: check if BF16 path + FlashInfer required
    alt BF16 path unsupported by routing method
        Fallback-->>ResolveCls: CutlassFusedMoE (fallback)
    else BF16 path supported
        Fallback-->>ResolveCls: TRTLLMGenFusedMoE
    end
    ResolveCls-->>CreateMoE: resolved_moe_cls
    CreateMoE->>Backend: create_moe_backend(resolved_moe_cls)
    Backend-->>CreateMoE: backend instance
    CreateMoE-->>Model: MoE module ready
    
    Note over Router: Runtime execution
    Router->>Backend: run_bf16_moe(x, router_logits)
    alt router_logits provided
        Backend->>Router: trtllm_bf16_moe() + routing conversion
    else router_logits absent
        Backend->>Router: pack_topk_ids() + trtllm_bf16_routed_moe()
    end
    Router-->>Backend: BF16 MoE output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • nv-guomingz
  • Tom-Zheng
  • litaotju
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 35.71%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Linked Issues check — ❓ Inconclusive: The PR title uses the '[None]' format, indicating no linked JIRA ticket, NVBugs ID, or GitHub issue. While the PR description and commit messages reference specific problems (Mamba2 prefill bubble, GDN performance), there is no explicit issue-tracker reference in the title. If this work addresses a tracked issue or bug, consider adding it to the PR title; if intentionally unlinked, this status is acceptable.

✅ Passed checks (3 passed)

  • Title check — ✅ Passed: The title '[None][feat] Optimize GDN of Qwen3-Next/3.5; adds BF16 TRTLLM MoE' clearly summarizes the main changes: optimization of Gated Delta Net, addition of BF16 TRTLLM MoE support, and improvements to Qwen3/3.5 models. The title is specific and directly related to the primary objectives.
  • Description check — ✅ Passed: The PR description partially follows the template, with key sections present (a Description with bullet points explaining changes, a performance table, and a PR Checklist), but the 'Test Coverage' section contains only a comment placeholder with no actual test names listed.
  • Out of Scope Changes check — ✅ Passed: The PR scope focuses on Qwen3/3.5 optimization, Mamba2 metadata fixes, GDN kernel improvements, and BF16 TRTLLM MoE enablement. All changes appear directly related to these stated objectives. The copyright-year update (2022–2025 to 2022–2026) is a minor in-scope housekeeping change.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py (1)

179-242: ⚠️ Potential issue | 🟠 Major

Restore the temporary ModelConfig mutation in a finally block.

Line 182 starts mutating backend_model_config, and Lines 190-193 temporarily disable weight creation on the same object. When override_quant_config is None, that object is the caller-owned model_config. If create_moe_backend(), validate_backend(), or the backend sync fails before Line 240, the shared config exits this constructor with a different skip_create_weights_in_init/freeze state, so a retry can silently skip weight creation. Please save the original values and restore them in finally, and keep the quant-config override inside that guarded mutation block.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py` around lines 179 -
242, The constructor mutates backend_model_config (which may alias caller-owned
model_config) and flips _frozen and skip_create_weights_in_init without
guaranteeing restoration on exceptions; capture the original values (e.g.,
orig_frozen = backend_model_config._frozen and tmp_skip_create_weights_in_init
already present) and perform the mutations (including making a deep copy only
when override_quant_config is set) inside a try block, then restore
backend_model_config._frozen and
backend_model_config.skip_create_weights_in_init in a finally block so failures
in create_moe_backend, validate_backend, or the sync won’t leave the shared
ModelConfig in a bad state; keep the later conditional create_weights() logic
unchanged (references: backend_model_config, model_config,
override_quant_config, tmp_skip_create_weights_in_init, create_moe_backend,
validate_backend, backend.create_weights).
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)

276-299: ⚠️ Potential issue | 🔴 Critical

Validate indexed-state inputs before launching Triton.

Line 298 forwards h0_i unconditionally, but the heuristic on Line 22 only checks h0_i is not None. If a caller passes initial_state_indices without initial_state, Lines 97-106 still enter indexed-state mode and do pointer arithmetic on a null h0. This wrapper also never checks fixed-length callers for N indices or validates slot IDs, so a malformed index tensor can read or write past initial_state.

🛡️ Suggested wrapper guards
     h = k.new_empty(B, NT, H, K, V)
     use_indexed_state = initial_state is not None and initial_state_indices is not None
+    if initial_state_indices is not None:
+        if initial_state is None:
+            raise ValueError("initial_state_indices requires initial_state")
+        if initial_state_indices.shape[0] != N:
+            raise ValueError(
+                f"Expected {N} initial_state_indices entries, got {initial_state_indices.shape[0]}."
+            )
+        # Also reject out-of-range slot ids here, or thread slot_num into the
+        # kernel and use tl.device_assert before doing pointer arithmetic.
+    if inplace_indexed_state_update and not use_indexed_state:
+        raise ValueError(
+            "inplace_indexed_state_update requires initial_state and initial_state_indices."
+        )
     if use_indexed_state and not inplace_indexed_state_update:
         raise ValueError(
             "Indexed chunk state updates require inplace_indexed_state_update=True."
         )
@@
-        h0_i=initial_state_indices,
+        h0_i=initial_state_indices if use_indexed_state else None,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py` around lines 276 - 299, The
kernel is currently launched with initial_state_indices (h0_i) unchecked which
can lead to null-pointer arithmetic or OOB access; before calling
chunk_gated_delta_rule_fwd_kernel_h_blockdim64 validate that if
initial_state_indices is not None then initial_state is not None, that
initial_state_indices.dtype is a long/torch.int64, that
initial_state_indices.dim() == 1 and initial_state_indices.numel() == N (or the
expected per-batch count), and that all index values satisfy 0 <= idx < K (raise
ValueError with a clear message if any check fails); only forward h0_i and h0 to
the kernel after these validations and keep the existing
inplace_indexed_state_update check.
🧹 Nitpick comments (7)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)

700-701: Make the shared cache intent explicit with ClassVar.

Line 701 uses a mutable class attribute cache; this is likely intentional, but it currently trips RUF012 and hides intent. Mark it as ClassVar[...] to document shared state explicitly.

🔧 Suggested fix
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
@@
-    _cache_permute_indices: Dict[tuple[tuple[int, ...], str, int],
-                                 torch.Tensor] = {}
+    _cache_permute_indices: ClassVar[
+        Dict[tuple[tuple[int, ...], str, int], torch.Tensor]
+    ] = {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 700 -
701, The class-level mutable cache _cache_permute_indices is intended to be
shared but not marked explicitly; import ClassVar from typing and change the
annotation of _cache_permute_indices to ClassVar[Dict[tuple[tuple[int, ...],
str, int], torch.Tensor]] = {} in the class where it is declared (symbol:
_cache_permute_indices) so the intent of shared mutable state is explicit to
linters and readers while keeping the same default empty dict behavior. Ensure
you only change the type annotation (and add the ClassVar import if missing) and
do not move the variable into __init__ or make it an instance attribute.
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

5859-5883: Trim the CUDA graph capture list to the new batch cap.

max_batch_size is now 32, but this test still asks CudaGraphConfig to warm up/capture 64 and 128 batches. If those sizes are not clamped internally, the warmup footprint stays larger than this test can ever use and undercuts the memory reduction from this change.

♻️ Proposed fix
-        cuda_graph_config = CudaGraphConfig(
-            enable_padding=True, batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128])
+        cuda_graph_config = CudaGraphConfig(
+            enable_padding=True, batch_sizes=[1, 2, 4, 8, 16, 32])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 5859 -
5883, The CudaGraph capture batch list in test_bf16 exceeds the test's
max_batch_size (32) causing unnecessary warmup memory; update the
cuda_graph_config used in test_bf16 (CudaGraphConfig instance) to only include
batch sizes <= the LLM's max_batch_size (e.g., trim or filter the batch_sizes
list [1,2,4,8,16,32] or programmatically filter against max_batch_size) so
captures never request 64 or 128 when max_batch_size=32.
tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py (2)

78-121: CPU fallback path has a potential device mismatch.

The CPU fallback at Lines 97-100 computes packed_topk_ids and copies to output. If output is None, it's created on topk_ids.device (Line 93), but the fallback computation happens on whatever device the inputs are on (could be CPU). This is fine since the fallback only triggers when inputs are not CUDA, but consider adding a comment for clarity.

Minor: The fallback comment "Fallback to CPU just in case" could be more precise:

-    # Fallback to CPU just in case
+    # Fallback to PyTorch path when Triton unavailable or inputs not on CUDA
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py` around lines 78 -
121, The CPU fallback in pack_topk_ids can create a subtle device mismatch:
output is created with device=topk_ids.device but the fallback path triggers
when inputs may be non-CUDA; make the behavior explicit by ensuring output is
allocated on the same device as the fallback computation (e.g., use
device=topk_ids.device or torch.device('cpu') when triton is None or inputs are
not CUDA) before computing packed_topk_ids, and update the comment "Fallback to
CPU just in case" to something precise like "CPU fallback path (used when triton
is unavailable or inputs are non-CUDA)"; reference pack_topk_ids, output,
packed_topk_ids, and triton when locating the change.

920-922: Conditional copy may produce unexpected return type.

When output is not None and do_finalize is True, the method copies result into output and returns output. However, when do_finalize=False, the raw result is returned which might have different semantics (tuple vs tensor based on FlashInfer's return). Consider documenting the return type behavior or verifying consistency with other run_*_moe methods.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py` around lines 920 -
922, The code path returns output (a tensor) when output is provided and
do_finalize is True, but returns raw result (which may be a tuple or different
type from FlashInfer) when do_finalize is False; make the return type
consistent: either always return the same structure (e.g., always return output
tensor) or always return result but copy into output only for in-place
semantics. Modify the method in moe_op_backend.py to normalize the return value
based on output/result/do_finalize (use output.clone() or wrap result into the
same tensor/tuple shape as other run_*_moe methods), and add a short comment
documenting the chosen semantic; ensure you update the code paths that reference
output, result, and do_finalize and match behavior of other run_*_moe
implementations.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (2)

327-330: FIXME comment indicates known accuracy bug.

The comment "FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug" suggests a workaround is in place. Consider tracking this with an issue if not already done.

Would you like me to help create an issue to track this FlashInfer BF16 routing accuracy bug?


321-325: Fix indentation for static analysis compliance.

Static analysis reports E125: continuation line with same indent as next logical line.

Fix indentation
     `@staticmethod`
     def _supports_flashinfer_bf16_routing_method(
-        routing_method: BaseMoeRoutingMethod, ) -> bool:
+            routing_method: BaseMoeRoutingMethod) -> bool:
         # FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
         return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py` around lines
321 - 325, The def of _supports_flashinfer_bf16_routing_method has misaligned
continuation indentation causing E125; fix by aligning the parameter list and
closing parenthesis with the def line (e.g., put the closing parenthesis and
return type on the same line as the parameters or place the closing parenthesis
on its own line indented to the same level as the def), leaving the body
unchanged (the check using isinstance(routing_method,
DeepSeekV3MoeRoutingMethod) must remain). Ensure references to
BaseMoeRoutingMethod and DeepSeekV3MoeRoutingMethod in the signature are
preserved exactly.
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)

68-72: Switch from typing.Tuple to builtin tuple[...] syntax.

Python >=3.10 fully supports builtin generic tuple[...]. Update both return annotations at lines 72 and 125, plus the docstring at line 132. After making these changes, remove the from typing import Tuple import.

♻️ Changes needed
  • Line 17: Remove from typing import Tuple
  • Line 72: Change -> Tuple[torch.Tensor, torch.Tensor]: to -> tuple[torch.Tensor, torch.Tensor]:
  • Line 125: Change -> Tuple[torch.Tensor, torch.Tensor]: to -> tuple[torch.Tensor, torch.Tensor]:
  • Line 132: Change Tuple[torch.Tensor, torch.Tensor] to tuple[torch.Tensor, torch.Tensor] in docstring
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py` around lines 68 - 72,
The return type annotations and docstring use typing.Tuple; update to the native
generic syntax and remove the unused import: replace the return annotation in
cu_seqlens_to_chunk_indices_offsets_triton (currently -> Tuple[torch.Tensor,
torch.Tensor]) with -> tuple[torch.Tensor, torch.Tensor], do the same for the
other function whose annotation is currently -> Tuple[torch.Tensor,
torch.Tensor], update the docstring occurrence of "Tuple[torch.Tensor,
torch.Tensor]" to "tuple[torch.Tensor, torch.Tensor]", and remove the now-unused
"from typing import Tuple" import.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py`:
- Around line 67-77: The variable name all shadows the builtin (causing Ruff
A001); rename it to a non-conflicting name like total_tokens in the block that
handles IS_VARLEN so both branches use total_tokens instead of all (update the
assignments where all = T and all = B * T and any subsequent uses), touching the
code around IS_VARLEN, bos/eos calculation (cu_seqlens, i_n), and seq_T/T/B
references to ensure consistency.

In `@tensorrt_llm/_torch/modules/fused_moe/create_moe.py`:
- Around line 80-96: In resolve_moe_cls, replace the inconsistent access to
effective_quant_config.layer_quant_mode with effective_quant_config.quant_mode
so it matches get_moe_cls; specifically update the has_quant computation to call
effective_quant_config.quant_mode.has_any_quant(exclude_kv_cache=True). Keep the
same logic that checks TRTLLMGenFusedMoE, routing_method and returns
CutlassFusedMoE when appropriate, and ensure references to
ModelConfig.quant_config, QuantConfig.quant_mode, layer_quant_mode,
has_any_quant, resolve_moe_cls, get_moe_cls, TRTLLMGenFusedMoE and
CutlassFusedMoE are adjusted accordingly.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 723-738: The current list-comprehension + torch.stack for
processed_w3_w1 and processed_w2 creates all expert tensors at once (memory
spike) and can trigger device-mismatch because cached permute indices may be on
CPU; instead, pre-allocate the output tensors with the correct
shape/dtype/device and fill them per-expert in a for-loop. Specifically, for
w3_w1: call self._get_w3_w1_permute_indices(...) once, then ensure the returned
w3_w1_permute_indices is moved to the expert device (e.g., .to(expert.device))
before using it; allocate processed_w3_w1 = torch.empty((num_experts, ...),
device=module.w3_w1_weight.device, dtype=...) and fill each slot by calling
_prepare_bf16_weight_for_trtllm_gen(expert, permute_indices_on_device,
self.block_k) inside a loop. Do the same for w2 using _get_w2_permute_indices,
processed_w2 allocation, and per-expert filling to avoid peak memory doubling
and device mismatch.

In `@tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py`:
- Around line 54-81: The store-side address calculation in the transposition
kernel needs 64-bit arithmetic to avoid wrapping when computing conv_offsets *
num_prefill_tokens; update the _extract_transpose_prefill_kernel implementation
so any computation that multiplies conv_offsets by num_prefill_tokens (and any
derived store offsets) is performed in 64-bit (e.g., cast conv_offsets or
num_prefill_tokens to int64/torch.int64 or triton.int64 before multiplication)
and ensure the resulting offset used to write into out (the transposed buffer)
is a 64-bit value; adjust kernel argument types if necessary to accept the
widened offset and keep the same call from extract_transpose_prefill_slice.

In `@tests/integration/test_lists/qa/llm_function_core_sanity.txt`:
- Around line 178-181: Sanity list omits BF16 TP2 variants for
TestQwen3_5_35B_A3B while FP8 TP2 is included and BF16 TP2 exists in
l0_b200.yml; add the missing BF16 TP2 entries (e.g.,
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-CUTLASS]
and
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-TRTLLM]) to
the sanity file to mirror the FP8 TP2 coverage, or if the omission is
intentional, add a comment in the sanity list explaining the deliberate
exclusion referencing TestQwen3_5_35B_A3B BF16 TP2 so reviewers know it was
intentional.

---

Outside diff comments:
In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py`:
- Around line 276-299: The kernel is currently launched with
initial_state_indices (h0_i) unchecked which can lead to null-pointer arithmetic
or OOB access; before calling chunk_gated_delta_rule_fwd_kernel_h_blockdim64
validate that if initial_state_indices is not None then initial_state is not
None, that initial_state_indices.dtype is a long/torch.int64, that
initial_state_indices.dim() == 1 and initial_state_indices.numel() == N (or the
expected per-batch count), and that all index values satisfy 0 <= idx < K (raise
ValueError with a clear message if any check fails); only forward h0_i and h0 to
the kernel after these validations and keep the existing
inplace_indexed_state_update check.

In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py`:
- Around line 179-242: The constructor mutates backend_model_config (which may
alias caller-owned model_config) and flips _frozen and
skip_create_weights_in_init without guaranteeing restoration on exceptions;
capture the original values (e.g., orig_frozen = backend_model_config._frozen
and tmp_skip_create_weights_in_init already present) and perform the mutations
(including making a deep copy only when override_quant_config is set) inside a
try block, then restore backend_model_config._frozen and
backend_model_config.skip_create_weights_in_init in a finally block so failures
in create_moe_backend, validate_backend, or the sync won’t leave the shared
ModelConfig in a bad state; keep the later conditional create_weights() logic
unchanged (references: backend_model_config, model_config,
override_quant_config, tmp_skip_create_weights_in_init, create_moe_backend,
validate_backend, backend.create_weights).

---

Nitpick comments:

In `@tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py`:
- Around line 78-121: The CPU fallback in pack_topk_ids can create a subtle
device mismatch: output is created with device=topk_ids.device but the fallback
path triggers when inputs may be non-CUDA; make the behavior explicit by
ensuring output is allocated on the same device as the fallback computation
(e.g., use device=topk_ids.device or torch.device('cpu') when triton is None or
inputs are not CUDA) before computing packed_topk_ids, and update the comment
"Fallback to CPU just in case" to something precise like "CPU fallback path
(used when triton is unavailable or inputs are non-CUDA)"; reference
pack_topk_ids, output, packed_topk_ids, and triton when locating the change.
- Around line 920-922: The code path returns output (a tensor) when output is
provided and do_finalize is True, but returns raw result (which may be a tuple
or different type from FlashInfer) when do_finalize is False; make the return
type consistent: either always return the same structure (e.g., always return
output tensor) or always return result but copy into output only for in-place
semantics. Modify the method in moe_op_backend.py to normalize the return value
based on output/result/do_finalize (use output.clone() or wrap result into the
same tensor/tuple shape as other run_*_moe methods), and add a short comment
documenting the chosen semantic; ensure you update the code paths that reference
output, result, and do_finalize and match behavior of other run_*_moe
implementations.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 700-701: The class-level mutable cache _cache_permute_indices is
intended to be shared but not marked explicitly; import ClassVar from typing and
change the annotation of _cache_permute_indices to
ClassVar[Dict[tuple[tuple[int, ...], str, int], torch.Tensor]] = {} in the class
where it is declared (symbol: _cache_permute_indices) so the intent of shared
mutable state is explicit to linters and readers while keeping the same default
empty dict behavior. Ensure you only change the type annotation (and add the
ClassVar import if missing) and do not move the variable into __init__ or make
it an instance attribute.
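A minimal sketch of the suggested annotation, using a simplified stand-in for the real cache (the key and value types here are illustrative, not the exact ones in quantization.py):

```python
from typing import ClassVar


class WeightPermuter:
    # ClassVar marks the dict as shared, class-level mutable state,
    # so linters and readers know it is not a per-instance attribute.
    _cache_permute_indices: ClassVar[dict[tuple[int, ...], list[int]]] = {}

    def get(self, shape: tuple[int, ...]) -> list[int]:
        if shape not in self._cache_permute_indices:
            self._cache_permute_indices[shape] = list(range(shape[-1]))
        return self._cache_permute_indices[shape]


a, b = WeightPermuter(), WeightPermuter()
a.get((4, 8))
print((4, 8) in b._cache_permute_indices)  # → True: cache is shared
```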


In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 5859-5883: The CudaGraph capture batch list in test_bf16 exceeds
the test's max_batch_size (32) causing unnecessary warmup memory; update the
cuda_graph_config used in test_bf16 (CudaGraphConfig instance) to only include
batch sizes <= the LLM's max_batch_size (e.g., trim or filter the batch_sizes
list [1,2,4,8,16,32] or programmatically filter against max_batch_size) so
captures never request 64 or 128 when max_batch_size=32.
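The suggested filter can be sketched as follows (the list values and max_batch_size mirror the test's parameters; the CudaGraphConfig wiring is omitted):

```python
max_batch_size = 32
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]

# Never capture a CUDA graph for a batch size the engine cannot serve.
capture_sizes = [b for b in batch_sizes if b <= max_batch_size]
print(capture_sizes)  # → [1, 2, 4, 8, 16, 32]
```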

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ee2a3c42-1ad1-400e-b2b0-898bbb826d10

📥 Commits

Reviewing files that changed from the base of the PR and between f6db7e3 and 30759a5.

📒 Files selected for processing (18)
  • cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fla/chunk.py
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
  • tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
  • tensorrt_llm/_torch/modules/fused_moe/moe_op_backend.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • tensorrt_llm/_torch/modules/mamba/fuse_elementwise_ops.py
  • tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/test_lists/test-db/l0_b200.yml
  • tests/integration/test_lists/waives.txt
  • tests/unittest/_torch/modules/moe/moe_test_utils.py
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

Comment on lines +67 to +77
        if IS_VARLEN:
            bos, eos = (
                tl.load(cu_seqlens + i_n).to(tl.int64),
                tl.load(cu_seqlens + i_n + 1).to(tl.int64),
            )
            all = T
            seq_T = eos - bos
        else:
            bos, eos = i_n * T, i_n * T + T
            all = B * T
            seq_T = T

⚠️ Potential issue | 🟡 Minor

Rename all to keep Ruff green.

Ruff already flags this block with A001 because all shadows the builtin. Renaming it to something like total_tokens avoids the lint error without changing the kernel logic.

✏️ Suggested rename
-            all = T
+            total_tokens = T
             seq_T = eos - bos
         else:
             bos, eos = i_n * T, i_n * T + T
-            all = B * T
+            total_tokens = B * T
             seq_T = T
@@
-        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+        p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v
🧰 Tools
🪛 Ruff (0.15.7)

[error] 72-72: Variable all is shadowing a Python builtin

(A001)


[error] 76-76: Variable all is shadowing a Python builtin

(A001)

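Outside Triton, the same shadowing silently hides the builtin for the rest of the scope, which is why A001 is worth fixing even when the kernel logic is unaffected. A minimal demonstration:

```python
def check(flags):
    all = len(flags)  # shadows the builtin, as in the kernel
    try:
        return all(flags)  # `all` is now an int, no longer callable
    except TypeError:
        return "builtin `all` was shadowed"


print(check([True, True]))  # → builtin `all` was shadowed
```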

Comment on lines +80 to +96
def resolve_moe_cls(
        model_config: ModelConfig,
        routing_method: BaseMoeRoutingMethod,
        dtype: Optional[torch.dtype],
        override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
    moe_cls = get_moe_cls(model_config, override_quant_config)

    effective_quant_config = override_quant_config or model_config.quant_config
    has_quant = (effective_quant_config is not None
                 and effective_quant_config.layer_quant_mode.has_any_quant(
                     exclude_kv_cache=True))
    if (moe_cls == TRTLLMGenFusedMoE and not has_quant
            and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
                routing_method)):
        return CutlassFusedMoE

    return moe_cls

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the difference between quant_mode and layer_quant_mode in QuantConfig
ast-grep --pattern 'class QuantConfig {
  $$$
}'

rg -n "layer_quant_mode|quant_mode" --type py -C2 tensorrt_llm/models/modeling_utils.py | head -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 1855


🏁 Script executed:

#!/bin/bash
# Read the create_moe.py file to see both get_moe_cls and resolve_moe_cls functions
cat -n tensorrt_llm/_torch/modules/fused_moe/create_moe.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 5038


🏁 Script executed:

#!/bin/bash
# Search for has_any_quant method to understand what it's defined on
rg -n "def has_any_quant" --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 317


🏁 Script executed:

#!/bin/bash
# Check the QuantConfig class definition to understand the structure
rg -n "class QuantConfig" --type py -A 20 | head -80

Repository: NVIDIA/TensorRT-LLM

Length of output: 6020


🏁 Script executed:

#!/bin/bash
# Check QuantModeWrapper definition and its methods
rg -n "class QuantModeWrapper" --type py -A 30

Repository: NVIDIA/TensorRT-LLM

Length of output: 1627


🏁 Script executed:

#!/bin/bash
# Check the full QuantConfig class definition, especially properties
rg -n "class QuantConfig" --type py -A 40 tensorrt_llm/models/modeling_utils.py | head -80

Repository: NVIDIA/TensorRT-LLM

Length of output: 2019


🏁 Script executed:

#!/bin/bash
# Check QuantMode class definition and methods
rg -n "class QuantMode" --type py -A 20 tensorrt_llm/quantization/mode.py | head -60

Repository: NVIDIA/TensorRT-LLM

Length of output: 956


🏁 Script executed:

#!/bin/bash
# Get the full QuantConfig class definition focusing on the properties
sed -n '132,200p' tensorrt_llm/models/modeling_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2696


🏁 Script executed:

#!/bin/bash
# Check what quant_mode_list contains in the QuantConfig.quant_mode property
sed -n '174,182p' tensorrt_llm/models/modeling_utils.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 344


🏁 Script executed:

#!/bin/bash
# Check the QuantMode.has_any_quant method implementation
sed -n '206,215p' tensorrt_llm/quantization/mode.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 585


🏁 Script executed:

#!/bin/bash
# Get more context on the QuantMode.from_quant_algo to understand what it returns
rg -n "def from_quant_algo" --type py -A 15 tensorrt_llm/quantization/mode.py | head -40

Repository: NVIDIA/TensorRT-LLM

Length of output: 1051


🏁 Script executed:

#!/bin/bash
# Verify how layer_quant_mode is used elsewhere in the codebase
rg -n "layer_quant_mode.has_any_quant" --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2431


🏁 Script executed:

#!/bin/bash
# Check how quant_mode is used elsewhere in the codebase
rg -n "quant_mode.has_any_quant|quant_mode.has_fp8|quant_mode.has_nvfp4" --type py | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 2827


🏁 Script executed:

#!/bin/bash
# Check if there are other uses of quant_mode.has_any_quant() beyond line 49 in get_moe_cls
rg -n "\.quant_mode\.has_any_quant" --type py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2198


Fix inconsistent attribute access: Use quant_mode instead of layer_quant_mode in resolve_moe_cls.

get_moe_cls (line 49) uses quant_config.quant_mode.has_any_quant(...) while resolve_moe_cls (line 89) uses effective_quant_config.layer_quant_mode.has_any_quant(...). Although both properties return equivalent results (both call QuantMode.from_quant_algo with identical arguments), the inconsistency creates unnecessary confusion. Align resolve_moe_cls with get_moe_cls:

Proposed fix
     effective_quant_config = override_quant_config or model_config.quant_config
     has_quant = (effective_quant_config is not None
-                 and effective_quant_config.layer_quant_mode.has_any_quant(
+                 and effective_quant_config.quant_mode.has_any_quant(
                      exclude_kv_cache=True))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def resolve_moe_cls(
         model_config: ModelConfig,
         routing_method: BaseMoeRoutingMethod,
         dtype: Optional[torch.dtype],
         override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
     moe_cls = get_moe_cls(model_config, override_quant_config)
     effective_quant_config = override_quant_config or model_config.quant_config
     has_quant = (effective_quant_config is not None
-                 and effective_quant_config.layer_quant_mode.has_any_quant(
+                 and effective_quant_config.quant_mode.has_any_quant(
                      exclude_kv_cache=True))
     if (moe_cls == TRTLLMGenFusedMoE and not has_quant
             and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
                 routing_method)):
         return CutlassFusedMoE
     return moe_cls

Comment on lines +723 to +738
        w3_w1_permute_indices = self._get_w3_w1_permute_indices(
            module.w3_w1_weight.data[0],
            is_gated_act_gemm=getattr(module, "is_gated_activation", True))
        w2_permute_indices = self._get_w2_permute_indices(
            module.w2_weight.data[0])

        processed_w3_w1 = torch.stack([
            _prepare_bf16_weight_for_trtllm_gen(expert, w3_w1_permute_indices,
                                                self.block_k)
            for expert in module.w3_w1_weight.data
        ])
        processed_w2 = torch.stack([
            _prepare_bf16_weight_for_trtllm_gen(expert, w2_permute_indices,
                                                self.block_k)
            for expert in module.w2_weight.data
        ])

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Final verification: check if current code would actually fail with GPU weights
sed -n '165,180p' tensorrt_llm/_torch/modules/fused_moe/quantization.py && echo "---" && sed -n '723,738p' tensorrt_llm/_torch/modules/fused_moe/quantization.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1505


Use pre-allocation instead of torch.stack to prevent memory spikes during expert weight processing.

At lines 729 and 734, the list comprehension with torch.stack creates all intermediate expert tensors in memory simultaneously before stacking, which temporarily doubles peak memory usage. For large MoE models with many experts, this can trigger OOM during weight loading.

Additionally, the code has a device placement issue: because _get_permute_indices caches results globally, indices may remain on CPU even when weights are on GPU, causing expert[permute_indices] to fail with a device mismatch error during GPU weight loading.

Use the pre-allocation approach with explicit device movement:

Recommended implementation
+        w3_w1_permute_indices = w3_w1_permute_indices.to(
+            module.w3_w1_weight.device)
+        w2_permute_indices = w2_permute_indices.to(module.w2_weight.device)
+
+        first_w3_w1 = _prepare_bf16_weight_for_trtllm_gen(
+            module.w3_w1_weight.data[0], w3_w1_permute_indices, self.block_k)
+        processed_w3_w1 = torch.empty(
+            (module.w3_w1_weight.shape[0], *first_w3_w1.shape),
+            dtype=first_w3_w1.dtype,
+            device=first_w3_w1.device,
+        )
+        processed_w3_w1[0].copy_(first_w3_w1)
+        for i in range(1, module.w3_w1_weight.shape[0]):
+            processed_w3_w1[i].copy_(
+                _prepare_bf16_weight_for_trtllm_gen(
+                    module.w3_w1_weight.data[i], w3_w1_permute_indices,
+                    self.block_k))
+
+        first_w2 = _prepare_bf16_weight_for_trtllm_gen(
+            module.w2_weight.data[0], w2_permute_indices, self.block_k)
+        processed_w2 = torch.empty(
+            (module.w2_weight.shape[0], *first_w2.shape),
+            dtype=first_w2.dtype,
+            device=first_w2.device,
+        )
+        processed_w2[0].copy_(first_w2)
+        for i in range(1, module.w2_weight.shape[0]):
+            processed_w2[i].copy_(
+                _prepare_bf16_weight_for_trtllm_gen(
+                    module.w2_weight.data[i], w2_permute_indices,
+                    self.block_k))
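The memory argument behind the pre-allocation pattern can be illustrated with NumPy stand-ins (process replaces _prepare_bf16_weight_for_trtllm_gen; shapes and values are arbitrary):

```python
import numpy as np


def process(expert: np.ndarray) -> np.ndarray:
    # Stand-in for _prepare_bf16_weight_for_trtllm_gen.
    return expert * 2


experts = np.arange(12, dtype=np.float32).reshape(3, 2, 2)

# stack: every processed expert tensor is alive at once before the final copy.
stacked = np.stack([process(e) for e in experts])

# pre-allocate: only one processed expert is alive at a time.
out = np.empty_like(experts)
for i, e in enumerate(experts):
    out[i] = process(e)

print(np.array_equal(stacked, out))  # → True
```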

Comment on lines +54 to +81
def extract_transpose_prefill_slice(
    src: torch.Tensor,
    num_prefill_tokens: int,
    start_col: int,
    width: int,
) -> torch.Tensor:
    """
    Extract and transpose a contiguous prefill slice for causal_conv1d_fn.

    Input: src[num_tokens, num_cols]
    Output: [width, num_prefill_tokens]
    """
    out = torch.empty(width, num_prefill_tokens, dtype=src.dtype, device=src.device)

    BLOCK_SEQ, BLOCK_CONV = 32, 128
    grid = (triton.cdiv(num_prefill_tokens, BLOCK_SEQ), triton.cdiv(width, BLOCK_CONV))

    _extract_transpose_prefill_kernel[grid](
        src,
        out,
        num_prefill_tokens,
        src.shape[1],
        start_col,
        width,
        BLOCK_SEQ,
        BLOCK_CONV,
    )
    return out

⚠️ Potential issue | 🟠 Major

Widen the store-side offsets too.

Lines 45-47 fixed the read-side overflow, but Lines 50-51 still compute conv_offsets * num_prefill_tokens in 32-bit. This new helper now reuses that kernel for arbitrary width, so long prefills can wrap the destination address and corrupt the transposed output.

🛠️ Suggested kernel fix
-    dst_offsets = conv_offsets[:, None] * num_prefill_tokens + seq_offsets[None, :]
+    dst_offsets = (
+        conv_offsets[:, None].to(tl.int64) * num_prefill_tokens
+        + seq_offsets[None, :].to(tl.int64)
+    )
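The failure mode can be reproduced on the host with NumPy int32 arithmetic (the sizes below are hypothetical; a real kernel hits this once width * num_prefill_tokens exceeds 2**31 - 1):

```python
import numpy as np

conv_offset = np.array([40_000], dtype=np.int32)   # column in the transposed output
num_prefill_tokens = 60_000                        # tokens in the prefill batch

# 40_000 * 60_000 = 2.4e9 > 2**31 - 1: the int32 product wraps negative.
wrapped = conv_offset * np.int32(num_prefill_tokens)
# Widening to int64 before the multiply keeps the offset correct.
widened = conv_offset.astype(np.int64) * num_prefill_tokens

print(int(wrapped[0]), int(widened[0]))  # → -1894967296 2400000000
```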

Comment on lines +178 to +181
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-CUTLASS]
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp1-TRTLLM]
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp1]
accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_fp8[tp2]

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and read the lines in question
if [ -f "tests/integration/test_lists/qa/llm_function_core_sanity.txt" ]; then
  echo "=== Lines 178-181 from sanity test list ==="
  sed -n '175,185p' "tests/integration/test_lists/qa/llm_function_core_sanity.txt" | cat -n
else
  echo "File not found"
fi

Repository: NVIDIA/TensorRT-LLM

Length of output: 975


🏁 Script executed:

# Check the overall structure of the sanity file to understand patterns
echo "=== Overall structure of sanity test list ==="
wc -l "tests/integration/test_lists/qa/llm_function_core_sanity.txt"
echo ""
echo "=== Check for BF16 and FP8 patterns ==="
grep -n "bf16\|fp8\|BF16\|FP8" "tests/integration/test_lists/qa/llm_function_core_sanity.txt" | tail -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 3230


🏁 Script executed:

# Check if l0_b200.yml exists and look for BF16 TP2 entries
if [ -f "tests/integration/test_lists/test-db/l0_b200.yml" ]; then
  echo "=== Checking l0_b200.yml for BF16 TP2 references ==="
  grep -n "bf16\|BF16\|tp2\|TP2" "tests/integration/test_lists/test-db/l0_b200.yml" | head -20
else
  echo "l0_b200.yml not found"
fi

Repository: NVIDIA/TensorRT-LLM

Length of output: 731


🏁 Script executed:

# Search for TestQwen3_5_35B_A3B entries in the sanity file to see all variants
echo "=== All TestQwen3_5_35B_A3B entries in sanity list ==="
grep -n "TestQwen3_5_35B_A3B" "tests/integration/test_lists/qa/llm_function_core_sanity.txt"

Repository: NVIDIA/TensorRT-LLM

Length of output: 426


Confirm whether BF16 TP2 omission from sanity is intentional.

The sanity test list includes FP8 TP2 for TestQwen3_5_35B_A3B but only BF16 TP1 variants (CUTLASS and TRTLLM). However, BF16 TP2 variants are present in the B200 pre-merge list (l0_b200.yml), creating an asymmetry that may reduce early BF16 TP2 regression detection in the sanity lane. Please clarify if this exclusion is deliberate.

