Skip to content

Conversation

@nvkgoyal
Copy link

@nvkgoyal nvkgoyal commented Aug 5, 2025

Summary by CodeRabbit

  • New Features

    • Added optional split-batch attention/MoE overlap mode controlled by runtime flags to enable parallel processing of batch halves.
    • Introduced configuration options to toggle the feature and set split batch size.
    • Provided a new benchmarking script to generate datasets, configure runs, and profile multiple overlap modes.
  • Chores

    • Added extensive runtime debug logging across attention, decoder, MoE, and execution paths to aid troubleshooting.
    • Relaxed an internal shape assertion under the new mode to support half-batch paths without affecting default behavior.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 5, 2025

Walkthrough

Adds a new benchmarking script and introduces an optional split-batch overlap execution path across model engine, attention backend, and DeepSeek-V3 model. When enabled via flags/env, attention metadata is split into two halves, processed via auxiliary CUDA streams, with alternating metadata references and adjusted collective ops; extensive debug logging is added.

Changes

Cohort / File(s) Summary of Changes
Benchmark Orchestration Script
run_deepseek_batch_split.sh
New script to configure and run DeepSeek-R1 batch-splitting and inter-layer overlap benchmarks; argument parsing, env var exports, dataset generation/verification, attention DP config generation, optional Nsight profiling, and multiple modes (no split, inter-layer only, split only, full overlap).
Attention Backend (TRT-LLM)
tensorrt_llm/_torch/attention_backend/trtllm.py
Added debug logs; prepare() signature updated to accept splitBatchOverlap; conditional assignment of attention metadata to half1/half2; minor forward adjustments for FP4 result wrapping and logging.
DeepSeek-V3 Model Split Path
tensorrt_llm/_torch/models/modeling_deepseekv3.py
Optional split-batch path: splits batch into two halves, runs on auxiliary CUDA streams, shallow-copies and adjusts attention/KV metadata per half, synchronizes via event, concatenates outputs; adds stream management and extensive debug logs.
Attention Module Alternation
tensorrt_llm/_torch/modules/attention.py
Introduces global toggle to alternate between attention_metadata_half1 and attention_metadata_half2 in extract_extra_attrs; extensive debug prints across MLA forward paths; retains original behavior when halves absent.
Model Engine Metadata Halving
tensorrt_llm/_torch/pyexecutor/model_engine.py
Adds env-flagged creation of attn_metadata_half1/half2 with sliced fields; passes halves through inputs; stores weak refs in model extras; verbose debug logging around prepare/forward and input inspection.
Config: Split-Batch Flags
tensorrt_llm/llmapi/llm_args.py
New Pydantic class SplitBatchAttnMoeOverlapConfig with enable_split_batch_attn_moe_overlap and split_batch_size.
Distributed Ops Tolerance
tensorrt_llm/_torch/distributed/ops.py
allgather assertion relaxed under ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP to accept half-sized tensors; minor commented debug notes.
MoE Wide EP Debug
tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
Added debug prints in forward_chunk and forward for shapes/dtypes.
Multi-Stream Utilities
tensorrt_llm/_torch/modules/multi_stream_utils.py
Added commented-out debug print in maybe_execute_in_parallel; no behavioral change.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Script as run_deepseek_batch_split.sh
  participant Engine as ModelEngine
  participant Attn as TrtllmAttention
  participant Model as DeepseekV3 Decoder
  participant MoE as MoE/MLP

  User->>Script: Invoke with split/overlap mode
  Script->>Engine: Set env vars, launch benchmark
  Engine->>Engine: Prepare attn_metadata
  alt Split-batch enabled
    Engine->>Engine: Create attn_metadata_half1/half2
    Engine->>Attn: prepare(splitBatchOverlap=1) [half1]
    Engine->>Attn: prepare(splitBatchOverlap=2) [half2]
    Engine->>Model: forward(inputs + half1/half2 refs)
    Note over Model: Initializes aux CUDA streams
    par Half1 on Attn stream
      Model->>Attn: forward with metadata_half1
      Attn-->>Model: hidden_states_half1
      Model->>MoE: process half1
    and Half2 on MoEChunkingOverlap stream
      Model->>Attn: forward with metadata_half2
      Attn-->>Model: hidden_states_half2
      Model->>MoE: process half2
    end
    Model->>Model: Concatenate halves
    Model-->>Engine: hidden_states
  else No split
    Engine->>Attn: prepare()
    Engine->>Model: forward(inputs)
    Model-->>Engine: hidden_states
  end
  Engine-->>Script: Metrics/logs
  Script-->>User: Results saved
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

A bunny split the batch in two, hop-hop on parallel streams,
With metadata halves to chew, it nibbled token dreams.
It synced, it stitched, with careful might,
Then logged the shapes by moonlit night—
“Overlap achieved!” it thumped in delight. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (3 warnings)
Check name Status Explanation Resolution
Title Check ⚠️ Warning The title “Draft: Split atten_metadata” includes a non‐descriptive “Draft:” prefix and does not follow the repository’s ticket/type template, making it too vague and not clearly summarizing the primary change implemented. Update the title to follow the repository naming convention (e.g., “[None][feat] Enable split batch attention metadata overlap”) and remove the “Draft:” prefix, ensuring it concisely reflects the core feature addition.
Description Check ⚠️ Warning The PR description remains the unfilled template with no summary of the changes, explanation of the solution, or test coverage details, so it provides no substantive information about what was done or why. Please complete the description by filling in the template’s sections: provide a concise summary of the changes and their rationale, list the tests covering the new code paths, and confirm adherence to the PR checklist items.
Docstring Coverage ⚠️ Warning Docstring coverage is 27.59% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1171-1214: Remove debug prints for production readiness.

Similar to other methods, this forward method contains multiple debug print statements that should be removed for production use.

-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Input keys: {list(kwargs.keys())}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - input_ids shape: {input_ids.shape if input_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - inputs_embeds shape: {inputs_embeds.shape if inputs_embeds is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - return_context_logits: {return_context_logits}")
-        
         attn_metadata.num_generations_per_batch = self.model_nextn + 1
         hidden_states = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
         )
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Model output hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")

         if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
             # get logits
             logits = self.logits_processor.forward(
                 hidden_states[spec_metadata.gather_ids],
                 self.lm_head,
                 attn_metadata,
                 True,
             )
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - MTP logits shape: {logits.shape}, dtype: {logits.dtype}")
             # get accepted tokens and next draft tokens
             return self.mtp_worker(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 logits=logits,
                 lm_head=self.lm_head,
                 embed_tokens=self.model.embed_tokens,
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
                 mtp_layers=self.model.layers[self.num_hidden_layers:])
         else:
             logits = self.logits_processor.forward(
                 hidden_states,
                 self.lm_head,
                 attn_metadata,
                 return_context_logits,
             )
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Final logits shape: {logits.shape}, dtype: {logits.dtype}")
             return logits
🧹 Nitpick comments (4)
run_deepseek_batch_split.sh (4)

13-13: Remove unused variable.

The variable disable_overlap_scheduler is declared but never used in the script.

-disable_overlap_scheduler="false"

3-7: Consider making hardcoded paths configurable.

The script contains hardcoded paths that may not be portable across different environments.

Consider making these configurable through environment variables or command line parameters:

model_card="${MODEL_CARD:-deepseek-ai/DeepSeek-R1}"
model_path="${MODEL_PATH:-/llm-models/DeepSeek-R1/DeepSeek-R1-FP4/}"
dataset_file="${DATASET_FILE:-/tmp/aa_prompt_50000.txt}"
nsys_on="${NSYS_ON:-true}"

194-194: Fix potential issue with variable declaration and assignment.

Shellcheck warns about masking return values when declaring and assigning in the same line.

-        local actual_lines=$(wc -l < "${dataset_file}")
+        local actual_lines
+        actual_lines=$(wc -l < "${dataset_file}")

260-268: Remove unused variable.

The nsys_prefix variable is assigned but never used in the script.

Either use the variable in the benchmark command or remove it:

-    nsys_prefix=""
-    # check NSYS_MODE is not empty
     if [ "${nsys_on}" == "true" ]; then
         nsys_file=${sub_dir}/nsys_worker_proc_${SLURM_PROCID}
-        nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
+        # Note: nsys profiling is handled by environment variables, not command prefix
         export TLLM_PROFILE_START_STOP=700-750
         export TLLM_PROFILE_RECORD_GC=1
         export TLLM_NVTX_DEBUG=1
     fi
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dcbfa7e and 615f600.

📒 Files selected for processing (5)
  • run_deepseek_batch_split.sh (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (3 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (6 hunks)
  • tensorrt_llm/_torch/modules/attention.py (10 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without reflection.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧠 Learnings (3)
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • run_deepseek_batch_split.sh
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm, examples directory can have different dependency versions than the root requirement...
Learnt from: yibinl-nvidia
PR: NVIDIA/TensorRT-LLM#6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.

Applied to files:

  • run_deepseek_batch_split.sh
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/modules/attention.py

734-734: Line too long (123 > 120)

(E501)


753-753: Line too long (134 > 120)

(E501)


766-766: Line too long (131 > 120)

(E501)


791-791: Line too long (131 > 120)

(E501)


815-815: Line too long (134 > 120)

(E501)

tensorrt_llm/_torch/attention_backend/trtllm.py

1098-1098: Line too long (142 > 120)

(E501)


1099-1099: Line too long (142 > 120)

(E501)


1101-1101: Line too long (130 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/model_engine.py

2123-2123: Line too long (139 > 120)

(E501)


2136-2136: Line too long (140 > 120)

(E501)


2146-2146: Line too long (124 > 120)

(E501)

tensorrt_llm/_torch/models/modeling_deepseekv3.py

739-739: Line too long (137 > 120)

(E501)


740-740: Line too long (140 > 120)

(E501)


742-742: Line too long (175 > 120)

(E501)


743-743: Line too long (157 > 120)

(E501)


744-744: Line too long (147 > 120)

(E501)


745-745: Line too long (141 > 120)

(E501)


746-746: Line too long (158 > 120)

(E501)


747-747: Line too long (168 > 120)

(E501)


1078-1078: Line too long (121 > 120)

(E501)


1079-1079: Line too long (130 > 120)

(E501)


1080-1080: Line too long (133 > 120)

(E501)


1089-1089: Line too long (137 > 120)

(E501)


1093-1093: Line too long (132 > 120)

(E501)


1096-1096: Line too long (122 > 120)

(E501)


1103-1103: Line too long (123 > 120)

(E501)


1105-1105: Line too long (130 > 120)

(E501)


1172-1172: Line too long (127 > 120)

(E501)


1173-1173: Line too long (136 > 120)

(E501)


1174-1174: Line too long (139 > 120)

(E501)


1184-1184: Line too long (143 > 120)

(E501)

🪛 Shellcheck (0.10.0)
run_deepseek_batch_split.sh

[warning] 13-13: disable_overlap_scheduler appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 22-22: NABLE_ATTENTION_METADATA_SPLITTING appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 194-194: Declare and assign separately to avoid masking return values.

(SC2155)


[warning] 264-264: nsys_prefix appears unused. Verify use (or export if used externally).

(SC2034)

🔇 Additional comments (6)
tensorrt_llm/_torch/attention_backend/trtllm.py (2)

1097-1104: Replace print with logger.debug and gate behind a runtime flag

print(…) in a hot path spams stdout, breaks multi-process training logs and materially degrades performance once sequence length × layer count grows.
Please:

  1. Use the project logger (tensorrt_llm.logger.logger) with the DEBUG level so verbosity can be enabled/disabled centrally.
  2. Guard the branch with an env-var or if logger.isEnabledFor(logging.DEBUG): to avoid string-formatting overhead when debug logs are off.
-import torch
+import logging
 import torch
 …
 if logger.isEnabledFor(logging.DEBUG):
-    print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, dtype: {q.dtype}")
+    logger.debug("TrtllmAttention.forward ‑ q %s dtype=%s", tuple(q.shape), q.dtype)

This pattern should be applied to every print added in this hunk.
[ suggest_essential_refactor ]

Also applies to: 1185-1186, 1203-1204


1198-1202: Return-type annotation now lies when NVFP4 path is hit

forward() is annotated to return torch.Tensor, yet in the NVFP4 branch it now returns an Fp4QuantizedTensor. Down-stream callers relying on the declared type (e.g., static analysis, IDEs, or mypy) will mis-behave.

Consider either:

-    def forward(... ) -> torch.Tensor:
+    from typing import Union
+    def forward(... ) -> Union[torch.Tensor, Fp4QuantizedTensor]:

or keep the public contract unchanged by returning the underlying torch.Tensor and exposing the scale separately.
[ request_verification ]

tensorrt_llm/_torch/modules/attention.py (1)

733-735: Debug print calls will flood logs and hurt performance; switch to logger.debug

print() executed per forward pass inside MLA.forward_impl / MLA.forward can easily emit millions of lines in real workloads, clogging I/O and slowing kernels.
Replace them with:

import logging
LOG = logging.getLogger(__name__)
...
if LOG.isEnabledFor(logging.DEBUG):
    LOG.debug("MLA.forward_impl – hidden %s dtype=%s", hidden_states.shape, hidden_states.dtype)

Also, several of these lines exceed the 120-char limit flagged by Ruff (E501). Converting to logger.debug with positional formatting naturally shortens them.

Apply the same change to every print added in this hunk.
[ suggest_essential_refactor ]

Also applies to: 753-767, 791-792, 815-816, 854-855, 1278-1299

tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

71-71: LGTM! Import is correctly placed.

The copy module import is properly positioned and necessary for the shallow copying of attention metadata in the batch splitting logic.

run_deepseek_batch_split.sh (2)

270-286: LGTM! Well-structured benchmark execution.

The trtllm-bench command is properly constructed with appropriate parameters and logging. The command structure is clear and follows good practices for parameter passing and output redirection.


315-334: Excellent debugging and validation approach.

The Python import test is a great way to validate that the batch split modules are properly available before running the benchmark. This proactive approach helps catch configuration issues early.

SYNC_AFTER_ATTENTION="true"
ENABLE_LAYER_COMPLETION_EVENTS="false"
ENABLE_INTRA_LAYER_BATCH_SPLITTING="true"
NABLE_ATTENTION_METADATA_SPLITTING="true"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix typo in variable name.

There's a typo in the variable name - missing 'E' at the beginning of NABLE_ATTENTION_METADATA_SPLITTING.

-NABLE_ATTENTION_METADATA_SPLITTING="true"
+ENABLE_ATTENTION_METADATA_SPLITTING="true"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
NABLE_ATTENTION_METADATA_SPLITTING="true"
ENABLE_ATTENTION_METADATA_SPLITTING="true"
🧰 Tools
🪛 Shellcheck (0.10.0)

[warning] 22-22: NABLE_ATTENTION_METADATA_SPLITTING appears unused. Verify use (or export if used externally).

(SC2034)

🤖 Prompt for AI Agents
In run_deepseek_batch_split.sh at line 22, fix the typo in the variable name by
adding the missing 'E' at the beginning, changing
`NABLE_ATTENTION_METADATA_SPLITTING` to `ENABLE_ATTENTION_METADATA_SPLITTING`.

eplb_num_slots=${7:-0}
updates_per_iter=${8:-1}
log_dir=${9:-"/scratch/TensorRT-LLM/results_batch_split_$(date +%Y%m%d)"}
mode=${10:-2} # Default to mode 2 (only inter-layer overlap)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Inconsistent default mode value.

Line 123 sets the default mode to 2, but the usage documentation at line 40-44 and the script comments suggest mode 4 should be the default for "full overlap".

Consider updating the default to match the documentation:

-mode=${10:-2}  # Default to mode 2 (only inter-layer overlap)
+mode=${10:-4}  # Default to mode 4 (full overlap)

Or update the documentation to clarify why mode 2 is the default.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mode=${10:-2} # Default to mode 2 (only inter-layer overlap)
mode=${10:-4} # Default to mode 4 (full overlap)
🤖 Prompt for AI Agents
In run_deepseek_batch_split.sh at line 123, the default mode is set to 2, which
conflicts with the usage documentation and comments around lines 40-44 that
indicate mode 4 should be the default for "full overlap." To fix this, update
the default value in line 123 from 2 to 4 to align with the documentation, or
alternatively, revise the documentation and comments to explain why mode 2 is
the intended default.

Comment on lines 737 to 790

# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states shape: {hidden_states.shape if hidden_states is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata: {attn_metadata}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.tokens_per_block: {attn_metadata.kv_cache_manager.tokens_per_block}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.max_seq_len: {attn_metadata.kv_cache_manager.max_seq_len}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.head_dim: {attn_metadata.kv_cache_manager.head_dim}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_block_offsets: {attn_metadata.kv_cache_block_offsets.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.host_kv_cache_block_offsets: {attn_metadata.host_kv_cache_block_offsets.shape}")
num_requests = position_ids.shape[1]
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests: {num_requests}")

if num_requests == 64 :
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests == 64")
hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)

attn_metadata_half1 = copy.copy(attn_metadata)
attn_metadata_half1.max_num_requests = 32

attn_metadata_half2 = copy.copy(attn_metadata)
attn_metadata_half2.max_num_requests = 32

hidden_states_half1 = self.self_attn(
position_ids=position_ids_half1,
hidden_states=hidden_states_half1,
attn_metadata=attn_metadata_half1,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)

hidden_states_half2 = self.self_attn(
position_ids=position_ids_half2,
hidden_states=hidden_states_half2,
attn_metadata=attn_metadata_half2,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)

hidden_states = torch.cat([hidden_states_half1, hidden_states_half2], dim=0)

else:
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Remove debug prints and fix potential issues in batch splitting logic.

This code segment introduces extensive debug prints and conditional batch splitting logic. Several concerns:

  1. Debug prints should be removed - These verbose debug statements will impact performance and clutter logs in production
  2. Hard-coded batch splitting condition - The num_requests == 64 condition is too specific and inflexible
  3. Shallow copy concerns - Using copy.copy() on attn_metadata may not properly duplicate all necessary internal state
  4. Missing validation - No validation that chunking produces expected tensor dimensions

Apply this diff to address the issues:

-        # Self Attention
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states shape: {hidden_states.shape if hidden_states is not None else None}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata: {attn_metadata}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.tokens_per_block: {attn_metadata.kv_cache_manager.tokens_per_block}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.max_seq_len: {attn_metadata.kv_cache_manager.max_seq_len}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.head_dim: {attn_metadata.kv_cache_manager.head_dim}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_block_offsets: {attn_metadata.kv_cache_block_offsets.shape}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.host_kv_cache_block_offsets: {attn_metadata.host_kv_cache_block_offsets.shape}")
-        num_requests = position_ids.shape[1]
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests: {num_requests}")
-        
-        if num_requests == 64 :
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests == 64")
-            hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
-            position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)
-
-            attn_metadata_half1 = copy.copy(attn_metadata)
-            attn_metadata_half1.max_num_requests = 32
-            
-            attn_metadata_half2 = copy.copy(attn_metadata)
-            attn_metadata_half2.max_num_requests = 32
-         
-            hidden_states_half1 = self.self_attn(
-                position_ids=position_ids_half1,
-                hidden_states=hidden_states_half1,
-                attn_metadata=attn_metadata_half1,
-                all_reduce_params=AllReduceParams(
-                    enable_allreduce=not (self.disable_attn_allreduce)),
-                **kwargs,
-            )
-
-            hidden_states_half2 = self.self_attn(
-                position_ids=position_ids_half2,
-                hidden_states=hidden_states_half2,
-                attn_metadata=attn_metadata_half2,
-                all_reduce_params=AllReduceParams(
-                    enable_allreduce=not (self.disable_attn_allreduce)),
-                **kwargs,
-            )
-
-            hidden_states = torch.cat([hidden_states_half1, hidden_states_half2], dim=0)
-            
-        else:
-            hidden_states = self.self_attn(
-                position_ids=position_ids,
-                hidden_states=hidden_states,
-                attn_metadata=attn_metadata,
-                all_reduce_params=AllReduceParams(
-                    enable_allreduce=not (self.disable_attn_allreduce)),
-                **kwargs,
-            )
+        # Self Attention
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.disable_attn_allreduce)),
+            **kwargs,
+        )

