diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 9b37f8c7b29..a215b85675a 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -138,6 +138,7 @@ def add_llm_args(parser): default=False, action='store_true') parser.add_argument('--dynamic_tree_max_topK', type=int, default=None) + parser.add_argument('--max_total_draft_tokens', type=int, default=None) # Relaxed acceptance parser.add_argument('--use_relaxed_acceptance_for_thinking', @@ -205,7 +206,8 @@ def setup_llm(args, **kwargs): eagle3_one_model=args.use_one_model, eagle_choices=args.eagle_choices, use_dynamic_tree=args.use_dynamic_tree, - dynamic_tree_max_topK=args.dynamic_tree_max_topK) + dynamic_tree_max_topK=args.dynamic_tree_max_topK, + max_total_draft_tokens=args.max_total_draft_tokens) elif spec_decode_algo == "DRAFT_TARGET": spec_config = DraftTargetDecodingConfig( max_draft_len=args.spec_decode_max_draft_len, diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index 43244fc1bc1..fa50aff1e8b 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -10,8 +10,6 @@ from typing_extensions import Self if TYPE_CHECKING: - from ..speculative.utils import SpecDecodingTensor - from ..speculative.interface import SpecMetadata from ..speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.functional import (PositionEmbeddingType, RopeEmbeddingUtils, @@ -346,10 +344,9 @@ def update_spec_dec_param( is_spec_dec_dynamic_tree, max_draft_len, max_total_draft_tokens, + is_target_model: bool = True, model_is_wrapped: bool = False, - spec_metadata: Optional['SpecMetadata'] = None, - spec_tree_manager: Optional['SpecTreeManager'] = None, - spec_decoding_tensor: Optional['SpecDecodingTensor'] = None): + spec_tree_manager: Optional['SpecTreeManager'] = None): """ Hook to be called when using TRTLLM attention backend in spec-dec mode. """ diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index ccdd1e25fde..14530ebc970 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -496,24 +496,21 @@ def create_expanded_buffers(self, capture_graph=False): # This function is only used to create the expanded buffers when the max_draft_tokens is changed. # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1. 
def update_spec_dec_param( - self, - batch_size, - is_spec_decoding_enabled, - is_spec_dec_tree, - is_spec_dec_dynamic_tree, - max_draft_len, - max_total_draft_tokens, - model_is_wrapped: bool = False, - spec_metadata: Optional['SpecMetadata'] = None, - spec_tree_manager: Optional['SpecTreeManager'] = None, - spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, - ): + self, + batch_size, + is_spec_decoding_enabled, + is_spec_dec_tree, + is_spec_dec_dynamic_tree, + max_draft_len, + max_total_draft_tokens, + is_target_model: bool = True, + model_is_wrapped: bool = False, + spec_tree_manager: Optional['SpecTreeManager'] = None): super().update_spec_dec_param(batch_size, is_spec_decoding_enabled, is_spec_dec_tree, is_spec_dec_dynamic_tree, max_draft_len, - max_total_draft_tokens, model_is_wrapped, - spec_metadata, spec_tree_manager, - spec_decoding_tensor) + max_total_draft_tokens, is_target_model, + model_is_wrapped, spec_tree_manager) self.max_draft_tokens = max_draft_len init_shape = self.kv_lens_expanded_host.shape[0] if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape: diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index d754eb701a8..5f664fb883b 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -7,8 +7,6 @@ import torch if TYPE_CHECKING: - from ..speculative.utils import SpecDecodingTensor - from ..speculative.interface import SpecMetadata from ..speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm._utils import get_sm_version @@ -1189,10 +1187,9 @@ def update_spec_dec_param( is_spec_dec_dynamic_tree, max_draft_len, max_total_draft_tokens, + is_target_model: bool = True, model_is_wrapped: bool = False, - spec_metadata: Optional['SpecMetadata'] = None, spec_tree_manager: Optional['SpecTreeManager'] = None, - spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, ) -> None: ''' Update the spec-dec parameters for the TRTLLM attention layer. @@ -1203,34 +1200,17 @@ def update_spec_dec_param( is_spec_dec_dynamic_tree: bool, whether using dynamic tree. max_draft_len: int, the number of the draft layers. max_total_draft_tokens: int, the number of all nodes in the tree (except the root). + is_target_model: bool = True, whether the model is the target model. model_is_wrapped: Optional[bool] = False, whether the drafter model is wrapped (i.e, CDL). - spec_metadata: Optional['SpecMetadata'] = None, the metadata of the spec-dec. spec_tree_manager: Optional['SpecTreeManager'] = None, the spec_tree_manager for draft token tree. - spec_decoding_tensor: Optional['SpecDecodingTensor'] = None, the spec_decoding_tensor for draft token tree. ''' - if spec_decoding_tensor is not None: - spec_decoding_position_offsets = spec_decoding_tensor.position_offsets - spec_decoding_packed_mask = spec_decoding_tensor.packed_mask - spec_decoding_generation_lengths = spec_decoding_tensor.generation_lengths - else: - spec_decoding_position_offsets = None - spec_decoding_packed_mask = None - spec_decoding_generation_lengths = None # spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree. 
self.is_spec_decoding_enabled = is_spec_decoding_enabled and ( get_sm_version() < 100 or get_sm_version() == 120) - self.is_spec_dec_tree = spec_tree_manager is not None - self.is_spec_dec_dynamic_tree = spec_tree_manager is not None and spec_tree_manager.use_dynamic_tree - - if get_sm_version() >= 100 and get_sm_version() != 120: - if self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree: - assert not self.is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree." - # use_spec_decoding is default to true by default, change in runtime by layers / requests self.use_spec_decoding = self.is_spec_decoding_enabled - self.is_spec_dec_tree = is_spec_dec_tree self.is_spec_dec_dynamic_tree = is_spec_dec_dynamic_tree @@ -1267,43 +1247,42 @@ def update_spec_dec_param( self.spec_decoding_bl_tree_mask = None self.spec_bl_tree_first_sparse_mask_offset_kv = None - # Case 1: dynamic tree - if self.is_spec_dec_dynamic_tree: - assert spec_decoding_position_offsets is not None, "spec_decoding_position_offsets is required for dynamic tree" - assert spec_decoding_packed_mask is not None, "spec_decoding_packed_mask is required for dynamic tree" - self.spec_decoding_position_offsets.copy_( - spec_decoding_position_offsets, non_blocking=True) - self.spec_decoding_packed_mask.copy_(spec_decoding_packed_mask, - non_blocking=True) - if spec_decoding_generation_lengths is not None: - self.spec_decoding_generation_lengths.copy_( - spec_decoding_generation_lengths, non_blocking=True) - else: - self.generate_spec_decoding_generation_length( - max_draft_len=max_total_draft_tokens) - - # Case 2/3: static tree - elif self.is_spec_dec_tree and not self.is_spec_dec_dynamic_tree and spec_metadata is not None: - assert spec_metadata.spec_dec_mode.is_eagle3( - ), "Tree decoding is only supported for Eagle3 now" - - is_target_model = not getattr(spec_metadata, 'is_draft_model', - False) - - # Case 2: static tree and target model + # Case 1: draft token tree + if self.is_spec_dec_tree: + assert spec_tree_manager is not None, "spec_tree_manager is required for tree" + # Case 1.1: target model if is_target_model: - # For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance. - self.spec_decoding_position_offsets[:batch_size, :].copy_( - spec_tree_manager.spec_dec_position_offsets[0, :], - non_blocking=True) - self.spec_decoding_packed_mask[:batch_size, :, :].copy_( - spec_tree_manager.spec_dec_packed_mask[0, :, :], - non_blocking=True) - self.spec_decoding_generation_lengths[:batch_size].fill_( - spec_tree_manager.max_total_draft_tokens + 1) - - # Case 3: static tree and the first drafter layer + # Case 1.1.1: dynamic tree + if self.is_spec_dec_dynamic_tree: + # For the dynamic tree, we just copy batch_size's spec_tree_manager.spec_dec_position_offsets, spec_dec_packed_mask. + # - For the context requests, we do not need to prepare these spec-dec parameters. + # - For the generation requests, their relative spec-dec parameters are update in 'model_drafter.py::reconstruct_dynamic_tree()' + # And the XQA kernel will only handle the generation requests. + self.spec_decoding_position_offsets[:batch_size, :].copy_( + spec_tree_manager. + spec_dec_position_offsets[:batch_size, :], + non_blocking=True) + self.spec_decoding_packed_mask[:batch_size, :, :].copy_( + spec_tree_manager. 
+ spec_dec_packed_mask[:batch_size, :, :], + non_blocking=True) + self.spec_decoding_generation_lengths[:batch_size].fill_( + spec_tree_manager.max_total_draft_tokens + 1) + # Case 1.1.2: static tree + else: + # For the target model, we update the spec-dec parameters with the spec_tree_manager, which is prepared in advance. + self.spec_decoding_position_offsets[:batch_size, :].copy_( + spec_tree_manager.spec_dec_position_offsets[0, :], + non_blocking=True) + self.spec_decoding_packed_mask[:batch_size, :, :].copy_( + spec_tree_manager.spec_dec_packed_mask[0, :, :], + non_blocking=True) + self.spec_decoding_generation_lengths[:batch_size].fill_( + spec_tree_manager.max_total_draft_tokens + 1) + + # Case 1.2: the first drafter layer else: + # Dynamic tree and static tree can use the same code path. assert model_is_wrapped == True, "The drafter model should be wrapped" # The first drafter layer will take the padded tokens as input (padding to the max_draft_len + 1) # But the spec-dec parameters are still in the shape of max_total_draft_tokens + 1. @@ -1332,7 +1311,7 @@ def update_spec_dec_param( self.generate_spec_decoding_generation_length( max_draft_len=max_draft_len) - # Case 4: linear tree + # Case 2: linear tree else: assert max_draft_len == max_total_draft_tokens, "max_draft_len should be equal to max_total_draft_tokens for linear tree" # Prepare for the linear-tree. diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 0e104185bc2..d76d1f00071 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -50,7 +50,6 @@ from ..speculative.drafting_loops import BaseDraftingLoopWrapper from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata from ..speculative.mtp import SampleStateTensorsMTP -from ..speculative.utils import SpecDecodingTensor from ..utils import (get_model_extra_attrs, set_per_request_piecewise_cuda_graph_flag, set_torch_compiling, with_model_extra_attrs) @@ -2610,7 +2609,6 @@ def forward(self, new_tensors_device: Optional[SampleStateTensors] = None, gather_context_logits: bool = False, cache_indirection_buffer: Optional[torch.Tensor] = None, - spec_decoding_tensor: Optional[SpecDecodingTensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): kv_cache_manager = resource_manager.get_resource_manager( @@ -2637,10 +2635,9 @@ def forward(self, is_spec_dec_dynamic_tree=spec_metadata.is_spec_dec_dynamic_tree, max_draft_len=self.original_max_draft_len, max_total_draft_tokens=self.original_max_total_draft_tokens, + is_target_model=not self.is_draft_model, model_is_wrapped=self.model_is_wrapped, - spec_metadata=spec_metadata, - spec_tree_manager=spec_tree_manager, - spec_decoding_tensor=spec_decoding_tensor) + spec_tree_manager=spec_tree_manager) else: spec_resource_manager = None spec_metadata = None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 3fc0027d638..6da7619d1ef 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -363,18 +363,33 @@ def allocation_scope(current_stage: ExecutorMemoryType, def drafting_loop_wrapper(model): from tensorrt_llm._torch.speculative.drafting_loops import ( - LinearDraftingLoopWrapper, TreeDraftingLoopWrapper) + DynamicTreeDraftingLoopWrapper, + LinearDraftingLoopWrapper, + 
StaticTreeDraftingLoopWrapper) from tensorrt_llm.llmapi import EagleDecodingConfig use_tree_drafter = isinstance( draft_spec_config, EagleDecodingConfig ) and not draft_spec_config.is_linear_tree - if use_tree_drafter: - return TreeDraftingLoopWrapper( + static_tree_drafter = isinstance( + draft_spec_config, EagleDecodingConfig + ) and draft_spec_config.eagle_choices is not None + + dynamic_tree_drafter = isinstance( + draft_spec_config, EagleDecodingConfig + ) and draft_spec_config.use_dynamic_tree + + if static_tree_drafter: + return StaticTreeDraftingLoopWrapper( spec_config.max_draft_len, spec_config.max_total_draft_tokens, max_batch_size, model) + elif dynamic_tree_drafter: + return DynamicTreeDraftingLoopWrapper( + spec_config.max_draft_len, + spec_config.max_total_draft_tokens, max_batch_size, + draft_spec_config.dynamic_tree_max_topK, model) else: return LinearDraftingLoopWrapper( spec_config.max_draft_len, diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py index f044fdd1063..d615e5e812d 100644 --- a/tensorrt_llm/_torch/speculative/drafting_loops.py +++ b/tensorrt_llm/_torch/speculative/drafting_loops.py @@ -13,6 +13,7 @@ from typing import Optional, final import torch +import torch.nn as nn from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata @@ -197,7 +198,7 @@ def prepare_for_generation(self, attn_metadata: AttentionMetadata, return new_position_ids -class TreeDraftingLoopWrapper(BaseDraftingLoopWrapper): +class StaticTreeDraftingLoopWrapper(BaseDraftingLoopWrapper): def __init__(self, max_draft_len: int, max_total_draft_tokens: int, max_batch_size: int, draft_model: torch.nn.Module): @@ -221,11 +222,8 @@ def __init__(self, max_draft_len: int, max_total_draft_tokens: int, def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata, **kwargs) -> dict[str, torch.Tensor]: - spec_tree_manager = None - if isinstance(spec_metadata, Eagle3SpecMetadata): - spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager - - assert spec_tree_manager is not None + assert isinstance(spec_metadata, Eagle3SpecMetadata) + spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager logits = self.draft_model.forward(input_ids=input_ids, position_ids=position_ids, @@ -509,3 +507,533 @@ def prepare_for_generation(self, attn_metadata: AttentionMetadata, spec_metadata.is_first_draft = False return + + +class DynamicTreeDraftingLoopWrapper(BaseDraftingLoopWrapper): + + def __init__(self, max_draft_len: int, max_total_draft_tokens: int, + max_batch_size: int, dynamic_tree_max_topK, + draft_model: torch.nn.Module): + super().__init__() + self.draft_model = draft_model + self.config = self.draft_model.config + self.model_config = self.draft_model.model_config + self.max_draft_len = max_draft_len + self.max_total_draft_tokens = max_total_draft_tokens + self.max_batch_size = max_batch_size + self.dynamic_tree_max_topK = dynamic_tree_max_topK + self.logsoftmax = nn.LogSoftmax(dim=-1) + + self.draft_tokens_buffer = torch.zeros( + (max_batch_size, max_total_draft_tokens + 1), + dtype=torch.int64, + device='cuda') + self.position_ids_buffer = torch.zeros( + (max_batch_size, max_total_draft_tokens + 1), + dtype=torch.int64, + device='cuda') + self.history_draft_tokens_buffer = torch.zeros( + (max_batch_size, dynamic_tree_max_topK + + dynamic_tree_max_topK 
* dynamic_tree_max_topK * + (max_draft_len - 1)), + dtype=torch.int64, + device='cuda') + self.history_score_buffer = torch.zeros( + (max_batch_size, dynamic_tree_max_topK + + dynamic_tree_max_topK * dynamic_tree_max_topK * + (max_draft_len - 1)), + dtype=torch.float32, + device='cuda') + self.history_draft_tokens_parent_buffer = torch.ones( + (max_batch_size, dynamic_tree_max_topK + + dynamic_tree_max_topK * dynamic_tree_max_topK * + (max_draft_len - 1)), + dtype=torch.int64, + device='cuda') * -1 + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata, + **kwargs) -> dict[str, torch.Tensor]: + + assert isinstance(spec_metadata, Eagle3SpecMetadata) + spec_tree_manager = spec_metadata.eagle3_resource_manager.spec_tree_manager + + logits = self.draft_model.forward(input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + return_context_logits=True) + batch_size = attn_metadata.num_seqs + vocab_size = logits.shape[-1] + logits = logits[spec_metadata.gather_ids] # [batch_size, vocab_size] + + # new_draft_tokens: [batch_size * dynamic_tree_max_topK] + # new_draft_scores: [batch_size * dynamic_tree_max_topK] + new_draft_tokens, new_draft_scores = self.sample( + logits=logits, max_top_k=self.dynamic_tree_max_topK) + + cur_scores = self.update_draft_tokens_and_scores( + cur_draft_idx=0, + batch_size=batch_size, + new_draft_tokens=new_draft_tokens, + new_draft_scores=new_draft_scores, + previous_draft_scores=None, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata) + + return_draft_logits = None + with save_metadata_state(attn_metadata, spec_metadata): + batch_size = attn_metadata.num_seqs + + self.prepare_for_generation(attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + spec_tree_manager=spec_tree_manager, + position_ids=position_ids) + + for layer_idx in range(1, self.max_draft_len): + # input_ids: [batch_size * (max_total_draft_tokens + 1)] + # position_ids: [batch_size * (max_total_draft_tokens + 1)] + # logits: [batch_size * (max_total_draft_tokens + 1), vocab_size] + logits = self.draft_model.forward( + input_ids=self.draft_tokens_buffer[:batch_size, :].reshape( + -1), + position_ids=self.position_ids_buffer[:batch_size, :]. + reshape(-1), + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + return_context_logits=True) + + # new_draft_tokens: [batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK] + # new_draft_scores: [batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK] + new_draft_tokens, new_draft_scores = self.sample( + logits=logits, + max_top_k=spec_tree_manager.dynamic_tree_max_topK) + # Keep updating + cur_scores = self.update_draft_tokens_and_scores( + cur_draft_idx=layer_idx, + batch_size=batch_size, + new_draft_tokens=new_draft_tokens, + new_draft_scores=new_draft_scores, + previous_draft_scores=cur_scores, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata) + + if layer_idx == self.max_draft_len - 1: + # FIXME: Actually the logits is incorrect; we don't have compatibility with that yet. 
+ return_draft_logits = logits + + # Resampling the final draft tokens + # real_draft_tokens: [batch_size, self.max_total_draft_tokens] + # topk_score_indices: [batch_size, self.max_total_draft_tokens] + real_draft_tokens, topk_score_indices = self.resampling_final_draft_tokens( + batch_size=batch_size) + + # return_new_draft_tokens: [max_total_draft_tokens, batch_size] + return_new_draft_tokens = torch.transpose(real_draft_tokens, 0, 1) + + # return_draft_logits: [batch_size, max_total_draft_tokens + 1, vocab_size] -> [max_total_draft_tokens, batch_size, vocab_size] + return_draft_logits = return_draft_logits.reshape( + batch_size, self.max_total_draft_tokens + 1, vocab_size) + return_draft_logits = torch.transpose(return_draft_logits[:, :-1, :], 0, + 1) + + assert return_new_draft_tokens.shape == (self.max_total_draft_tokens, + batch_size) + assert return_draft_logits.shape == (self.max_total_draft_tokens, + batch_size, vocab_size) + + return { + "new_draft_tokens": return_new_draft_tokens, + "draft_logits": return_draft_logits, + "dynamic_tree_buffers": { + "topk_score_indices": + topk_score_indices, + "history_draft_tokens_parent_buffer": + self.history_draft_tokens_parent_buffer[:batch_size, :] + } + } + + def sample(self, logits: torch.Tensor, max_top_k: int) -> torch.Tensor: + # TODO: inject the sampler here so we can support non-greedy + + # for draft_layer_idx == 0, logits is of shape [batch_size, vocab_size] + # for draft_layer_idx > 0, logits is of shape [batch_size * (max_total_draft_tokens + 1), vocab_size] + last_p = self.logsoftmax(logits) + topk_values, topk_indices = torch.topk( + last_p, k=max_top_k, dim=-1 + ) # [batch_size, max_top_k] or [batch_size * max_total_draft_tokens, max_top_k] + + tokens = topk_indices.reshape(-1) + scores = topk_values.reshape(-1) + + if hasattr(self.draft_model.model, "d2t"): + d2t = self.draft_model.model.d2t.data + tokens = tokens + d2t[tokens] + + return tokens, scores + + def update_draft_tokens_and_scores( + self, cur_draft_idx: int, batch_size: int, + new_draft_tokens: torch.Tensor, new_draft_scores: torch.Tensor, + previous_draft_scores: torch.Tensor, + attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata): + ''' + Args: + cur_draft_idx: int, already finished forward. + batch_size: int + new_draft_tokens: + when cur_draft_idx == 0: [batch_size * dynamic_tree_max_topK] + when cur_draft_idx > 0: [batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK] + previous_draft_scores: + when cur_draft_idx == 0: None + when cur_draft_idx > 0: [batch_size, dynamic_tree_max_topK] + ''' + ''' + What this function does: + 1) Update the scores (exclude the first drafter layer) + 2) Extract the real draft tokens this layer + 3) Save the draft tokens and scores to self.history_draft_tokens_buffer and self.history_score_buffer, respectively. + 4) Update the attn_metadata.spec_decoding_packed_mask for the subsequent drafter layer. + 5) Update the spec_metadata.hidden_states_read_indices for the subsequent drafter layer. + 6) Update the parent nodes of the next layer's new nodes in advance. 
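+        7) Return the cumulative log-prob scores of the top-K nodes kept at this layer
+           (shape [batch_size, dynamic_tree_max_topK]); the next drafter layer adds its
+           per-child scores onto these as `previous_draft_scores`.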
+ ''' + # After the first drafter layer, new_draft_tokens: [batch_size * dynamic_tree_max_topK] + # For other drafter layers, new_draft_tokens: [batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK] + if cur_draft_idx == 0: + assert new_draft_tokens.shape[0] == (batch_size * + self.dynamic_tree_max_topK) + assert new_draft_scores.shape[0] == (batch_size * + self.dynamic_tree_max_topK) + else: + assert new_draft_tokens.shape[0] == ( + batch_size * (self.max_total_draft_tokens + 1) * + self.dynamic_tree_max_topK) + assert new_draft_scores.shape[0] == ( + batch_size * (self.max_total_draft_tokens + 1) * + self.dynamic_tree_max_topK) + + if cur_draft_idx == 0: + # new_draft_tokens: [batch_size, self.dynamic_tree_max_topK] + # new_draft_scores: [batch_size, self.dynamic_tree_max_topK] + new_draft_tokens = new_draft_tokens.reshape( + batch_size, self.dynamic_tree_max_topK) + new_draft_scores = new_draft_scores.reshape( + batch_size, self.dynamic_tree_max_topK) + + # 2) & 3) Update draft tokens and scores buffer. + self.draft_tokens_buffer[:batch_size, :self. + dynamic_tree_max_topK] = new_draft_tokens[:, :] + self.history_draft_tokens_buffer[:batch_size, :self. + dynamic_tree_max_topK] = new_draft_tokens[:, :] + self.history_score_buffer[:batch_size, :self. + dynamic_tree_max_topK] = new_draft_scores[:, :] + + # 4) Update the attn_metadata.spec_decoding_packed_mask + attn_metadata.spec_decoding_packed_mask[:batch_size, :, :].fill_(0) + dummy_idx = torch.arange(self.dynamic_tree_max_topK, + dtype=torch.int32, + device='cuda') + packed_mask = torch.pow(2, + dummy_idx) # [self.dynamic_tree_max_topK] + attn_metadata.spec_decoding_packed_mask[:batch_size, :self. + dynamic_tree_max_topK, :] = packed_mask.unsqueeze( + 1) + + # 5) Update the attn_metadata.hidden_states_read_indices + ## Will be updated in the prepare_for_generation function. Because it will need the information of the old_write_indices and so on. + + # 6) Process the parent buffer. + self.history_draft_tokens_parent_buffer[:batch_size, :self. + dynamic_tree_max_topK] = -1 # Use -1 to represent the root node + # These selected nodes will expand into new nodes at the next layer. + # We update the parent nodes of these new nodes in advance. + parents_indices_for_next_layer_draft_tokens = torch.repeat_interleave( + torch.arange(0, + self.dynamic_tree_max_topK, + dtype=torch.int32, + device='cuda'), + self.dynamic_tree_max_topK, + dim=0 + ) # [self.dynamic_tree_max_topK * self.dynamic_tree_max_topK] + self.history_draft_tokens_parent_buffer[:batch_size, + self.dynamic_tree_max_topK: + self.dynamic_tree_max_topK + + self.dynamic_tree_max_topK * + self. + dynamic_tree_max_topK] = parents_indices_for_next_layer_draft_tokens + + return new_draft_scores # [batch_size, self.dynamic_tree_max_topK] + else: + # new_draft_tokens: [batch_size * (self.max_total_draft_tokens + 1) * self.dynamic_tree_max_topK] + # new_draft_scores: [batch_size * (self.max_total_draft_tokens + 1) * self.dynamic_tree_max_topK] + + new_draft_tokens = new_draft_tokens.reshape( + batch_size, (self.max_total_draft_tokens + 1), + self.dynamic_tree_max_topK) + new_draft_scores = new_draft_scores.reshape( + batch_size, (self.max_total_draft_tokens + 1), + self.dynamic_tree_max_topK) + + # We process 'self.max_total_draft_tokens + 1' draft tokens, but we only need specific draft tokens for each layer. 
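+            # The previous layer's surviving draft tokens occupy positions
+            # [(cur_draft_idx - 1) * topK, cur_draft_idx * topK) of the padded input, so only
+            # the candidate sets produced at those positions matter here. E.g. with
+            # dynamic_tree_max_topK = 3 and cur_draft_idx = 2, we gather positions 3..5.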
+ gather_draft_tokens_start_offset = (cur_draft_idx - + 1) * self.dynamic_tree_max_topK + gather_draft_tokens_end_offset = gather_draft_tokens_start_offset + self.dynamic_tree_max_topK + gather_new_draft_tokens = new_draft_tokens[:, + gather_draft_tokens_start_offset: + gather_draft_tokens_end_offset, :].reshape( + batch_size, self. + dynamic_tree_max_topK + * self. + dynamic_tree_max_topK + ) # [batch_size, self.dynamic_tree_max_topK * self.dynamic_tree_max_topK] + gather_new_draft_scores = new_draft_scores[:, + gather_draft_tokens_start_offset: + gather_draft_tokens_end_offset, :] # [batch_size, self.dynamic_tree_max_topK, self.dynamic_tree_max_topK] + + # 1) Update the scores with the previous layer's scores + assert previous_draft_scores.shape == (batch_size, + self.dynamic_tree_max_topK) + gather_new_draft_scores = gather_new_draft_scores + previous_draft_scores.unsqueeze( + 2 + ) # [batch_size, self.dynamic_tree_max_topK, self.dynamic_tree_max_topK] + gather_new_draft_scores = gather_new_draft_scores.reshape( + batch_size, + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + ) # [batch_size, self.dynamic_tree_max_topK * self.dynamic_tree_max_topK] + + # 2) Extract the real draft tokens this layer, topk again. + # topk_values: [batch_size, self.dynamic_tree_max_topK], the output scores of this layer + # topk_indices: [batch_size, self.dynamic_tree_max_topK] + topk_values, topk_indices = torch.topk(gather_new_draft_scores, + k=self.dynamic_tree_max_topK, + dim=-1) + real_draft_tokens = torch.gather( + gather_new_draft_tokens, dim=1, + index=topk_indices) # [batch_size, self.dynamic_tree_max_topK] + write_back_real_draft_tokens_start_offset = cur_draft_idx * self.dynamic_tree_max_topK + write_back_real_draft_tokens_end_offset = write_back_real_draft_tokens_start_offset + self.dynamic_tree_max_topK + self.draft_tokens_buffer[:batch_size, + write_back_real_draft_tokens_start_offset: + write_back_real_draft_tokens_end_offset] = real_draft_tokens[:, :] + + # 3) Save the draft tokens and scores to self.history_draft_tokens_buffer and self.history_score_buffer. + write_history_start_offset = self.dynamic_tree_max_topK + ( + cur_draft_idx - + 1) * self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + write_history_end_offset = write_history_start_offset + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + self.history_draft_tokens_buffer[:batch_size, + write_history_start_offset: + write_history_end_offset] = gather_new_draft_tokens[:, :] + self.history_score_buffer[:batch_size, write_history_start_offset: + write_history_end_offset] = gather_new_draft_scores[:, :] + + # 4) Update the attn_metadata.spec_decoding_packed_mask, shape: [max_num_requests, max_total_draft_tokens + 1, math.ceil(max_total_draft_tokens + 1 / 32)] + selected_parents = topk_indices // self.dynamic_tree_max_topK # [batch_size, self.dynamic_tree_max_topK] + # For simplicity, we will only consider the case where math.ceil(max_total_draft_tokens + 1 / 32) == 1. + parents_packed_mask = torch.gather( + attn_metadata. + spec_decoding_packed_mask[:batch_size, + gather_draft_tokens_start_offset: + gather_draft_tokens_end_offset, :]. 
+ squeeze(-1), + dim=1, + index=selected_parents + ) # [batch_size, self.dynamic_tree_max_topK] + child_packed_mask = torch.pow( + 2, + torch.arange(cur_draft_idx * self.dynamic_tree_max_topK, + cur_draft_idx * self.dynamic_tree_max_topK + + self.dynamic_tree_max_topK, + dtype=torch.int32, + device='cuda')) # [self.dynamic_tree_max_topK] + child_packed_mask = child_packed_mask + parents_packed_mask # [batch_size, self.dynamic_tree_max_topK] + attn_metadata.spec_decoding_packed_mask[:batch_size, + write_back_real_draft_tokens_start_offset: + write_back_real_draft_tokens_end_offset, :] = child_packed_mask.unsqueeze( + -1) + + # 5) Update the spec_metadata.hidden_states_read_indices, shape: [max_num_tokens], but we save as [:batch_size * (max_total_draft_tokens + 1)] + selected_parents_write_indices = selected_parents + gather_draft_tokens_start_offset # [batch_size, self.dynamic_tree_max_topK] + hidden_states_write_indices_view = spec_metadata.hidden_states_write_indices[:batch_size * ( + self.max_total_draft_tokens + + 1)] # [batch_size, self.max_total_draft_tokens + 1] + hidden_states_write_indices_view = hidden_states_write_indices_view.view( + batch_size, self.max_total_draft_tokens + + 1) # [batch_size, self.max_total_draft_tokens + 1] + child_hidden_states_read_indices = torch.gather( + hidden_states_write_indices_view, + dim=1, + index=selected_parents_write_indices + ) # [batch_size, self.dynamic_tree_max_topK] + + hidden_states_read_indices_view = spec_metadata.hidden_states_read_indices[:batch_size * ( + self.max_total_draft_tokens + + 1)] # [batch_size * (max_total_draft_tokens + 1)] + hidden_states_read_indices_view = hidden_states_read_indices_view.view( + batch_size, self.max_total_draft_tokens + + 1) # [batch_size, self.max_total_draft_tokens + 1] + hidden_states_read_indices_view[:, + write_back_real_draft_tokens_start_offset: + write_back_real_draft_tokens_end_offset] = child_hidden_states_read_indices[:, :] + + if cur_draft_idx < self.max_draft_len - 1: + # 6) Update the parent nodes of the next layer's new nodes in advance. + # We need to know next layer's draft tokens are expanded from which parents. + # i.e. calculate the index of the selected draft tokens in the entire tree (including all historical nodes, for subsequent reconstruction of the entire tree). + parents_indices = topk_indices + ( + self.dynamic_tree_max_topK + (cur_draft_idx - 1) * + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + ) # [batch_size, self.dynamic_tree_max_topK] + parents_indices = torch.repeat_interleave( + parents_indices, self.dynamic_tree_max_topK, dim=1 + ) # [batch_size, self.dynamic_tree_max_topK * self.dynamic_tree_max_topK] + next_layer_draft_tokens_start_offset = self.dynamic_tree_max_topK + cur_draft_idx * self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + next_layer_draft_tokens_end_offset = next_layer_draft_tokens_start_offset + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK + self.history_draft_tokens_parent_buffer[:batch_size, + next_layer_draft_tokens_start_offset: + next_layer_draft_tokens_end_offset] = parents_indices[:, :] + + return topk_values # [batch_size, self.dynamic_tree_max_topK] + + def resampling_final_draft_tokens(self, batch_size: int): + ''' + Restruct the tree based on the self.history_draft_tokens_buffer, self.history_draft_tokens_parent_buffer and self.history_score_buffer. 
+ ''' + # self.history_score_buffer[:batch_size, :] shape: [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] + topk_score_indices = torch.topk( + self.history_score_buffer[:batch_size, :], + k=self.max_total_draft_tokens, + dim=-1).indices + topk_score_indices = torch.sort( + topk_score_indices + ).values # [batch_size, self.max_total_draft_tokens] + + # The final output draft tokens + real_draft_tokens = torch.gather( + self.history_draft_tokens_buffer[:batch_size, :], + dim=1, + index=topk_score_indices + ) # [batch_size, self.max_total_draft_tokens] + + # self.history_draft_tokens_parent_buffer[:batch_size, :] shape: [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] + # real_draft_tokens_parents = torch.gather(self.history_draft_tokens_parent_buffer[:batch_size, :], dim=1, index=topk_score_indices) # [batch_size, self.max_total_draft_tokens] + + # return real_draft_tokens, topk_score_indices, real_draft_tokens_parents + return real_draft_tokens, topk_score_indices + + def prepare_for_generation(self, attn_metadata: AttentionMetadata, + spec_metadata: SpecMetadata, + spec_tree_manager: SpecTreeManager, + position_ids: torch.Tensor): + ''' + Setup the attn_metadata and spec_metadata for the subsequent drafter layer. Therefore, only call once after the first drafter layer. + To the subsequent drafter layer, we take 'max_total_drafter_tokens + 1' draft tokens as input. + Only the first part of the draft tokens is meaningful, and the later tokens can be regarded as padding + until we continuously write the correct value. + + This introduces additional redundant computation, but it makes it compatible with cuda graphs. + + What we need to prepare are: + 1) position_ids + 2) attn_metadata + 2.1) kv_lens_cuda + 2.2) _seq_lens, _seq_lens_cuda + 2.3) host_request_types + 2.4) num_contexts + 2.5) use_spec_decoding + 2.6) spec_decoding_position_offsets + 2.7) spec_decoding_packed_mask + 2.8) spec_decoding_generation_lengths + 3) spec_metadata + 3.1) num_tokens + 3.2) hidden_states_read_indices, hidden_states_write_indices + 3.3) is_first_draft + ''' + batch_size = attn_metadata.num_seqs + + # 1) Prepare the position_ids + num_accepted_draft_tokens = spec_metadata.num_accepted_draft_tokens[: + batch_size] + seq_lens = attn_metadata.seq_lens_cuda[:batch_size] + # Calculate last accepted token indices + last_tokens_idx = torch.cumsum( + seq_lens, dim=0, + dtype=torch.long) - seq_lens + num_accepted_draft_tokens + position_start_idx = position_ids[0, + last_tokens_idx] + 1 # [batch_size] + self.position_ids_buffer[:batch_size, :] = position_start_idx.unsqueeze( + 1) + spec_tree_manager.spec_dec_position_offsets_for_drafter_model[ + 0, :].unsqueeze(0) # [batch_size, max_total_draft_tokens + 1] + + # 2) Prepare the attn_metadata + ## 2.1) kv_lens_cuda + attn_metadata.kv_lens_cuda[: + batch_size] -= seq_lens - num_accepted_draft_tokens - 1 + attn_metadata.kv_lens_cuda[:batch_size] += ( + self.max_total_draft_tokens + 1) + + ## 2.2) _seq_lens, _seq_lens_cuda + attn_metadata._seq_lens[:batch_size].fill_(self.max_total_draft_tokens + + 1) + attn_metadata._seq_lens_cuda[:batch_size].fill_( + self.max_total_draft_tokens + 1) + attn_metadata.on_update() + + ## 2.3) host_request_types + attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1) + + ## 2.4) num_contexts + attn_metadata.num_contexts = 0 + + ## 2.5) use_spec_decoding + attn_metadata.use_spec_decoding = True + + ## 2.6) 
spec_decoding_position_offsets + ### attn_metadata.spec_decoding_position_offsets: [max_num_requests, max_total_draft_tokens + 1] + attn_metadata.spec_decoding_position_offsets[:batch_size, :] = spec_tree_manager.spec_dec_position_offsets_for_drafter_model[ + 0, :].unsqueeze(0) # [batch_size, max_total_draft_tokens + 1] + + ## 2.7) spec_decoding_packed_mask + ### NOTE: spec_decoding_packed_mask will be updated for each drafter layer in 'update_draft_tokens_and_scores' + # attn_metadata.spec_decoding_packed_mask[:batch_size, :, :].fill_(0) + # dummy_idx = torch.arange(self.dynamic_tree_max_topK, dtype=torch.int32, device='cuda') + # packed_mask = torch.pow(2, dummy_idx + 1) - 1 + # attn_metadata.spec_decoding_packed_mask[:batch_size, :self.dynamic_tree_max_topK, :] = packed_mask.unsqueeze(1) + + ## 2.8) spec_decoding_generation_lengths + ### attn_metadata.spec_decoding_generation_lengths: [max_num_requests] + attn_metadata.spec_decoding_generation_lengths[: + batch_size] = self.max_total_draft_tokens + 1 + + # 3) Update spec_metadata + ## 3.1) num_tokens + spec_metadata.num_tokens = batch_size * (self.max_total_draft_tokens + + 1) + ## 3.2) hidden_states_read_indices, hidden_states_write_indices + old_write_indices = spec_metadata.hidden_states_write_indices + start_idx = old_write_indices[ + last_tokens_idx] # [batch_size], already take the accepted tokens into account. + + ### spec_metadata.hidden_states_read_indices: [max_num_tokens], but we save as [:batch_size * (max_total_draft_tokens + 1)] + ### NOTE: spec_metadata.hidden_states_read_indices needs to be updated for each drafter layer + hidden_states_read_offset = start_idx.unsqueeze(1).repeat( + 1, self.max_total_draft_tokens + + 1) # [batch_size, max_total_draft_tokens + 1] + spec_metadata.hidden_states_read_indices[:batch_size * ( + self.max_total_draft_tokens + + 1)] = hidden_states_read_offset.reshape(-1) + + ### spec_metadata.hidden_states_write_indices: [max_num_tokens], but we save as [:batch_size * (max_total_draft_tokens + 1)] + hidden_states_write_offset = torch.arange( + 1, self.max_total_draft_tokens + 1 + 1, + device=position_ids.device).unsqueeze(0).repeat( + batch_size, 1) + start_idx.unsqueeze(1) + spec_metadata.hidden_states_write_indices[:batch_size * ( + self.max_total_draft_tokens + + 1)] = hidden_states_write_offset.reshape(-1) + + ## 3.3) is_first_draft + spec_metadata.eagle3_resource_manager.is_first_draft = False + spec_metadata.is_first_draft = False + + return diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 44fed8fddad..5504a515476 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -65,7 +65,8 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype, self.spec_tree_manager = None if isinstance(config, - EagleDecodingConfig) and config.eagle_choices is not None: + EagleDecodingConfig) and (config.eagle_choices is not None + or config.use_dynamic_tree): self.spec_tree_manager = SpecTreeManager( max_num_requests=self.max_num_requests, use_dynamic_tree=config.use_dynamic_tree, diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 8727ded6d68..346855cd43d 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -8,6 +8,7 @@ from tensorrt_llm._utils import nvtx_range from tensorrt_llm.logger import logger +from ...llmapi.llm_args import EagleDecodingConfig from 
..attention_backend.trtllm import TrtllmAttention from ..pyexecutor.guided_decoder import GuidedDecoder from ..pyexecutor.handle_logits import HandleLogits @@ -617,6 +618,22 @@ def process_static_draft_outputs(self, outputs: dict[str, torch.Tensor] draft_tokens_host = outputs[1].host.new_tokens outputs[1].sampler_event.synchronize() + # Get the dynamic tree buffers + topk_score_indices, history_draft_tokens_parent_buffer = None, None + spec_tree_manager = None + if isinstance( + self.spec_config, + EagleDecodingConfig) and self.spec_config.use_dynamic_tree: + dynamic_tree_buffers = outputs["dynamic_tree_buffers"] + topk_score_indices = dynamic_tree_buffers["topk_score_indices"].cpu( + ) # [batch_size, self.max_total_draft_tokens] + history_draft_tokens_parent_buffer = dynamic_tree_buffers[ + "history_draft_tokens_parent_buffer"].cpu( + ) # [batch_size, dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] + + spec_tree_manager = self.spec_resource_manager.spec_tree_manager + assert spec_tree_manager is not None + for req_idx, req in enumerate(draft_batch.all_requests()): target_model_req = self.req_id_to_old_request[req.py_request_id] if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: @@ -633,6 +650,90 @@ def process_static_draft_outputs(self, outputs: dict[str, torch.Tensor] if self.disable_overlap_scheduler: target_model_req.py_draft_logits = torch.stack(py_draft_logits) + if topk_score_indices is not None and history_draft_tokens_parent_buffer is not None: + self.reconstruct_dynamic_tree( + req_idx, topk_score_indices[req_idx], + history_draft_tokens_parent_buffer[req_idx], + spec_tree_manager) + + def reconstruct_dynamic_tree( + self, req_idx: int, cur_topk_score_indices: torch.Tensor, + cur_history_draft_tokens_parent_buffer: torch.Tensor, + spec_tree_manager: "SpecTreeManager") -> None: + ''' + Reconstruct the dynamic tree based on the current topk score indices and draft tokens parents. + Update the spec_tree_manager buffers: + - spec_tree_manager.eagle_paths + - spec_tree_manager.spec_dec_mask_matrix + - spec_tree_manager.spec_dec_packed_mask + - spec_tree_manager.spec_dec_position_offsets + All these buffers will be used to update the attn_metadata for the target model's tree attention. + And the verification after forward the target model. + Args: + req_idx: The index of the request. + cur_topk_score_indices: The topk score indices for the final output draft tokens. + The indices will take all history draft tokens into account. + shape: [self.max_total_draft_tokens] + cur_history_draft_tokens_parent_buffer: The draft tokens' parent indices. + The indices will also take all history draft tokens into account. + shape: [dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1)] + spec_tree_manager: The spec tree manager. + ''' + + topk_score_indices_list = cur_topk_score_indices.tolist() + cur_history_draft_tokens_parent_buffer = cur_history_draft_tokens_parent_buffer.tolist( + ) + + # 1) Generate the mapping + # Because these indices will take all history draft tokens into account, + # We need to generate a mapping between the current node index and the index of the final output tree. 
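+        # E.g. if cur_topk_score_indices = [0, 2, 5], then idx_mapping = {-1: 0, 0: 1, 2: 2, 5: 3}:
+        # the root (-1) becomes node 0 of the output tree, and the kept history indices become
+        # nodes 1..max_total_draft_tokens in ascending order of their history index.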
+ idx_mapping = {} + idx_mapping[-1] = 0 # root node + for i in range(len(cur_topk_score_indices)): + idx_mapping[topk_score_indices_list[ + i]] = i + 1 # shift by 1 for the root node + + # 2) Update the eagle_paths + spec_tree_manager.eagle_paths[req_idx].fill_(-1) + # root node + spec_tree_manager.eagle_paths[req_idx][0][0] = 0 + for path_idx, cur_node_idx in enumerate(topk_score_indices_list): + tmp_path = [cur_node_idx] + parent_idx = cur_history_draft_tokens_parent_buffer[cur_node_idx] + tmp_path = [parent_idx] + tmp_path + while parent_idx != -1: + parent_idx = cur_history_draft_tokens_parent_buffer[parent_idx] + tmp_path = [parent_idx] + tmp_path + + assert len(tmp_path) >= 2 and len( + tmp_path) <= self.max_draft_len + 1 + # Map the indices + tmp_map_path = [idx_mapping[idx] for idx in tmp_path] + tmp_map_path += [-1] * (self.max_draft_len + 1 - len(tmp_map_path) + ) # pad with -1 + spec_tree_manager.eagle_paths[req_idx, + path_idx + 1, :] = torch.tensor( + tmp_map_path, dtype=torch.int32) + + # 3) Update the spec_dec_mask_matrix + spec_tree_manager.compute_spec_dec_mask_matrix(req_idx) + + # 4)Update the spec_dec_packed_mask + spec_tree_manager.compute_spec_dec_packed_mask( + spec_tree_manager.spec_dec_mask_matrix, + spec_tree_manager.spec_dec_packed_mask) + + # 5) Update the spec_dec_position_offsets + spec_tree_manager.spec_dec_position_offsets[req_idx, :] = 0 + start_idx = 0 + for i in range(self.max_draft_len + 1): + tmp_set = set(spec_tree_manager.eagle_paths[req_idx, :, i].tolist()) + tmp_set.discard(-1) + num_nodes_this_layer = len(tmp_set) + spec_tree_manager.spec_dec_position_offsets[ + req_idx, start_idx:start_idx + num_nodes_this_layer] = i + start_idx += num_nodes_this_layer + def process_dynamic_draft_outputs( self, outputs: Any, @@ -897,11 +998,8 @@ def generate_draft_tokens_with_overlap( self.previous_scheduled_batch = scheduled_batch @nvtx_range("prepare_draft_tokens") - def prepare_draft_tokens( - self, - scheduled_requests: ScheduledRequests, - resource_manager: Optional[ResourceManager] = None, - ) -> None: + def prepare_draft_tokens(self, scheduled_requests: ScheduledRequests, + resource_manager: ResourceManager) -> None: """ Prepare draft tokens for the scheduled requests. diff --git a/tensorrt_llm/_torch/speculative/spec_tree_manager.py b/tensorrt_llm/_torch/speculative/spec_tree_manager.py index 5b7a42a24c3..63bcdba0163 100644 --- a/tensorrt_llm/_torch/speculative/spec_tree_manager.py +++ b/tensorrt_llm/_torch/speculative/spec_tree_manager.py @@ -35,39 +35,44 @@ class SpecTreeManager: # shape: [num_trees, max_total_draft_tokens + 1], device tensor. spec_dec_position_offsets: torch.Tensor = None - # TODO: Optimized together with the subsequent dynamic tree. - # Auxiliary buffers for the static tree. + ############################ Auxiliary buffers for the static tree. ############################ # Considering that the static tree does not modify the tree structure during inference, we can calculate some buffers in advance. # NOTE: Most of these buffers are introduced due to limitations of XQA: # With tree attention, XQA cannot simply take the tokens to be processed in the next round as input. Instead, it needs to take ALL of their parent nodes as input. # This incurs additional computation, but it is unavoidable. # NOTE: The reason why most of these auxiliary buffers are with `len == max_draft_len - 1` is that: we do not need to prepare specific input data for the first draft layer. - # The top k value for each draft layer. Device tensor. 
top_k_list_cuda: list[torch.Tensor] = None # The max top k value for all draft layers. Which is used for torch.topk and cuda graph. max_top_k = -1 - # Gather the required draft tokens from all currently generated draft tokens as the input of the next draft layer. + # Gather the required draft tokens among the 'max_total_draft_tokens + 1' tokens. # Only the nodes has child(s) this layer and all their parents nodes will be gathered. - # Device tensor. len(tokens_gather_idx) == max_draft_len - 1. Each element is a tensor with shape [num_tokens_for_next_layer]. tokens_gather_idx_for_drafter_model: list[torch.Tensor] = None # Gather the required logits from all currently generated logits. - # Device tensor. len(tokens_gather_idx) == max_draft_len - 1. logits_gather_idx: list[torch.Tensor] = None # The packed mask for the drafter model's attention (i.e., xqa). + # shape: [1, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32)], device tensor. spec_dec_packed_mask_for_drafter_model: torch.Tensor = None # The read indices offset for the drafter model. + # shape: [max_total_draft_tokens + 1], device tensor. hidden_states_read_indices_offset_for_drafter_model: torch.Tensor = None # The write back start indices for the drafter tokens between different draft layers. + # shape: [max_draft_len + 1], device tensor. draft_tokens_indices_cumsum: torch.Tensor = None + ############################ Auxiliary buffers for the dynamic tree. ############################ + # For the dynamic tree, before reconstructing the tree, + # the draft token offsets of each draft layer are fixed, which we can set in advance. + # shape: [1, max_total_draft_tokens + 1], device tensor. + spec_dec_position_offsets_for_drafter_model: torch.Tensor = None + def __init__(self, max_num_requests: int, use_dynamic_tree: bool, max_total_draft_tokens: int, max_draft_len: int, eagle_choices: [List[List[int]]], dynamic_tree_max_topK: int): @@ -118,12 +123,21 @@ def __init__(self, max_num_requests: int, use_dynamic_tree: bool, def init_tree_info_for_dynamic_tree(self): # For the dynamic tree # To the internal layer, the number of nodes is the same as the dynamic_tree_max_topK. - self.top_k_list = [ - torch.ones(self.dynamic_tree_max_topK, - dtype=torch.int32, - device='cpu', - pin_memory=True) * self.dynamic_tree_max_topK - ] + self.spec_dec_position_offsets_for_drafter_model = torch.zeros( + (1, self.max_total_draft_tokens + 1), + dtype=torch.int32, + device='cuda', + ) + tmp_position_offsets = [] + for i in range(self.max_draft_len): + tmp_position_offsets.extend([i] * self.dynamic_tree_max_topK) + self.spec_dec_position_offsets_for_drafter_model[ + 0, :len(tmp_position_offsets)].copy_( + torch.tensor(tmp_position_offsets, + dtype=torch.int32, + device='cuda'), + non_blocking=True, + ) # For the static tree def init_tree_info_for_static_tree(self): @@ -174,9 +188,7 @@ def init_tree_info_for_static_tree(self): pin_memory=True)) # 6) Compute the spec decoding according to the eagle_paths for the target model - for i, path in enumerate(self.eagle_paths[0]): - indices = path[path > -1] - self.spec_dec_mask_matrix[0][i, indices] = 1 + self.compute_spec_dec_mask_matrix(0) self.compute_spec_dec_packed_mask(self.spec_dec_mask_matrix, self.spec_dec_packed_mask) @@ -255,22 +267,19 @@ def init_tree_info_for_static_tree(self): # Get the eagle_paths def get_eagle_paths(self, tree_idx=0): - if self.use_dynamic_tree: - self.eagle_paths[tree_idx].fill_(-1) - # If dynamic tree, return the eagle_paths according to the mask. 
- for i in range(self.max_total_draft_tokens + 1): - self.eagle_paths[tree_idx][:, i, :] = self.spec_dec_mask_matrix[ - tree_idx][i, :].nonzero() - return self.eagle_paths[tree_idx] - else: - # If static tree, return the prepared eagle_paths. These paths are immutable. - return self.eagle_paths[0] + return self.eagle_paths[tree_idx] # Get the topK list for the specific draft layer def get_top_k_list(self, draft_layer_id): assert draft_layer_id >= 0 return self.top_k_list[draft_layer_id] + # Compute the spec decoding mask matrix according to the eagle_paths + def compute_spec_dec_mask_matrix(self, tree_idx=0): + for i, path in enumerate(self.eagle_paths[tree_idx]): + indices = path[path > -1] + self.spec_dec_mask_matrix[tree_idx][i, indices] = 1 + # Compute the packed mask according to the mask matrix def compute_spec_dec_packed_mask(self, mask_matrix, packed_mask): # mask_matrix: shape: [num_trees, max_total_draft_tokens + 1, max_total_draft_tokens + 1] @@ -315,9 +324,14 @@ def compute_spec_dec_packed_mask(self, mask_matrix, packed_mask): # Print the tree info def dump_tree_info(self): print(f"TopK list: {self.top_k_list}") - if not self.use_dynamic_tree: + if self.use_dynamic_tree: + print(f"Dynamic max top k: {self.dynamic_tree_max_topK}") + print( + f"Spec dec position offsets for drafter model: {self.spec_dec_position_offsets_for_drafter_model}" + ) + else: print(f"Max top k list cuda: {self.max_top_k}") - print(f"Static tree: {self.eagle_paths}") + print(f"Eagle paths: {self.eagle_paths}") print(f"Index mapping set: {self.index_mapping_set}") print(f"Nodes list per layer: {self.nodes_list_per_layer}") print( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 616f20488c1..f9079a1319e 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -802,11 +802,25 @@ def __init__(self, **kwargs): self.max_total_draft_tokens = len(self.eagle_choices) # Dynamic tree logic - if self.use_dynamic_tree: + if self.use_dynamic_tree or self.dynamic_tree_max_topK is not None: + self.use_dynamic_tree = True + assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK is required for dynamic tree" assert self.eagle_choices is None, "If use_dynamic_tree is True, eagle_choices should be None" - assert self.max_draft_len is not None and self.max_draft_len > 0, "max_draft_len should be provided, which indicates the number of drafter layers" - assert self.dynamic_tree_max_topK is not None and self.dynamic_tree_max_topK > 0, "dynamic_tree_max_topK should be provided, which indicates the number of nodes to expand each time" - assert self.max_total_draft_tokens is not None and self.max_total_draft_tokens > 0, "max_total_draft_tokens should be provided, which indicates the total nodes of the final draft tree. 
(exclude the root node)" + total_history_draft_tokens = self.dynamic_tree_max_topK + self.dynamic_tree_max_topK * self.dynamic_tree_max_topK * ( + self.max_draft_len - 1) + default_max_total_draft_tokens = self.dynamic_tree_max_topK * self.max_draft_len + + if self.max_total_draft_tokens is None: + self.max_total_draft_tokens = default_max_total_draft_tokens + logger.warning( + f"max_total_draft_tokens is not provided, use the default value {default_max_total_draft_tokens} (default_max_total_draft_tokens = dynamic_tree_max_topK * max_draft_len)" + ) + else: + assert self.max_total_draft_tokens <= total_history_draft_tokens and self.max_total_draft_tokens >= default_max_total_draft_tokens, f"max_total_draft_tokens should be between {default_max_total_draft_tokens} and {total_history_draft_tokens}" + + # Linear tree + if self.max_total_draft_tokens is None: + self.max_total_draft_tokens = self.max_draft_len @classmethod def from_dict(cls, data: dict): diff --git a/tests/integration/defs/accuracy/test_llm_api.py b/tests/integration/defs/accuracy/test_llm_api.py index e019572ada0..4b1410b4b33 100644 --- a/tests/integration/defs/accuracy/test_llm_api.py +++ b/tests/integration/defs/accuracy/test_llm_api.py @@ -496,7 +496,8 @@ class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness): MODEL_PATH = f"{llm_models_root()}/vicuna-7b-v1.3" speculative_config = EagleDecodingConfig( - max_draft_len=63, + max_draft_len=4, + max_total_draft_tokens=63, speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3", num_eagle_layers=4, max_non_leaves_per_layer=10, diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index c1fe7c90895..426e845a72b 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -2073,9 +2073,8 @@ def test_ptp_quickstart_advanced_eagle3(llm_root, llm_venv, model_name, ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct", "EAGLE3-LLaMA3.1-Instruct-8B"), ]) -def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv, - model_name, model_path, - eagle_model_path): +def test_static_draft_token_tree_quickstart_advanced_eagle3( + llm_root, llm_venv, model_name, model_path, eagle_model_path): print(f"Testing {model_name}.") example_root = Path(os.path.join(llm_root, "examples", "llm-api")) with tempfile.NamedTemporaryFile(mode='w+t', @@ -2104,6 +2103,43 @@ def test_draft_token_tree_quickstart_advanced_eagle3(llm_root, llm_venv, _check_mem_usage(running_log, [27, 0, 0, 0]) +@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [ + ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct", + "EAGLE3-LLaMA3.1-Instruct-8B"), +]) +def test_dynamic_draft_token_tree_quickstart_advanced_eagle3( + llm_root, llm_venv, model_name, model_path, eagle_model_path): + print(f"Testing {model_name}.") + example_root = Path(os.path.join(llm_root, "examples", "llm-api")) + with tempfile.NamedTemporaryFile(mode='w+t', + suffix=f".{model_name}.log", + dir="./", + delete=True, + delete_on_close=True) as running_log: + llm_venv.run_cmd([ + str(example_root / "quickstart_advanced.py"), + "--prompt", + "You are a good assistant. 
Please tell me the capital of France is", + "--spec_decode_max_draft_len", + "3", + "--spec_decode_algo", + "eagle3", + "--model_dir", + f"{llm_models_root()}/{model_path}", + "--draft_model_dir", + f"{llm_models_root()}/{eagle_model_path}", + "--disable_kv_cache_reuse", + "--disable_overlap_scheduler", + "--use_dynamic_tree", + "--dynamic_tree_max_topK", + "3", + "--max_total_draft_tokens", + "12", + ], + stdout=running_log) + _check_mem_usage(running_log, [27, 0, 0, 0]) + + @pytest.mark.parametrize("model_name,model_path", [ ("Llama-3.1-8B-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct"), ]) diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index cad865087be..759c933c60c 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -19,7 +19,7 @@ from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests -from tensorrt_llm._torch.speculative.utils import SpecDecodingTensor +from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm._utils import get_sm_version from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -492,20 +492,19 @@ def run_forward(input_ids, position_ids, attn_metadata): ], dtype=torch.int, device=device) - spec_decoding_position_offsets = torch.tensor([ + spec_decoding_position_offsets = torch.tensor([[ 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3 - ], + ]], dtype=torch.int, device=device) - spec_decoding_packed_mask = torch.tensor( - [ - 1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2051, 4099, 8195, - 16387, 32771, 65541, 131077, 262153, 524297, 1048593, 2097169, - 4194321, 8388641, 16842757 - ], - dtype=torch.int, - device=device).unsqueeze(0).unsqueeze(2) + spec_decoding_packed_mask = torch.tensor([[ + 1, 3, 5, 9, 17, 33, 65, 129, 257, 513, 1025, 2051, 4099, 8195, + 16387, 32771, 65541, 131077, 262153, 524297, 1048593, 2097169, + 4194321, 8388641, 16842757 + ]], + dtype=torch.int, + device=device).unsqueeze(-1) num_cached_tokens_per_seq = [input_ids.size(-1)] is_spec_decoding_enabled = True @@ -532,9 +531,17 @@ def run_forward(input_ids, position_ids, attn_metadata): is_spec_dec_dynamic_tree=is_spec_dec_dynamic_tree, num_heads_per_kv=num_heads_per_kv, ) - spec_decoding_tensor = SpecDecodingTensor( - position_offsets=spec_decoding_position_offsets, - packed_mask=spec_decoding_packed_mask) + # Use the spec_tree_manager to save spec_decoding_tensor for testing + spec_tree_manager = SpecTreeManager( + max_num_requests=batch_size, + use_dynamic_tree=True, + max_draft_len=3, + max_total_draft_tokens=max_total_draft_tokens, + eagle_choices=None, + dynamic_tree_max_topK=3, + ) + spec_tree_manager.spec_dec_position_offsets = spec_decoding_position_offsets + spec_tree_manager.spec_dec_packed_mask = spec_decoding_packed_mask attn_metadata_gen_phase_0.prepare() attn_metadata_gen_phase_0.update_spec_dec_param( @@ -544,8 +551,9 @@ def run_forward(input_ids, position_ids, attn_metadata): is_spec_dec_tree=is_spec_dec_tree, max_draft_len=max_total_draft_tokens, max_total_draft_tokens=max_total_draft_tokens, + is_target_model=True, model_is_wrapped=False, - spec_decoding_tensor=spec_decoding_tensor, + spec_tree_manager=spec_tree_manager, ) gen_position_ids_0 = [ @@ -587,10 
+595,19 @@ def run_forward(input_ids, position_ids, attn_metadata): [gen_input_ids_1.size(-1)], dtype=torch.int) attn_metadata_gen_phase_0.kv_cache_params.num_cached_tokens_per_seq = num_cached_tokens_per_seq_1 + spec_tree_manager.spec_dec_position_offsets = torch.tensor( + [[0, 1]], dtype=torch.int, device=device) + spec_tree_manager.spec_dec_packed_mask = torch.tensor([[[1], [3]]], + dtype=torch.int, + device=device) + spec_tree_manager.spec_dec_generation_lengths = torch.tensor( + [2], dtype=torch.int, device=device) + attn_metadata_gen_phase_0.spec_decoding_position_offsets = None attn_metadata_gen_phase_0.spec_decoding_packed_mask = None attn_metadata_gen_phase_0.spec_decoding_generation_lengths = None attn_metadata_gen_phase_0.prepare() + attn_metadata_gen_phase_0.update_spec_dec_param( batch_size=batch_size, is_spec_decoding_enabled=is_spec_decoding_enabled, @@ -599,7 +616,11 @@ def run_forward(input_ids, position_ids, attn_metadata): is_spec_dec_dynamic_tree=False, max_draft_len=gen_input_ids_1.size(-1) - 1, max_total_draft_tokens=gen_input_ids_1.size(-1) - 1, - model_is_wrapped=False) + model_is_wrapped=False, + is_target_model=True, + spec_tree_manager=spec_tree_manager, + ) + attn_metadata_gen_phase_0.spec_decoding_generation_lengths = spec_tree_manager.spec_dec_generation_lengths gen_position_ids_1 = [ torch.full( @@ -646,6 +667,14 @@ def run_forward(input_ids, position_ids, attn_metadata): attn_metadata_ref.spec_decoding_packed_mask = None attn_metadata_ref.spec_decoding_generation_lengths = None attn_metadata_ref.prepare() + + spec_tree_manager.spec_dec_position_offsets = torch.tensor( + [[0, 1, 2, 3]], dtype=torch.int, device=device) + spec_tree_manager.spec_dec_packed_mask = torch.tensor( + [[[1], [3], [7], [15]]], dtype=torch.int, device=device) + spec_tree_manager.spec_dec_generation_lengths = torch.tensor( + [4], dtype=torch.int, device=device) + attn_metadata_ref.update_spec_dec_param( batch_size=batch_size, is_spec_decoding_enabled=is_spec_decoding_enabled, @@ -654,7 +683,11 @@ def run_forward(input_ids, position_ids, attn_metadata): is_spec_dec_dynamic_tree=False, max_draft_len=gen_input_ids_ref.size(-1) - 1, max_total_draft_tokens=gen_input_ids_ref.size(-1) - 1, - model_is_wrapped=False) + model_is_wrapped=False, + is_target_model=True, + spec_tree_manager=spec_tree_manager, + ) + attn_metadata_ref.spec_decoding_generation_lengths = spec_tree_manager.spec_dec_generation_lengths gen_position_ids_ref = [ torch.full((gen_input_ids_ref.size(-1), ), diff --git a/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py index 4a75e1b6f4a..7110535b32d 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py +++ b/tests/unittest/_torch/speculative/test_draft_token_prepare_for_generation.py @@ -7,8 +7,12 @@ from utils.llm_data import llm_models_root from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttentionMetadata -from tensorrt_llm._torch.speculative.drafting_loops import TreeDraftingLoopWrapper +from tensorrt_llm._torch.speculative.drafting_loops import ( + DynamicTreeDraftingLoopWrapper, + StaticTreeDraftingLoopWrapper, +) from tensorrt_llm._torch.speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata +from tensorrt_llm._torch.speculative.model_drafter import ModelDrafter from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.llmapi import EagleDecodingConfig @@ -21,6 +25,7 @@ 
def __init__(self): self.model_config = None self.config = None self.model = {} + self.model_is_wrapped = True def forward(self, *args, **kwargs) -> torch.Tensor: pass @@ -138,8 +143,8 @@ def run_test( input_hidden_states_read_indices # set from input ) - # 3) Create TreeDraftingLoopWrapper - tree_drafting_loop_wrapper = TreeDraftingLoopWrapper( + # 3) Create StaticTreeDraftingLoopWrapper + static_tree_drafting_loop_wrapper = StaticTreeDraftingLoopWrapper( max_batch_size=max_batch_size, max_draft_len=max_draft_len, max_total_draft_tokens=max_total_draft_tokens, @@ -147,7 +152,7 @@ def run_test( ) # 3) Run the function - tree_drafting_loop_wrapper.prepare_for_generation( + static_tree_drafting_loop_wrapper.prepare_for_generation( attn_metadata=attn_metadata, spec_metadata=spec_metadata, spec_tree_manager=spec_tree_manager, @@ -156,7 +161,8 @@ def run_test( # Compare input_ids and position_ids print( - f"tree_drafting_loop_wrapper.position_ids_buffer: {tree_drafting_loop_wrapper.position_ids_buffer}, \ + f"static_tree_drafting_loop_wrapper.position_ids_buffer: \ + {static_tree_drafting_loop_wrapper.position_ids_buffer}, \ ref_output_position_ids: {ref_position_ids}" ) @@ -208,7 +214,7 @@ def run_test( ref_spec_metadata.hidden_states_write_indices: {ref_spec_metadata['hidden_states_write_indices']}" ) - assert torch.all(tree_drafting_loop_wrapper.position_ids_buffer == ref_position_ids) + assert torch.all(static_tree_drafting_loop_wrapper.position_ids_buffer == ref_position_ids) assert torch.all(attn_metadata.kv_lens_cuda == ref_attn_metadata["kv_lens_cuda"]) assert torch.all(attn_metadata._seq_lens == ref_attn_metadata["_seq_lens"]) assert torch.all(attn_metadata._seq_lens_cuda == ref_attn_metadata["_seq_lens_cuda"]) @@ -454,5 +460,895 @@ def run_test( ) +def test_dynamic_tree_update_draft_tokens_and_scores(): + # Fix parameters + models_path = llm_models_root() + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. 
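+    # Each case below calls update_draft_tokens_and_scores with hand-derived inputs and checks
+    # one expansion layer of the dynamic draft-token tree: the selected draft tokens, the
+    # flattened history of tokens / parent indices / cumulative scores, the packed attention
+    # mask (each entry effectively packs, as bits, the node itself plus its ancestors among the
+    # draft tokens), and the hidden-state read indices. All cases use dynamic_tree_max_topK = 3,
+    # i.e. three candidate children per expanded node.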
+ use_dynamic_tree = True + max_new_tokens = 128 + kv_cache_manager = None + + def run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_draft_idx, + new_draft_tokens, + new_draft_scores, + previous_draft_scores, + input_spec_decoding_packed_mask, + input_hidden_states_write_indices, + input_hidden_states_read_indices, + ref_draft_tokens_buffer, + ref_history_draft_tokens_buffer, + ref_history_draft_tokens_parent_buffer, + ref_history_score_buffer, + ref_spec_decoding_packed_mask, + ref_hidden_states_read_indices, + ): + # 1) Create attention metadata + attn_metadata = TrtllmAttentionMetadata( + max_num_requests=max_batch_size, + max_num_tokens=max_new_tokens, + kv_cache_manager=kv_cache_manager, + ) + # Set initial values + attn_metadata.spec_decoding_packed_mask = input_spec_decoding_packed_mask + attn_metadata.spec_decoding_generation_lengths = torch.zeros( + [max_batch_size], + dtype=torch.int, + device="cuda", + ) + + # 3) Create spec metadata + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=False, + eagle_choices=None, + use_dynamic_tree=use_dynamic_tree, + dynamic_tree_max_topK=dynamic_tree_max_topK, + ) + eagle3_resource_manager = Eagle3ResourceManager( + config=spec_config, + dtype=torch.bfloat16, + hidden_size=1024, + max_num_requests=max_batch_size, + max_seq_len=max_new_tokens, + max_num_tokens=max_new_tokens, + ) + spec_metadata = Eagle3SpecMetadata( + max_draft_len=spec_config.max_draft_len, + spec_dec_mode=spec_config.spec_dec_mode, + max_num_requests=max_batch_size, + num_layers=32, + hidden_size=1024, + max_num_tokens=max_new_tokens, + dtype=torch.bfloat16, + is_draft_model=True, + eagle3_resource_manager=eagle3_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + is_spec_dec_tree=spec_config.eagle_choices is not None or spec_config.use_dynamic_tree, + is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, + ) + spec_metadata.hidden_states_write_indices = input_hidden_states_write_indices + spec_metadata.hidden_states_read_indices = input_hidden_states_read_indices + + # 4) Create DynamicTreeDraftingLoopWrapper + dynamic_tree_drafting_loop_wrapper = DynamicTreeDraftingLoopWrapper( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + max_batch_size=max_batch_size, + dynamic_tree_max_topK=dynamic_tree_max_topK, + draft_model=DummyModel(), + ) + + # 5) Run the function + _ = dynamic_tree_drafting_loop_wrapper.update_draft_tokens_and_scores( + cur_draft_idx=cur_draft_idx, + batch_size=max_batch_size, + new_draft_tokens=new_draft_tokens, + new_draft_scores=new_draft_scores, + previous_draft_scores=previous_draft_scores, + attn_metadata=attn_metadata, + spec_metadata=spec_metadata, + ) + + # 5) Check the results + print("==================================") + print(f"ref_draft_tokens_buffer: {ref_draft_tokens_buffer}") + print( + f"dynamic_tree_drafting_loop_wrapper.draft_tokens_buffer: \ + {dynamic_tree_drafting_loop_wrapper.draft_tokens_buffer}" + ) + + print(f"ref_history_draft_tokens_buffer: {ref_history_draft_tokens_buffer}") + print( + f"dynamic_tree_drafting_loop_wrapper.history_draft_tokens_buffer: \ + {dynamic_tree_drafting_loop_wrapper.history_draft_tokens_buffer}" + ) + + print( + f"ref_history_draft_tokens_parent_buffer: \ + 
{ref_history_draft_tokens_parent_buffer}" + ) + print( + f"dynamic_tree_drafting_loop_wrapper.history_draft_tokens_parent_buffer: \ + {dynamic_tree_drafting_loop_wrapper.history_draft_tokens_parent_buffer}" + ) + + print(f"ref_history_score_buffer: {ref_history_score_buffer}") + print( + f"dynamic_tree_drafting_loop_wrapper.history_score_buffer: \ + {dynamic_tree_drafting_loop_wrapper.history_score_buffer}" + ) + + print( + f"ref_spec_decoding_packed_mask: \ + {ref_spec_decoding_packed_mask}" + ) + print( + f"attn_metadata.spec_decoding_packed_mask: \ + {attn_metadata.spec_decoding_packed_mask}" + ) + + print(f"ref_hidden_states_read_indices: {ref_hidden_states_read_indices}") + print( + f"spec_metadata.hidden_states_read_indices: \ + {spec_metadata.hidden_states_read_indices}" + ) + + assert torch.all( + dynamic_tree_drafting_loop_wrapper.draft_tokens_buffer == ref_draft_tokens_buffer + ) + assert torch.all( + dynamic_tree_drafting_loop_wrapper.history_draft_tokens_buffer + == ref_history_draft_tokens_buffer + ) + if ref_history_draft_tokens_parent_buffer is not None: + assert torch.all( + dynamic_tree_drafting_loop_wrapper.history_draft_tokens_parent_buffer + == ref_history_draft_tokens_parent_buffer + ) + assert torch.allclose( + dynamic_tree_drafting_loop_wrapper.history_score_buffer, + ref_history_score_buffer, + atol=1e-3, + ) + if ref_spec_decoding_packed_mask is not None: + assert torch.all( + attn_metadata.spec_decoding_packed_mask == ref_spec_decoding_packed_mask + ) + if ref_hidden_states_read_indices is not None: + assert torch.all( + spec_metadata.hidden_states_read_indices == ref_hidden_states_read_indices + ) + + ##### CASE 1 dynamic tree, batch size = 1, cur_draft_idx = 0 ############# + max_batch_size = 1 + max_draft_len = 3 + max_total_draft_tokens = 15 + dynamic_tree_max_topK = 3 + cur_draft_idx = 0 + + new_draft_tokens = torch.tensor([2, 6, 3], dtype=torch.int32, device="cuda") + new_draft_scores = torch.tensor([0.3, 0.2, 0.1], dtype=torch.float32, device="cuda") + previous_draft_scores = None + + input_spec_decoding_packed_mask = torch.zeros( + max_batch_size, + max_total_draft_tokens + 1, + math.ceil((max_total_draft_tokens + 1) / 32), + dtype=torch.int32, + device="cuda", + ) + input_hidden_states_write_indices = torch.arange( + 1, max_batch_size * (max_total_draft_tokens + 1) + 1, dtype=torch.int32, device="cuda" + ) + input_hidden_states_read_indices = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ) + + ref_draft_tokens_buffer = torch.zeros( + max_batch_size, max_total_draft_tokens + 1, dtype=torch.int32, device="cuda" + ) + ref_draft_tokens_buffer[:, :3] = new_draft_tokens + + ref_history_draft_tokens_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + ref_history_draft_tokens_buffer[:, :3] = new_draft_tokens + + ref_history_draft_tokens_parent_buffer = ( + torch.ones( + max_batch_size, + dynamic_tree_max_topK + + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + * -1 + ) + ref_history_draft_tokens_parent_buffer[:, 0:12] = torch.tensor( + [-1, -1, -1, 0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.int32, device="cuda" + ) + + ref_history_score_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.float32, + device="cuda", + ) + 
ref_history_score_buffer[:, :3] = new_draft_scores + + ref_spec_decoding_packed_mask = torch.tensor( + [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + ref_hidden_states_read_indices = None + + run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_draft_idx, + new_draft_tokens, + new_draft_scores, + previous_draft_scores, + input_spec_decoding_packed_mask, + input_hidden_states_write_indices, + input_hidden_states_read_indices, + ref_draft_tokens_buffer, + ref_history_draft_tokens_buffer, + ref_history_draft_tokens_parent_buffer, + ref_history_score_buffer, + ref_spec_decoding_packed_mask, + ref_hidden_states_read_indices, + ) + + ##### CASE 2 dynamic tree, batch size = 1, cur_draft_idx = 1 ############# + max_batch_size = 1 + max_draft_len = 3 + max_total_draft_tokens = 15 + dynamic_tree_max_topK = 3 + cur_draft_idx = 1 + + # new_draft_tokens: [[48, 47, 46], [45, 44, 43], ..., [6, 5, 4], [3, 2, 1]] + # new_draft_scores: [[0.48, 0.47, 0.46], [0.45, 0.44, 0.43], ..., [0.06, 0.05, 0.04], [0.03, 0.02, 0.01]] + # But the valuable draft tokens are new_draft_tokens[3:] + new_draft_tokens = torch.arange( + max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK, + 0, + -1, + dtype=torch.int32, + device="cuda", + ).reshape(max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK) + new_draft_scores = ( + torch.arange( + max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK, + 0, + -1, + dtype=torch.float32, + device="cuda", + ).reshape(max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK) + * 0.01 + ) + previous_draft_scores = torch.tensor([[1.5, 0.7, 0.4]], dtype=torch.float32, device="cuda") + + input_spec_decoding_packed_mask = torch.tensor( + [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + input_hidden_states_write_indices = torch.arange( + 1, max_batch_size * (max_total_draft_tokens + 1) + 1, dtype=torch.int32, device="cuda" + ) + input_hidden_states_read_indices = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ) + + ref_draft_tokens_buffer = torch.zeros( + max_batch_size, max_total_draft_tokens + 1, dtype=torch.int32, device="cuda" + ) + ref_draft_tokens_buffer[:, 3:6] = torch.tensor([48, 47, 46], dtype=torch.int32, device="cuda") + + ref_history_draft_tokens_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + ref_history_draft_tokens_buffer[:, 3:12] = torch.tensor( + [48, 47, 46, 45, 44, 43, 42, 41, 40], dtype=torch.int32, device="cuda" + ) + + ref_history_draft_tokens_parent_buffer = ( + torch.ones( + max_batch_size, + dynamic_tree_max_topK + + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + * -1 + ) + ref_history_draft_tokens_parent_buffer[:, 12 : 12 + 9] = torch.tensor( + [3, 3, 3, 4, 4, 4, 5, 5, 5], dtype=torch.int32, device="cuda" + ) + + ref_history_score_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.float32, + device="cuda", + ) + 
    ref_history_score_buffer[:, 3:12] = torch.tensor(
+        [1.98, 1.97, 1.96, 1.15, 1.14, 1.13, 0.82, 0.81, 0.80], dtype=torch.float32, device="cuda"
+    )
+
+    ref_spec_decoding_packed_mask = torch.tensor(
+        [1, 2, 4, 9, 17, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda"
+    ).reshape(
+        max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32)
+    )
+    ref_hidden_states_read_indices = torch.tensor(
+        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda"
+    )
+
+    run_test(
+        max_batch_size,
+        max_draft_len,
+        max_total_draft_tokens,
+        dynamic_tree_max_topK,
+        cur_draft_idx,
+        new_draft_tokens,
+        new_draft_scores,
+        previous_draft_scores,
+        input_spec_decoding_packed_mask,
+        input_hidden_states_write_indices,
+        input_hidden_states_read_indices,
+        ref_draft_tokens_buffer,
+        ref_history_draft_tokens_buffer,
+        ref_history_draft_tokens_parent_buffer,
+        ref_history_score_buffer,
+        ref_spec_decoding_packed_mask,
+        ref_hidden_states_read_indices,
+    )
+
+    ##### CASE 3 dynamic tree, batch size = 1, cur_draft_idx = 2 #############
+    max_batch_size = 1
+    max_draft_len = 3
+    max_total_draft_tokens = 15
+    dynamic_tree_max_topK = 3
+    cur_draft_idx = 2
+
+    # new_draft_tokens: [[48, 47, 46], [45, 44, 43], ..., [6, 5, 4], [3, 2, 1]]
+    # new_draft_scores: [[0.48, 0.47, 0.46], [0.45, 0.44, 0.43], ..., [0.06, 0.05, 0.04], [0.03, 0.02, 0.01]]
+    # But the valuable draft tokens are new_draft_tokens[3:]
+    new_draft_tokens = torch.arange(
+        max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK,
+        0,
+        -1,
+        dtype=torch.int32,
+        device="cuda",
+    ).reshape(max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK)
+    new_draft_scores = (
+        torch.arange(
+            max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK,
+            0,
+            -1,
+            dtype=torch.float32,
+            device="cuda",
+        ).reshape(max_batch_size * (max_total_draft_tokens + 1) * dynamic_tree_max_topK)
+        * 0.01
+    )
+    previous_draft_scores = torch.tensor([[1.5, 0.7, 0.4]], dtype=torch.float32, device="cuda")
+
+    input_spec_decoding_packed_mask = torch.tensor(
+        [1, 2, 4, 9, 17, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda"
+    ).reshape(
+        max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32)
+    )
+    input_hidden_states_write_indices = torch.arange(
+        1, max_batch_size * (max_total_draft_tokens + 1) + 1, dtype=torch.int32, device="cuda"
+    )
+    input_hidden_states_read_indices = torch.tensor(
+        [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda"
+    )
+
+    ref_draft_tokens_buffer = torch.zeros(
+        max_batch_size, max_total_draft_tokens + 1, dtype=torch.int32, device="cuda"
+    )
+    ref_draft_tokens_buffer[:, 6:9] = torch.tensor([39, 38, 37], dtype=torch.int32, device="cuda")
+
+    ref_history_draft_tokens_buffer = torch.zeros(
+        max_batch_size,
+        dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1),
+        dtype=torch.int32,
+        device="cuda",
+    )
+    ref_history_draft_tokens_buffer[:, 12:21] = torch.tensor(
+        [39, 38, 37, 36, 35, 34, 33, 32, 31], dtype=torch.int32, device="cuda"
+    )
+
+    ref_history_draft_tokens_parent_buffer = None  # will not be updated for this layer
+
+    ref_history_score_buffer = torch.zeros(
+        max_batch_size,
+        dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1),
+        dtype=torch.float32,
+        device="cuda",
+    )
+    ref_history_score_buffer[:, 12:21] = torch.tensor(
+        [1.89, 1.88, 1.87, 1.06, 1.05, 1.04, 0.73, 0.72,
0.71], dtype=torch.float32, device="cuda" + ) + + ref_spec_decoding_packed_mask = torch.tensor( + [1, 2, 4, 9, 17, 33, 73, 137, 265, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + ref_hidden_states_read_indices = torch.tensor( + [0, 0, 0, 1, 1, 1, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32, device="cuda" + ) + + run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_draft_idx, + new_draft_tokens, + new_draft_scores, + previous_draft_scores, + input_spec_decoding_packed_mask, + input_hidden_states_write_indices, + input_hidden_states_read_indices, + ref_draft_tokens_buffer, + ref_history_draft_tokens_buffer, + ref_history_draft_tokens_parent_buffer, + ref_history_score_buffer, + ref_spec_decoding_packed_mask, + ref_hidden_states_read_indices, + ) + + ##### CASE 4 dynamic tree, batch size = 2, cur_draft_idx = 1 ############# + max_batch_size = 2 + max_draft_len = 3 + max_total_draft_tokens = 15 + dynamic_tree_max_topK = 3 + cur_draft_idx = 1 + + # new_draft_tokens: [[48, 47, 46], [45, 44, 43], ..., [6, 5, 4], [3, 2, 1]] + # new_draft_scores: [[0.48, 0.47, 0.46], [0.45, 0.44, 0.43], ..., [0.06, 0.05, 0.04], [0.03, 0.02, 0.01]] + # But the valuable draft tokens are new_draft_tokens[3:] + new_draft_tokens = ( + torch.arange( + (max_total_draft_tokens + 1) * dynamic_tree_max_topK, + 0, + -1, + dtype=torch.int32, + device="cuda", + ) + .reshape((max_total_draft_tokens + 1) * dynamic_tree_max_topK) + .repeat(max_batch_size) + ) + new_draft_scores = ( + torch.arange( + (max_total_draft_tokens + 1) * dynamic_tree_max_topK, + 0, + -1, + dtype=torch.float32, + device="cuda", + ) + .reshape((max_total_draft_tokens + 1) * dynamic_tree_max_topK) + .repeat(max_batch_size) + * 0.01 + ) + previous_draft_scores = torch.tensor( + [[1.5, 0.7, 0.4], [2.5, 1.7, 1.4]], dtype=torch.float32, device="cuda" + ) + + input_spec_decoding_packed_mask = torch.tensor( + [ + [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=torch.int32, + device="cuda", + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + input_hidden_states_write_indices = torch.arange( + 1, max_batch_size * (max_total_draft_tokens + 1) + 1, dtype=torch.int32, device="cuda" + ) + input_hidden_states_read_indices = torch.tensor( + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req1 + 16, + 16, + 16, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req2 + ], + dtype=torch.int32, + device="cuda", + ) + + ref_draft_tokens_buffer = torch.zeros( + max_batch_size, max_total_draft_tokens + 1, dtype=torch.int32, device="cuda" + ) + ref_draft_tokens_buffer[:, 3:6] = torch.tensor([48, 47, 46], dtype=torch.int32, device="cuda") + + ref_history_draft_tokens_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + ref_history_draft_tokens_buffer[:, 3:12] = torch.tensor( + [48, 47, 46, 45, 44, 43, 42, 41, 40], dtype=torch.int32, device="cuda" + ) + + ref_history_draft_tokens_parent_buffer = ( + torch.ones( + max_batch_size, + dynamic_tree_max_topK + + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.int32, + device="cuda", + ) + * -1 + ) + 
ref_history_draft_tokens_parent_buffer[:, 12 : 12 + 9] = torch.tensor( + [3, 3, 3, 4, 4, 4, 5, 5, 5], dtype=torch.int32, device="cuda" + ) + + ref_history_score_buffer = torch.zeros( + max_batch_size, + dynamic_tree_max_topK + dynamic_tree_max_topK * dynamic_tree_max_topK * (max_draft_len - 1), + dtype=torch.float32, + device="cuda", + ) + ref_history_score_buffer[0:, 3:12] = torch.tensor( + [1.98, 1.97, 1.96, 1.15, 1.14, 1.13, 0.82, 0.81, 0.80], dtype=torch.float32, device="cuda" + ) + ref_history_score_buffer[1:, 3:12] = torch.tensor( + [2.98, 2.97, 2.96, 2.15, 2.14, 2.13, 1.82, 1.81, 1.80], dtype=torch.float32, device="cuda" + ) + + ref_spec_decoding_packed_mask = torch.tensor( + [ + 1, + 2, + 4, + 9, + 17, + 33, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req1 + 1, + 2, + 4, + 9, + 17, + 33, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req2 + ], + dtype=torch.int32, + device="cuda", + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + ref_hidden_states_read_indices = torch.tensor( + [ + 0, + 0, + 0, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req1 + 16, + 16, + 16, + 17, + 17, + 17, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, # req2 + ], + dtype=torch.int32, + device="cuda", + ) + + run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_draft_idx, + new_draft_tokens, + new_draft_scores, + previous_draft_scores, + input_spec_decoding_packed_mask, + input_hidden_states_write_indices, + input_hidden_states_read_indices, + ref_draft_tokens_buffer, + ref_history_draft_tokens_buffer, + ref_history_draft_tokens_parent_buffer, + ref_history_score_buffer, + ref_spec_decoding_packed_mask, + ref_hidden_states_read_indices, + ) + + +def test_dynamic_tree_restruct_tree(): + # Fix parameters + models_path = llm_models_root() + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" # It will not actually be used. 
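+    # Each case below replays ModelDrafter.reconstruct_dynamic_tree on a hand-built parent
+    # buffer plus the indices of the top-scoring history nodes, and checks the rebuilt tree
+    # metadata: eagle_paths (root-to-node paths, padded with -1), the packed attention mask,
+    # and the per-node position offsets (node depth in the tree).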
+ use_dynamic_tree = True + max_new_tokens = 128 + + def run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_topk_score_indices, + cur_history_draft_tokens_parent_buffer, + ref_eagle_paths, + ref_spec_dec_packed_mask, + ref_spec_dec_position_offsets, + ): + # 1) Create spec metadata + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + speculative_model_dir=eagle_model_dir, + eagle3_one_model=False, + eagle_choices=None, + use_dynamic_tree=use_dynamic_tree, + dynamic_tree_max_topK=dynamic_tree_max_topK, + ) + + eagle3_resource_manager = Eagle3ResourceManager( + config=spec_config, + dtype=torch.bfloat16, + hidden_size=1024, + max_num_requests=max_batch_size, + max_seq_len=max_new_tokens, + max_num_tokens=max_new_tokens, + ) + + spec_tree_manager = SpecTreeManager( + max_num_requests=max_batch_size, + use_dynamic_tree=spec_config.use_dynamic_tree, + max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, + eagle_choices=spec_config.eagle_choices, + dynamic_tree_max_topK=spec_config.dynamic_tree_max_topK, + ) + + # 2) Create model drafter + model_drafter = ModelDrafter( + spec_config=spec_config, + draft_model_engine=DummyModel(), + max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, + draft_seq_slot_manager=None, + sampler=None, + spec_resource_manager=eagle3_resource_manager, + guided_decoder=None, + ) + + # 3) Reconstruct the dynamic tree + model_drafter.reconstruct_dynamic_tree( + 0, cur_topk_score_indices, cur_history_draft_tokens_parent_buffer, spec_tree_manager + ) + + print("==================================") + print(f"ref_eagle_paths: {ref_eagle_paths}") + print(f"spec_tree_manager.eagle_paths: {spec_tree_manager.eagle_paths}") + + print(f"ref_spec_dec_packed_mask: {ref_spec_dec_packed_mask}") + print(f"spec_tree_manager.spec_dec_packed_mask: {spec_tree_manager.spec_dec_packed_mask}") + + print(f"ref_spec_dec_position_offsets: {ref_spec_dec_position_offsets}") + print( + f"spec_tree_manager.spec_dec_position_offsets: {spec_tree_manager.spec_dec_position_offsets}" + ) + + assert torch.all(ref_eagle_paths == spec_tree_manager.eagle_paths) + assert torch.all(ref_spec_dec_packed_mask == spec_tree_manager.spec_dec_packed_mask) + assert torch.all( + ref_spec_dec_position_offsets == spec_tree_manager.spec_dec_position_offsets + ) + + ##### CASE 1 dynamic tree, batch size = 1 ############# + max_batch_size = 1 + max_draft_len = 3 + max_total_draft_tokens = 10 + dynamic_tree_max_topK = 3 + + cur_topk_score_indices = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 9, 12], dtype=torch.int32, device="cpu" + ) + cur_history_draft_tokens_parent_buffer = torch.tensor( + [-1, -1, -1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5], + dtype=torch.int32, + device="cpu", + ) + + ref_eagle_paths = torch.tensor( + [ + [ + [0, -1, -1, -1], + [0, 1, -1, -1], + [0, 2, -1, -1], + [0, 3, -1, -1], + [0, 1, 4, -1], + [0, 1, 5, -1], + [0, 1, 6, -1], + [0, 2, 7, -1], + [0, 2, 8, -1], + [0, 3, 9, -1], + [0, 1, 4, 10], + ] + ], + dtype=torch.int32, + device="cpu", + pin_memory=True, + ) + + ref_spec_dec_packed_mask = torch.tensor( + [1, 3, 5, 9, 19, 35, 67, 133, 261, 521, 1043], dtype=torch.int32, device="cuda" + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + + ref_spec_dec_position_offsets = torch.tensor( + [[0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3]], dtype=torch.int32, 
device="cuda" + ) + + run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_topk_score_indices, + cur_history_draft_tokens_parent_buffer, + ref_eagle_paths, + ref_spec_dec_packed_mask, + ref_spec_dec_position_offsets, + ) + + ##### CASE 2 dynamic tree, batch size = 1 ############# + max_batch_size = 1 + max_draft_len = 3 + max_total_draft_tokens = 9 + dynamic_tree_max_topK = 3 + + cur_topk_score_indices = torch.tensor( + [0, 1, 2, 3, 4, 6, 12, 13, 18], dtype=torch.int32, device="cpu" + ) + cur_history_draft_tokens_parent_buffer = torch.tensor( + [-1, -1, -1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 6, 6, 6], + dtype=torch.int32, + device="cpu", + ) + + ref_eagle_paths = torch.tensor( + [ + [ + [0, -1, -1, -1], + [0, 1, -1, -1], + [0, 2, -1, -1], + [0, 3, -1, -1], + [0, 1, 4, -1], + [0, 1, 5, -1], + [0, 2, 6, -1], + [0, 1, 4, 7], + [0, 1, 4, 8], + [0, 2, 6, 9], + ] + ], + dtype=torch.int32, + device="cpu", + pin_memory=True, + ) + + ref_spec_dec_packed_mask = torch.tensor( + [1, 3, 5, 9, 19, 35, 69, 147, 275, 581], dtype=torch.int32, device="cuda" + ).reshape( + max_batch_size, max_total_draft_tokens + 1, math.ceil((max_total_draft_tokens + 1) / 32) + ) + + ref_spec_dec_position_offsets = torch.tensor( + [[0, 1, 1, 1, 2, 2, 2, 3, 3, 3]], dtype=torch.int32, device="cuda" + ) + + run_test( + max_batch_size, + max_draft_len, + max_total_draft_tokens, + dynamic_tree_max_topK, + cur_topk_score_indices, + cur_history_draft_tokens_parent_buffer, + ref_eagle_paths, + ref_spec_dec_packed_mask, + ref_spec_dec_position_offsets, + ) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py index 6002d9d6856..08160e77094 100644 --- a/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py +++ b/tests/unittest/_torch/speculative/test_draft_token_tree_sampling.py @@ -6,7 +6,7 @@ from utils.llm_data import llm_models_root from tensorrt_llm._torch.speculative.drafting_loops import \ - TreeDraftingLoopWrapper + StaticTreeDraftingLoopWrapper from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager from tensorrt_llm.llmapi import EagleDecodingConfig @@ -53,7 +53,7 @@ def run_test(max_batch_size, draft_layer_id, max_total_draft_tokens, ) # Create the chain drafter - tree_drafter = TreeDraftingLoopWrapper( + tree_drafter = StaticTreeDraftingLoopWrapper( max_batch_size=max_batch_size, max_draft_len=spec_config.max_draft_len, max_total_draft_tokens=spec_config.max_total_draft_tokens,