4 changes: 3 additions & 1 deletion examples/llm-api/quickstart_advanced.py
@@ -138,6 +138,7 @@ def add_llm_args(parser):
default=False,
action='store_true')
parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
parser.add_argument('--max_total_draft_tokens', type=int, default=None)

# Relaxed acceptance
parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -205,7 +206,8 @@ def setup_llm(args, **kwargs):
eagle3_one_model=args.use_one_model,
eagle_choices=args.eagle_choices,
use_dynamic_tree=args.use_dynamic_tree,
dynamic_tree_max_topK=args.dynamic_tree_max_topK)
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
max_total_draft_tokens=args.max_total_draft_tokens)
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
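For context, the new flag flows straight into the Eagle config. A minimal sketch of an invocation and the resulting config (the flag and field names are taken from this diff; the other flags and all values are illustrative):

```python
# Hypothetical invocation of the quickstart (values illustrative):
#   python examples/llm-api/quickstart_advanced.py \
#       --use_dynamic_tree --dynamic_tree_max_topK 4 \
#       --max_total_draft_tokens 12
from tensorrt_llm.llmapi import EagleDecodingConfig

# Inside setup_llm, the new argument is forwarded to the Eagle config:
spec_config = EagleDecodingConfig(
    use_dynamic_tree=True,
    dynamic_tree_max_topK=4,
    max_total_draft_tokens=12,  # all tree nodes except the root
)
```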
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -10,8 +10,6 @@
from typing_extensions import Self

if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager

from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils,
@@ -346,10 +344,9 @@ def update_spec_dec_param(
is_spec_dec_dynamic_tree,
max_draft_len,
max_total_draft_tokens,
is_target_model: bool = True,
model_is_wrapped: bool = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None):
spec_tree_manager: Optional['SpecTreeManager'] = None):
"""
Hook to be called when using TRTLLM attention backend in spec-dec mode.
"""
27 changes: 12 additions & 15 deletions tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -496,24 +496,21 @@ def create_expanded_buffers(self, capture_graph=False):
# This function is only used to recreate the expanded buffers when max_draft_tokens changes.
# TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
def update_spec_dec_param(
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_len,
max_total_draft_tokens,
model_is_wrapped: bool = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
):
self,
batch_size,
is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree,
max_draft_len,
max_total_draft_tokens,
is_target_model: bool = True,
model_is_wrapped: bool = False,
spec_tree_manager: Optional['SpecTreeManager'] = None):
super().update_spec_dec_param(batch_size, is_spec_decoding_enabled,
is_spec_dec_tree,
is_spec_dec_dynamic_tree, max_draft_len,
max_total_draft_tokens, model_is_wrapped,
spec_metadata, spec_tree_manager,
spec_decoding_tensor)
max_total_draft_tokens, is_target_model,
model_is_wrapped, spec_tree_manager)
self.max_draft_tokens = max_draft_len
init_shape = self.kv_lens_expanded_host.shape[0]
if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
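For reference, the resize check above can be read as: the expanded KV-length buffer has one slot per (sequence, token) pair across the accepted token plus up to max_draft_tokens drafts. A minimal sketch under that reading (the helper name and dtype are assumptions, not the library's code):

```python
import torch

# Hypothetical helper mirroring the check above: rebuild the expanded
# KV-length buffer when max_draft_tokens changes, since its length is
# max_num_sequences * (1 + max_draft_tokens).
def expand_kv_lens_buffer(kv_lens_host: torch.Tensor, max_num_sequences: int,
                          max_draft_tokens: int) -> torch.Tensor:
    needed = max_num_sequences * (1 + max_draft_tokens)
    if kv_lens_host.shape[0] == needed:
        return kv_lens_host  # already sized for the new draft length
    return torch.zeros(needed, dtype=torch.int32)  # dtype is an assumption
```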
95 changes: 37 additions & 58 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -7,8 +7,6 @@
import torch

if TYPE_CHECKING:
from ..speculative.utils import SpecDecodingTensor
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager

from tensorrt_llm._utils import get_sm_version
@@ -1189,10 +1187,9 @@ def update_spec_dec_param(
is_spec_dec_dynamic_tree,
max_draft_len,
max_total_draft_tokens,
is_target_model: bool = True,
model_is_wrapped: bool = False,
spec_metadata: Optional['SpecMetadata'] = None,
spec_tree_manager: Optional['SpecTreeManager'] = None,
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
) -> None:
'''
Update the spec-dec parameters for the TRTLLM attention layer.
@@ -1203,34 +1200,17 @@
is_spec_dec_dynamic_tree: bool, whether using dynamic tree.
max_draft_len: int, the number of draft layers.
max_total_draft_tokens: int, the total number of nodes in the tree (excluding the root).
is_target_model: bool = True, whether the model is the target model.
model_is_wrapped: bool = False, whether the drafter model is wrapped (i.e., CDL).
spec_metadata: Optional['SpecMetadata'] = None, the metadata of the spec-dec.
spec_tree_manager: Optional['SpecTreeManager'] = None, the spec_tree_manager for draft token tree.
spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, the spec_decoding_tensor for draft token tree.
'''
if spec_decoding_tensor is not None:
spec_decoding_position_offsets = spec_decoding_tensor.position_offsets
spec_decoding_packed_mask = spec_decoding_tensor.packed_mask
spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths
else:
spec_decoding_position_offsets = None
spec_decoding_packed_mask = None
spec_decoding_generation_lengths = None

# spec-dec mode can only be enabled on pre-SM100 machines and on SM120.
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
get_sm_version() < 100 or get_sm_version() == 120)

self.is_spec_dec_tree = spec_tree_manager is not None
self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree

if get_sm_version() >= 100 and get_sm_version() != 120:
if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree:
assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."

# use_spec_decoding defaults to True; it can be changed at runtime per layer / request.
self.use_spec_decoding = self.is_spec_decoding_enabled

self.is_spec_dec_tree = is_spec_dec_tree
self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree

@@ -1267,43 +1247,42 @@ def update_spec_dec_param(
self.spec_decoding_bl_tree_mask = None
self.spec_bl_tree_first_sparse_mask_offset_kv = None

# Case 1: dynamic tree
if self.is_spec_dec_dynamic_tree:
assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree"
assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree"
self.spec_decoding_position_offsets.copy_(
spec_decoding_position_offsets, non_blocking=True)
self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask,
non_blocking=True)
if spec_decoding_generation_lengths is not None:
self.spec_decoding_generation_lengths.copy_(
spec_decoding_generation_lengths, non_blocking=True)
else:
self.generate_spec_decoding_generation_length(
max_draft_len=max_total_draft_tokens)

# Case 2/3: static tree
elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None:
assert spec_metadata.spec_dec_mode.is_eagle3(
), "Tree decoding is only supported for Eagle3 now"

is_target_model = not getattr(spec_metadata, 'is_draft_model',
False)

# Case 2: static tree and target model
# Case 1: draft token tree
if self.is_spec_dec_tree:
assert spec_tree_manager is not None, "spec_tree_manager is required for tree"
# Case 1.1: target model
if is_target_model:
# For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance.
self.spec_decoding_position_offsets[:batch_size, :].copy_(
spec_tree_manager.spec_dec_position_offsets[0, :],
non_blocking=True)
self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
spec_tree_manager.spec_dec_packed_mask[0, :, :],
non_blocking=True)
self.spec_decoding_generation_lengths[:batch_size].fill_(
spec_tree_manager.max_total_draft_tokens + 1)

# Case 3: static tree and the first drafter layer
# Case 1.1.1: dynamic tree
if self.is_spec_dec_dynamic_tree:
# For the dynamic tree, we copy the first batch_size rows of
# spec_tree_manager.spec_dec_position_offsets and spec_dec_packed_mask.
# - For the context requests, we do not need to prepare these spec-dec parameters.
# - For the generation requests, their spec-dec parameters are updated in
#   'model_drafter.py::reconstruct_dynamic_tree()', and the XQA kernel only
#   handles the generation requests.
self.spec_decoding_position_offsets[:batch_size, :].copy_(
spec_tree_manager.
spec_dec_position_offsets[:batch_size, :],
non_blocking=True)
self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
spec_tree_manager.
spec_dec_packed_mask[:batch_size, :, :],
non_blocking=True)
self.spec_decoding_generation_lengths[:batch_size].fill_(
spec_tree_manager.max_total_draft_tokens + 1)
# Case 1.1.2: static tree
else:
# For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance.
self.spec_decoding_position_offsets[:batch_size, :].copy_(
spec_tree_manager.spec_dec_position_offsets[0, :],
non_blocking=True)
self.spec_decoding_packed_mask[:batch_size, :, :].copy_(
spec_tree_manager.spec_dec_packed_mask[0, :, :],
non_blocking=True)
self.spec_decoding_generation_lengths[:batch_size].fill_(
spec_tree_manager.max_total_draft_tokens + 1)

# Case 1.2: the first drafter layer
else:
# Dynamic tree and static tree can use the same code path.
assert model_is_wrapped, "The drafter model should be wrapped"
# The first drafter layer takes padded tokens as input (padded to max_draft_len + 1),
# but the spec-dec parameters still have shape (max_total_draft_tokens + 1).
@@ -1332,7 +1311,7 @@ def update_spec_dec_param(
self.generate_spec_decoding_generation_length(
max_draft_len=max_draft_len)

# Case 4: linear tree
# Case 2: linear tree
else:
assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree"
# Prepare for the linear-tree.
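For intuition, the tree parameters above can be derived from a parent list: each node's position offset is its depth below the root, and each row of the packed mask sets the bits of the node itself and its ancestors, packed into 32-bit words. A self-contained sketch of that derivation (an illustrative helper, not TensorRT-LLM's API):

```python
# Illustrative helper (not TensorRT-LLM's API): derive position offsets and a
# packed ancestor mask for a draft token tree. Node 0 is the root (the last
# accepted token); parents[i] is the parent of node i + 1.
def tree_spec_dec_params(parents: list[int]) -> tuple[list[int], list[list[int]]]:
    num_nodes = len(parents) + 1  # + 1 for the root
    # Position offset of a node = its depth in the tree.
    depth = [0] * num_nodes
    for i, p in enumerate(parents, start=1):
        depth[i] = depth[p] + 1
    # Each node attends to itself and all of its ancestors; pack the
    # attended-node bitmask into 32-bit words, one row per node.
    num_words = (num_nodes + 31) // 32
    packed_mask = [[0] * num_words for _ in range(num_nodes)]
    for i in range(num_nodes):
        node = i
        while True:
            packed_mask[i][node // 32] |= 1 << (node % 32)
            if node == 0:
                break
            node = parents[node - 1]
    return depth, packed_mask

# Example: root -> {1, 2}, and node 1 -> {3}.
offsets, mask = tree_spec_dec_params([0, 0, 1])
assert offsets == [0, 1, 1, 2]
assert mask[3][0] == 0b1011  # node 3 attends to itself, node 1, and the root
```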
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -50,7 +50,6 @@
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import SpecDecodingTensor
from ..utils import (get_model_extra_attrs,
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
@@ -2610,7 +2609,6 @@ def forward(self,
new_tensors_device: Optional[SampleStateTensors] = None,
gather_context_logits: bool = False,
cache_indirection_buffer: Optional[torch.Tensor] = None,
spec_decoding_tensor: Optional[SpecDecodingTensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
@@ -2637,10 +2635,9 @@
is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree,
max_draft_len=self.original_max_draft_len,
max_total_draft_tokens=self.original_max_total_draft_tokens,
is_target_model=not self.is_draft_model,
model_is_wrapped=self.model_is_wrapped,
spec_metadata=spec_metadata,
spec_tree_manager=spec_tree_manager,
spec_decoding_tensor=spec_decoding_tensor)
spec_tree_manager=spec_tree_manager)
else:
spec_resource_manager = None
spec_metadata = None
21 changes: 18 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -363,18 +363,33 @@ def allocation_scope(current_stage: ExecutorMemoryType,

def drafting_loop_wrapper(model):
from tensorrt_llm._torch.speculative.drafting_loops import (
LinearDraftingLoopWrapper, TreeDraftingLoopWrapper)
DynamicTreeDraftingLoopWrapper,
LinearDraftingLoopWrapper,
StaticTreeDraftingLoopWrapper)
from tensorrt_llm.llmapi import EagleDecodingConfig

use_tree_drafter = isinstance(
draft_spec_config, EagleDecodingConfig
) and not draft_spec_config.is_linear_tree

if use_tree_drafter:
return TreeDraftingLoopWrapper(
static_tree_drafter = isinstance(
draft_spec_config, EagleDecodingConfig
) and draft_spec_config.eagle_choices is not None

dynamic_tree_drafter = isinstance(
draft_spec_config, EagleDecodingConfig
) and draft_spec_config.use_dynamic_tree

if static_tree_drafter:
return StaticTreeDraftingLoopWrapper(
spec_config.max_draft_len,
spec_config.max_total_draft_tokens, max_batch_size,
model)
elif dynamic_tree_drafter:
return DynamicTreeDraftingLoopWrapper(
spec_config.max_draft_len,
spec_config.max_total_draft_tokens, max_batch_size,
draft_spec_config.dynamic_tree_max_topK, model)
else:
return LinearDraftingLoopWrapper(
spec_config.max_draft_len,
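The wrapper selection above keys entirely off the Eagle config. A sketch of configs that would hit each branch (field names are taken from this diff; the values and the choices format are illustrative assumptions):

```python
from tensorrt_llm.llmapi import EagleDecodingConfig

# Fixed tree shape -> StaticTreeDraftingLoopWrapper
static_cfg = EagleDecodingConfig(
    max_draft_len=3,
    eagle_choices=[[0], [1], [0, 0]],  # illustrative choice tree
)

# Tree grown at runtime -> DynamicTreeDraftingLoopWrapper
dynamic_cfg = EagleDecodingConfig(
    max_draft_len=3,
    use_dynamic_tree=True,
    dynamic_tree_max_topK=4,
    max_total_draft_tokens=12,
)

# Neither -> LinearDraftingLoopWrapper (a plain chain of draft tokens)
linear_cfg = EagleDecodingConfig(max_draft_len=3)
```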