If batch splitting is required, consider making it configurable via environment variables or model config rather than hard-coding specific request counts.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states shape: {hidden_states.shape if hidden_states is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata: {attn_metadata}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.tokens_per_block: {attn_metadata.kv_cache_manager.tokens_per_block}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.max_seq_len: {attn_metadata.kv_cache_manager.max_seq_len}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.head_dim: {attn_metadata.kv_cache_manager.head_dim}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_block_offsets: {attn_metadata.kv_cache_block_offsets.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.host_kv_cache_block_offsets: {attn_metadata.host_kv_cache_block_offsets.shape}")
num_requests = position_ids.shape[1]
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests: {num_requests}")
if num_requests == 64 :
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests == 64")
hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)
attn_metadata_half1 = copy.copy(attn_metadata)
attn_metadata_half1.max_num_requests = 32
attn_metadata_half2 = copy.copy(attn_metadata)
attn_metadata_half2.max_num_requests = 32
hidden_states_half1 = self.self_attn(
position_ids=position_ids_half1,
hidden_states=hidden_states_half1,
attn_metadata=attn_metadata_half1,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
hidden_states_half2 = self.self_attn(
position_ids=position_ids_half2,
hidden_states=hidden_states_half2,
attn_metadata=attn_metadata_half2,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
hidden_states = torch.cat([hidden_states_half1, hidden_states_half2], dim=0)
else:
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)
🧰 Tools
🪛 Ruff (0.12.2)

739-739: Line too long (137 > 120)

(E501)


740-740: Line too long (140 > 120)

(E501)


742-742: Line too long (175 > 120)

(E501)


743-743: Line too long (157 > 120)

(E501)


744-744: Line too long (147 > 120)

(E501)


745-745: Line too long (141 > 120)

(E501)


746-746: Line too long (158 > 120)

(E501)


747-747: Line too long (168 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py around lines 737 to 790,
remove all debug print statements to avoid performance impact and log clutter.
Replace the hard-coded batch splitting condition (num_requests == 64) with a
configurable parameter from environment variables or model configuration. Use a
deep copy method instead of copy.copy() for attn_metadata to ensure all internal
state is duplicated correctly. Add validation checks after chunking tensors to
confirm the resulting dimensions are as expected before proceeding.

Comment on lines 2106 to 2117
print(f"[DEBUG] _forward_step - Input keys: {list(inputs.keys())}")
for key, value in inputs.items():
if hasattr(value, 'shape'):
print(f"[DEBUG] _forward_step - {key} shape: {value.shape}, dtype: {value.dtype}")
elif isinstance(value, (list, tuple)):
print(f"[DEBUG] _forward_step - {key}: {type(value)} with {len(value)} elements")
else:
print(f"[DEBUG] _forward_step - {key}: {type(value)}")

if gather_ids is not None:
print(f"[DEBUG] _forward_step - gather_ids shape: {gather_ids.shape}, dtype: {gather_ids.dtype}")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Replace debug prints with conditional logging.

These unconditional debug print statements will execute in production and may impact performance. Consider using the logger with appropriate log levels instead.

Apply this diff to use proper logging:

-        print(f"[DEBUG] _forward_step - Input keys: {list(inputs.keys())}")
-        for key, value in inputs.items():
-            if hasattr(value, 'shape'):
-                print(f"[DEBUG] _forward_step - {key} shape: {value.shape}, dtype: {value.dtype}")
-            elif isinstance(value, (list, tuple)):
-                print(f"[DEBUG] _forward_step - {key}: {type(value)} with {len(value)} elements")
-            else:
-                print(f"[DEBUG] _forward_step - {key}: {type(value)}")
-        
-        if gather_ids is not None:
-            print(f"[DEBUG] _forward_step - gather_ids shape: {gather_ids.shape}, dtype: {gather_ids.dtype}")
+        if logger.level <= logging.DEBUG:
+            logger.debug(f"_forward_step - Input keys: {list(inputs.keys())}")
+            for key, value in inputs.items():
+                if hasattr(value, 'shape'):
+                    logger.debug(f"_forward_step - {key} shape: {value.shape}, dtype: {value.dtype}")
+                elif isinstance(value, (list, tuple)):
+                    logger.debug(f"_forward_step - {key}: {type(value)} with {len(value)} elements")
+                else:
+                    logger.debug(f"_forward_step - {key}: {type(value)}")
+            
+            if gather_ids is not None:
+                logger.debug(f"_forward_step - gather_ids shape: {gather_ids.shape}, dtype: {gather_ids.dtype}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] _forward_step - Input keys: {list(inputs.keys())}")
for key, value in inputs.items():
if hasattr(value, 'shape'):
print(f"[DEBUG] _forward_step - {key} shape: {value.shape}, dtype: {value.dtype}")
elif isinstance(value, (list, tuple)):
print(f"[DEBUG] _forward_step - {key}: {type(value)} with {len(value)} elements")
else:
print(f"[DEBUG] _forward_step - {key}: {type(value)}")
if gather_ids is not None:
print(f"[DEBUG] _forward_step - gather_ids shape: {gather_ids.shape}, dtype: {gather_ids.dtype}")
if logger.level <= logging.DEBUG:
logger.debug(f"_forward_step - Input keys: {list(inputs.keys())}")
for key, value in inputs.items():
if hasattr(value, 'shape'):
logger.debug(f"_forward_step - {key} shape: {value.shape}, dtype: {value.dtype}")
elif isinstance(value, (list, tuple)):
logger.debug(f"_forward_step - {key}: {type(value)} with {len(value)} elements")
else:
logger.debug(f"_forward_step - {key}: {type(value)}")
if gather_ids is not None:
logger.debug(f"_forward_step - gather_ids shape: {gather_ids.shape}, dtype: {gather_ids.dtype}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 2106 to 2117,
replace all unconditional print debug statements with calls to a logger using an
appropriate debug log level. Use the existing logger instance if available, or
create one, and change each print statement to logger.debug to ensure debug
messages are only output when debug logging is enabled, preventing performance
impact in production.

Comment on lines 2119 to 2124
print(f"[DEBUG] _forward_step - After _preprocess_inputs, keys: {list(inputs.keys())}")

if inputs.get('spec_metadata', None):
gather_ids = inputs['spec_metadata'].gather_ids
print(f"[DEBUG] _forward_step - Updated gather_ids from spec_metadata: {gather_ids.shape if gather_ids is not None else None}")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix logging and line length violation.

Same logging concerns as above, plus Line 2123 exceeds the 120 character limit.

Apply this diff:

-        print(f"[DEBUG] _forward_step - After _preprocess_inputs, keys: {list(inputs.keys())}")
-        
+        if logger.level <= logging.DEBUG:
+            logger.debug(f"_forward_step - After _preprocess_inputs, keys: {list(inputs.keys())}")
+        
-            print(f"[DEBUG] _forward_step - Updated gather_ids from spec_metadata: {gather_ids.shape if gather_ids is not None else None}")
+            if logger.level <= logging.DEBUG:
+                gather_shape = gather_ids.shape if gather_ids is not None else None
+                logger.debug(f"_forward_step - Updated gather_ids from spec_metadata: {gather_shape}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] _forward_step - After _preprocess_inputs, keys: {list(inputs.keys())}")
if inputs.get('spec_metadata', None):
gather_ids = inputs['spec_metadata'].gather_ids
print(f"[DEBUG] _forward_step - Updated gather_ids from spec_metadata: {gather_ids.shape if gather_ids is not None else None}")
if logger.level <= logging.DEBUG:
logger.debug(f"_forward_step - After _preprocess_inputs, keys: {list(inputs.keys())}")
if inputs.get('spec_metadata', None):
gather_ids = inputs['spec_metadata'].gather_ids
if logger.level <= logging.DEBUG:
gather_shape = gather_ids.shape if gather_ids is not None else None
logger.debug(f"_forward_step - Updated gather_ids from spec_metadata: {gather_shape}")
🧰 Tools
🪛 Ruff (0.12.2)

2123-2123: Line too long (139 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 2119 to 2124, the
debug print statements should be replaced with proper logging calls to follow
project standards, and the line at 2123 exceeds the 120 character limit. Replace
print statements with logger.debug calls and break long lines into shorter ones
to comply with the line length limit.

Comment on lines 2125 to 2132
if self.without_logits:
print(f"[DEBUG] _forward_step - Calling model_forward without logits")
outputs = self.model_forward(**inputs)
print(f"[DEBUG] _forward_step - Output keys: {list(outputs.keys())}")
for key, value in outputs.items():
if hasattr(value, 'shape'):
print(f"[DEBUG] _forward_step - output.{key} shape: {value.shape}, dtype: {value.dtype}")
return outputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use conditional logging instead of print statements.

Convert these debug prints to proper logging as well.

Apply this diff:

-            print(f"[DEBUG] _forward_step - Calling model_forward without logits")
+            if logger.level <= logging.DEBUG:
+                logger.debug("_forward_step - Calling model_forward without logits")
-            print(f"[DEBUG] _forward_step - Output keys: {list(outputs.keys())}")
-            for key, value in outputs.items():
-                if hasattr(value, 'shape'):
-                    print(f"[DEBUG] _forward_step - output.{key} shape: {value.shape}, dtype: {value.dtype}")
+                logger.debug(f"_forward_step - Output keys: {list(outputs.keys())}")
+                for key, value in outputs.items():
+                    if hasattr(value, 'shape'):
+                        logger.debug(f"_forward_step - output.{key} shape: {value.shape}, dtype: {value.dtype}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.without_logits:
print(f"[DEBUG] _forward_step - Calling model_forward without logits")
outputs = self.model_forward(**inputs)
print(f"[DEBUG] _forward_step - Output keys: {list(outputs.keys())}")
for key, value in outputs.items():
if hasattr(value, 'shape'):
print(f"[DEBUG] _forward_step - output.{key} shape: {value.shape}, dtype: {value.dtype}")
return outputs
if self.without_logits:
if logger.level <= logging.DEBUG:
logger.debug("_forward_step - Calling model_forward without logits")
logger.debug(f"_forward_step - Output keys: {list(outputs.keys())}")
for key, value in outputs.items():
if hasattr(value, 'shape'):
logger.debug(f"_forward_step - output.{key} shape: {value.shape}, dtype: {value.dtype}")
outputs = self.model_forward(**inputs)
return outputs
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 2125 to 2132,
replace the print debug statements with conditional logging calls using the
appropriate logger instance. Use logger.debug() for these messages instead of
print(), ensuring that debug output respects the logging configuration and can
be enabled or disabled as needed.

Comment on lines 2136 to 2147
print(f"[DEBUG] _forward_step - Calling model_forward with return_context_logits={gather_ids is not None or gather_context_logits}")
logits = self.model_forward(
**inputs,
return_context_logits=gather_ids is not None
or gather_context_logits,
)
print(f"[DEBUG] _forward_step - Raw logits shape: {logits.shape}, dtype: {logits.dtype}")

if gather_ids is not None:
return {'logits': logits[gather_ids]}
gathered_logits = logits[gather_ids]
print(f"[DEBUG] _forward_step - Gathered logits shape: {gathered_logits.shape}, dtype: {gathered_logits.dtype}")
return {'logits': gathered_logits}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix logging and line length violations.

Multiple line length violations (Lines 2136, 2146) and unconditional debug prints need to be addressed.

Apply this diff:

-        print(f"[DEBUG] _forward_step - Calling model_forward with return_context_logits={gather_ids is not None or gather_context_logits}")
+        if logger.level <= logging.DEBUG:
+            return_ctx_logits = gather_ids is not None or gather_context_logits
+            logger.debug(f"_forward_step - Calling model_forward with return_context_logits={return_ctx_logits}")
-        print(f"[DEBUG] _forward_step - Raw logits shape: {logits.shape}, dtype: {logits.dtype}")
-        
+            logger.debug(f"_forward_step - Raw logits shape: {logits.shape}, dtype: {logits.dtype}")
+        
-            gathered_logits = logits[gather_ids]
-            print(f"[DEBUG] _forward_step - Gathered logits shape: {gathered_logits.shape}, dtype: {gathered_logits.dtype}")
-            return {'logits': gathered_logits}
+            if logger.level <= logging.DEBUG:
+                gathered_logits = logits[gather_ids]
+                logger.debug(f"_forward_step - Gathered logits shape: {gathered_logits.shape}, dtype: {gathered_logits.dtype}")
+                return {'logits': gathered_logits}
+            else:
+                return {'logits': logits[gather_ids]}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] _forward_step - Calling model_forward with return_context_logits={gather_ids is not None or gather_context_logits}")
logits = self.model_forward(
**inputs,
return_context_logits=gather_ids is not None
or gather_context_logits,
)
print(f"[DEBUG] _forward_step - Raw logits shape: {logits.shape}, dtype: {logits.dtype}")
if gather_ids is not None:
return {'logits': logits[gather_ids]}
gathered_logits = logits[gather_ids]
print(f"[DEBUG] _forward_step - Gathered logits shape: {gathered_logits.shape}, dtype: {gathered_logits.dtype}")
return {'logits': gathered_logits}
if logger.level <= logging.DEBUG:
return_ctx_logits = gather_ids is not None or gather_context_logits
logger.debug(f"_forward_step - Calling model_forward with return_context_logits={return_ctx_logits}")
logger.debug(f"_forward_step - Raw logits shape: {logits.shape}, dtype: {logits.dtype}")
logits = self.model_forward(
**inputs,
return_context_logits=gather_ids is not None
or gather_context_logits,
)
if gather_ids is not None:
if logger.level <= logging.DEBUG:
gathered_logits = logits[gather_ids]
logger.debug(f"_forward_step - Gathered logits shape: {gathered_logits.shape}, dtype: {gathered_logits.dtype}")
return {'logits': gathered_logits}
else:
return {'logits': logits[gather_ids]}
🧰 Tools
🪛 Ruff (0.12.2)

2136-2136: Line too long (140 > 120)

(E501)


2146-2146: Line too long (124 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 2136 to 2147, the
debug print statements cause line length violations and are unconditional.
Replace print statements with proper logging calls at the debug level, and break
long lines to respect line length limits. Ensure debug logs are conditional on
the logger's debug level to avoid unnecessary output.

Comment on lines 2149 to 2150
print(f"[DEBUG] _forward_step - Returning original logits")
return {'logits': logits}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Use conditional logging.

Complete the logging conversion for consistency.

Apply this diff:

-            print(f"[DEBUG] _forward_step - Returning original logits")
+            if logger.level <= logging.DEBUG:
+                logger.debug("_forward_step - Returning original logits")

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 2149 to 2150,
replace the print statement used for debugging with a conditional logging call
to maintain consistent logging practices. Use the appropriate logger instance to
log the debug message instead of print, ensuring the message is only logged when
debug logging is enabled.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 18

🔭 Outside diff range comments (4)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1265-1308: Remove debug prints to improve performance.

Multiple debug print statements throughout this method will create performance overhead and excessive logging in production.

Apply this diff to remove the debug prints:

-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Input keys: {list(kwargs.keys())}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - input_ids shape: {input_ids.shape if input_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - inputs_embeds shape: {inputs_embeds.shape if inputs_embeds is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - return_context_logits: {return_context_logits}")
-        
         attn_metadata.num_generations_per_batch = self.model_nextn + 1
         hidden_states = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
         )
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Model output hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")

         if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
             # get logits
             logits = self.logits_processor.forward(
                 hidden_states[spec_metadata.gather_ids],
                 self.lm_head,
                 attn_metadata,
                 True,
             )
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - MTP logits shape: {logits.shape}, dtype: {logits.dtype}")
             # get accepted tokens and next draft tokens
             return self.mtp_worker(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 logits=logits,
                 lm_head=self.lm_head,
                 embed_tokens=self.model.embed_tokens,
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
                 mtp_layers=self.model.layers[self.num_hidden_layers:])
         else:
             logits = self.logits_processor.forward(
                 hidden_states,
                 self.lm_head,
                 attn_metadata,
                 return_context_logits,
             )
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Final logits shape: {logits.shape}, dtype: {logits.dtype}")
             return logits
tensorrt_llm/_torch/attention_backend/trtllm.py (3)

735-795: Consolidate debug logging and address line length violations.

The debug prints provide valuable insights into the preparation process, but several improvements are needed:

  1. Line length violations: Lines 742 and 782-783 exceed 120 characters
  2. Inconsistent logging: Mix of print statements and scattered debug information
  3. Performance considerations: Debug prints will execute unconditionally

Consider consolidating debug logging:

-        print(f"[DEBUG] TrtllmAttention.prepare - num_seqs: {self.num_seqs}")
-        print(f"[DEBUG] TrtllmAttention.prepare - self.kv_cache_params.use_cache: {self.kv_cache_params.use_cache}")
-        print(f"[DEBUG] TrtllmAttention.prepare - self.kv_cache_params.num_cached_tokens_per_seq: {self.kv_cache_params.num_cached_tokens_per_seq}")
-        print(f"[DEBUG] TrtllmAttention.prepare - cached_token_lens: {cached_token_lens}")
-        print(f"[DEBUG] TrtllmAttention.prepare - self.seq_lens_kv: {self.seq_lens_kv}")
+        if logger.level <= logger.DEBUG:
+            logger.debug(
+                "TrtllmAttentionMetadata.prepare - num_seqs=%d, use_cache=%s, "
+                "cached_tokens=%s, seq_lens_kv=%s",
+                self.num_seqs, self.kv_cache_params.use_cache,
+                self.kv_cache_params.num_cached_tokens_per_seq, self.seq_lens_kv
+            )

1125-1175: Address severe line length violations and optimize debug logging.

The debug logging in the forward method has significant issues:

  1. Severe line length violations: Lines 1164-1175 exceed 120 characters by up to 170 characters
  2. Performance impact: Extensive string formatting will execute unconditionally
  3. Readability: The multi-line debug statement is difficult to read and maintain

The line length violations are severe and must be addressed:

-        print(f"[DEBUG] TrtllmAttention.forward wrapper.plan layer_idx: {self.get_local_layer_idx(metadata)}\
-              tokens_per_block: {metadata.tokens_per_block}, max_num_requests: {metadata.max_num_requests},\
-              max_seq_len: {metadata.max_seq_len}, max_num_tokens: {metadata.max_num_tokens}\
-              attention_window_size: {attention_window_size}, sink_token_length: {0}, beam_width: {metadata.beam_width}\
-              sequence_length: {metadata.kv_lens_cuda_runtime.shape if metadata.kv_lens_cuda_runtime is not None else None}, host_past_key_value_lengths: {metadata.kv_lens_runtime.shape if metadata.kv_lens_runtime is not None else None}\
-              context_lengths: {metadata.prompt_lens_cuda_runtime.shape if metadata.prompt_lens_cuda_runtime is not None else None}, host_context_lengths: {metadata.prompt_lens_cpu_runtime.shape if metadata.prompt_lens_cpu_runtime is not None else None}\
-              host_request_types: {metadata.host_request_types_runtime.shape if metadata.host_request_types_runtime is not None else None}, kv_cache_block_offsets: {metadata.kv_cache_block_offsets.shape if metadata.kv_cache_block_offsets is not None else None}\
-              host_kv_cache_block_offsets: {metadata.host_kv_cache_block_offsets.shape if metadata.host_kv_cache_block_offsets is not None else None}, host_kv_cache_pool_pointers: {metadata.host_kv_cache_pool_pointers.shape if metadata.host_kv_cache_pool_pointers is not None else None}\
-              host_kv_cache_pool_mapping: {metadata.host_kv_cache_pool_mapping.shape if metadata.host_kv_cache_pool_mapping is not None else None}, block_ids_per_seq: {metadata.block_ids_per_seq.shape if metadata.block_ids_per_seq is not None else None}\
-              workspace: {metadata.workspace.shape if metadata.workspace is not None else None}, cache_indirection: {metadata.cache_indirection.shape if metadata.cache_indirection is not None else None}, kv_scale_orig_quant: {self.kv_scale_orig_quant}\
-              kv_scale_quant_orig: {self.kv_scale_quant_orig}, out_scale: {out_scale.shape if out_scale is not None else None}, out_scale_sf: {out_scale_sf.shape if out_scale_sf is not None else None}\
-              latent_cache: {latent_cache.shape if latent_cache is not None else None}, q_pe: {q_pe.shape if q_pe is not None else None}, mrope_config: {mrope_config}, mla_context_paged_kv: {mla_context_paged_kv.shape if mla_context_paged_kv is not None else None}\
-              mla_context_kv_cache_block_offsets: {mla_context_kv_cache_block_offsets.shape if mla_context_kv_cache_block_offsets is not None else None}, softmax_stats_tensor: {softmax_stats_tensor.shape if softmax_stats_tensor is not None else None}\
-              is_spec_decoding_enabled: {metadata.is_spec_decoding_enabled}, use_spec_decoding: {metadata.use_spec_decoding}\
-              spec_decoding_position_offsets: {metadata.spec_decoding_position_offsets.shape if metadata.spec_decoding_position_offsets is not None else None}, spec_decoding_packed_mask: {metadata.spec_decoding_packed_mask.shape if metadata.spec_decoding_packed_mask is not None else None}\
-              spec_decoding_generation_lengths: {metadata.spec_decoding_generation_lengths}")
+        if logger.level <= logger.DEBUG:
+            logger.debug(
+                "TrtllmAttention.forward wrapper.plan - layer_idx=%d, "
+                "tokens_per_block=%s, max_requests=%d, max_seq_len=%d",
+                self.get_local_layer_idx(metadata), metadata.tokens_per_block,
+                metadata.max_num_requests, metadata.max_seq_len
+            )

1-11: Missing required NVIDIA copyright header.

According to the coding guidelines, "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted."

Add the required NVIDIA copyright header at the top of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import os
 import weakref
♻️ Duplicate comments (11)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

739-884: Remove debug prints and fix batch splitting implementation.

This code segment has the same issues that were previously identified in past reviews:

  1. Debug prints impact performance - Extensive debug statements will degrade performance in production
  2. Hard-coded condition - The num_requests == 64 check is inflexible and environment-dependent
  3. Shallow copy concerns - Manual field modifications on copied attn_metadata may not preserve all internal state correctly
  4. Complex stream management - The CUDA stream orchestration adds significant complexity

The implementation needs the same fixes as previously suggested: remove debug prints, make batch splitting configurable, use proper deep copying, and add validation checks.


1172-1200: Remove debug prints for production readiness.

These debug print statements are identical to those flagged in previous reviews. They will negatively impact performance and create excessive log output in production environments.

As previously recommended, all debug print statements should be removed to improve performance and avoid excessive logging.

tensorrt_llm/_torch/pyexecutor/model_engine.py (9)

1439-1439: Replace debug print with logging.


1448-1448: Replace debug print with logging.


1465-1466: Document the splitBatchOverlap parameter.

The prepare() method is called with splitBatchOverlap=1 and splitBatchOverlap=2 for the two halves, but this parameter lacks documentation. Please add comments explaining what these values represent and how they affect the attention computation.


2143-2143: Replace debug print with logging.


2161-2172: Replace debug prints with logging.


2174-2179: Replace debug prints with logging.


2181-2187: Replace debug prints with logging.


2191-2202: Replace debug prints with logging.


2204-2205: Replace debug print with logging.

🧹 Nitpick comments (5)
tensorrt_llm/llmapi/llm_args.py (1)

62-71: Tidy docstring and enforce basic validation.

Ruff flags (D205/D415/W291) plus trailing whitespace stem from this block.
Consider:

-class SplitBatchAttnMoeOverlapConfig(BaseModel):
-    """
-    Configuration for split batch attention and MoE overlap.
-        Baseline batch_size = 2N |Attention|MOE||Attention|MOE| 
-        Split batch_size = N |Attention|MOE|           |Attention|MOE|
-                         = N           |Attention|MOE|           |Attention|MOE|
-    """
-    enable_split_batch_attn_moe_overlap: bool = Field(default=False, description="Enable split batch attention.")
-    split_batch_size: int = Field(default=0, description="Split batch size.")
+class SplitBatchAttnMoeOverlapConfig(BaseModel):
+    """Configuration for split-batch attention & MoE overlap.
+
+    Baseline (`batch_size = 2 N`)
+        |Attention|MOE||Attention|MOE|
+
+    Split (`batch_size = N`)
+        |Attention|MOE|           |Attention|MOE|
+    """
+
+    enable_split_batch_attn_moe_overlap: bool = Field(
+        default=False,
+        description="Toggle split-batch Attention/MoE overlap.",
+    )
+    split_batch_size: int = Field(
+        default=0,
+        description="Effective per-split batch size (N). Must be > 0 when the feature is enabled.",
+    )
+
+    @field_validator("split_batch_size")
+    @classmethod
+    def _check_positive(cls, v, info):
+        if info.data.get("enable_split_batch_attn_moe_overlap") and v <= 0:
+            raise ValueError("split_batch_size must be positive when overlap is enabled.")
+        return v

Fixes style issues, removes trailing whitespace, and adds a sanity check.

tensorrt_llm/_torch/attention_backend/trtllm.py (2)

698-713: LGTM! Good implementation of split batch overlap support.

The addition of the optional splitBatchOverlap parameter is well-implemented:

  • Maintains backward compatibility by making the parameter optional
  • Uses appropriate conditional logic to handle different overlap scenarios
  • Correctly uses weakref to avoid circular references when storing global attributes
  • Follows proper naming conventions

Consider using the logger instead of print statements for consistency:

-        print(f"[DEBUG] TrtllmAttention.prepare")
+        logger.debug("TrtllmAttentionMetadata.prepare called")

1244-1248: Final debug prints follow the same pattern - consider logging optimization.

These final debug statements are consistent with the overall debugging approach but could benefit from the same optimizations suggested for other debug prints (using logger with appropriate levels).

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

1453-1453: Remove commented-out code.

This commented line should be removed or properly explained if it's needed for future implementation.

-            #attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]

2053-2054: Incomplete TODO comment.

The TODO comment appears to be incomplete or incorrectly formatted. Please complete the comment or remove it if not needed.

-        #kgoyal TODO: This is where the 
+        # TODO(kgoyal): Complete this comment or remove if not needed
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 615f600 and 08e4f0a.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/attention_backend/interface.py (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (9 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (6 hunks)
  • tensorrt_llm/_torch/modules/attention.py (15 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (6 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/llmapi/llm_args.py
🧠 Learnings (4)
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm, test files (files under tests/ directories) do not require nvidia copyright headers...
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: in tensorrt_llm/executor/worker.py, the lora adapter cache optimization logic that checks `is_adapte...
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

739-739: Line too long (140 > 120)

(E501)


741-741: Line too long (175 > 120)

(E501)


742-742: Line too long (157 > 120)

(E501)


743-743: Line too long (147 > 120)

(E501)


744-744: Line too long (141 > 120)

(E501)


745-745: Line too long (158 > 120)

(E501)


746-746: Line too long (168 > 120)

(E501)


750-750: Line too long (129 > 120)

(E501)


769-769: Line too long (133 > 120)

(E501)


777-777: Line too long (136 > 120)

(E501)


786-786: Line too long (133 > 120)

(E501)


795-795: Line too long (136 > 120)

(E501)


867-867: Line too long (162 > 120)

(E501)


868-868: Line too long (162 > 120)

(E501)


869-869: Line too long (162 > 120)

(E501)


870-870: Line too long (162 > 120)

(E501)


1180-1180: Line too long (137 > 120)

(E501)


1184-1184: Line too long (132 > 120)

(E501)


1187-1187: Line too long (122 > 120)

(E501)


1194-1194: Line too long (123 > 120)

(E501)


1196-1196: Line too long (130 > 120)

(E501)


1265-1265: Line too long (139 > 120)

(E501)


1275-1275: Line too long (143 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/model_engine.py

1439-1439: Line too long (174 > 120)

(E501)


1447-1447: Line too long (147 > 120)

(E501)


1453-1453: Line too long (133 > 120)

(E501)


1463-1463: Line too long (147 > 120)

(E501)


1465-1465: Line too long (190 > 120)

(E501)


1472-1472: Line too long (133 > 120)

(E501)


1481-1481: Line too long (147 > 120)

(E501)


1483-1483: Line too long (190 > 120)

(E501)


2178-2178: Line too long (139 > 120)

(E501)


2191-2191: Line too long (140 > 120)

(E501)


2201-2201: Line too long (124 > 120)

(E501)

tensorrt_llm/_torch/attention_backend/interface.py

260-260: Line too long (123 > 120)

(E501)

tensorrt_llm/_torch/attention_backend/trtllm.py

415-415: Line too long (133 > 120)

(E501)


416-416: Line too long (240 > 120)

(E501)


417-417: Line too long (171 > 120)

(E501)


712-712: Line too long (129 > 120)

(E501)


742-742: Line too long (148 > 120)

(E501)


782-782: Line too long (136 > 120)

(E501)


783-783: Line too long (146 > 120)

(E501)


1126-1126: Line too long (142 > 120)

(E501)


1127-1127: Line too long (142 > 120)

(E501)


1129-1129: Line too long (130 > 120)

(E501)


1164-1164: Line too long (237 > 120)

(E501)


1165-1165: Line too long (254 > 120)

(E501)


1166-1166: Line too long (261 > 120)

(E501)


1167-1167: Line too long (287 > 120)

(E501)


1168-1168: Line too long (254 > 120)

(E501)


1169-1169: Line too long (252 > 120)

(E501)


1170-1170: Line too long (201 > 120)

(E501)


1171-1171: Line too long (265 > 120)

(E501)


1172-1172: Line too long (251 > 120)

(E501)


1173-1173: Line too long (125 > 120)

(E501)


1174-1174: Line too long (290 > 120)

(E501)

tensorrt_llm/_torch/modules/attention.py

747-747: Line too long (123 > 120)

(E501)


766-766: Line too long (134 > 120)

(E501)


779-779: Line too long (131 > 120)

(E501)


797-797: Line too long (123 > 120)

(E501)


804-804: Line too long (131 > 120)

(E501)


828-828: Line too long (134 > 120)

(E501)

tensorrt_llm/llmapi/llm_args.py

63-67: 1 blank line required between summary line and description

(D205)


63-67: First line should end with a period, question mark, or exclamation point

Add closing punctuation

(D415)


64-64: Trailing whitespace

Remove trailing whitespace

(W291)

🔇 Additional comments (4)
tensorrt_llm/_torch/modules/attention.py (1)

320-335: Metadata-splitting logic is correct and requires no changes.

The code in attention.py uses a global ping (initialized to 0 at line 318) to alternate between attention_metadata_half1 and attention_metadata_half2 only when both are set by model_engine.py. Otherwise it falls back to the default attention_metadata. get_model_extra_attrs() simply returns the current thread-local attrs dict, which is populated in all cases (default and split) in model_engine.py. The observed behavior matches the intended split-batch attention pattern.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

72-73: LGTM - imports support batch splitting functionality.

The copy and weakref imports are appropriately added to support the batch splitting logic introduced later in the file.

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

14-15: LGTM!

The copy module import is correctly placed and necessary for the attention metadata splitting functionality.


2145-2148: LGTM! Good use of weak references.

Storing weak references to the metadata halves is appropriate to avoid circular references and potential memory leaks.

Comment on lines 260 to 261
print(f"[DEBUG] AttentionMetadata.is_cross - self.seq_lens: {self.seq_lens}, self.seq_lens_kv: {self.seq_lens_kv}")
return self.seq_lens is not self.seq_lens_kv
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Replace print with proper logger & avoid hot-path I/O.

is_cross is called in every forward pass; an unconditional print will spam stdout, hurt performance, and break silent service scenarios.
Use the project logger at DEBUG level (or drop the log entirely) and keep the line ≤ 120 chars.

-        print(f"[DEBUG] AttentionMetadata.is_cross - self.seq_lens: {self.seq_lens}, self.seq_lens_kv: {self.seq_lens_kv}")
+        logger = logging.getLogger(__name__)
+        logger.debug(
+            "AttentionMetadata.is_cross – seq_lens=%s, seq_lens_kv=%s",
+            self.seq_lens,
+            self.seq_lens_kv,
+        )

Add once near the other imports:

import logging
🧰 Tools
🪛 Ruff (0.12.2)

260-260: Line too long (123 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/attention_backend/interface.py around lines 260 to 261,
replace the unconditional print statement with a call to the project logger at
DEBUG level to avoid spamming stdout and hurting performance in the hot path.
Import the logging module near other imports if not already done, then use
logging.getLogger() to get the logger and log the debug message instead of
printing. Ensure the log message is concise and does not exceed 120 characters.

Comment on lines 408 to 417

print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, \
sequence_length shape: {self.sequence_length.shape}, \
host_past_key_value_lengths shape: {self.host_past_key_value_lengths.shape}, \
context_lengths shape: {self.context_lengths.shape}, \
host_context_lengths shape: {self.host_context_lengths.shape}, \
host_request_types shape: {self.host_request_types.shape}, \
kv_cache_block_offsets shape: {self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None}, \
k shape: {k.shape if k is not None else None}, v shape: {v.shape if v is not None else None}, output shape: {output.shape if output is not None else None}, output_sf shape: {output_sf.shape if output_sf is not None else None}, \
out_dtype: {out_dtype if out_dtype is not None else None}, is_fused_qkv: {is_fused_qkv}, update_kv_cache: {update_kv_cache}, attention_mask: {attention_mask}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve debug logging approach and fix line length violations.

The debug print statements are useful for development but have several issues:

  1. Line length violations: Multiple lines exceed 120 characters (flagged by static analysis)
  2. Performance impact: These prints will execute in production unless controlled by flags
  3. Verbose output: The extensive logging may overwhelm logs during normal operation

Consider using the existing logger with appropriate debug levels:

-        
-        print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, \
-            sequence_length shape: {self.sequence_length.shape}, \
-            host_past_key_value_lengths shape: {self.host_past_key_value_lengths.shape}, \
-            context_lengths shape: {self.context_lengths.shape}, \
-            host_context_lengths shape: {self.host_context_lengths.shape}, \
-            host_request_types shape: {self.host_request_types.shape}, \
-            kv_cache_block_offsets shape: {self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None}, \
-            k shape: {k.shape if k is not None else None}, v shape: {v.shape if v is not None else None}, output shape: {output.shape if output is not None else None}, output_sf shape: {output_sf.shape if output_sf is not None else None}, \
-            out_dtype: {out_dtype if out_dtype is not None else None}, is_fused_qkv: {is_fused_qkv}, update_kv_cache: {update_kv_cache}, attention_mask: {attention_mask}")
+        if logger.level <= logger.DEBUG:
+            logger.debug(
+                "TrtllmAttention.run - q_shape=%s, seq_len_shape=%s, "
+                "kv_cache_offsets_shape=%s, fused_qkv=%s",
+                q.shape, self.sequence_length.shape,
+                self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None,
+                is_fused_qkv
+            )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] TrtllmAttention.forward - q shape: {q.shape}, \
sequence_length shape: {self.sequence_length.shape}, \
host_past_key_value_lengths shape: {self.host_past_key_value_lengths.shape}, \
context_lengths shape: {self.context_lengths.shape}, \
host_context_lengths shape: {self.host_context_lengths.shape}, \
host_request_types shape: {self.host_request_types.shape}, \
kv_cache_block_offsets shape: {self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None}, \
k shape: {k.shape if k is not None else None}, v shape: {v.shape if v is not None else None}, output shape: {output.shape if output is not None else None}, output_sf shape: {output_sf.shape if output_sf is not None else None}, \
out_dtype: {out_dtype if out_dtype is not None else None}, is_fused_qkv: {is_fused_qkv}, update_kv_cache: {update_kv_cache}, attention_mask: {attention_mask}")
if logger.level <= logger.DEBUG:
logger.debug(
"TrtllmAttention.run - q_shape=%s, seq_len_shape=%s, "
"kv_cache_offsets_shape=%s, fused_qkv=%s",
q.shape,
self.sequence_length.shape,
self.kv_cache_block_offsets.shape if self.kv_cache_block_offsets is not None else None,
is_fused_qkv,
)
🧰 Tools
🪛 Ruff (0.12.2)

415-415: Line too long (133 > 120)

(E501)


416-416: Line too long (240 > 120)

(E501)


417-417: Line too long (171 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/attention_backend/trtllm.py around lines 408 to 417,
replace the long debug print statement with a logger.debug call that respects
the logging level to avoid performance impact in production. Break the log
message into multiple shorter lines to comply with line length limits, and
ensure the logging is conditional on the debug level being enabled. Use the
existing logger instance instead of print, and format the message clearly but
concisely to reduce verbosity.

return q, k, v


ping = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Thread safety concern with global variable.

The global ping variable could cause race conditions in multi-threaded environments, leading to unpredictable behavior when alternating between metadata halves.

Consider using thread-local storage or passing the ping state through function parameters:

-ping = 0
+import threading
+_local = threading.local()
-    global ping
+    if not hasattr(_local, 'ping'):
+        _local.ping = 0
-        if ping % 2 == 0:
+        if _local.ping % 2 == 0:
             metadata = metadata_ref_half1()

         else:
             metadata = metadata_ref_half2()
-        ping += 1
+        _local.ping += 1

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py at line 318, the global variable
'ping' is used, which can cause race conditions in multi-threaded environments.
To fix this, remove the global 'ping' variable and instead manage the ping state
by passing it explicitly as a function parameter or by using thread-local
storage to ensure thread safety and avoid shared mutable state.

Comment on lines 346 to 347
print(f"[DEBUG] extract_extra_attrs - mla_layer: {mla_layer}")
print(f"[DEBUG] extract_extra_attrs - mla_layer type: {type(mla_layer)}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove debug print statements for production.

Debug print statements should not be present in production code as they can impact performance and clutter logs.

Consider using proper logging with appropriate levels:

-    print(f"[DEBUG] extract_extra_attrs - mla_layer: {mla_layer}")
-    print(f"[DEBUG] extract_extra_attrs - mla_layer type: {type(mla_layer)}")
+    logger.debug(f"extract_extra_attrs - mla_layer: {mla_layer}")
+    logger.debug(f"extract_extra_attrs - mla_layer type: {type(mla_layer)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] extract_extra_attrs - mla_layer: {mla_layer}")
print(f"[DEBUG] extract_extra_attrs - mla_layer type: {type(mla_layer)}")
logger.debug(f"extract_extra_attrs - mla_layer: {mla_layer}")
logger.debug(f"extract_extra_attrs - mla_layer type: {type(mla_layer)}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py around lines 346 to 347, remove the
debug print statements as they should not be present in production code to avoid
performance impact and log clutter. Replace them with proper logging calls using
a logging framework at an appropriate debug level if needed for future
troubleshooting.

output: torch.Tensor,
) -> None:
metadata, mla_layer = extract_extra_attrs(layer_idx)
print(f"[DEBUG] MLA.forward_impl.mla_custom_op_inplace - metadata: {metadata}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove debug print statement.

Debug print statements should be replaced with proper logging for production readiness.

-    print(f"[DEBUG] MLA.forward_impl.mla_custom_op_inplace - metadata: {metadata}")
+    logger.debug(f"MLA custom op inplace - metadata: {metadata}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] MLA.forward_impl.mla_custom_op_inplace - metadata: {metadata}")
logger.debug(f"MLA custom op inplace - metadata: {metadata}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py at line 364, replace the debug print
statement with a proper logging call. Remove the print statement and use the
logging module to log the metadata at an appropriate log level, such as debug,
to ensure production readiness and better log management.

output[attn_output_context.shape[0]:, :] = attn_output_gen
attn_output_context = None
attn_output_gen = None
print(f"[DEBUG] MLA.forward_impl - Final combined output: {output.shape}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove debug print statement.

-        print(f"[DEBUG] MLA.forward_impl - Final combined output: {output.shape}")
+        logger.debug(f"MLA.forward_impl - Final combined output: {output.shape}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] MLA.forward_impl - Final combined output: {output.shape}")
logger.debug(f"MLA.forward_impl - Final combined output: {output.shape}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py at line 867, there is a debug print
statement outputting the shape of the final combined output. Remove this print
statement to clean up the code and avoid unnecessary console output in
production.

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Although we use FP8 MLA for generation phase, the output is still in BF16

print(f"[DEBUG] MLA.forward_genertion {attn_metadata.is_cross} {attn_metadata}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove debug print statement.

-        print(f"[DEBUG] MLA.forward_genertion {attn_metadata.is_cross} {attn_metadata}")
+        logger.debug(f"MLA.forward_generation {attn_metadata.is_cross} {attn_metadata}")

Note: Also fix the typo "genertion" -> "generation" in the log message.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] MLA.forward_genertion {attn_metadata.is_cross} {attn_metadata}")
logger.debug(f"MLA.forward_generation {attn_metadata.is_cross} {attn_metadata}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py at line 1240, remove the debug print
statement entirely and correct the typo in the log message from "genertion" to
"generation" if you decide to keep any logging. Since the comment requests
removal, simply delete the print line to clean up the code.

raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

print(f"[DEBUG] MLA.forward_generation - output shape: {output.shape}, dtype: {output.dtype}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove debug print statement.

-        print(f"[DEBUG] MLA.forward_generation - output shape: {output.shape}, dtype: {output.dtype}")
+        logger.debug(f"MLA.forward_generation - output shape: {output.shape}, dtype: {output.dtype}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] MLA.forward_generation - output shape: {output.shape}, dtype: {output.dtype}")
logger.debug(f"MLA.forward_generation - output shape: {output.shape}, dtype: {output.dtype}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py at line 1281, remove the debug print
statement that outputs the shape and dtype of the output in
MLA.forward_generation to clean up the code and avoid unnecessary console
output.

Comment on lines 1291 to 1314
print(f"[DEBUG] MLA.forward - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
print(f"[DEBUG] MLA.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
print(f"[DEBUG] MLA.forward - attn_metadata: {attn_metadata}")

attn_output = self.create_output(hidden_states)
print(f"[DEBUG] MLA.forward - Created output tensor shape: {attn_output.shape}, dtype: {attn_output.dtype}")

if self.register_to_config:
print(f"[DEBUG] MLA.forward - Using custom op for layer {self.layer_idx_str}")
torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids,
self.layer_idx_str,
attn_output)
else:
print(f"[DEBUG] MLA.forward - Using forward_impl")
self.forward_impl(position_ids,
hidden_states,
attn_metadata,
output=attn_output)

print(f"[DEBUG] MLA.forward - Before o_proj, attn_output shape: {attn_output.shape}")
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
print(f"[DEBUG] MLA.forward - After o_proj, attn_output shape: {attn_output.shape}, dtype: {attn_output.dtype}")
return attn_output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove multiple debug print statements.

The forward method contains numerous debug print statements that should be replaced with proper logging.

-        print(f"[DEBUG] MLA.forward - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-        print(f"[DEBUG] MLA.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] MLA.forward - attn_metadata: {attn_metadata}")
-        print(f"[DEBUG] MLA.forward - ")
+        logger.debug(f"MLA.forward - hidden_states: {hidden_states.shape}, dtype: {hidden_states.dtype}")
+        logger.debug(f"MLA.forward - position_ids: {position_ids.shape if position_ids is not None else None}")
+        logger.debug(f"MLA.forward - attn_metadata: {attn_metadata}")

-        print(f"[DEBUG] MLA.forward - Created output tensor shape: {attn_output.shape}, dtype: {attn_output.dtype}")
-        print(f"[DEBUG] MLA.forward - ")
+        logger.debug(f"MLA.forward - Created output tensor: {attn_output.shape}, dtype: {attn_output.dtype}")

-            print(f"[DEBUG] MLA.forward - Using custom op for layer {self.layer_idx_str}")
+            logger.debug(f"MLA.forward - Using custom op for layer {self.layer_idx_str}")

-            print(f"[DEBUG] MLA.forward - Using forward_impl")
+            logger.debug(f"MLA.forward - Using forward_impl")

-        print(f"[DEBUG] MLA.forward - ")
-        print(f"[DEBUG] MLA.forward - Before o_proj, attn_output shape: {attn_output.shape}")
+        logger.debug(f"MLA.forward - Before o_proj, attn_output shape: {attn_output.shape}")

-        print(f"[DEBUG] MLA.forward - After o_proj, attn_output shape: {attn_output.shape}, dtype: {attn_output.dtype}")
+        logger.debug(f"MLA.forward - After o_proj, attn_output: {attn_output.shape}, dtype: {attn_output.dtype}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print(f"[DEBUG] MLA.forward - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
print(f"[DEBUG] MLA.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
print(f"[DEBUG] MLA.forward - attn_metadata: {attn_metadata}")
attn_output = self.create_output(hidden_states)
print(f"[DEBUG] MLA.forward - Created output tensor shape: {attn_output.shape}, dtype: {attn_output.dtype}")
if self.register_to_config:
print(f"[DEBUG] MLA.forward - Using custom op for layer {self.layer_idx_str}")
torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids,
self.layer_idx_str,
attn_output)
else:
print(f"[DEBUG] MLA.forward - Using forward_impl")
self.forward_impl(position_ids,
hidden_states,
attn_metadata,
output=attn_output)
print(f"[DEBUG] MLA.forward - Before o_proj, attn_output shape: {attn_output.shape}")
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
print(f"[DEBUG] MLA.forward - After o_proj, attn_output shape: {attn_output.shape}, dtype: {attn_output.dtype}")
return attn_output
logger.debug(f"MLA.forward - hidden_states: {hidden_states.shape}, dtype: {hidden_states.dtype}")
logger.debug(f"MLA.forward - position_ids: {position_ids.shape if position_ids is not None else None}")
logger.debug(f"MLA.forward - attn_metadata: {attn_metadata}")
attn_output = self.create_output(hidden_states)
logger.debug(f"MLA.forward - Created output tensor: {attn_output.shape}, dtype: {attn_output.dtype}")
if self.register_to_config:
logger.debug(f"MLA.forward - Using custom op for layer {self.layer_idx_str}")
torch.ops.trtllm.mla_custom_op_inplace(hidden_states, position_ids,
self.layer_idx_str,
attn_output)
else:
logger.debug(f"MLA.forward - Using forward_impl")
self.forward_impl(position_ids,
hidden_states,
attn_metadata,
output=attn_output)
logger.debug(f"MLA.forward - Before o_proj, attn_output shape: {attn_output.shape}")
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
logger.debug(f"MLA.forward - After o_proj, attn_output: {attn_output.shape}, dtype: {attn_output.dtype}")
return attn_output
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/attention.py around lines 1291 to 1314, replace
all debug print statements with appropriate logging calls using a logger
instance. Remove the print statements and use logger.debug() to log the same
messages, ensuring the debug information is properly managed and can be enabled
or disabled via logging configuration.

Comment on lines 1447 to 1484
if attn_metadata.num_contexts == 0 and len(attn_metadata.request_ids) == 64 and os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1":
print(f"[DEBUG] TrtllmAttention.prepare - {attn_metadata}")
attn_metadata_half1 = copy.copy(attn_metadata)
attn_metadata_half1.max_num_requests = 32
attn_metadata_half1.max_num_sequences = 32
attn_metadata_half1.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
#attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]
attn_metadata_half1.kv_cache_manager.tokens_per_block = 32
attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,:32,:,:]
attn_metadata_half1.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,:32,:,:]
attn_metadata_half1.seq_lens = attn_metadata.seq_lens[:32]
# This is done as is_cross check is seq_lens_kv and seq_lens are the same tensor
attn_metadata_half1.seq_lens_kv = attn_metadata_half1.seq_lens
attn_metadata_half1.prompt_lens = attn_metadata.prompt_lens[:32]
attn_metadata_half1.request_ids = attn_metadata.request_ids[:32]
attn_metadata_half1.kv_cache_params = copy.copy(attn_metadata.kv_cache_params)
attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq = copy.copy(attn_metadata.kv_cache_params.num_cached_tokens_per_seq[:32])
attn_metadata_half1.on_update()
print(f"[DEBUG] TrtllmAttention.prepare (half1) - attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq: {attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq}")
attn_metadata_half1.prepare(splitBatchOverlap=1)

attn_metadata_half2 = copy.copy(attn_metadata)
attn_metadata_half2.max_num_requests = 32
attn_metadata_half2.max_num_sequences = 32
attn_metadata_half2.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
#attn_metadata_half2.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[32:]
attn_metadata_half2.kv_cache_manager.tokens_per_block = 32
attn_metadata_half2.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,32:,:,:]
attn_metadata_half2.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,32:,:,:]
attn_metadata_half2.seq_lens = attn_metadata.seq_lens[32:]
attn_metadata_half2.seq_lens_kv = attn_metadata_half2.seq_lens
attn_metadata_half2.prompt_lens = attn_metadata.prompt_lens[32:]
attn_metadata_half2.request_ids = attn_metadata.request_ids[32:]
attn_metadata_half2.kv_cache_params = copy.copy(attn_metadata.kv_cache_params)
attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq = copy.copy(attn_metadata.kv_cache_params.num_cached_tokens_per_seq[32:])
attn_metadata_half2.on_update()
print(f"[DEBUG] TrtllmAttention.prepare (half2) - attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq: {attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq}")
attn_metadata_half2.prepare(splitBatchOverlap=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Critical: Incorrect modification of tokens_per_block and use of magic numbers.

Multiple issues in this batch splitting logic:

  1. Incorrect tokens_per_block modification: Lines 1454 and 1473 set tokens_per_block = 32, but this is a fundamental KV cache configuration that shouldn't change per batch half. This could cause memory corruption or incorrect KV cache indexing.

  2. Magic numbers: The values 64, 32 should be defined as constants for maintainability.

  3. Line length violations: Multiple lines exceed 120 characters.

Apply this diff to fix the critical issues:

+        # Constants for batch splitting configuration
+        SPLIT_BATCH_SIZE = 64
+        SPLIT_BATCH_HALF_SIZE = 32
+        
         attn_metadata_half1 = None
         attn_metadata_half2 = None
-        if attn_metadata.num_contexts == 0 and len(attn_metadata.request_ids) == 64 and os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1":
+        if (attn_metadata.num_contexts == 0 and 
+            len(attn_metadata.request_ids) == SPLIT_BATCH_SIZE and 
+            os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1"):
             print(f"[DEBUG] TrtllmAttention.prepare - {attn_metadata}")
             attn_metadata_half1 = copy.copy(attn_metadata)
-            attn_metadata_half1.max_num_requests = 32
-            attn_metadata_half1.max_num_sequences = 32
+            attn_metadata_half1.max_num_requests = SPLIT_BATCH_HALF_SIZE
+            attn_metadata_half1.max_num_sequences = SPLIT_BATCH_HALF_SIZE
             attn_metadata_half1.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
-            #attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]
-            attn_metadata_half1.kv_cache_manager.tokens_per_block = 32  
-            attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,:32,:,:]
-            attn_metadata_half1.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,:32,:,:]
-            attn_metadata_half1.seq_lens = attn_metadata.seq_lens[:32]
+            # Remove incorrect tokens_per_block modification - this is a cache configuration
+            attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[
+                :, :SPLIT_BATCH_HALF_SIZE, :, :]
+            attn_metadata_half1.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[
+                :, :SPLIT_BATCH_HALF_SIZE, :, :]
+            attn_metadata_half1.seq_lens = attn_metadata.seq_lens[:SPLIT_BATCH_HALF_SIZE]
             # This is done as is_cross check is seq_lens_kv and seq_lens are the same tensor
             attn_metadata_half1.seq_lens_kv = attn_metadata_half1.seq_lens
-            attn_metadata_half1.prompt_lens = attn_metadata.prompt_lens[:32]
-            attn_metadata_half1.request_ids = attn_metadata.request_ids[:32]
+            attn_metadata_half1.prompt_lens = attn_metadata.prompt_lens[:SPLIT_BATCH_HALF_SIZE]
+            attn_metadata_half1.request_ids = attn_metadata.request_ids[:SPLIT_BATCH_HALF_SIZE]
             attn_metadata_half1.kv_cache_params = copy.copy(attn_metadata.kv_cache_params)
-            attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq = copy.copy(attn_metadata.kv_cache_params.num_cached_tokens_per_seq[:32])
+            attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq = copy.copy(
+                attn_metadata.kv_cache_params.num_cached_tokens_per_seq[:SPLIT_BATCH_HALF_SIZE])

Similar changes should be applied to attn_metadata_half2 initialization.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.12.2)

1447-1447: Line too long (147 > 120)

(E501)


1453-1453: Line too long (133 > 120)

(E501)


1463-1463: Line too long (147 > 120)

(E501)


1465-1465: Line too long (190 > 120)

(E501)


1472-1472: Line too long (133 > 120)

(E501)


1481-1481: Line too long (147 > 120)

(E501)


1483-1483: Line too long (190 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 1447 to 1484, the
batch splitting logic incorrectly modifies tokens_per_block to 32 for each half,
which should remain unchanged as it is a fundamental KV cache parameter; revert
tokens_per_block to its original value from attn_metadata for both halves.
Replace all magic numbers 64 and 32 with named constants defined at the top of
the file for clarity and maintainability. Also, refactor any lines exceeding 120
characters by breaking them into shorter lines to comply with style guidelines.
Apply these fixes symmetrically to both attn_metadata_half1 and
attn_metadata_half2 initializations.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

1268-1310: Remove debug prints throughout the method.

The debug print statements on lines 1268-1273, 1281, 1291, and 1310 should be removed as they will negatively impact performance and create log clutter in production environments.

Apply this diff to remove the debug prints:

-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Input keys: {list(kwargs.keys())}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - input_ids shape: {input_ids.shape if input_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - inputs_embeds shape: {inputs_embeds.shape if inputs_embeds is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - return_context_logits: {return_context_logits}")
-        
         attn_metadata.num_generations_per_batch = self.model_nextn + 1
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Model output hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - MTP logits shape: {logits.shape}, dtype: {logits.dtype}")
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Final logits shape: {logits.shape}, dtype: {logits.dtype}")
♻️ Duplicate comments (2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

741-878: Remove debug prints and address batch splitting issues.

The extensive debug logging and batch splitting logic still contain the same issues identified in previous reviews:

  1. Debug prints impact performance - Lines 741-751 and throughout the batch splitting logic contain verbose debug statements
  2. Hard-coded batch splitting condition - The num_requests == 64 condition remains inflexible
  3. Shallow copy concerns - Using copy.copy() on attn_metadata may not properly duplicate internal state
  4. Code complexity - The batch splitting logic spans over 120 lines and makes the method difficult to maintain
  5. Line length violations - Multiple lines exceed 120 characters per static analysis

Please address the issues outlined in the previous review comment, particularly removing debug prints and making batch splitting configurable rather than hard-coded.


1175-1202: Remove debug prints for production readiness.

The debug print statements added throughout this method will negatively impact performance and create excessive log output in production environments. These should be removed as previously identified.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 08e4f0a and 3e8c876.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (6 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧠 Learnings (2)
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm, test files (files under tests/ directories) do not require nvidia copyright headers...
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

739-739: Line too long (140 > 120)

(E501)


741-741: Line too long (175 > 120)

(E501)


742-742: Line too long (157 > 120)

(E501)


743-743: Line too long (147 > 120)

(E501)


744-744: Line too long (141 > 120)

(E501)


745-745: Line too long (158 > 120)

(E501)


746-746: Line too long (168 > 120)

(E501)


750-750: Line too long (129 > 120)

(E501)


760-760: Line too long (139 > 120)

(E501)


770-770: Line too long (133 > 120)

(E501)


778-778: Line too long (136 > 120)

(E501)


787-787: Line too long (133 > 120)

(E501)


796-796: Line too long (136 > 120)

(E501)


838-838: Line too long (134 > 120)

(E501)


870-870: Line too long (145 > 120)

(E501)


871-871: Line too long (145 > 120)

(E501)


1183-1183: Line too long (137 > 120)

(E501)


1187-1187: Line too long (132 > 120)

(E501)


1190-1190: Line too long (122 > 120)

(E501)


1197-1197: Line too long (123 > 120)

(E501)


1199-1199: Line too long (130 > 120)

(E501)


1268-1268: Line too long (139 > 120)

(E501)


1278-1278: Line too long (143 > 120)

(E501)

🔇 Additional comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

72-73: LGTM - Imports support batch splitting functionality.

The copy and weakref imports are necessary for the batch splitting logic implemented later in this file.

Comment on lines 739 to 878

# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs,
)

if isinstance(self.mlp, Deepseekv3MoE):
return self.forward_MoE(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states shape: {hidden_states.shape if hidden_states is not None else None}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata: {attn_metadata}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.tokens_per_block: {attn_metadata.kv_cache_manager.tokens_per_block}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.max_seq_len: {attn_metadata.kv_cache_manager.max_seq_len}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.head_dim: {attn_metadata.kv_cache_manager.head_dim}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_block_offsets: {attn_metadata.kv_cache_block_offsets.shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.host_kv_cache_block_offsets: {attn_metadata.host_kv_cache_block_offsets.shape}")
num_requests = position_ids.shape[1]
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests: {num_requests}")

if num_requests == 64 and os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1" and attn_metadata.num_contexts == 0:

stream_half1 = torch.cuda.Stream()
stream_half2 = torch.cuda.Stream()

# Create CUDA event to synchronize between streams
event_half1_complete = torch.cuda.Event()

print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests == 64")
hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1.shape} {hidden_states_half2.shape}")
position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)
residual_half1, residual_half2 = residual.chunk(2, dim=0)

attn_metadata_half1 = copy.copy(attn_metadata)
attn_metadata_half1.max_num_requests = 32
attn_metadata_half1.max_num_sequences = 32
#attn_metadata_half1.all_rank_num_tokens = [32, 32]
#attn_metadata_half1.all_rank_max_num_tokens = 32
attn_metadata_half1.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
#attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]
attn_metadata_half1.kv_cache_manager.tokens_per_block = 32
attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,:32,:,:]
attn_metadata_half1.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,:32,:,:]
attn_metadata_half1.seq_lens = attn_metadata.seq_lens[:32]
attn_metadata_half1.seq_lens_kv = attn_metadata.seq_lens_kv[:32]
attn_metadata_half1.prompt_lens = attn_metadata.prompt_lens[:32]
attn_metadata_half1.request_ids = attn_metadata.request_ids[:32]
attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq[:32]
attn_metadata_half1.on_update()

attn_metadata_half2 = copy.copy(attn_metadata)
attn_metadata_half2.max_num_requests = 32
attn_metadata_half2.max_num_sequences = 32
#attn_metadata_half2.all_rank_num_tokens = [32, 32]
#attn_metadata_half2.all_rank_max_num_tokens = 32
attn_metadata_half2.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
#attn_metadata_half2.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[32:]
attn_metadata_half2.kv_cache_manager.tokens_per_block = 32
attn_metadata_half2.kv_cache_manager.max_seq_len = 32
attn_metadata_half2.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,32:,:,:]
attn_metadata_half2.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,32:,:,:]
attn_metadata_half2.seq_lens = attn_metadata.seq_lens[32:]
attn_metadata_half2.seq_lens_kv = attn_metadata.seq_lens_kv[32:]
attn_metadata_half2.prompt_lens = attn_metadata.prompt_lens[32:]
attn_metadata_half2.request_ids = attn_metadata.request_ids[32:]
attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq[32:]
attn_metadata_half2.on_update()

print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata_half1: {attn_metadata_half1}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata_half2: {attn_metadata_half2}")

print(f"[DEBUG] DeepseekV3DecoderLayer.forward - kwargs: {kwargs}")

kwargs_half1 = kwargs.copy()
if 'attn_metadata' in kwargs_half1:
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - kwargs_half1: {kwargs_half1}")
del kwargs_half1['attn_metadata']

with torch.cuda.stream(stream_half1):
hidden_states_half1 = self.self_attn(
position_ids=position_ids_half1,
hidden_states=hidden_states_half1,
attn_metadata=attn_metadata_half1,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs_half1,
)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1.shape}")
# Record event when half1 is complete
event_half1_complete.record()

with torch.cuda.stream(stream_half1):
event_half1_complete.wait()
if isinstance(self.mlp, Deepseekv3MoE):
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - self.mlp is Deepseekv3MoE")
hidden_states_half1, residual_half1 = self.forward_MoE(
hidden_states=hidden_states_half1,
attn_metadata=attn_metadata_half1,
residual=residual_half1,
)
else:
assert isinstance(self.mlp, GatedMLP)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - self.mlp is GatedMLP")
hidden_states_half1, residual_half1 = self.forward_mlp(
hidden_states=hidden_states_half1,
residual=residual_half1,
)
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1 after MOE/GatedMLP: {hidden_states_half1.shape}")
kwargs_half2 = kwargs.copy()
if 'attn_metadata' in kwargs_half2:
del kwargs_half2['attn_metadata']

with torch.cuda.stream(stream_half2):
# Wait for half1 to complete before starting half2
event_half1_complete.wait()
hidden_states_half2 = self.self_attn(
position_ids=position_ids_half2,
hidden_states=hidden_states_half2,
attn_metadata=attn_metadata_half2,
all_reduce_params=AllReduceParams(
enable_allreduce=not (self.disable_attn_allreduce)),
**kwargs_half2,
)
if isinstance(self.mlp, Deepseekv3MoE):
hidden_states_half2, residual_half2 = self.forward_MoE(
hidden_states=hidden_states_half2,
attn_metadata=attn_metadata_half2,
residual=residual_half2,
)
else:
assert isinstance(self.mlp, GatedMLP)
hidden_states_half2, residual_half2 = self.forward_mlp(
hidden_states=hidden_states_half2,
residual=residual_half2,
)
# Wait for both streams to complete before concatenating
torch.cuda.current_stream().wait_stream(stream_half1)
torch.cuda.current_stream().wait_stream(stream_half2)

print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1[0].shape} {hidden_states_half1[1].shape}")
print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half2: {hidden_states_half2[0].shape} {hidden_states_half2[1].shape}")

hidden_states = torch.cat([hidden_states_half1, hidden_states_half2], dim=0)
residual = torch.cat([residual_half1, residual_half2], dim=0)
return hidden_states, residual
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Address line length violations and production readiness.

Multiple lines exceed the 120-character limit per static analysis. These violations are primarily in the debug print statements and batch splitting logic that should be refactored or removed for production readiness.

The current implementation contains experimental debugging code that needs to be either:

  1. Removed entirely if this is meant for production
  2. Made configurable through proper configuration classes rather than hard-coded conditions
  3. Moved to development/testing utilities if needed for benchmarking

Consider implementing batch splitting through the existing configuration system rather than environment variables and hard-coded request counts.

🧰 Tools
🪛 Ruff (0.12.2)

739-739: Line too long (140 > 120)

(E501)


741-741: Line too long (175 > 120)

(E501)


742-742: Line too long (157 > 120)

(E501)


743-743: Line too long (147 > 120)

(E501)


744-744: Line too long (141 > 120)

(E501)


745-745: Line too long (158 > 120)

(E501)


746-746: Line too long (168 > 120)

(E501)


750-750: Line too long (129 > 120)

(E501)


760-760: Line too long (139 > 120)

(E501)


770-770: Line too long (133 > 120)

(E501)


778-778: Line too long (136 > 120)

(E501)


787-787: Line too long (133 > 120)

(E501)


796-796: Line too long (136 > 120)

(E501)


838-838: Line too long (134 > 120)

(E501)


870-870: Line too long (145 > 120)

(E501)


871-871: Line too long (145 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py between lines 739 and 878,
the code contains many debug print statements and hard-coded batch splitting
logic that exceed 120 characters per line and are not suitable for production.
To fix this, remove or disable all debug print statements, replace hard-coded
batch size checks and environment variable conditions with configurable
parameters accessed via the existing configuration system, and move any
experimental or benchmarking code into separate development or testing utility
modules. Ensure all lines comply with the 120-character limit and that
production code does not rely on debug prints or hard-coded environment checks.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (7)
tensorrt_llm/_torch/distributed/ops.py (1)

179-193: Refactor environment-dependent assertions and remove debug prints.

This code segment has several production readiness concerns:

  1. Core distributed operations should not depend on environment variables for correctness
  2. Extensive debug prints will impact performance
  3. The conditional assertion logic appears experimental

Apply this diff to address the issues:

-            for val in input:
-                if val is not None:
-                    print(f"[DEBUG] allgather - val shape: {val.shape}, dtype: {val.dtype}")
-                    print(f"[DEBUG] allgather - val shape: {val.shape[dim]}, dtype: {val.dtype}")
-                    print(f"[DEBUG] allgather - sizes[mapping.tp_rank]: {sizes[mapping.tp_rank]}")
-            if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
-                assert all([
-                    val.shape[dim] == sizes[mapping.tp_rank] / 2 for val in input
-                    if val is not None
-                ])
-            else:
-                assert all([
-                    val.shape[dim] == sizes[mapping.tp_rank] for val in input
-                    if val is not None
-                ])
+            assert all([
+                val.shape[dim] == sizes[mapping.tp_rank] for val in input
+                if val is not None
+            ])

Consider making batch splitting configurable through proper configuration classes rather than environment variables.

tensorrt_llm/_torch/models/modeling_deepseekv3.py (6)

519-519: Remove debug print for production readiness.

This debug print statement will negatively impact performance and create excessive log output in production environments.

Apply this diff:

-                print(f"[DEBUG] Deepseekv3MoE.compute_routed_output - all_rank_num_tokens: {all_rank_num_tokens}")

552-552: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output in production environments.

Apply this diff:

-            print(f"[DEBUG] Deepseekv3MoE.forward - _compute_shared_output - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-            print(f"[DEBUG] Deepseekv3MoE.forward - _compute_routed_output - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}, all_rank_num_tokens: {all_rank_num_tokens}")

Also applies to: 560-560


913-913: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output in production environments.

Apply this diff to remove all debug prints:

-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward_MoE - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward_MoE - self.fusion_config.PRE_MOE_FUSION: {self.fusion_config.PRE_MOE_FUSION}")
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward_MoE - No fusion")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward_MoE - do_finalize: {do_finalize}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward_MoE - hidden_states after MoE: {hidden_states.shape}, dtype: {hidden_states.dtype}")

Also applies to: 928-928, 940-940, 950-951, 954-954


1181-1208: Remove debug prints for production readiness.

Multiple debug print statements have been added that will negatively impact performance and create excessive log output in production environments.

Apply this diff to remove the debug prints:

-        print(f"[DEBUG] DeepSeekV3Model.forward - input_ids shape: {input_ids.shape if input_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3Model.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3Model.forward - inputs_embeds shape: {inputs_embeds.shape if inputs_embeds is not None else None}")
-        
-        
-            print(f"[DEBUG] DeepSeekV3Model.forward - Computed inputs_embeds shape: {inputs_embeds.shape}, dtype: {inputs_embeds.dtype}")
-        print(f"[DEBUG] DeepSeekV3Model.forward - Initial hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-        for layer_idx, decoder_layer in enumerate(self.layers[:self.num_hidden_layers]):
-            print(f"[DEBUG] DeepSeekV3Model.forward - Layer {layer_idx} input hidden_states shape: {hidden_states.shape}")
+        for decoder_layer in self.layers[:self.num_hidden_layers]:
-            print(f"[DEBUG] DeepSeekV3Model.forward - Layer {layer_idx} output hidden_states shape: {hidden_states.shape}")
-        print(f"[DEBUG] DeepSeekV3Model.forward - Final hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")

1274-1279: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output in production environments.

Apply this diff to remove the debug prints:

-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Input keys: {list(kwargs.keys())}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - input_ids shape: {input_ids.shape if input_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - inputs_embeds shape: {inputs_embeds.shape if inputs_embeds is not None else None}")
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - return_context_logits: {return_context_logits}")
-        
-        print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Model output hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - MTP logits shape: {logits.shape}, dtype: {logits.dtype}")
-            print(f"[DEBUG] DeepSeekV3ForCausalLM.forward - Final logits shape: {logits.shape}, dtype: {logits.dtype}")

Also applies to: 1287-1287, 1297-1297, 1316-1316


742-881: Critical: Remove experimental batch splitting code and debug prints.

This code segment contains experimental batch splitting logic with multiple critical issues:

  1. Hard-coded conditions - num_requests == 64 is too specific and inflexible
  2. Extensive debug prints - Will severely impact performance and create log clutter
  3. Environment variable dependency - Core model logic should not depend on environment variables
  4. Shallow copying concerns - Using copy.copy() may not properly duplicate all necessary state
  5. Line length violations - Multiple lines exceed 120 characters
  6. Complex conditional logic - Makes code hard to maintain and debug

This experimental code should be either:

  1. Removed entirely if not ready for production
  2. Moved to a separate experimental module with proper configuration
  3. Refactored to use proper configuration classes instead of environment variables

Apply this diff to remove the experimental code:

-        
-        # Self Attention
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - position_ids shape: {position_ids.shape if position_ids is not None else None}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states shape: {hidden_states.shape if hidden_states is not None else None}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata: {attn_metadata}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.tokens_per_block: {attn_metadata.kv_cache_manager.tokens_per_block}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.max_seq_len: {attn_metadata.kv_cache_manager.max_seq_len}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.head_dim: {attn_metadata.kv_cache_manager.head_dim}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_block_offsets: {attn_metadata.kv_cache_block_offsets.shape}")
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.host_kv_cache_block_offsets: {attn_metadata.host_kv_cache_block_offsets.shape}")
-        num_requests = position_ids.shape[1]
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests: {num_requests}")
-        
-        if num_requests == 64 and os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1" and attn_metadata.num_contexts == 0:
-            
-            stream_half1 = torch.cuda.Stream()
-            stream_half2 = torch.cuda.Stream()
-
-            # Create CUDA event to synchronize between streams
-            event_half1_complete = torch.cuda.Event()
-
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - num_requests == 64")
-            hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1.shape} {hidden_states_half2.shape}")
-            position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)
-            residual_half1, residual_half2 = residual.chunk(2, dim=0)
-            
-            attn_metadata_half1 = copy.copy(attn_metadata)
-            attn_metadata_half1.max_num_requests = 32
-            attn_metadata_half1.max_num_sequences = 32
-            #attn_metadata_half1.all_rank_num_tokens = [32, 32]
-            #attn_metadata_half1.all_rank_max_num_tokens = 32
-            attn_metadata_half1.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
-            #attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]
-            attn_metadata_half1.kv_cache_manager.tokens_per_block = 32  
-            attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,:32,:,:]
-            attn_metadata_half1.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,:32,:,:]
-            attn_metadata_half1.seq_lens = attn_metadata.seq_lens[:32]
-            attn_metadata_half1.seq_lens_kv = attn_metadata.seq_lens_kv[:32]
-            attn_metadata_half1.prompt_lens = attn_metadata.prompt_lens[:32]
-            attn_metadata_half1.request_ids = attn_metadata.request_ids[:32]
-            attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq[:32]
-            attn_metadata_half1.on_update()
-
-            attn_metadata_half2 = copy.copy(attn_metadata)
-            attn_metadata_half2.max_num_requests = 32
-            attn_metadata_half2.max_num_sequences = 32
-            #attn_metadata_half2.all_rank_num_tokens = [32, 32]
-            #attn_metadata_half2.all_rank_max_num_tokens = 32
-            attn_metadata_half2.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
-            #attn_metadata_half2.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[32:]
-            attn_metadata_half2.kv_cache_manager.tokens_per_block = 32  
-            attn_metadata_half2.kv_cache_manager.max_seq_len = 32
-            attn_metadata_half2.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,32:,:,:]
-            attn_metadata_half2.host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets[:,32:,:,:]
-            attn_metadata_half2.seq_lens = attn_metadata.seq_lens[32:]
-            attn_metadata_half2.seq_lens_kv = attn_metadata.seq_lens_kv[32:]
-            attn_metadata_half2.prompt_lens = attn_metadata.prompt_lens[32:]
-            attn_metadata_half2.request_ids = attn_metadata.request_ids[32:]
-            attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq[32:]
-            attn_metadata_half2.on_update()
-
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata_half1: {attn_metadata_half1}")
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata_half2: {attn_metadata_half2}")
-
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - kwargs: {kwargs}")
-
-            kwargs_half1 = kwargs.copy()
-            if 'attn_metadata' in kwargs_half1:
-                print(f"[DEBUG] DeepseekV3DecoderLayer.forward - kwargs_half1: {kwargs_half1}")
-                del kwargs_half1['attn_metadata']
-           
-            with torch.cuda.stream(stream_half1):
-                hidden_states_half1 = self.self_attn(
-                    position_ids=position_ids_half1,
-                    hidden_states=hidden_states_half1,
-                    attn_metadata=attn_metadata_half1,
-                    all_reduce_params=AllReduceParams(
-                        enable_allreduce=not (self.disable_attn_allreduce)),
-                    **kwargs_half1,
-                )
-                print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1.shape}")
-            # Record event when half1 is complete
-            event_half1_complete.record()
-
-            with torch.cuda.stream(stream_half1):
-                event_half1_complete.wait()
-                if isinstance(self.mlp, Deepseekv3MoE):
-                    print(f"[DEBUG] DeepseekV3DecoderLayer.forward - self.mlp is Deepseekv3MoE")
-                    hidden_states_half1, residual_half1 = self.forward_MoE(
-                        hidden_states=hidden_states_half1,
-                        attn_metadata=attn_metadata_half1,
-                        residual=residual_half1,
-                    )
-                else:
-                    assert isinstance(self.mlp, GatedMLP)
-                    print(f"[DEBUG] DeepseekV3DecoderLayer.forward - self.mlp is GatedMLP")
-                    hidden_states_half1, residual_half1 = self.forward_mlp(
-                        hidden_states=hidden_states_half1,
-                        residual=residual_half1,
-                    )
-                print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1 after MOE/GatedMLP: {hidden_states_half1.shape}")
-            kwargs_half2 = kwargs.copy()
-            if 'attn_metadata' in kwargs_half2:
-                del kwargs_half2['attn_metadata']
-
-            with torch.cuda.stream(stream_half2):
-                # Wait for half1 to complete before starting half2
-                event_half1_complete.wait()
-                hidden_states_half2 = self.self_attn(
-                    position_ids=position_ids_half2,
-                    hidden_states=hidden_states_half2,
-                    attn_metadata=attn_metadata_half2,
-                    all_reduce_params=AllReduceParams(
-                    enable_allreduce=not (self.disable_attn_allreduce)),
-                    **kwargs_half2,
-                )
-                if isinstance(self.mlp, Deepseekv3MoE):
-                    hidden_states_half2, residual_half2 = self.forward_MoE(
-                        hidden_states=hidden_states_half2,
-                        attn_metadata=attn_metadata_half2,
-                        residual=residual_half2,
-                    )
-                else:
-                    assert isinstance(self.mlp, GatedMLP)
-                    hidden_states_half2, residual_half2 = self.forward_mlp(
-                        hidden_states=hidden_states_half2,
-                        residual=residual_half2,
-                    )
-            # Wait for both streams to complete before concatenating
-            torch.cuda.current_stream().wait_stream(stream_half1)
-            torch.cuda.current_stream().wait_stream(stream_half2)
-            
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half1: {hidden_states_half1[0].shape} {hidden_states_half1[1].shape}")
-            print(f"[DEBUG] DeepseekV3DecoderLayer.forward - hidden_states_half2: {hidden_states_half2[0].shape} {hidden_states_half2[1].shape}")
-                 
-            hidden_states = torch.cat([hidden_states_half1, hidden_states_half2], dim=0)
-            residual = torch.cat([residual_half1, residual_half2], dim=0)
-            return hidden_states, residual
-        else:
-            hidden_states = self.self_attn(
-                position_ids=position_ids,
+        # Self Attention
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.disable_attn_allreduce)),
+            **kwargs,
+        )
+
+        if isinstance(self.mlp, Deepseekv3MoE):
+            return self.forward_MoE(
                 hidden_states=hidden_states,
-                attn_metadata=attn_metadata,
-                all_reduce_params=AllReduceParams(
-                    enable_allreduce=not (self.disable_attn_allreduce)),
-                **kwargs,
+                attn_metadata=attn_metadata,
+                residual=residual,
             )
-
-            if isinstance(self.mlp, Deepseekv3MoE):
-                return self.forward_MoE(
-                    hidden_states=hidden_states,
-                    attn_metadata=attn_metadata,
-                    residual=residual,
-                )
-            else:
-                assert isinstance(self.mlp, GatedMLP)
-                return self.forward_mlp(
-                    hidden_states=hidden_states,
-                    residual=residual,
-                )
+        else:
+            assert isinstance(self.mlp, GatedMLP)
+            return self.forward_mlp(
+                hidden_states=hidden_states,
+                residual=residual,
+            )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3e8c876 and c1730af.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/distributed/ops.py (3 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (12 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (2 hunks)
  • tensorrt_llm/_torch/modules/multi_stream_utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
  • tensorrt_llm/_torch/modules/multi_stream_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧠 Learnings (5)
📚 Learning: in tensorrt-llm, test files (files under tests/ directories) do not require nvidia copyright headers...
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm, examples directory can have different dependency versions than the root requirement...
Learnt from: yibinl-nvidia
PR: NVIDIA/TensorRT-LLM#6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: applies to **/*.py : the code developed for tensorrt-llm should conform to python 3.8+....
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T08:45:40.701Z
Learning: Applies to **/*.py : The code developed for TensorRT-LLM should conform to Python 3.8+.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

557-557: Line too long (195 > 120)

(E501)


742-742: Line too long (140 > 120)

(E501)


744-744: Line too long (175 > 120)

(E501)


745-745: Line too long (157 > 120)

(E501)


746-746: Line too long (147 > 120)

(E501)


747-747: Line too long (141 > 120)

(E501)


748-748: Line too long (158 > 120)

(E501)


749-749: Line too long (168 > 120)

(E501)


753-753: Line too long (129 > 120)

(E501)


763-763: Line too long (139 > 120)

(E501)


773-773: Line too long (133 > 120)

(E501)


781-781: Line too long (136 > 120)

(E501)


790-790: Line too long (133 > 120)

(E501)


799-799: Line too long (136 > 120)

(E501)


841-841: Line too long (134 > 120)

(E501)


873-873: Line too long (145 > 120)

(E501)


874-874: Line too long (145 > 120)

(E501)


951-951: Line too long (139 > 120)

(E501)


1189-1189: Line too long (137 > 120)

(E501)


1193-1193: Line too long (132 > 120)

(E501)


1196-1196: Line too long (122 > 120)

(E501)


1203-1203: Line too long (123 > 120)

(E501)


1205-1205: Line too long (130 > 120)

(E501)


1274-1274: Line too long (139 > 120)

(E501)


1284-1284: Line too long (143 > 120)

(E501)

🔇 Additional comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

72-73: LGTM - Necessary imports for batch splitting functionality.

The addition of copy and weakref imports supports the attention metadata copying functionality in the batch splitting logic.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

743-1318: Fix line length violations (>120 characters).

Multiple lines exceed the 120-character limit specified in the coding guidelines. Most violations are in debug print statements and the experimental batch splitting logic.

Addressing the previously flagged issues (removing debug prints and refactoring batch splitting logic) will resolve most of these line length violations. Any remaining long lines should be broken appropriately:

# Example fix for long lines:
-        print(f"[DEBUG] DeepseekV3DecoderLayer.forward - attn_metadata.kv_cache_manager.kv_cache_pool_pointers: {attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape}")
+        # Remove debug prints entirely, or if needed:
+        logger.debug(
+            "kv_cache_pool_pointers shape: %s", 
+            attn_metadata.kv_cache_manager.kv_cache_pool_pointers.shape
+        )
♻️ Duplicate comments (7)
tensorrt_llm/_torch/distributed/ops.py (4)

172-173: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output in production environments.


208-211: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output.


565-565: Remove debug print for production readiness.

This debug print statement will negatively impact performance and create excessive log output in production environments.


179-183: Remove debug prints for production readiness.

These debug print statements will negatively impact performance and create excessive log output in production environments.

Apply this diff to remove the debug prints:

-            for val in input:
-                if val is not None:
-                    print(f"[DEBUG] allgather - val shape: {val.shape}, dtype: {val.dtype}")
-                    print(f"[DEBUG] allgather - val shape: {val.shape[dim]}, dtype: {val.dtype}")
-                    print(f"[DEBUG] allgather - sizes[mapping.tp_rank]: {sizes[mapping.tp_rank]}")
tensorrt_llm/_torch/models/modeling_deepseekv3.py (3)

519-578: Remove debug prints for production readiness.

Multiple debug print statements have been added throughout the MoE forward methods that will negatively impact performance and create excessive log output in production environments.

Apply this diff to remove the debug prints:

-        print(f"[DEBUG] Deepseekv3MoE.compute_routed_output - all_rank_num_tokens: {all_rank_num_tokens}")
         hidden_states = allgather(hidden_states,
                                   self.mapping,
                                   dim=0,
                                   sizes=all_rank_num_tokens)
-        print(f"[DEBUG] Deepseekv3MoE.compute_routed_output.gather - hidden_states shape: {hidden_states.shape}, dtype: {hidden_states.dtype}")
         router_logits = self.gate(hidden_states)
-        print(f"[DEBUG] Deepseekv3MoE.compute_routed_output.gate - router_logits shape: {router_logits.shape}, dtype: {router_logits.dtype}")

Similar removals should be applied to all other debug print statements in the file.


1182-1318: Remove debug prints as previously requested.

Multiple debug print statements logging tensor shapes and metadata have been added to the model forward methods. These were flagged in previous reviews for removal due to performance impact and excessive logging.

All debug prints should be removed from production model code. If debugging capabilities are needed, implement them through a proper logging framework with configurable levels rather than hardcoded print statements.


743-904: Address experimental batch splitting code as previously flagged.

This extensive batch splitting logic contains the same issues that were previously identified in past reviews:

  1. Hard-coded conditions - num_requests == 64 is inflexible
  2. Environment variable dependencies - Production model code shouldn't rely on environment flags
  3. Debug prints throughout - Performance impact and log clutter
  4. Line length violations - Multiple lines exceed 120 characters
  5. Complex stream management - Potential for race conditions and resource leaks

The batch splitting functionality should be:

  • Made configurable through proper model configuration classes
  • Moved to a separate utility module for experimental features
  • Thoroughly tested for correctness and performance
  • Stripped of all debug prints

Consider implementing this as a configurable execution strategy rather than embedding it directly in the model forward pass.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (1)

72-73: Consider removing these imports if batch splitting code is refactored.

The copy and weakref imports are only used for the experimental batch splitting logic. If this code is moved to a separate utility or made properly configurable, these imports may no longer be needed in this core model file.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c1730af and 750135a.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/distributed/ops.py (3 hunks)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (13 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
🧠 Learnings (5)
📚 Learning: in tensorrt-llm, test files (files under tests/ directories) do not require nvidia copyright headers...
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()...
Learnt from: yechank-nvidia
PR: NVIDIA/TensorRT-LLM#6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: in tensorrt-llm testing, it's common to have both cli flow tests (test_cli_flow.py) and pytorch api ...
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: applies to **/*.py : the code developed for tensorrt-llm should conform to python 3.8+....
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T08:45:40.701Z
Learning: Applies to **/*.py : The code developed for TensorRT-LLM should conform to Python 3.8+.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
📚 Learning: applies to **/*.{cpp,h,hpp,cc,cxx} : use a maximum of 120 characters per line in c++ code....
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-06T08:45:40.701Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use a maximum of 120 characters per line in C++ code.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/distributed/ops.py

186-186: Line too long (125 > 120)

(E501)

tensorrt_llm/_torch/models/modeling_deepseekv3.py

521-521: Line too long (143 > 120)

(E501)


524-524: Line too long (141 > 120)

(E501)


558-558: Line too long (195 > 120)

(E501)


743-743: Line too long (140 > 120)

(E501)


745-745: Line too long (175 > 120)

(E501)


746-746: Line too long (157 > 120)

(E501)


747-747: Line too long (147 > 120)

(E501)


748-748: Line too long (141 > 120)

(E501)


749-749: Line too long (158 > 120)

(E501)


750-750: Line too long (168 > 120)

(E501)


754-754: Line too long (129 > 120)

(E501)


764-764: Line too long (139 > 120)

(E501)


774-774: Line too long (133 > 120)

(E501)


782-782: Line too long (136 > 120)

(E501)


791-791: Line too long (133 > 120)

(E501)


800-800: Line too long (136 > 120)

(E501)


842-842: Line too long (134 > 120)

(E501)


874-874: Line too long (145 > 120)

(E501)


875-875: Line too long (145 > 120)

(E501)


952-952: Line too long (139 > 120)

(E501)


1190-1190: Line too long (137 > 120)

(E501)


1194-1194: Line too long (132 > 120)

(E501)


1197-1197: Line too long (122 > 120)

(E501)


1204-1204: Line too long (123 > 120)

(E501)


1206-1206: Line too long (130 > 120)

(E501)


1275-1275: Line too long (139 > 120)

(E501)


1285-1285: Line too long (143 > 120)

(E501)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

713-713: Line too long (141 > 120)

(E501)

Comment on lines +184 to +193
if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
assert all([
val.shape[dim] == sizes[mapping.tp_rank] / 2 or val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
else:
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix line length violation and improve readability.

The conditional assertion logic implements valid functionality for split batch overlap, but the line exceeds the 120-character limit and could be more readable.

Apply this diff to fix the line length and improve readability:

             if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
-                assert all([
-                    val.shape[dim] == sizes[mapping.tp_rank] / 2 or val.shape[dim] == sizes[mapping.tp_rank] for val in input
-                    if val is not None
-                ])
+                expected_size = sizes[mapping.tp_rank]
+                assert all([
+                    val.shape[dim] == expected_size / 2 or val.shape[dim] == expected_size
+                    for val in input if val is not None
+                ])
             else:
                 assert all([
                     val.shape[dim] == sizes[mapping.tp_rank] for val in input
                     if val is not None
                 ])
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
assert all([
val.shape[dim] == sizes[mapping.tp_rank] / 2 or val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
else:
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
if os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP", "0") == "1":
expected_size = sizes[mapping.tp_rank]
assert all([
val.shape[dim] == expected_size / 2 or val.shape[dim] == expected_size
for val in input if val is not None
])
else:
assert all([
val.shape[dim] == sizes[mapping.tp_rank] for val in input
if val is not None
])
🧰 Tools
🪛 Ruff (0.12.2)

186-186: Line too long (125 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/distributed/ops.py around lines 184 to 193, the assertion
lines exceed the 120-character limit and reduce readability. Refactor the
assertions by breaking down the conditions into intermediate variables or using
multiple lines for the list comprehension, ensuring each line stays within the
character limit and the logic remains clear and easy to follow.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (7)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (7)

519-527: Remove debug print statements.

These debug print statements should be removed as they will impact performance and create excessive log output in production environments.


553-561: Remove debug print statements.

These debug print statements should be removed as they will impact performance and create excessive log output in production environments.


577-578: Remove debug print statements and fix line length.

This debug print statement should be removed and the line exceeds the 120-character limit.


924-965: Remove debug print statements from MoE forward methods.

Multiple debug print statements should be removed for production readiness.


1192-1219: Remove debug print statements from model forward.

These debug print statements should be eliminated to improve performance and avoid excessive logging.


1285-1327: Remove debug print statements from causal LM forward.

All debug print statements in this method should be removed for production deployment.


752-892: Critical issues in split-batch implementation.

The split-batch logic has several critical problems:

  1. All debug prints must be removed for production
  2. Shallow copy issues: Using copy.copy() on complex objects like attn_metadata may not properly duplicate internal state
  3. Hard-coded magic numbers: The logic relies on specific batch sizes without proper validation
  4. Potential memory corruption: Tensor slicing and concatenation without validation could cause shape mismatches
  5. Complex stream synchronization: The event-based synchronization adds complexity without clear benefits over simpler approaches

This implementation needs significant refactoring:

  1. Remove all debug print statements
  2. Replace environment variable checks with proper configuration system
  3. Use deep copy or proper metadata splitting utilities for attn_metadata
  4. Add tensor shape validation after chunking operations
  5. Consider if the complexity of dual-stream processing provides measurable benefits

Based on the PR objectives indicating this is still a "Draft", recommend completing the design review before proceeding with this complex parallel processing logic.

🧹 Nitpick comments (2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

893-915: Refactor conditional logic for better readability.

The nested if-else logic can be simplified and the debug prints removed.

-            if isinstance(self.mlp, Deepseekv3MoE):
-                return self.forward_MoE(
-                    hidden_states=hidden_states,
-                    attn_metadata=attn_metadata,
-                    residual=residual,
-                )
-            else:
-                assert isinstance(self.mlp, GatedMLP)
-                return self.forward_mlp(
-                    hidden_states=hidden_states,
-                    residual=residual,
-                )
+            # Forward through MLP or MoE
+            if isinstance(self.mlp, Deepseekv3MoE):
+                return self.forward_MoE(hidden_states, attn_metadata, residual)
+            else:
+                assert isinstance(self.mlp, GatedMLP)
+                return self.forward_mlp(hidden_states, residual)

765-892: Add explicit validations around split-batch tensor and metadata operations

To prevent silent runtime errors when splitting batches and slicing metadata, introduce assertions and more robust copying before performing the chunking and metadata manipulation.

Key areas to address:

  • Validate that hidden_states, position_ids, and residual have dimensions divisible by two and match the configured batch sizes (enable_split_batch_overlap_local_bs, enable_split_batch_overlap_split_bs).
  • Ensure attn_metadata fields (e.g. seq_lens, kv_cache_block_offsets, kv_cache_params.num_cached_tokens_per_seq) are long enough for the intended slices, and raise descriptive errors if not.
  • Switch from shallow copy.copy to copy.deepcopy when duplicating nested metadata to avoid unintended shared-state mutations.

Suggested diff around lines 770–776:

@@ tensorrt_llm/_torch/models/modeling_deepseekv3.py:765
-        if self.enable_split_batch_overlap and \
-            num_requests == self.enable_split_batch_overlap_local_bs and attn_metadata.num_contexts == 0:
+        if self.enable_split_batch_overlap and \
+           num_requests == self.enable_split_batch_overlap_local_bs and \
+           attn_metadata.num_contexts == 0:
+            # --- Validate shapes before splitting ---
+            split_bs = self.enable_split_batch_overlap_split_bs
+            batch_size, reqs = hidden_states.size(0), position_ids.size(1)
+            assert batch_size == 2 * split_bs, (
+                f"hidden_states.batch_size ({batch_size}) must equal 2 * split_bs ({split_bs})"
+            )
+            assert reqs == 2 * split_bs, (
+                f"position_ids.num_requests ({reqs}) must equal 2 * split_bs ({split_bs})"
+            )
+            assert residual.size(0) == batch_size, (
+                f"residual.batch_size ({residual.size(0)}) must match hidden_states"
+            )
+            # Validate metadata lengths
+            if len(attn_metadata.seq_lens) < 2 * split_bs:
+                raise ValueError(
+                    f"attn_metadata.seq_lens has length {len(attn_metadata.seq_lens)}, "
+                    f"expected at least {2*split_bs}"
+                )
+            # Deep-copy metadata to avoid shared-state issues
+            attn_metadata_half1 = copy.deepcopy(attn_metadata)
+            attn_metadata_half2 = copy.deepcopy(attn_metadata)
+            # … proceed with chunk() and slicing …
             hidden_states_half1, hidden_states_half2 = hidden_states.chunk(2, dim=0)
             position_ids_half1, position_ids_half2 = position_ids.chunk(2, dim=1)
             residual_half1, residual_half2 = residual.chunk(2, dim=0)

• File: tensorrt_llm/_torch/models/modeling_deepseekv3.py
• Lines: ~770–776

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 750135a and 962b614.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py (15 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA Apache-2.0 copyright header with current year to all source files

Files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧠 Learnings (1)
📚 Learning: 2025-08-14T06:36:40.701Z
Learnt from: timlee0212
PR: NVIDIA/TensorRT-LLM#6886
File: tensorrt_llm/_torch/models/modeling_deepseekv3.py:0-0
Timestamp: 2025-08-14T06:36:40.701Z
Learning: In DeepSeek V3 model (tensorrt_llm/_torch/models/modeling_deepseekv3.py), the disagreement between AllReduce.__init__ guard and _compute_mlp_tp_size logic for MNNVL usage is expected by design. The AllReduce component and MLP TP-size computation intentionally use different criteria for MNNVL availability decisions.

Applied to files:

  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (4)
tensorrt_llm/_torch/utils.py (4)
  • get_model_extra_attrs (47-48)
  • set_torch_compiling (27-29)
  • with_model_extra_attrs (61-71)
  • shape (98-99)
tensorrt_llm/_torch/distributed/ops.py (1)
  • allgather (138-238)
tensorrt_llm/_torch/attention_backend/interface.py (9)
  • all_rank_num_tokens (161-162)
  • all_rank_num_tokens (165-168)
  • num_contexts (199-200)
  • num_contexts (203-206)
  • seq_lens (171-172)
  • seq_lens (175-196)
  • seq_lens_kv (223-224)
  • seq_lens_kv (227-234)
  • on_update (148-158)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
  • tokens_per_block (582-586)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py

521-521: Line too long (145 > 120)

(E501)


524-524: Line too long (143 > 120)

(E501)


558-558: Line too long (197 > 120)

(E501)


620-620: Line too long (217 > 120)

(E501)


751-751: Line too long (142 > 120)

(E501)


753-753: Line too long (177 > 120)

(E501)


754-754: Line too long (159 > 120)

(E501)


755-755: Line too long (149 > 120)

(E501)


756-756: Line too long (143 > 120)

(E501)


757-757: Line too long (160 > 120)

(E501)


758-758: Line too long (170 > 120)

(E501)


773-773: Line too long (141 > 120)

(E501)


780-780: Line too long (138 > 120)

(E501)


783-783: Line too long (133 > 120)

(E501)


785-785: Line too long (142 > 120)

(E501)


786-786: Line too long (152 > 120)

(E501)


791-791: Line too long (174 > 120)

(E501)


797-797: Line too long (138 > 120)

(E501)


800-800: Line too long (133 > 120)

(E501)


803-803: Line too long (142 > 120)

(E501)


804-804: Line too long (152 > 120)

(E501)


809-809: Line too long (174 > 120)

(E501)


854-854: Line too long (136 > 120)

(E501)


882-882: Line too long (147 > 120)

(E501)


883-883: Line too long (147 > 120)

(E501)


962-962: Line too long (141 > 120)

(E501)


1200-1200: Line too long (139 > 120)

(E501)


1204-1204: Line too long (134 > 120)

(E501)


1207-1207: Line too long (124 > 120)

(E501)


1214-1214: Line too long (125 > 120)

(E501)


1216-1216: Line too long (132 > 120)

(E501)


1285-1285: Line too long (141 > 120)

(E501)


1295-1295: Line too long (145 > 120)

(E501)

🔇 Additional comments (2)
tensorrt_llm/_torch/models/modeling_deepseekv3.py (2)

69-73: LGTM - Proper import additions for new functionality.

The imports for get_model_extra_attrs, set_torch_compiling, with_model_extra_attrs, copy, and weakref are appropriate for the split-batch metadata handling functionality being introduced.


689-690: LGTM - Proper auxiliary stream management.

The addition of aux_stream_dict instance variable enables proper CUDA stream management for the split-batch functionality.

Comment on lines +69 to +73
from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)

import copy
import weakref
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add copyright header to comply with coding guidelines.

The file is missing the required NVIDIA Apache-2.0 copyright header that should be present in all source files according to the coding guidelines.

Add the copyright header at the top of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # --------------------------------------------------
 # Portions of this code were derived from DeepSeek‑V3:
 #   https://github.com/deepseek-ai/DeepSeek-V3
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)
import copy
import weakref
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------------------------------------------------
# Portions of this code were derived from DeepSeek-V3:
# https://github.com/deepseek-ai/DeepSeek-V3
from ..utils import (get_model_extra_attrs, set_torch_compiling,
with_model_extra_attrs)
import copy
import weakref
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_deepseekv3.py around lines 69 to 73, the
file is missing the required NVIDIA Apache-2.0 copyright header; add the
standard NVIDIA Apache-2.0 copyright header block at the very top of the file
(before any imports), including the copyright year(s), NVIDIA Corporation as the
copyright holder, the SPDX-License-Identifier: Apache-2.0 and a link or short
note referencing the full Apache-2.0 license, matching the project's standard
header format.

Comment on lines +619 to +624
self.enable_split_batch_overlap = os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1"
self.enable_split_batch_overlap_local_bs = int(os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_LOCAL_BS",0))
self.enable_split_batch_overlap_split_bs = int(os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_SPLIT_BS",0))
if self.enable_split_batch_overlap:
assert self.enable_split_batch_overlap_local_bs > 0 and self.enable_split_batch_overlap_split_bs > 0, "ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_LOCAL_BS and ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_SPLIT_BS must be set"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Environment variable configuration pattern needs improvement.

The environment variable-based configuration approach has several issues:

  1. Hard-coded environment variable names scattered throughout the code
  2. No validation of configuration values
  3. Type conversion from strings without error handling
  4. Line length violations

Consider introducing a configuration class to centralize these settings:

-        self.enable_split_batch_overlap = os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP") == "1"
-        self.enable_split_batch_overlap_local_bs = int(os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_LOCAL_BS",0))
-        self.enable_split_batch_overlap_split_bs = int(os.environ.get("ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_SPLIT_BS",0))
-        if self.enable_split_batch_overlap:
-            assert self.enable_split_batch_overlap_local_bs > 0 and self.enable_split_batch_overlap_split_bs > 0, "ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_LOCAL_BS and ENABLE_TRTLLM_SPLIT_BATCH_OVERLAP_SPLIT_BS must be set"
+        # Use configuration from model_config instead of environment variables
+        split_batch_config = getattr(model_config, 'split_batch_config', None)
+        self.enable_split_batch_overlap = split_batch_config is not None
+        if self.enable_split_batch_overlap:
+            self.enable_split_batch_overlap_local_bs = split_batch_config.local_batch_size
+            self.enable_split_batch_overlap_split_bs = split_batch_config.split_batch_size

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.12.2)

620-620: Line too long (217 > 120)

(E501)

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (6)
run_deepseek_batch_split.sh (2)

23-23: Fix typo in attention metadata flag export.

ENABLE_ATTENTION_METADATA_SPLITTING never gets its default because the variable name here drops the leading E, so downstream exports see an empty value. Restore the correct name so the script behaves as intended.

Apply this diff:

-ENABLE_INTRA_LAYER_BATCH_SPLITTING="true"
-NABLE_ATTENTION_METADATA_SPLITTING="true"
+ENABLE_INTRA_LAYER_BATCH_SPLITTING="true"
+ENABLE_ATTENTION_METADATA_SPLITTING="true"

124-124: Align default mode with documented behavior.

print_usage advertises mode 4 as the default (“full overlap”), yet the positional default here is 2. That discrepancy will silently run the wrong configuration when callers omit the mode. Please either set the default to 4 or update the usage text to match.

Apply this diff:

-mode=${10:-2}  # Default to mode 2 (only inter-layer overlap)
+mode=${10:-4}  # Default to mode 4 (full overlap)
tensorrt_llm/_torch/modules/attention.py (1)

318-335: Make metadata half selection thread-safe and remove stray prints.

The new global ping counter is shared across requests and threads, so concurrent inference can interleave updates and route callers to the wrong metadata half. The print calls here also spam stdout on every call. Please move the counter into thread-local storage (or another per-request structure) and drop the prints.

Apply this diff:

-import math
-import weakref
+import math
+import threading
+import weakref
@@
-ping = 0
+_metadata_ping = threading.local()
@@
-    global ping
+    counter = getattr(_metadata_ping, "counter", 0)
@@
-        if ping % 2 == 0:
-            print(f"[DEBUG] extract_extra_attrs - {ping} - metadata_ref_half1 {metadata_ref_half1}")
-            metadata = metadata_ref_half1()
-
-        else:
-            print(f"[DEBUG] extract_extra_attrs - {ping} - metadata_ref_half2 {metadata_ref_half2}")
-            metadata = metadata_ref_half2()
-        ping += 1
+        if counter % 2 == 0:
+            metadata = metadata_ref_half1()
+        else:
+            metadata = metadata_ref_half2()
+        _metadata_ping.counter = counter + 1
tensorrt_llm/_torch/attention_backend/trtllm.py (1)

698-713: Remove debug prints from prepare().

These print calls fire on every prepare and flood stdout, hurting performance in production. Use the existing logger (at DEBUG level) if you need diagnostics, otherwise drop them.

Apply this diff:

-    def prepare(self, splitBatchOverlap: Optional[int] = None) -> None:
-        print(f"[DEBUG] TrtllmAttention.prepare {splitBatchOverlap}")
+    def prepare(self, splitBatchOverlap: Optional[int] = None) -> None:
+        logger.debug("TrtllmAttention.prepare split_batch_overlap=%s",
+                     splitBatchOverlap)
@@
-            if splitBatchOverlap is not None:
-                #print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is not None")
-                if splitBatchOverlap == 1:
-                    print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is 1")
-                    get_global_attrs().attention_metadata_half1 = weakref.ref(self)
-                else:
-                    print(f"[DEBUG] TrtllmAttention.prepare - splitBatchOverlap is not 1")
-                    get_global_attrs().attention_metadata_half2 = weakref.ref(self)
-            else:
-                print(f"[DEBUG] TrtllmAttention.prepare - extra_attrs is None Setting self refrence to global attention_metadata")
-                get_global_attrs().attention_metadata = weakref.ref(self)
+            if splitBatchOverlap is not None:
+                if splitBatchOverlap == 1:
+                    get_global_attrs().attention_metadata_half1 = weakref.ref(self)
+                else:
+                    get_global_attrs().attention_metadata_half2 = weakref.ref(self)
+            else:
+                get_global_attrs().attention_metadata = weakref.ref(self)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

1456-1477: Keep tokens_per_block immutable across splits

KVCacheManager.tokens_per_block encodes the physical block size of the cache; overriding it to the split batch size makes the metadata disagree with the actual allocation, which will corrupt KV indexing as soon as kernels touch the cache. Reuse the existing manager and drop these assignments instead of mutating the block size per half.

-            attn_metadata_half1.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
-            #attn_metadata_half1.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[:32]
-            attn_metadata_half1.kv_cache_manager.tokens_per_block = self.enable_split_batch_overlap_split_bs #32
+            attn_metadata_half1.kv_cache_manager = attn_metadata.kv_cache_manager
...
-            attn_metadata_half2.kv_cache_manager = copy.copy(attn_metadata.kv_cache_manager)
-            #attn_metadata_half2.kv_cache_manager.kv_cache_pool_pointers = attn_metadata.kv_cache_manager.kv_cache_pool_pointers[32:]
-            attn_metadata_half2.kv_cache_manager.tokens_per_block = self.enable_split_batch_overlap_split_bs #32
+            attn_metadata_half2.kv_cache_manager = attn_metadata.kv_cache_manager

1458-1478: Slice KV cache offsets correctly for half1

Half1 currently assigns kv_cache_block_offsets[:, split:, ...], i.e. it drops the first split entries even though the host offsets keep them, so GPU/host views diverge and half1 never sees its own requests. Use the leading slice for half1 (keep the trailing slice for half2) so each half owns the correct slots.

-            attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:,self.enable_split_batch_overlap_split_bs:,:,:]
+            attn_metadata_half1.kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets[:, :self.enable_split_batch_overlap_split_bs, :, :]
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1444-1487: Replace debug prints with logger calls

These prints fire unconditionally in the hot path and bypass the module logging controls. Please switch them to logger.debug (or guard with isEnabledFor) so we can dial them on/off without spewing to stdout.

-        print(f"[DEBUG] TrtllmAttention.prepare - attn_metadata")
+        logger.debug("TrtllmAttention.prepare - attn_metadata")
...
-            print(f"[DEBUG] TrtllmAttention.prepare (half1) - attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq: {attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq}")
+            logger.debug(
+                "TrtllmAttention.prepare (half1) - attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq: %s",
+                attn_metadata_half1.kv_cache_params.num_cached_tokens_per_seq,
+            )
...
-            print(f"[DEBUG] TrtllmAttention.prepare (half2) - attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq: {attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq}")
+            logger.debug(
+                "TrtllmAttention.prepare (half2) - attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq: %s",
+                attn_metadata_half2.kv_cache_params.num_cached_tokens_per_seq,
+            )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 962b614 and 2b42417.

📒 Files selected for processing (7)
  • run_deepseek_batch_split.sh (1 hunks)
  • tensorrt_llm/_torch/attention_backend/trtllm.py (9 hunks)
  • tensorrt_llm/_torch/distributed/ops.py (3 hunks)
  • tensorrt_llm/_torch/modules/attention.py (14 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (3 hunks)
  • tensorrt_llm/_torch/modules/multi_stream_utils.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
  • tensorrt_llm/_torch/modules/multi_stream_utils.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/distributed/ops.py
  • tensorrt_llm/_torch/attention_backend/trtllm.py
🧠 Learnings (1)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/modules/attention.py (3)
tensorrt_llm/_torch/utils.py (1)
  • get_model_extra_attrs (47-48)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
  • TrtllmAttentionMetadata (527-1012)
tensorrt_llm/_torch/attention_backend/interface.py (4)
  • num_tokens (271-272)
  • num_contexts (199-200)
  • num_contexts (203-206)
  • num_ctx_tokens (267-268)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
  • prepare (698-792)
tensorrt_llm/_torch/attention_backend/interface.py (8)
  • prepare (274-277)
  • num_contexts (199-200)
  • num_contexts (203-206)
  • seq_lens (171-172)
  • seq_lens (175-196)
  • seq_lens_kv (223-224)
  • seq_lens_kv (227-234)
  • on_update (148-158)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  • get_resource_manager (1117-1118)
tensorrt_llm/_torch/distributed/ops.py (3)
tensorrt_llm/mapping.py (1)
  • tp_group (337-338)
cpp/tensorrt_llm/thop/allgatherOp.cpp (2)
  • input (108-111)
  • input (108-108)
tensorrt_llm/_torch/distributed/communicator.py (1)
  • tp_rank (54-55)
tensorrt_llm/_torch/attention_backend/trtllm.py (3)
tensorrt_llm/_torch/attention_backend/interface.py (2)
  • prepare (274-277)
  • num_seqs (249-253)
tensorrt_llm/_torch/attention_backend/flashinfer.py (2)
  • prepare (190-306)
  • cached_token_lens (113-115)
tensorrt_llm/_torch/utils.py (3)
  • get_model_extra_attrs (47-48)
  • get_global_attrs (40-41)
  • Fp4QuantizedTensor (92-99)
🪛 Ruff (0.13.3)
tensorrt_llm/_torch/pyexecutor/model_engine.py

1444-1444: f-string without any placeholders

Remove extraneous f prefix

(F541)


2146-2146: f-string without any placeholders

Remove extraneous f prefix

(F541)


2149-2149: f-string without any placeholders

Remove extraneous f prefix

(F541)


2152-2152: f-string without any placeholders

Remove extraneous f prefix

(F541)

tensorrt_llm/_torch/attention_backend/trtllm.py

706-706: f-string without any placeholders

Remove extraneous f prefix

(F541)


709-709: f-string without any placeholders

Remove extraneous f prefix

(F541)


712-712: f-string without any placeholders

Remove extraneous f prefix

(F541)

🪛 Shellcheck (0.11.0)
run_deepseek_batch_split.sh

[warning] 13-13: disable_overlap_scheduler appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 23-23: NABLE_ATTENTION_METADATA_SPLITTING appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 195-195: Declare and assign separately to avoid masking return values.

(SC2155)


[warning] 266-266: Quotes/backslashes will be treated literally. Use an array.

(SC2089)


[warning] 277-277: Quotes/backslashes in this variable will not be respected.

(SC2090)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant