[#9147][feat] AutoDeploy: Draft Target Speculative Decoding #9275
base: main
Conversation
…runs as ADEngine, draft model runs as PyTorchModelEngine. Only two-model spec dec is supported. Signed-off-by: Govind Ramnarayan <[email protected]>
📝 Walkthrough
This change introduces speculative decoding support to the AutoDeploy framework by adding draft model engine creation, configuration wiring, KV cache management, and speculative resource integration. New helper functions enable constructing draft configurations and orchestrating speculative decoding alongside the main model engine.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Main as Autodeploy Executor
    participant ADEngine as AD Engine Build
    participant DraftEngine as Draft Engine Create
    participant SpecMgr as Spec Resource Manager
    participant Drafter as Spec Drafter
    Main->>ADEngine: build_from_config(ad_config)
    ADEngine->>ADEngine: Initialize with ad_config<br/>Store model_config & run_with_spec_decode
    rect rgb(220, 240, 255)
        Note over Main,Drafter: If speculative decoding enabled
        Main->>DraftEngine: create_draft_model_engine_maybe()
        DraftEngine-->>Main: draft_model_engine (or None)
        Main->>SpecMgr: create_spec_resource_manager()
        SpecMgr->>SpecMgr: Extract config from ADEngine<br/>Use draft_model_engine if present
        SpecMgr-->>Main: spec_resource_manager
        Main->>Drafter: get_spec_drafter()
        Drafter->>Drafter: Build from spec_config,<br/>draft_model_engine, sampler
        Drafter-->>Main: drafter instance
        Main->>Main: Wire draft KV cache<br/>& spec resources to ResourceManager
    end
    Main->>ADEngine: Forward with spec_metadata
    ADEngine-->>Main: Speculative decoded output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 3
🧹 Nitpick comments (8)
tests/integration/defs/examples/test_ad_speculative_decoding.py (2)
33-47: Consider adding path validation for clearer error messages.

The function constructs model paths but doesn't verify they exist. While the test will fail if paths are missing, adding validation would provide clearer error messages.

You could add a quick check like:

```python
def get_model_paths():
    """Get model paths using llm_models_root()."""
    models_root = llm_models_root()
    base_model = os.path.join(
        models_root,
        "llama-models-v2/Llama-3.1-8B-Instruct",
    )
    speculative_model = os.path.join(
        models_root,
        "llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
    )

    # Validate paths exist
    if not os.path.exists(base_model):
        raise FileNotFoundError(f"Base model not found: {base_model}")
    if not os.path.exists(speculative_model):
        raise FileNotFoundError(f"Speculative model not found: {speculative_model}")

    print(f"Base model path: {base_model}")
    print(f"Speculative model path: {speculative_model}")
    return base_model, speculative_model
```
62-62: Consider adding a boundary check for batch_size.

The slice operation `prompts[:batch_size]` won't fail but could silently return fewer prompts than requested if `batch_size > len(prompts)`. While the current test only uses batch_size values of 1 and 4 (both within bounds), adding a check would make the function more robust for future modifications.

```diff
 # Select prompts based on batch size
+if batch_size > len(prompts):
+    raise ValueError(f"batch_size ({batch_size}) exceeds available prompts ({len(prompts)})")
 selected_prompts = prompts[:batch_size]
```

tensorrt_llm/_torch/pyexecutor/_util.py (1)
498-672: Shared `_create_kv_cache_manager` helper looks sound; consider documenting behavior differences across modes.

The new helper correctly:

- Enforces `model_engine.model.model_config.is_generation`.
- Derives `head_dim` robustly and respects FP8/FP4 KV cache quantization.
- Handles MLA, Nemotron hybrid, Qwen3 Next, and the general/VSWA case, including connector and beam-search constraints and passing `is_draft`/`max_num_tokens` into the generic KVCacheManager.

Given the complexity of the branching, a short docstring summarizing when each branch is taken (especially VSWA vs non-VSWA and when `kv_connector_manager` is honored or ignored) would make future maintenance easier, but functionally this looks correct. A possible docstring sketch follows.
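One way such a docstring could read, as a stub only (the branch summary restates the review's description; the real `_create_kv_cache_manager` has a much longer parameter list):

```python
def _create_kv_cache_manager_sketch(model_engine, kv_cache_config, is_draft=False, kv_connector_manager=None):
    """Create the KV cache manager for a generation model engine.

    Assumed branch summary:
      * MLA models       -> MLA-specific cache layout.
      * Nemotron hybrid  -> hybrid Mamba/attention cache manager.
      * Qwen3 Next       -> dedicated cache manager class.
      * VSWA             -> variable sliding-window path; beam-search constraints enforced.
      * everything else  -> generic KVCacheManager with is_draft / max_num_tokens forwarded.

    kv_connector_manager is honored only on the generic, non-draft path; draft
    engines always pass None.
    """
    raise NotImplementedError("Docstring sketch only; see _util.py for the real implementation.")
```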
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (5)

103-173: Draft TorchLlmArgs construction is reasonable; consider slightly more general load_format handling.

The way you reuse `LlmArgs` to build `TorchLlmArgs` for the draft model is solid:

- Core sizing and KV-cache fields are mirrored.
- Optional fields are copied only when present and non-None.
- A separate checkpoint loader for the draft (`draft_checkpoint_loader`) is supported.

One minor limitation: `draft_spec_config.load_format` only handles the `"dummy"` string. If other load formats are allowed on the draft spec config in the future (e.g. `"VISION_ONLY"`), you might want to pass them through generically instead of special-casing just `"dummy"`; a sketch of that idea follows.
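A minimal sketch of that generic pass-through, assuming `LoadFormat` is the enum in `tensorrt_llm/llmapi/llm_args.py` and that the draft spec config exposes `load_format` as either a string or an enum member (both assumptions):

```python
from tensorrt_llm.llmapi.llm_args import LoadFormat


def resolve_draft_load_format(draft_spec_config) -> LoadFormat:
    """Forward whatever load_format the draft spec config carries, not just "dummy"."""
    load_format = getattr(draft_spec_config, "load_format", None)
    if load_format is None:
        return LoadFormat.AUTO  # default: load real weights
    if isinstance(load_format, LoadFormat):
        return load_format
    # Accept strings such as "dummy" or "auto" and let the enum reject unknown values.
    return LoadFormat[str(load_format).upper()]
```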
388-539: Overlap-scheduler + draft-token handling in `_prepare_inputs` looks coherent; no obvious logic bugs.

The new helpers:

- `_compute_num_tokens_seen` correctly distinguishes extend vs normal generation and adjusts `input_pos` depending on whether overlap scheduling (`new_tokens`) and `py_batch_idx` are active.
- `_build_input_ids` builds either true token sequences (no overlap / dummy) or dummy-token sequences plus gather indices for the overlap scheduler, including the extend case where draft tokens are concatenated.

Because `py_batch_idx` is only set after `_build_input_ids` on the first iteration and persists on the request, new sequences will naturally take the "no overlap" path initially, and only reuse `new_tokens` once they've been assigned a batch index. That matches typical overlap behavior.

Given the complexity, adding a small comment that clarifies the assumed shape/layout of `new_tokens` (rows vs columns) for the extend path would make future reasoning about `gather_idx` safer, but functionally this looks correct. A comment sketch follows.
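A possible shape of that comment; the layout described here (steps as rows, sequence slots as columns, mirroring the PyTorch overlap path) is an assumption and should be checked against the sampler that actually produces `new_tokens`:

```python
# Assumed layout of new_tokens from the overlap scheduler:
#   new_tokens[step, seq_slot, beam]
#     - step 0 holds the token the target model just produced,
#     - steps 1..max_draft_len hold positions used by extend (draft-token) requests.
# gather_idx therefore addresses flat positions step * num_seq_slots + py_batch_idx,
# so any change to the sampler's output shape must be mirrored here.
```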
554-585: `spec_metadata` is currently unused and the type hint for `new_tensors_device` is stale.

In `forward`:

- You build a `spec_metadata` object when `run_with_spec_decode` is true and pass it into `_prepare_inputs`, but `_prepare_inputs` never uses it. That means we pay the cost of constructing `SpecMetadata` each step without any effect. Either wiring this through to wherever it's intended to be consumed, or dropping it for now (with a TODO), would reduce confusion and dead code.
- The `new_tensors_device` parameter is annotated as `Optional[torch.Tensor]`, but you now treat it as `SampleStateTensors` and immediately access `.new_tokens`. Updating the type hint to `Optional[SampleStateTensors]` will make this clearer to readers and tooling.

Neither affects runtime today (beyond minor overhead), but tightening them up would improve maintainability. A small signature sketch follows.
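A minimal sketch of the suggested annotation change; the surrounding parameter list is illustrative rather than the actual `forward` signature, and the `SampleStateTensors` import path is an assumption:

```python
from typing import Optional

from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors


def forward(
    self,
    scheduled_requests,
    resource_manager,
    new_tensors_device: Optional[SampleStateTensors] = None,  # was Optional[torch.Tensor]
    gather_context_logits: bool = False,
):
    # Only touch .new_tokens when the overlap scheduler actually handed us tensors.
    new_tokens = new_tensors_device.new_tokens if new_tensors_device is not None else None
    ...
```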
604-704: Draft model engine creation is well-structured; ensure the DraftTarget-only assert matches intended mode support.

`create_draft_model_engine_maybe`:

- Correctly gates on `spec_config is not None` and `has_draft_model()`.
- Builds a `draft_spec_config` copy, optionally wraps with `ChainDrafter`, and configures `AttentionRuntimeFeatures` so the draft model can participate in chunked prefill and cache reuse.
- Constructs a `PyTorchModelEngine` with `is_draft_model=True`, sets the draft KV cache manager key, and disables chunked prefill for MLA targets via `is_mla(engine.model_config)`.

The explicit `assert ad_config.speculative_config.spec_dec_mode.is_draft_target()` makes it clear this helper is currently only meant for DraftTarget speculative decoding, which matches the PR description. Just keep in mind that the earlier MTP-Eagle special case (`is_mtp_eagle`) will be dead code as long as this assert remains; if you later extend AutoDeploy to other two-model modes, you'll need to revisit this assertion and the DeepseekV3 MTP tweak.
835-852: Guided decoding and speculative drafter are still independent; consider an explicit TODO about combining them.

You correctly construct an optional `GuidedDecoder` for the last PP rank and a speculative drafter via `get_spec_drafter`, and pass both into `PyExecutor`. However, `get_spec_drafter` is always called with `guided_decoder=None`, which means guided decoding and speculative decoding cannot yet be combined in the AD flow.

You already have a TODO in `_prepare_inputs` about managing guided + speculative decoding together. It might be worth adding a brief comment here noting that the combination isn't supported yet and that `guided_decoder` intentionally isn't wired into the drafter, to avoid confusion for future readers.

Also applies to: 853-876
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- tensorrt_llm/_torch/auto_deploy/llm_args.py (1 hunks)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (17 hunks)
- tensorrt_llm/_torch/pyexecutor/_util.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (1 hunks)
- tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
- tensorrt_llm/_torch/speculative/utils.py (6 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
- tests/integration/defs/examples/test_ad_speculative_decoding.py (1 hunks)
- tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (1 hunks)
🧰 Additional context used
🧠 Learnings (13)
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.
Applied to files:
tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Applied to files:
tensorrt_llm/llmapi/llm_args.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tensorrt_llm/llmapi/llm_args.py
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/py_executor.py
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-15T06:46:53.813Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:53.813Z
Learning: In the TensorRT-LLM KV cache manager, SWA (Sliding Window Attention) combined with beam search is currently in a broken/non-functional state and is planned for future rework. During preparatory refactoring phases, code related to SWA+beam search may intentionally remain in a non-working state until the broader rework is completed.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-21T09:41:49.347Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.347Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
📚 Learning: 2025-08-26T06:07:02.166Z
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py
📚 Learning: 2025-08-18T08:42:02.640Z
Learnt from: samuellees
Repo: NVIDIA/TensorRT-LLM PR: 6974
File: tensorrt_llm/serve/scripts/benchmark_dataset.py:558-566
Timestamp: 2025-08-18T08:42:02.640Z
Learning: In TensorRT-LLM's RandomDataset (tensorrt_llm/serve/scripts/benchmark_dataset.py), when using --random-token-ids option, sequence length accuracy is prioritized over semantic correctness for benchmarking purposes. The encode/decode operations should use skip_special_tokens=True and add_special_tokens=False to ensure exact target token lengths.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.
Applied to files:
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
📚 Learning: 2025-09-16T09:30:09.716Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7763
File: cpp/tensorrt_llm/CMakeLists.txt:297-301
Timestamp: 2025-09-16T09:30:09.716Z
Learning: In the TensorRT-LLM project, NCCL libraries are loaded earlier by PyTorch libraries or the bindings library, so the main shared library doesn't need NCCL paths in its RPATH - the libraries will already be available in the process address space when needed.
Applied to files:
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
🧬 Code graph analysis (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (3)
- tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1): get_small_model_config (508-547)
- examples/auto_deploy/build_and_run_ad.py (1): ExperimentConfig (126-239)
- tensorrt_llm/llmapi/llm_args.py (2): DraftTargetDecodingConfig (958-971), KvCacheConfig (1426-1570)

tests/integration/defs/examples/test_ad_speculative_decoding.py (2)
- examples/auto_deploy/build_and_run_ad.py (1): ExperimentConfig (126-239)
- tensorrt_llm/llmapi/llm_args.py (2): DraftTargetDecodingConfig (958-971), KvCacheConfig (1426-1570)

tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
- tensorrt_llm/llmapi/llm_args.py (1): Field (63-90)

tensorrt_llm/_torch/pyexecutor/_util.py (4)
- tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py (1): KvCacheConnectorManager (364-572)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1): KVCacheManager (151-1196)
- tensorrt_llm/_utils.py (4): str_dtype_to_binding (221-224), torch_dtype_to_str (230-231), dtype (985-986), dtype (993-1003)
- tensorrt_llm/_torch/pyexecutor/config_utils.py (3): is_mla (12-16), is_nemotron_hybrid (4-9), is_qwen3_next (19-23)

tensorrt_llm/_torch/speculative/utils.py (2)
- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (2): model_config (301-302), get_spec_drafter (243-281)
- tensorrt_llm/_torch/speculative/eagle3.py (1): Eagle3ResourceManager (23-107)

tensorrt_llm/_torch/speculative/__init__.py (1)
- tensorrt_llm/_torch/speculative/utils.py (2): _get_spec_drafter (225-268), _get_spec_resource_manager (111-175)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (7)
- tensorrt_llm/_torch/attention_backend/interface.py (1): AttentionRuntimeFeatures (27-32)
- tensorrt_llm/_torch/pyexecutor/_util.py (3): _create_kv_cache_manager (424-466), _create_kv_cache_manager (498-672), get_kv_cache_manager_cls (47-55)
- tensorrt_llm/_torch/pyexecutor/config_utils.py (1): is_mla (12-16)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1): get_draft_token_length (807-818)
- tensorrt_llm/_torch/speculative/interface.py (3): SpecMetadata (148-236), is_mtp_eagle (35-36), is_draft_target (56-57)
- tensorrt_llm/_torch/speculative/utils.py (3): _get_spec_drafter (225-268), _get_spec_resource_manager (111-175), get_spec_drafter (271-289)
- tensorrt_llm/llmapi/llm_args.py (8): DecodingBaseConfig (547-701), LoadFormat (2507-2512), TorchLlmArgs (2576-3009), spec_dec_mode (694-701), spec_dec_mode (832-837), spec_dec_mode (886-889), spec_dec_mode (1018-1025), speculative_model_dir (1931-1932)
🪛 Ruff (0.14.5)
tests/integration/defs/examples/test_ad_speculative_decoding.py
151-151: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tensorrt_llm/_torch/pyexecutor/_util.py
557-558: Avoid specifying long messages outside the exception class
(TRY003)
598-599: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
394-394: Unused method argument: spec_metadata
(ARG002)
492-492: Consider [dummy_token, *dummy_draft_tokens] instead of concatenation
Replace with [dummy_token, *dummy_draft_tokens]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (16)
tests/integration/defs/examples/test_ad_speculative_decoding.py (4)
1-22: LGTM! The license header and imports are properly structured. All imported modules are used in the test implementation.

24-30: LGTM! Test prompts are well-defined and provide a good variety of queries for testing speculative decoding functionality.

100-104: Excellent use of deterministic sampling! Setting `temperature=0.0` ensures deterministic outputs, which is essential for comparing speculative decoding results with baseline outputs. This is the correct approach for this correctness test.

114-165: Well-structured correctness test! The test approach is sound:
- Tests both batch sizes (1 and 4) to cover single and multi-batch scenarios
- Runs identical configurations with and without speculative decoding
- Uses deterministic sampling (temperature=0.0) to ensure reproducible outputs
- Thoroughly compares both prompts and outputs element-by-element
- Provides helpful diagnostic output for debugging
This validates that speculative decoding produces bit-identical results to the baseline, which is the expected behavior for Draft-Target decoding.
tensorrt_llm/_torch/auto_deploy/llm_args.py (1)
188-192: New `draft_checkpoint_loader` field is consistent but could be documented more explicitly.

The optional `draft_checkpoint_loader` hook fits well into `AutoDeployConfig` and is backward-compatible as a pure additive field. Consider tightening the description to explicitly mention that this is expected to be a `BaseCheckpointLoader`-compatible object for the draft model (mirroring `TorchLlmArgs.checkpoint_loader` semantics) so users don't confuse it with the main-model loader. A field sketch follows.
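One possible wording for the field, assuming pydantic `Field` is used as elsewhere in the config (the class name and description text here are illustrative only):

```python
from typing import Any, Optional

from pydantic import BaseModel, Field


class AutoDeployConfigSketch(BaseModel):
    """Illustrative subset of AutoDeployConfig; only the new field is shown."""

    draft_checkpoint_loader: Optional[Any] = Field(
        default=None,
        description=(
            "Checkpoint loader for the draft model used in speculative decoding. "
            "Expected to be a BaseCheckpointLoader-compatible object, mirroring "
            "TorchLlmArgs.checkpoint_loader; it is not used for the target model."
        ),
    )
```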
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1045-1067: Spec-decode gating now correctly keyed off `llm_args.max_num_tokens` and spec config.

Passing `model_engine.llm_args.max_num_tokens` and `model_engine.spec_config.max_total_draft_tokens` into `drafter.should_use_spec_decode(...)` aligns the gating decision with the active runtime config and the actual speculative setup rather than any cached executor-side values. This looks correct, assuming the existing invariant that `model_engine.spec_config` is set whenever `drafter` is non-`None` continues to hold.

tensorrt_llm/llmapi/llm_args.py (1)
968-972: Extending the DraftTarget speculative config to the `_autodeploy` backend is consistent.

Allowing `DraftTargetDecodingConfig.supports_backend` to return `True` for both `"pytorch"` and `"_autodeploy"`, and asserting the same set in `validate_speculative_config`, cleanly enables two-model DraftTarget decoding under AutoDeploy while preserving the existing behavior for other backends. The wiring to `SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL` and `build_config.max_draft_len` remains unchanged and looks coherent with the rest of the speculative config handling.

Also applies to: 2249-2255
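For illustration, the shape this backend gate presumably takes (sketch only; the real method in `llm_args.py` may be written differently):

```python
class DraftTargetDecodingConfigSketch:
    """Illustrative stand-in for DraftTargetDecodingConfig's backend check."""

    def supports_backend(self, backend: str) -> bool:
        # DraftTarget two-model decoding is wired for the PyTorch runtime and,
        # with this PR, for the AutoDeploy ("_autodeploy") backend as well.
        return backend in ("pytorch", "_autodeploy")
```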
tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_speculative_decoding.py (1)
22-74: Smoke test wiring for AutoDeploy DraftTarget spec decode looks reasonable.

The test exercises a realistic AutoDeploy path: it builds base and draft configs via `get_small_model_config`, configures `DraftTargetDecodingConfig` and `KvCacheConfig`, toggles `runtime="trtllm"` and `world_size=1`, and then drives `build_and_run_ad.main()` with deterministic sampling before asserting on a single prompt/output pair. This is a good minimal regression guard for "DraftTarget + AutoDeploy" wiring; if you later want stronger coverage, you could also assert that speculative decoding was actually enabled (e.g., via stats or flags) rather than just that some text was produced.

tensorrt_llm/_torch/speculative/__init__.py (1)
8-12: Re-exporting internal spec helpers is fine but broadens the public surface.

Exposing `_get_spec_drafter` and `_get_spec_resource_manager` via `__all__` is a straightforward way to make these helpers available to other modules/tests, and it doesn't affect existing callers. Just be aware that, despite the leading underscore, this effectively promotes them to public API from a tooling perspective, so future refactors should treat them as semi-stable entry points.

Also applies to: 23-25
tensorrt_llm/_torch/speculative/utils.py (2)
111-176: Centralizing spec resource manager construction looks correct and keeps behavior consistent.

The new `_get_spec_resource_manager` plus `get_spec_resource_manager` cleanly factor out logic and preserve existing semantics:

- `spec_config is None` short-circuits early.
- Eagle3 / MTP Eagle assert a non-None `draft_model_config`, which will catch misconfigured two-model flows early.
- The public wrapper correctly derives `spec_config`, `model_config`, `batch_size`, `max_seq_len`, and `max_num_tokens` from `model_engine`.

This is a nice internal API surface for both regular PyTorch engines and the new AutoDeploy helpers.
Also applies to: 178-207
225-267: Spec drafter helpers are consistent with existing mode handling.
`_get_spec_drafter` and its `get_spec_drafter` wrapper mirror prior logic:

- `spec_config is None` and `is_user_provided()` paths are preserved.
- Draft-target, Eagle3, and MTP Eagle all route through `ModelDrafter` with a per-request `SeqSlotManager(max_num_requests)`.
- N-gram and save-hidden-states modes still use their specialized drafters.
The ADEngine-specific wrapper in AutoDeploy can safely reuse this helper.
Also applies to: 271-289
tensorrt_llm/_torch/pyexecutor/_util.py (1)
424-447: Delegating KvCacheCreator._create_kv_cache_manager to the shared helper is a good cleanup.

The method now simply forwards the creator's fields into `_create_kv_cache_manager`, which keeps all branching logic centralized while preserving existing behavior. Draft vs non-draft and estimation vs non-estimation flows are still controlled via `is_draft` and `estimating_kv_cache` flags.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (4)
175-180: `get_max_num_tokens` helper keeps max-token logic consistent across call sites.

Using `num_tokens_limit` when set and falling back to `max_seq_len * batch_size` matches the typical pattern used elsewhere and avoids duplicating this computation in the AD path. A sketch of that fallback follows.
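A minimal sketch of the described fallback (argument names follow the review text; the real helper in `ad_executor.py` may differ in signature):

```python
from typing import Optional


def get_max_num_tokens(max_seq_len: int, batch_size: int, num_tokens_limit: Optional[int] = None) -> int:
    """Return the configured token limit, or a conservative upper bound."""
    # Prefer the explicitly configured limit; otherwise assume every sequence
    # in the batch could be at its maximum length.
    if num_tokens_limit is not None:
        return num_tokens_limit
    return max_seq_len * batch_size
```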
182-209: Draft KV cache manager creation correctly reuses the shared helper.
`create_draft_kv_cache_manager_maybe`:

- Safely returns `None` when no draft model or a non-generation draft is present.
- Uses `get_kv_cache_manager_cls` on the draft's `model.model_config`.
- Delegates to the shared `_create_kv_cache_manager` with `is_draft=True` and `kv_connector_manager=None`, which aligns with the "no connector for draft models" constraint you mention.

This should give the draft model a proper KV cache manager separate from the AutoDeploy fake pool.
347-377: ADEngine now carries ad_config/spec_config; this looks correct and is needed for spec wiring.

Storing `ad_config` and `spec_config` on ADEngine and mirroring `seq_info.max_seq_len` into `self.llm_args.max_seq_len` are straightforward but important for downstream helpers (spec resource creation, `get_max_num_sequences`, etc.). This aligns the AD engine with expectations of the generic PyExecutor path.
706-795: Draft KV cache, spec resource manager, and drafter wiring into AutoDeploy executor looks correct.

In `create_autodeploy_executor`:

- You create a target ADEngine and an optional draft `PyTorchModelEngine`, plus:
  - A fake KV cache manager for the AD target.
  - A real draft KV cache manager via `create_draft_kv_cache_manager_maybe`.
  - A speculative resource manager via `create_spec_resource_manager`.
- These are all registered on the `ResourceManager` under the expected `ResourceManagerType` keys, with `KV_CACHE_MANAGER` kept last as required.
- A `TorchSampler` is instantiated with `max_draft_len`/`max_total_draft_tokens`, and `get_spec_drafter` is used to obtain the appropriate drafter when `spec_config` is present.

This is a clean integration of the new speculative path into the existing AutoDeploy executor construction. A registration-order sketch follows.
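A simplified sketch of the registration-order constraint mentioned above; the `ResourceManagerType` member names are taken from the review text, and the constructor call is illustrative:

```python
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager, ResourceManagerType


def build_resource_manager_sketch(kv_cache_manager, draft_kv_cache_manager=None, spec_resource_manager=None):
    """Register speculative resources first; the target KV cache manager must stay last."""
    resources = {}
    if spec_resource_manager is not None:
        resources[ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
    if draft_kv_cache_manager is not None:
        resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager
    # Kept last so its per-step prepare/update hooks run after the spec resources.
    resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
    return ResourceManager(resources)
```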
Signed-off-by: Govind Ramnarayan <[email protected]>
/bot run

PR_Github #24949 [ run ] triggered by Bot. Commit:
this will only run in CI when we also properly register it in the corresponding CI runs. Just fyi, we can address this before merging
```python
# Construct TorchLlmArgs for the draft model
draft_llm_args = construct_draft_llm_args(
```
do we really need this construct_draft_llm_args function? why not just re-use ad_config?
The type is slightly different - it's TorchLlmArgs for the PyTorchModelEngine, and LlmArgs for the ADEngine. Both of these inherit from BaseLlmArgs, but PyTorchModelEngine specifically requires TorchLlmArgs. So I made this change due to the type checker complaining. https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/model_engine.py#L133
I only made this conversion function because I was getting an error otherwise.
I guess an alternative would be to have LlmArgs extend/inherit TorchLlmArgs. Wdyt? My inclination is we can try it if its preferable, but it might be overkill if it takes time.
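For context, a rough sketch of the kind of conversion being discussed; the mirrored fields are illustrative, and the actual `construct_draft_llm_args` in `ad_executor.py` copies more of them:

```python
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


def construct_draft_llm_args_sketch(ad_config, draft_model_dir) -> TorchLlmArgs:
    """Build TorchLlmArgs for the draft PyTorchModelEngine from the AutoDeploy LlmArgs."""
    # PyTorchModelEngine requires TorchLlmArgs specifically, while the ADEngine is
    # configured with AutoDeploy's LlmArgs; both inherit from BaseLlmArgs, so the
    # core sizing / KV-cache settings can simply be mirrored onto the draft args.
    return TorchLlmArgs(
        model=draft_model_dir,
        max_batch_size=ad_config.max_batch_size,
        max_seq_len=ad_config.max_seq_len,
        max_num_tokens=ad_config.max_num_tokens,
        kv_cache_config=ad_config.kv_cache_config,
    )
```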
```diff
 mb_scheduler = BindMicroBatchScheduler(
     max_batch_size=ad_config.max_batch_size,
-    max_num_tokens=engine.cache_seq_interface.info.max_num_tokens,
+    max_num_tokens=engine.llm_args.max_num_tokens,
```
why not stick with engine.cache_seq_interface.info.max_num_tokens?
Doesn't matter, I thought since the purpose of llm_args is basically to have the arguments to the ADEngine stored within it, might as well read from there - presumably it plays some role in engine.cache_seq_interface.info, and accessing it from there to pass without the rest of engine.cache_seq_interface.info is peeking into this abstraction. But here in engine.llm_args.max_num_tokens its only purpose is to be accessed from the outside. So it was just a slight preference, I can undo it.
```python
# If the model is an MLA model, we need to set
# draft_model_engine.attn_runtime_features.chunked_prefill to False
```
seems like pytorch backend requirement. Not sure why something similar would apply for us
In this case, maybe no need for this check altogether? (No need for target_model_config at all here)?
```python
# draft_model_engine.attn_runtime_features.chunked_prefill to False
target_model_config = engine.model_config

if is_mla(target_model_config):
```
is this the only reason why we need target_model_config? In this case I suggest we have a utility that can extract that information from the graph.
For example, one approach could be to read the transform history that is stored in the graph here:
TensorRT-LLM/tensorrt_llm/_torch/auto_deploy/transform/interface.py
Lines 394 to 396 in 34f845b
```python
history[t_name] = info
autodeploy_meta[self._history_key] = history
self._set_autodeploy_meta(mod, autodeploy_meta)
```
and check if the insert_cached_mla_attention transform had any matches (which would indicate presence of MLA)
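A rough sketch of such a utility, assuming the transform history is stored via the `_set_autodeploy_meta` mechanism shown above and that the MLA transform is recorded under the name `insert_cached_mla_attention`; the accessor names below are hypothetical:

```python
def graph_has_mla(gm) -> bool:
    """Heuristically detect MLA by inspecting the AutoDeploy transform history on the graph module."""
    # Hypothetical accessors: the real code reads this back through the same meta
    # dict that _set_autodeploy_meta writes to.
    autodeploy_meta = gm.meta.get("autodeploy", {})
    history = autodeploy_meta.get("transform_history", {})
    info = history.get("insert_cached_mla_attention")
    # A recorded transform with at least one match indicates MLA attention was inserted.
    return bool(info) and getattr(info, "num_matches", 0) > 0
```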
See above - should I just delete this?
PR_Github #24949 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
Tests
Description
Implementation for vanilla ("DraftTarget") speculative decoding for AutoDeploy.
Target model runs as ADEngine, draft model runs as PyTorchModelEngine. Only two-model spec dec is supported.
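For reference, a minimal sketch of how this mode is expected to be enabled through the LLM API; the AutoDeploy `LLM` import path and the exact `DraftTargetDecodingConfig`/`KvCacheConfig` arguments are assumptions based on the tests in this PR, and the model paths are placeholders:

```python
from tensorrt_llm._torch.auto_deploy import LLM
from tensorrt_llm.llmapi.llm_args import DraftTargetDecodingConfig, KvCacheConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # target model, runs as ADEngine
    speculative_config=DraftTargetDecodingConfig(
        max_draft_len=4,  # number of draft tokens proposed per step
        speculative_model_dir="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # draft model, runs as PyTorchModelEngine
    ),
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)

outputs = llm.generate(["What is the capital of France?"])
print(outputs[0].outputs[0].text)
```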
Tested with various Llama models.
Fixes: #9147
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

Kill all running builds associated with pull request.

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.