[TRTLLM-8164][feat] Add dynamic tree support on CDL #9469
Conversation
Force-pushed from 21c71ab to 07b2975 (Compare)
/bot run --disable-fail-fast

PR_Github #26137 [ run ] triggered by Bot. Commit:

PR_Github #26137 [ run ] completed with state
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Signed-off-by: Yue Weng <[email protected]>
Force-pushed from 07b2975 to 04129d5 (Compare)
/bot run --disable-fail-fast

PR_Github #26373 [ run ] triggered by Bot. Commit:

PR_Github #26373 [ run ] completed with state
Signed-off-by: Yue Weng <[email protected]>
/bot run --disable-fail-fast

PR_Github #26449 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This PR introduces a dynamic tree drafting mode for EAGLE3 speculative decoding alongside the existing static tree support. Changes include refactored API signatures for speculative parameter updates, new drafting loop wrappers for static/dynamic modes, enhanced tree manager buffers, and updated configuration validation for the max_total_draft_tokens calculation.
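For orientation, a hedged sketch of how the new mode could be enabled through the LLM API, reusing the field names and values exercised elsewhere in this PR; the draft-model path and exact required fields are assumptions:

```python
from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig

# Hypothetical usage; field names come from this PR, the model path is a placeholder.
spec_config = EagleDecodingConfig(
    max_draft_len=4,                 # depth of the draft tree
    use_dynamic_tree=True,           # enable the new dynamic-tree drafting mode
    dynamic_tree_max_topK=10,        # per-layer expansion width
    max_total_draft_tokens=63,       # total draft-token budget (range-checked by the config)
    speculative_model_dir="<path-to-eagle3-draft-model>",
)
```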
Changes

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
**1535-1557: Undefined / stale `num_draft_tokens` in overlap-scheduler tree path**

In the `extend_requests` branch where a request does have a previous batch (`next_draft_tokens_device is not None and not request.is_dummy and request.py_batch_idx is not None`), `num_draft_tokens` is never assigned, but it is used in:

```python
if not self.is_draft_model and not spec_config.is_linear_tree:
    assert spec_tree_manager is not None
    assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
```

This will either raise `UnboundLocalError` on the first such request or, worse, reuse a stale `num_draft_tokens` value from a different code path/iteration. Given that the overlap scheduler assumes a fixed number of draft tokens (`self.runtime_draft_len`) in this branch, the assertion should compare against that instead of an uninitialized `num_draft_tokens`.
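For context, the `UnboundLocalError` case is the standard Python local-scoping pitfall; a minimal generic repro (not TensorRT-LLM code):

```python
def f(cond):
    if cond:
        n = 1
    return n  # fine for f(True); raises UnboundLocalError for f(False)
```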
Suggested fix:

```diff
-            if not self.is_draft_model and not spec_config.is_linear_tree:
-                assert spec_tree_manager is not None
-                assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
-                position_ids.extend(
-                    past_seen_token_num +
-                    spec_tree_manager.spec_dec_position_offsets[
-                        0]  # [max_total_draft_tokens + 1]
-                )
+            if not self.is_draft_model and not spec_config.is_linear_tree:
+                assert spec_tree_manager is not None
+                assert self.runtime_draft_len == spec_tree_manager.max_total_draft_tokens
+                position_ids.extend(
+                    past_seen_token_num +
+                    spec_tree_manager.spec_dec_position_offsets[
+                        0]  # [max_total_draft_tokens + 1]
+                )
```

tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
**496-519: Aligns with the base `update_spec_dec_param`, but missing `SpecTreeManager` import (F821)**

The override correctly mirrors the updated base signature and forwards `is_target_model`/`spec_tree_manager` to `super().update_spec_dec_param`, so behavior is consistent with the TRTLLM backend.

However, the type annotation `Optional['SpecTreeManager']` currently has no corresponding symbol in this module, and Ruff flags it as `F821` (undefined name). Even though it is a string annotation at runtime, this will fail lint/CI.

Recommend adding a TYPE_CHECKING-only import alongside the existing `DecodingBaseConfig` import:

```diff
 if TYPE_CHECKING:
-    from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
+    from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig
+    from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
```

This keeps runtime dependencies unchanged while satisfying static analysis.
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
**190-201: Reset the per-tree mask before recomputing in `compute_spec_dec_mask_matrix`**

`compute_spec_dec_mask_matrix` currently only sets bits to 1 for indices along each path and never clears previous values. For dynamic trees, where `eagle_paths[tree_idx]` is reconstructed each step, this can leave stale 1s from earlier trees in `spec_dec_mask_matrix[tree_idx]`, producing an over-permissive attention mask.

You can safely zero out the per-tree slice before recomputing; for both static and dynamic trees, the path for each node includes the node itself, so diagonal entries remain correct.
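To see why clearing matters, a tiny self-contained demonstration (the buffer name and shapes are invented for this sketch):

```python
import torch

# "mask" plays the role of one spec_dec_mask_matrix[tree_idx] slice.
mask = torch.zeros(4, 4, dtype=torch.int32)
mask[0, [0, 1, 2]] = 1   # tree A: node 0's path covers nodes 0, 1, 2
# Tree B reuses the buffer and its path for node 0 is just {0}; without clearing,
# the old bits survive and the attention mask stays over-permissive:
mask[0, [0]] = 1
print(mask[0])           # tensor([1, 1, 1, 0], dtype=torch.int32) -- stale 1s remain
mask.zero_()             # the proposed fix: clear before recomputing
mask[0, [0]] = 1
print(mask[0])           # tensor([1, 0, 0, 0], dtype=torch.int32)
```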
Suggested diff:
```diff
-    # Compute the spec decoding mask matrix according to the eagle_paths
-    def compute_spec_dec_mask_matrix(self, tree_idx=0):
-        for i, path in enumerate(self.eagle_paths[tree_idx]):
-            indices = path[path > -1]
-            self.spec_dec_mask_matrix[tree_idx][i, indices] = 1
+    # Compute the spec decoding mask matrix according to the eagle_paths
+    def compute_spec_dec_mask_matrix(self, tree_idx: int = 0) -> None:
+        # Clear previous mask for this tree to avoid stale entries when paths change.
+        self.spec_dec_mask_matrix[tree_idx].zero_()
+        for i, path in enumerate(self.eagle_paths[tree_idx]):
+            indices = path[path > -1]
+            if indices.numel() > 0:
+                self.spec_dec_mask_matrix[tree_idx, i, indices] = 1
```

Also fix the fullwidth parenthesis in the comment near the packed-mask section:
```diff
-        # 4）Compute the spec_dec_packed_mask for the drafter model
+        # 4) Compute the spec_dec_packed_mask for the drafter model
```

Also applies to: 277-282
🧹 Nitpick comments (6)
examples/llm-api/quickstart_advanced.py (1)
**141-141: New `max_total_draft_tokens` CLI flag is wired correctly; consider clarifying help text**

Plumbing `--max_total_draft_tokens` through to `EagleDecodingConfig(max_total_draft_tokens=...)` is correct and matches the new config semantics. To make it more usable from the CLI, consider adding a brief `help=` description (e.g., that it bounds total draft tokens for EAGLE3 dynamic/static trees and defaults according to `dynamic_tree_max_topK * max_draft_len` when omitted).

Also applies to: 203-211
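For illustration, one possible shape of the flag with such a help string (the wording is a suggestion, not the PR's):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_total_draft_tokens",
    type=int,
    default=None,
    help="Upper bound on the total number of draft tokens for EAGLE3 "
    "static/dynamic trees; defaults to dynamic_tree_max_topK * max_draft_len "
    "when omitted.")
args = parser.parse_args(["--max_total_draft_tokens", "63"])
print(args.max_total_draft_tokens)  # 63
```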
tests/integration/defs/test_e2e.py (1)
**1-1: Update copyright year to include 2025.**

Per coding guidelines, TensorRT-LLM source files should include the current year in the copyright header.
```diff
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (2)
**146-259: Static-tree `prepare_for_generation` test wiring is consistent; minor nit on unused arg**

The refactoring to use `StaticTreeDraftingLoopWrapper.prepare_for_generation` with `SpecTreeManager`, plus the detailed reference tensors for `position_ids`, attn metadata, and spec metadata, provides good coverage of the static-tree path. The assertions look correct and shape-safe.

Only minor nit: the `prepare_for_layer_idx` parameter in `run_test` is currently unused, which may confuse future readers. If you don't plan to use it, consider dropping it from `run_test`'s signature and call sites; otherwise, a brief comment on its future role would help.
**463-1153: Dynamic-tree update path tests are thorough and match expected buffer semantics**

The new `test_dynamic_tree_update_draft_tokens_and_scores` cases (various `cur_draft_idx`, batch sizes, and reference buffers) exercise `DynamicTreeDraftingLoopWrapper.update_draft_tokens_and_scores`' handling of:

- Root vs. deeper draft layers (`cur_draft_idx` 0/1/2).
- Single- and multi-request batches.
- History buffers (`history_draft_tokens_buffer`, parent buffer, score buffer).
- `spec_decoding_packed_mask` and `hidden_states_read_indices` updates.

Shapes and dtypes for all constructed inputs/refs are consistent, and the use of `torch.all`/`torch.allclose` with a reasonable `atol` makes the checks robust.
tensorrt_llm/_torch/speculative/drafting_loops.py (2)
**512-655: Dynamic-tree drafting loop: overall structure looks sound, but the logits FIXME is a known gap**

The new `DynamicTreeDraftingLoopWrapper` sets up per-batch draft/history buffers and returns:

- `new_draft_tokens`, shaped `[max_total_draft_tokens, batch_size]`;
- `draft_logits`, shaped `[max_total_draft_tokens, batch_size, vocab_size]` (currently taken from the last drafter layer and marked as FIXME);
- `dynamic_tree_buffers`, with `topk_score_indices` and `history_draft_tokens_parent_buffer`.

The high-level flow (initial forward → per-layer expansion with `update_draft_tokens_and_scores` → final resampling) is coherent; a toy version of this flow is sketched below. The explicit FIXME on `return_draft_logits` is a good reminder that compatibility for dynamic tree logits is pending and can be addressed in a follow-up without blocking this PR.
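To make that three-phase flow concrete, a self-contained toy version (the stand-in `draft_model` and all sizes are invented; only the control flow mirrors the wrapper described above):

```python
import torch

torch.manual_seed(0)
vocab, top_k, max_draft_len, max_total = 100, 3, 3, 4

def draft_model(tokens):  # stand-in for one drafter forward pass
    return torch.log_softmax(torch.randn(tokens.shape[0], vocab), dim=-1)

log_probs = draft_model(torch.tensor([42]))          # 1) initial forward
scores, tokens = torch.topk(log_probs[0], top_k)     # first tree layer
history_tokens, history_scores = [tokens], [scores]
for _ in range(max_draft_len - 1):                   # 2) per-layer expansion
    layer_log_probs = draft_model(tokens)            # [top_k, vocab]
    joint = (scores.unsqueeze(-1) + layer_log_probs).flatten()
    scores, flat_idx = torch.topk(joint, top_k)      # additive path scores
    tokens = flat_idx % vocab
    history_tokens.append(tokens)
    history_scores.append(scores)
keep = torch.topk(torch.cat(history_scores), max_total).indices  # 3) final resampling
print(torch.cat(history_tokens)[keep])               # the retained draft tokens
```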
**554-566: Forward should propagate `**kwargs` to `draft_model.forward`**

`DynamicTreeDraftingLoopWrapper.forward` accepts `**kwargs` but currently ignores them when calling `self.draft_model.forward`. Other wrappers in this module propagate extra kwargs to the underlying model (e.g., for additional features or debug flags). Consider forwarding `**kwargs` to both the top-level and per-layer `draft_model.forward` calls to keep behavior consistent:

```diff
-        logits = self.draft_model.forward(input_ids=input_ids,
-                                          position_ids=position_ids,
-                                          attn_metadata=attn_metadata,
-                                          spec_metadata=spec_metadata,
-                                          return_context_logits=True)
+        logits = self.draft_model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attn_metadata=attn_metadata,
+            spec_metadata=spec_metadata,
+            return_context_logits=True,
+            **kwargs,
+        )
@@
-            logits = self.draft_model.forward(
+            logits = self.draft_model.forward(
                 input_ids=self.draft_tokens_buffer[:batch_size, :].reshape(
                     -1),
                 position_ids=self.position_ids_buffer[:batch_size, :].
                 reshape(-1),
                 attn_metadata=attn_metadata,
                 spec_metadata=spec_metadata,
-                return_context_logits=True)
+                return_context_logits=True,
+                **kwargs,
+            )
```

This also addresses the Ruff `ARG002` warning about unused `kwargs`.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (16)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/attention_backend/interface.py (1 hunks)
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1 hunks)
- tensorrt_llm/_torch/attention_backend/trtllm.py (4 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
- tensorrt_llm/_torch/speculative/drafting_loops.py (4 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (1 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (4 hunks)
- tensorrt_llm/_torch/speculative/spec_tree_manager.py (5 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tests/integration/defs/accuracy/test_llm_api.py (1 hunks)
- tests/integration/defs/test_e2e.py (2 hunks)
- tests/unittest/_torch/modeling/test_modeling_llama.py (8 hunks)
- tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (6 hunks)
- tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use `from package.subpackage import foo` and then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)
Python filenames should use snake_case (e.g., `some_file.py`)
Python class names should use PascalCase (e.g., `class SomeClass`)
Python function and method names should use snake_case (e.g., `def my_awesome_function():`)
Python local variable names should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile = ...`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL = ...`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT = ...`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., `self.x = 5` followed by `"""<type>: Description of 'x'"""`)
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic
Files:
tensorrt_llm/_torch/attention_backend/interface.py, examples/llm-api/quickstart_advanced.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/model_drafter.py, tests/unittest/_torch/modeling/test_modeling_llama.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py, tests/integration/defs/accuracy/test_llm_api.py, tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/speculative/spec_tree_manager.py, tensorrt_llm/_torch/speculative/drafting_loops.py, tests/integration/defs/test_e2e.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/llmapi/llm_args.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,h,cu,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top
Files:
tensorrt_llm/_torch/attention_backend/interface.py, examples/llm-api/quickstart_advanced.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/model_drafter.py, tests/unittest/_torch/modeling/test_modeling_llama.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py, tests/integration/defs/accuracy/test_llm_api.py, tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/speculative/spec_tree_manager.py, tensorrt_llm/_torch/speculative/drafting_loops.py, tests/integration/defs/test_e2e.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/llmapi/llm_args.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py, tensorrt_llm/_torch/pyexecutor/model_engine.py
🧠 Learnings (12)
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.
Applied to files:
tensorrt_llm/_torch/attention_backend/interface.py, tests/unittest/_torch/modeling/test_modeling_llama.py, tensorrt_llm/_torch/attention_backend/trtllm.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.
Applied to files:
tensorrt_llm/_torch/attention_backend/interface.py, tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: 2025-08-27T14:23:55.566Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/modules/rms_norm.py:17-17
Timestamp: 2025-08-27T14:23:55.566Z
Learning: The TensorRT-LLM project requires Python 3.10+ as evidenced by the use of TypeAlias from typing module, match/case statements, and union type | syntax throughout the codebase, despite some documentation still mentioning Python 3.8+.
Applied to files:
tensorrt_llm/_torch/attention_backend/interface.py, tests/unittest/_torch/modeling/test_modeling_llama.py, tensorrt_llm/_torch/attention_backend/trtllm.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py, tensorrt_llm/_torch/attention_backend/sparse/dsa.py, tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tensorrt_llm/_torch/speculative/model_drafter.py
📚 Learning: 2025-08-01T15:14:45.673Z
Learnt from: yibinl-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 6506
File: examples/models/core/mixtral/requirements.txt:3-3
Timestamp: 2025-08-01T15:14:45.673Z
Learning: In TensorRT-LLM, examples directory can have different dependency versions than the root requirements.txt file. Version conflicts between root and examples dependencies are acceptable because examples are designed to be standalone and self-contained.
Applied to files:
tests/unittest/_torch/modeling/test_modeling_llama.py, tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
📚 Learning: 2025-08-21T00:16:56.457Z
Learnt from: farshadghodsian
Repo: NVIDIA/TensorRT-LLM PR: 7101
File: docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md:36-36
Timestamp: 2025-08-21T00:16:56.457Z
Learning: TensorRT-LLM container release tags in documentation should only reference published NGC container images. The README badge version may be ahead of the actual published container versions.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: 2025-08-15T06:46:53.813Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:53.813Z
Learning: In the TensorRT-LLM KV cache manager, SWA (Sliding Window Attention) combined with beam search is currently in a broken/non-functional state and is planned for future rework. During preparatory refactoring phases, code related to SWA+beam search may intentionally remain in a non-working state until the broader rework is completed.
Applied to files:
tensorrt_llm/_torch/attention_backend/trtllm.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
tests/integration/defs/test_e2e.py, tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
Applied to files:
tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py
📚 Learning: 2025-08-26T06:07:02.166Z
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (7)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
tensorrt_llm/_torch/speculative/drafting_loops.py (3)
- DynamicTreeDraftingLoopWrapper (512-1039)
- LinearDraftingLoopWrapper (98-198)
- StaticTreeDraftingLoopWrapper (201-509)

tensorrt_llm/llmapi/llm_args.py (3)
- EagleDecodingConfig (746-877)
- is_linear_tree (709-710)
- is_linear_tree (874-877)
tensorrt_llm/_torch/speculative/eagle3.py (1)
tensorrt_llm/llmapi/llm_args.py (1)
- EagleDecodingConfig (746-877)
tensorrt_llm/_torch/speculative/model_drafter.py (4)
tensorrt_llm/llmapi/llm_args.py (1)
- EagleDecodingConfig (746-877)

tensorrt_llm/_torch/speculative/spec_tree_manager.py (2)
- compute_spec_dec_mask_matrix (278-281)
- compute_spec_dec_packed_mask (284-322)

tensorrt_llm/_torch/speculative/ngram.py (1)
- prepare_draft_tokens (182-205)

tensorrt_llm/_torch/speculative/drafter.py (1)
- prepare_draft_tokens (27-38)
tests/unittest/_torch/modeling/test_modeling_llama.py (1)
tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
- SpecTreeManager (7-355)
tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
- StaticTreeDraftingLoopWrapper (201-509)
tensorrt_llm/_torch/speculative/drafting_loops.py (3)
tensorrt_llm/_torch/speculative/eagle3.py (2)
- Eagle3SpecMetadata (114-276)
- forward (373-495)

tensorrt_llm/_torch/speculative/interface.py (1)
- SpecMetadata (187-275)

tensorrt_llm/_torch/speculative/spec_tree_manager.py (1)
- SpecTreeManager (7-355)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (2)
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
- update_spec_dec_param (1182-1324)

tensorrt_llm/_torch/attention_backend/interface.py (1)
- update_spec_dec_param (339-352)
🪛 Ruff (0.14.6)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
371-371: Local variable use_tree_drafter is assigned to but never used
Remove assignment to unused variable use_tree_drafter
(F841)
tensorrt_llm/_torch/speculative/model_drafter.py
662-662: Undefined name SpecTreeManager
(F821)
703-703: Consider [parent_idx, *tmp_path] instead of concatenation
Replace with [parent_idx, *tmp_path]
(RUF005)
706-706: Consider [parent_idx, *tmp_path] instead of concatenation
Replace with [parent_idx, *tmp_path]
(RUF005)
721-721: Comment contains ambiguous `）` (FULLWIDTH RIGHT PARENTHESIS). Did you mean `)` (RIGHT PARENTHESIS)?
(RUF003)
tensorrt_llm/_torch/speculative/drafting_loops.py
556-556: Unused method argument: kwargs
(ARG002)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
508-508: Undefined name SpecTreeManager
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (24)
tensorrt_llm/_torch/speculative/eagle3.py (1)
**65-77: SpecTreeManager init condition correctly extended for dynamic tree**

Initializing `spec_tree_manager` when `config.use_dynamic_tree` is true (in addition to `eagle_choices` being set) matches the new dynamic-tree workflow and keeps the linear-tree path unchanged (`spec_tree_manager` stays `None` there). The constructor arguments are consistent with `EagleDecodingConfig`'s new fields.

tensorrt_llm/llmapi/llm_args.py (1)
**804-823: Dynamic-tree bounds and defaults for `max_total_draft_tokens` look consistent**

The new dynamic-tree block correctly:

- Treats `dynamic_tree_max_topK` as the switch (with or without `use_dynamic_tree`).
- Enforces a strict, computed range for `max_total_draft_tokens`.
- Falls back to `dynamic_tree_max_topK * max_draft_len` when unset, and to `max_draft_len` only in the linear-tree case.

This keeps static-tree behavior intact while making dynamic-tree configuration safer.
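As a paraphrase of that defaulting rule, a minimal sketch (the actual `EagleDecodingConfig` validation also enforces a computed range not reproduced here):

```python
def default_max_total_draft_tokens(max_draft_len,
                                   dynamic_tree_max_topK=None,
                                   max_total_draft_tokens=None):
    """Paraphrase of the fallback logic described above; not the actual implementation."""
    if max_total_draft_tokens is not None:
        return max_total_draft_tokens           # user-provided (range-checked upstream)
    if dynamic_tree_max_topK is not None:       # dynamic-tree mode
        return dynamic_tree_max_topK * max_draft_len
    return max_draft_len                        # linear-tree fallback

print(default_max_total_draft_tokens(4, dynamic_tree_max_topK=10))  # 40
```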
tests/integration/defs/accuracy/test_llm_api.py (1)
**498-505: Dynamic-tree Eagle2 test config matches new invariants**

Using `max_draft_len=4`, `max_total_draft_tokens=63`, and `dynamic_tree_max_topK=10` is consistent with the new `EagleDecodingConfig` checks (63 falls within the allowed dynamic-tree range). This test should give good coverage for the dynamic-tree path without changing the overall draft-token budget too drastically.

tensorrt_llm/_torch/attention_backend/interface.py (1)
**339-350: `update_spec_dec_param` signature now reflects the SpecTreeManager-based design; verify external backends**

The new signature (an `is_target_model` flag, no `spec_metadata`/`spec_decoding_tensor`, optional `spec_tree_manager`) matches the rest of the spec-dec refactor and in-tree call sites. However, this is a breaking change for any out-of-tree `AttentionBackend` implementations that overrode or called `update_spec_dec_param` with the old parameters, so it's worth double-checking whether such extensions exist and, if so, documenting the migration.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
**375-396: Tri-state drafting loop selection looks correct.**

The logic properly selects between `StaticTreeDraftingLoopWrapper`, `DynamicTreeDraftingLoopWrapper`, and `LinearDraftingLoopWrapper` based on the `EagleDecodingConfig` attributes. The mutual exclusivity between `eagle_choices` (static) and `use_dynamic_tree` (dynamic) is enforced in `EagleDecodingConfig.__init__` per the relevant code snippets from `llm_args.py`.
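The dispatch reduces to a simple tri-state selection; an illustrative sketch with stub classes (the real code in `py_executor_creator.py` carries more context):

```python
# Stub classes stand in for the real wrappers; only the dispatch logic is the point.
class LinearDraftingLoopWrapper: ...
class StaticTreeDraftingLoopWrapper: ...
class DynamicTreeDraftingLoopWrapper: ...

def select_drafting_loop_wrapper(spec_config):
    """Illustrative tri-state dispatch mirroring the mutual exclusivity noted above."""
    if getattr(spec_config, "eagle_choices", None) is not None:  # static tree topology
        return StaticTreeDraftingLoopWrapper
    if getattr(spec_config, "use_dynamic_tree", False):          # dynamic tree mode
        return DynamicTreeDraftingLoopWrapper
    return LinearDraftingLoopWrapper                             # linear chain default
```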
tests/integration/defs/test_e2e.py (2)

**2076-2104: Static tree test rename looks good.**

The function rename from `test_draft_token_tree_quickstart_advanced_eagle3` to `test_static_draft_token_tree_quickstart_advanced_eagle3` clearly distinguishes it from the new dynamic tree test.
**2106-2141: New dynamic tree test properly exercises the feature.**

The test correctly uses the new CLI options (`--use_dynamic_tree`, `--dynamic_tree_max_topK`, `--max_total_draft_tokens`) to exercise dynamic tree drafting. The memory threshold of 27 GiB matches the static tree test, which is reasonable for the same model configuration.

tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py (1)
**8-9: Import and instantiation updated correctly for class rename.**

The test now correctly uses `StaticTreeDraftingLoopWrapper` following the class rename in `drafting_loops.py`. The test name (`test_draft_token_static_tree_sampling`) and functionality align with the static tree drafting path.

Also applies to: 56-61
tests/unittest/_torch/modeling/test_modeling_llama.py (5)
**22-22: Import updated to use SpecTreeManager.**

The import correctly switches from `SpecDecodingTensor` to `SpecTreeManager` to align with the refactored spec-decoding parameter handling.
**495-507: Tensor shapes updated for SpecTreeManager compatibility.**

The `spec_decoding_position_offsets` now uses a 2D shape `[[...]]`, and `spec_decoding_packed_mask` includes `.unsqueeze(-1)` to match the expected buffer shapes in `SpecTreeManager` (per the relevant code snippets showing shapes like `[num_trees, max_total_draft_tokens + 1, ...]`).
**534-557: SpecTreeManager setup and parameter update look correct.**

The `SpecTreeManager` is properly initialized with dynamic tree configuration (`use_dynamic_tree=True`, `max_draft_len=3`, `max_total_draft_tokens=max_total_draft_tokens`). The `spec_dec_position_offsets` and `spec_dec_packed_mask` are assigned as attributes before calling `update_spec_dec_param`. The addition of `is_target_model=True` correctly indicates this is testing the target model path.
**598-623: Second generation phase correctly updates SpecTreeManager state.**

The test properly updates `spec_tree_manager.spec_dec_position_offsets`, `spec_dec_packed_mask`, and `spec_dec_generation_lengths` for the second generation phase. The assignment of `spec_decoding_generation_lengths` to the metadata at line 623 ensures the value propagates correctly.
**670-690: Reference generation phase updated consistently.**

The reference path mirrors the same pattern: updating `SpecTreeManager` attributes and passing them through `update_spec_dec_param`. The test maintains consistency across all three generation phases.

tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
**2621-2641: Spec-dec param call wiring for dynamic/static tree looks consistent**

The updated `update_spec_dec_param` call matches the new interface (`is_spec_dec_dynamic_tree`, `is_target_model`, `spec_tree_manager`) and correctly uses the saved `original_max_*` values so attention metadata is sized for the full tree even when the draft model's runtime config has been zeroed. No issues from this change itself.

tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py (2)
**10-29: Wrapper imports and DummyModel flag look appropriate**

Switching to `StaticTreeDraftingLoopWrapper`/`DynamicTreeDraftingLoopWrapper` and importing `ModelDrafter` aligns this test with the new speculative API. Adding `self.model_is_wrapped = True` on `DummyModel` is a simple way to satisfy wrapper expectations in tests without pulling in full engine machinery. No issues here.
**1155-1350: Dynamic-tree restructuring test correctly exercises `ModelDrafter.reconstruct_dynamic_tree`**

`test_dynamic_tree_restruct_tree` sets up `Eagle3ResourceManager`, `SpecTreeManager`, and `ModelDrafter`, then validates that `reconstruct_dynamic_tree` produces the expected:

- `eagle_paths`,
- `spec_dec_packed_mask`,
- `spec_dec_position_offsets`.

The reference tensors encode non-trivial tree topologies and masking patterns, so this gives strong coverage of the reconstruction logic for dynamic trees. The constructions and assertions are shape-consistent and deterministic.
tensorrt_llm/_torch/speculative/spec_tree_manager.py (2)
**123-141: Dynamic-tree drafter offsets look consistent with Eagle config**

The new `spec_dec_position_offsets_for_drafter_model` buffer and its initialization in `init_tree_info_for_dynamic_tree` correctly precompute fixed per-layer offsets based on `dynamic_tree_max_topK` and `max_draft_len`. This matches how `DynamicTreeDraftingLoopWrapper.prepare_for_generation` flattens per-layer tokens later.
**324-355: Dynamic vs. static dump output separation is clear**

The updated `dump_tree_info` cleanly distinguishes dynamic-tree vs. static-tree logging and surfaces the new `spec_dec_position_offsets_for_drafter_model` field for debugging, which should help when validating dynamic-tree behavior.

tensorrt_llm/_torch/speculative/drafting_loops.py (4)
**201-303: Static tree wrapper correctly requires Eagle3 metadata and SpecTreeManager**

`StaticTreeDraftingLoopWrapper.forward` now asserts that `spec_metadata` is `Eagle3SpecMetadata` and pulls `spec_tree_manager` from the Eagle3 resource manager, which makes the static-tree dependency explicit. The reuse of the prior tree logic (sampling with `max_top_k`, filling `draft_tokens_buffer`, and using `spec_dec_packed_mask_for_drafter_model`) looks consistent.

No functional issues spotted here.
**656-674: Sampling via log-softmax is appropriate for path scoring**

Using `LogSoftmax` followed by `torch.topk` to obtain both tokens and log-prob scores is appropriate for additive path scoring in `update_draft_tokens_and_scores`. The `d2t` adjustment is consistent with the static-tree and linear wrappers.

No changes needed here.
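As background on why log-softmax scores compose additively, a generic sketch (not the wrapper's actual code):

```python
import torch

# Log-probabilities add along a path: log P(a, b) = log P(a) + log P(b | a),
# so accumulating per-layer log-softmax scores ranks joint path probabilities.
logits = torch.randn(2, 50)                       # [num_parent_nodes, vocab]
log_probs = torch.log_softmax(logits, dim=-1)
child_scores, child_tokens = torch.topk(log_probs, k=3, dim=-1)
parent_scores = torch.tensor([[-0.1], [-2.3]])    # accumulated path log-probs
path_scores = parent_scores + child_scores        # joint log-prob of extended paths
print(path_scores.shape)                          # torch.Size([2, 3])
```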
**898-923: Resampling final draft tokens from history buffers is consistent**

`resampling_final_draft_tokens` correctly takes the top `max_total_draft_tokens` paths based on accumulated scores in `history_score_buffer` and uses those indices to gather both tokens and indices for later tree reconstruction. Given that `update_draft_tokens_and_scores` populates the entire history buffer over all layers, there are no uninitialized entries and this top-k is well-defined.

This function looks correct as written.
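The gather pattern being described, as a minimal generic sketch with assumed `[batch, num_candidates]` buffer shapes:

```python
import torch

history_score_buffer = torch.randn(2, 8)                      # accumulated path scores
history_draft_tokens_buffer = torch.randint(0, 100, (2, 8))   # candidate tokens
max_total_draft_tokens = 4

top_scores, top_idx = torch.topk(history_score_buffer,
                                 max_total_draft_tokens, dim=-1)
final_draft_tokens = torch.gather(history_draft_tokens_buffer, 1, top_idx)
print(final_draft_tokens.shape)  # torch.Size([2, 4]); top_idx also feeds reconstruction
```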
**924-1038: Dynamic-tree `prepare_for_generation` mirrors static-tree setup with drafter-specific offsets**

The dynamic-tree `prepare_for_generation` correctly:

- derives `position_ids` from the last accepted token per request and uses `spec_dec_position_offsets_for_drafter_model` to tile them;
- adjusts `kv_lens_cuda`, `_seq_lens`, `_seq_lens_cuda`, and `num_contexts` similarly to the static-tree path;
- enables `use_spec_decoding` and sets `spec_decoding_position_offsets`/`spec_decoding_generation_lengths` for the drafter;
- initializes `hidden_states_read_indices`/`hidden_states_write_indices` based on per-request `start_idx`.

This aligns with the expected Eagle3 semantics for the first drafter layer in dynamic-tree mode.
tensorrt_llm/_torch/attention_backend/trtllm.py (1)
**1182-1325: Add assertion to verify `max_total_draft_tokens` consistency with `SpecTreeManager`**

The refactored `update_spec_dec_param` cleanly separates target-model and drafter-layer cases. However, there is an implicit assumption that `max_total_draft_tokens` passed as a parameter matches `spec_tree_manager.max_total_draft_tokens` when the tree path is taken.

Add an assertion in the tree case to catch mismatches early:
```diff
 if self.is_spec_dec_tree:
     assert spec_tree_manager is not None, "spec_tree_manager is required for tree"
+    assert (
+        spec_tree_manager.max_total_draft_tokens == max_total_draft_tokens
+    ), "max_total_draft_tokens must match SpecTreeManager.max_total_draft_tokens"
```

This prevents silent bugs where inconsistent `max_total_draft_tokens` values could cause incorrect buffer layouts or off-by-one errors in offset/mask generation.

tensorrt_llm/_torch/speculative/model_drafter.py (1)
**1001-1059: `prepare_draft_tokens` signature change may break existing call sites**

`ModelDrafter.prepare_draft_tokens` now requires a non-optional `ResourceManager` parameter (no default value), while the base `Drafter.prepare_draft_tokens` still declares `resource_manager: Optional[ResourceManager] = None`. Code calling `prepare_draft_tokens(scheduled_requests)` without explicitly passing `resource_manager` will raise a `TypeError`. Verify that all call sites have been updated to pass `resource_manager` explicitly. If not, consider maintaining backward compatibility by keeping a default value and validating non-`None` only where required; see the sketch below.
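One possible backward-compatible shape, per the suggestion above (the class skeletons are illustrative, not the actual implementation):

```python
from typing import Optional

class ResourceManager: ...

class Drafter:
    def prepare_draft_tokens(self, scheduled_requests,
                             resource_manager: Optional[ResourceManager] = None):
        raise NotImplementedError

class ModelDrafter(Drafter):
    def prepare_draft_tokens(self, scheduled_requests,
                             resource_manager: Optional[ResourceManager] = None):
        # Keep the base default for API compatibility, but validate where required.
        if resource_manager is None:
            raise ValueError(
                "ModelDrafter.prepare_draft_tokens requires a resource_manager")
        return []  # draft-token preparation elided in this sketch
```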
PR_Github #26449 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #26690 [ run ] triggered by Bot. Commit:

PR_Github #26690 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
- `--max_total_draft_tokens` CLI parameter for granular control over draft token allocation

Chores
Description
This PR adds runtime support for dynamic trees based on CDL. AR has not yet been verified. This will be done in future work.
The execution flow of a dynamic tree is as follows.
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help] to print this help message.

See details below for each supported subcommand.
Details
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]

Launch build/test pipelines. All previously running jobs will be killed.
- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail-fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity check stages. Note: does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: does NOT update GitHub check status.
docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
kill

Kill all running builds associated with the pull request.
skip
skip --comment COMMENT

Skip testing for the latest commit on the pull request.
--comment "Reason for skipping build/test"is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.