
Conversation

@nvchenghaoz (Collaborator) commented Nov 14, 2025

Summary by CodeRabbit

Release Notes

  • New Features
    • Added automatic CUDA hardware capability detection to enable device-specific acceleration on compatible GPUs
    • Implemented optimized fused CUDA kernel-based routing for improved inference performance

For Nemotron MoE (1k/1k/8):

  • Baseline (no FP8 KV cache): 0.5945
  • FP8 KV cache: 0.6075
  • FP8 KV cache + cuBLAS kernel: 0.6417
  • FP8 KV cache + CUDA (no cuBLAS) kernel + NemotronHTopkRouter patch: 0.6733
  • FP8 KV cache + cuBLAS kernel + NemotronHTopkRouter patch: 0.7288

@coderabbitai coderabbitai bot (Contributor) commented Nov 14, 2025

📝 Walkthrough

Two files in the auto_deploy module are modified: quant.py adds runtime CUDA capability detection and a conditional enable_cuda_core flag based on SM 8.9 or 12.0 support, narrowing the CUDA-core path to that specific hardware; nemotron_h.py introduces a new optimized forward method for NemotronHTopkRouter that uses fused CUDA kernel-based top-k routing, registered via module patching.

Changes

CUDA Hardware Detection & Optimization — tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
Adds runtime CUDA device capability detection: determines whether the device is SM 8.9 or 12.0 and sets the enable_cuda_core flag accordingly. The CUDA-core path is now taken only when both the input size is ≤ 8 and enable_cuda_core is true; otherwise the op falls back to cuBLAS (a hedged sketch of this gating follows this list).

Model Routing Optimization — tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
Introduces a new _nemotron_h_topk_router_forward method implementing fused CUDA kernel-based top-k routing: it reshapes the input, computes router logits via a linear transformation, calls the noaux_tc_op kernel for top-k weights and indices, and returns the results. Registered via CUSTOM_MODULE_PATCHES for runtime patching of NemotronHTopkRouter (a sketch follows the sequence diagram below).
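To make the quant.py change concrete, here is a minimal sketch of the described gating, assuming the capability check uses torch.cuda.get_device_capability() and that the flag is computed once at import time; the helper names and module-level constant are illustrative, not the actual quant.py code.

```python
import torch


def _supports_cuda_core_path() -> bool:
    """Illustrative check: enable the CUDA-core fast path only on SM 8.9 or SM 12.0 devices."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) in ((8, 9), (12, 0))


# Assumed pattern: evaluate once so the per-call dispatch stays cheap.
ENABLE_CUDA_CORE = _supports_cuda_core_path()


def _select_gemm_path(m: int) -> str:
    """Mirror the described rule: CUDA-core kernel only for tiny inputs (m <= 8)
    on supported hardware, otherwise fall back to the cuBLAS path."""
    return "cuda_core" if (m <= 8 and ENABLE_CUDA_CORE) else "cublas"
```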

Sequence Diagram

sequenceDiagram
    participant caller as Caller
    participant forward as _nemotron_h_topk_router_forward
    participant reshape as Reshape Input
    participant linear as Router Linear
    participant kernel as noaux_tc_op Kernel
    participant return as Return Results
    
    caller->>forward: hidden_states
    forward->>reshape: reshape(hidden_states)
    reshape->>linear: reshaped_input
    linear->>linear: compute router logits
    linear->>kernel: logits
    rect rgba(100, 150, 200, 0.2)
        note over kernel: CUDA kernel execution<br/>(top-k selection)
    end
    kernel->>kernel: extract top-k weights & indices
    kernel->>return: weights, indices
    return->>caller: (indices, weights)
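Putting the walkthrough and the diagram together, a rough sketch of what the patched forward could look like is shown below. The op namespace (torch.ops.trtllm.noaux_tc_op), the router attribute names (self.weight, self.e_score_correction_bias, self.n_group, self.topk_group, self.top_k, self.routed_scaling_factor), and the return order are assumptions inferred from the review notes, not the verified patch source.

```python
import torch
import torch.nn.functional as F


def _nemotron_h_topk_router_forward(self, hidden_states: torch.Tensor):
    """Sketch of the fused top-k routing path (attribute names and op signature assumed)."""
    # Flatten tokens to a 2D [num_tokens, hidden_size] matrix for the router.
    hidden_states = hidden_states.reshape(-1, self.config.hidden_size)

    # Router logits are computed in fp32 for numerical stability.
    router_logits = F.linear(hidden_states.to(torch.float32), self.weight.to(torch.float32))

    # Fused CUDA kernel performs grouped top-k selection and scaling in one call
    # (argument order taken from the review comment further down).
    topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
        router_logits,
        self.e_score_correction_bias,
        self.n_group,
        self.topk_group,
        self.top_k,
        self.routed_scaling_factor,
    )
    return topk_indices, topk_weights
```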

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • quant.py: Review the CUDA capability detection logic; verify that SM 8.9 and 12.0 are the correct target architectures and that the input-size threshold of ≤ 8 is appropriate
  • nemotron_h.py: Verify noaux_tc_op kernel call correctness; confirm module patching mechanism via CUSTOM_MODULE_PATCHES properly replaces the original forward method at runtime; check shape transformations and tensor operations

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Description check — ⚠️ Warning: The PR description is incomplete and does not follow the template; it lacks a Description section explaining what changes were made and why, and provides only benchmark results without context for the code modifications. Resolution: add a comprehensive Description section explaining the rationale for the changes to quant.py and the nemotron_h.py patches, and a Test Coverage section listing relevant tests that validate the modifications.
  • Docstring coverage — ⚠️ Warning: Docstring coverage is 75.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed: The title clearly summarizes the main change (performance improvement for small batch size in AutoDeploy) and follows the required format with the [None][feat] prefix.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

91-117: Fused router forward looks consistent with existing MOE interface; consider reshape instead of view.

The new _nemotron_h_topk_router_forward keeps the contract that self.gate(hidden_states) returns (topk_indices, topk_weights) in the shape expected by torch_ops.auto_deploy.torch_moe, without extra reshaping, which is aligned with the NemotronH MOE usage pattern. Based on learnings.

One minor robustness tweak: hidden_states = hidden_states.view(-1, self.config.hidden_size) assumes that hidden_states is contiguous. To avoid surprises if a non‑contiguous tensor ever reaches this router, using reshape (or .contiguous().view(...)) would be safer:

-    hidden_states = hidden_states.view(-1, self.config.hidden_size)
+    hidden_states = hidden_states.reshape(-1, self.config.hidden_size)

Behavior otherwise looks correct: logits are computed in fp32, and noaux_tc_op receives logits along with e_score_correction_bias, n_group, topk_group, top_k, and routed_scaling_factor in a sensible order.
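A small, self-contained illustration of the failure mode behind this suggestion: view requires a compatible contiguous layout, while reshape silently falls back to a copy when it cannot return a view.

```python
import torch

x = torch.randn(4, 8, 16)
x_t = x.transpose(0, 1)      # non-contiguous view over the same storage

print(x_t.is_contiguous())   # False
y = x_t.reshape(-1, 16)      # ok: reshape copies when a view is impossible

try:
    x_t.view(-1, 16)         # raises: view cannot merge these strided dims
except RuntimeError as err:
    print("view failed:", err)
```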

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cc4c980 and 2761887.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (2 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.
📚 Learning: 2025-09-23T15:13:48.819Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/multimem.h:20-30
Timestamp: 2025-09-23T15:13:48.819Z
Learning: TRT-LLM targets modern CUDA toolkits that support FP8 datatypes, so cuda_fp8.h can be included unconditionally without version guards in TRT-LLM code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
📚 Learning: 2025-09-23T15:12:38.312Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/thop/allreduceOp.cpp:352-446
Timestamp: 2025-09-23T15:12:38.312Z
Learning: In TensorRT-LLM NCCL device implementation, NCCL version 2.28+ requirements are handled at runtime in the nccl_device/config layer rather than with compile-time guards. This allows the allreduceOp to remain version-agnostic and delegates version compatibility validation to the appropriate lower-level components that can gracefully handle unsupported configurations.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
📚 Learning: 2025-10-20T17:07:18.745Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py:98-116
Timestamp: 2025-10-20T17:07:18.745Z
Learning: In NemotronH models (tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py), the gate (self.gate) returns topk_indices and topk_weights that are already in the correct shape to be passed directly to torch_ops.auto_deploy.torch_moe without needing to reshape them when hidden_states is flattened.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h (3)
  • n_group (241-241)
  • topk_group (243-243)
  • top_k (240-240)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py (1)

169-169: Patch registration for NemotronHTopkRouter forward looks correct.

Binding _nemotron_h_topk_router_forward via CUSTOM_MODULE_PATCHES["NemotronHTopkRouter"] is consistent with the other NemotronH patches and should seamlessly swap in the fused router implementation at load time.
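For readers unfamiliar with the mechanism, module patching of this kind conceptually amounts to rebinding the class's forward on matching module instances at load time. The sketch below is a simplified stand-in, not the actual CUSTOM_MODULE_PATCHES implementation in auto_deploy; the registry shape and the apply step are assumptions.

```python
from typing import Callable, Dict

import torch.nn as nn

# Simplified stand-in for the registry: class name -> replacement forward.
_PATCH_REGISTRY: Dict[str, Callable] = {}


def register_patch(class_name: str, new_forward: Callable) -> None:
    """Record a replacement forward for later application (illustrative only)."""
    _PATCH_REGISTRY[class_name] = new_forward


def apply_patches(model: nn.Module) -> None:
    """Walk the model and rebind forward on any module whose class name is registered."""
    for module in model.modules():
        patch = _PATCH_REGISTRY.get(type(module).__name__)
        if patch is not None:
            # Bind as an instance method so `self` resolves to the module.
            module.forward = patch.__get__(module, type(module))


# e.g. register_patch("NemotronHTopkRouter", _nemotron_h_topk_router_forward)
```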

@suyoggupta (Collaborator) commented:

Could you please also post accuracy numbers for tp1 and tp2?

@nvchenghaoz (Collaborator, Author) commented:

I had to use the updated model to get a good accuracy number.

The accuracy I got: MMLU 73.879, GSM8K 86.884

@nvchenghaoz nvchenghaoz enabled auto-merge (squash) November 17, 2025 19:07
@nvchenghaoz (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #24797 [ run ] triggered by Bot. Commit: 5f121cb

@tensorrt-cicd (Collaborator) commented:

PR_Github #24797 [ run ] completed with state SUCCESS. Commit: 5f121cb
/LLM/main/L0_MergeRequest_PR pipeline #18711 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #24808 [ run ] triggered by Bot. Commit: 5f121cb

@tensorrt-cicd (Collaborator) commented:

PR_Github #24808 [ run ] completed with state SUCCESS. Commit: 5f121cb
/LLM/main/L0_MergeRequest_PR pipeline #18721 completed with status: 'FAILURE'

@galagam (Collaborator) commented Nov 18, 2025

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #24902 [ run ] triggered by Bot. Commit: 5f121cb

@tensorrt-cicd (Collaborator) commented:

PR_Github #24902 [ run ] completed with state SUCCESS. Commit: 5f121cb
/LLM/main/L0_MergeRequest_PR pipeline #18803 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #24934 [ run ] triggered by Bot. Commit: 5f121cb

@tensorrt-cicd (Collaborator) commented:

PR_Github #24934 [ run ] completed with state SUCCESS. Commit: 5f121cb
/LLM/main/L0_MergeRequest_PR pipeline #18834 completed with status: 'SUCCESS'

@nvchenghaoz nvchenghaoz merged commit f0b68e4 into NVIDIA:main Nov 18, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from Backlog to Done in AutoDeploy Board Nov 18, 2025
lkomali pushed a commit to lkomali/TensorRT-LLM that referenced this pull request Nov 19, 2025