diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index c402480b7d9..6826cda6114 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -11,7 +11,7 @@ import weakref from collections import deque, namedtuple from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import torch @@ -308,7 +308,7 @@ def __init__(self, if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) - if self.draft_model_engine is not None: + if self.drafter is not None: if self.event_loop.__name__ != self._executor_loop.__name__: raise NotImplementedError( "Drafting is not supported for selected executor loop. " @@ -905,10 +905,6 @@ def _executor_loop_pp(self): def _executor_loop(self): torch.cuda.set_device(self.device_id) - is_ngram = hasattr( - self.model_engine, "spec_config" - ) and self.model_engine.spec_config is not None and self.model_engine.spec_config.spec_dec_mode.is_ngram( - ) with self._profiler() as profile_step: sample_state = None iter_start_time = time.time() @@ -931,7 +927,7 @@ def _executor_loop(self): self._pad_attention_dp_dummy_request() - if self.draft_model_engine is not None or is_ngram or self.drafter is not None: + if self.drafter is not None: self._prepare_draft_requests(self.active_requests) scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( @@ -971,11 +967,9 @@ def _executor_loop(self): scheduled_batch) self.resource_manager.prepare_resources(scheduled_batch) - if self.draft_model_engine is not None: - self._prepare_draft_tokens(scheduled_batch) - if self.drafter is not None: - self.drafter.prepare_draft_tokens(scheduled_batch) + self.drafter.prepare_draft_tokens( + scheduled_batch, self.resource_manager) if self.kv_cache_transceiver: # For generation requests which have completed KV cache transfer @@ -1798,188 +1792,6 @@ def _update_requests(self, sample_state: SampleState): logger.error(f"Encountered an error in sampling: {error_msg}") self._handle_errors(error_msg) - @nvtx_range("_prepare_draft_batch") - def _prepare_draft_batch( - self, scheduled_requests: ScheduledRequests - ) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]: - """ - Prepares a batch for the draft model engine. Draft tokens are only produced - for generation requests. - - The requests are prepared as follows: - 1. The first time the draft engine sees a request, it's a context request. - 2. Otherwise, if draft tokens were accepted on the last target model decoding - step, it's a chunked context request (we process all the accepted tokens together). - 3. Otherwise, it's a generation request. - """ - try: - draft_batch = ScheduledRequests() - - for request in scheduled_requests.generation_requests: - if request.py_draft_pages_allocated == 0: - # No space for draft tokens. - continue - - # Stop drafting when we hit the max seqlen. We still need dummy draft - # tokens attached to the requests to make sure everything works properly - # with CUDA graph. These dummy tokens are already added by - # _prepare_draft_requests to make the KV cache/scheduler aware of the fact - # that we want to do spec decoding, so no need to do anything else here. - # This makes the perf for this case suboptimal, but that's OK - this is - # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. 
- if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: - continue - - num_draft_tokens = len( - request.py_last_draft_tokens - ) if request.py_last_draft_tokens is not None else 0 - request.py_draft_tokens = [] - - num_accepted_tokens = request.py_num_accepted_draft_tokens - num_rejected_tokens = num_draft_tokens - num_accepted_tokens - assert num_rejected_tokens >= 0 - - spec_config = self.model_engine.spec_config - beam_idx = 0 - input_tokens = spec_config.get_draft_model_prompt( - request.get_tokens()[beam_idx]) - - def create_new_request(input_tokens): - return LlmRequest( - request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens, - sampling_config=request.sampling_config, - return_perf_metrics=request.return_perf_metrics, - is_streaming=False, - is_draft=True) - - if request.max_beam_num_tokens - 1 == request.py_prompt_len: - # This is the first time the draft model is seeing this request. - # Prepare a context request. We discard the first token and take - # the newly decoded one - this is the convention for EAGLE 2 and 3. - new_request = create_new_request(input_tokens) - draft_batch.context_requests.append(new_request) - elif num_accepted_tokens == 0: - new_request = create_new_request(input_tokens[:-1]) - # Explicitly add the last token so get_last_tokens() returns - # the right value - new_request.add_new_token(input_tokens[-1], beam_idx) - new_request.state = LlmRequestState.GENERATION_IN_PROGRESS - draft_batch.generation_requests.append(new_request) - else: - new_request = create_new_request(input_tokens) - new_request.context_chunk_size = num_accepted_tokens + 1 - new_request.context_current_position = len( - input_tokens) - num_accepted_tokens - 1 - new_request.context_chunk_size = num_accepted_tokens + 1 - new_request.context_current_position = len( - input_tokens) - num_accepted_tokens - 1 - - draft_batch.context_requests.append(new_request) - - new_request.py_stop_words_list = request.py_stop_words_list - - return draft_batch - - except Exception as e: - traceback.print_exc() - error_msg = str(e) - logger.error(f"Encountered an error in decode: {error_msg}") - self._handle_errors(error_msg) - - @nvtx_range("_prepare_draft_tokens") - def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): - if not self.draft_model_engine: - raise ValueError("Draft model engine is not set") - - try: - draft_batch = self._prepare_draft_batch(scheduled_requests) - - if draft_batch.batch_size == 0: - return - self.draft_seq_slot_manager.prepare_resources(draft_batch) - - req_id_to_old_request = { - req.py_request_id: req - for req in scheduled_requests.all_requests() - } - - # Disable cuda graph for the 1st draft model forward - if self.model_engine.spec_config.spec_dec_mode.needs_kv_cache_recompute( - ): - with self.draft_model_engine.no_cuda_graph(): - outputs = self.draft_model_engine.forward( - draft_batch, self.resource_manager) - else: - outputs = self.draft_model_engine.forward( - draft_batch, self.resource_manager) - if hasattr(self.draft_model_engine.model.model, 'd2t'): - outputs['d2t'] = self.draft_model_engine.model.model.d2t.data - - sample_state = self._sample_async(draft_batch, outputs) - previous_batch = sample_state - - self._update_request_states(draft_batch) - - def _process_decoded_tokens(draft_batch): - new_requests = [] - for req in draft_batch.all_requests(): - target_model_req = req_id_to_old_request[req.py_request_id] - target_model_req.py_draft_tokens.append( - 
req.get_last_tokens(0)) - if req.state != LlmRequestState.GENERATION_COMPLETE and len( - target_model_req.py_draft_tokens - ) < target_model_req.py_draft_pages_allocated: - new_requests.append(req) - else: - self.draft_seq_slot_manager.free_resources(req) - - return new_requests - - # The TRTLLM attention kernels cannot handle generation requests with - # different seqlens. No issues with flashinfer, should we look into removing - # this? Just needs proper kernel support. - def _pad_to_max_draft_tokens(): - for req in scheduled_requests.generation_requests: - max_draft_len = self.max_draft_len - num_draft_tokens = len(req.py_draft_tokens) - req.py_draft_tokens.extend( - 0 for _ in range(max_draft_len - num_draft_tokens)) - - draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests - draft_batch.context_requests = [] - - for i in range(self.max_draft_len - 1): - if len(draft_batch.generation_requests) == 0: - break - - outputs = self.draft_model_engine.forward( - draft_batch, - self.resource_manager, - new_tensors_device=previous_batch.device) - - if hasattr(self.draft_model_engine.model.model, 'd2t'): - outputs[ - 'd2t'] = self.draft_model_engine.model.model.d2t.data - sample_state = self._sample_async(draft_batch, outputs) - self._update_request_states(draft_batch) - self._update_requests(previous_batch) - new_requests = _process_decoded_tokens( - previous_batch.scheduled_requests) - draft_batch.generation_requests = new_requests - previous_batch = sample_state - self._update_requests(previous_batch) - new_requests = _process_decoded_tokens( - previous_batch.scheduled_requests) - _pad_to_max_draft_tokens() - - except Exception as e: - traceback.print_exc() - error_msg = str(e) - logger.error(f"Encountered an error in decode: {error_msg}") - self._handle_errors(error_msg) - def _handle_errors(self, error_msg: Optional[str] = None): error_responses = {} error_msg = error_msg or "error" diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index b9eccc90601..446b647618d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -382,7 +382,8 @@ def create_py_executor( # Drafter for speculative decoding with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER): - drafter = get_spec_drafter(model_engine, spec_resource_manager) + drafter = get_spec_drafter(model_engine, draft_model_engine, sampler, + spec_resource_manager) with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_EXTRA_RESOURCES diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index d99c5dd92d8..e08044cbb4f 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,16 +1,23 @@ from abc import ABC, abstractmethod +from typing import Optional +from ..pyexecutor.resource_manager import ResourceManager from ..pyexecutor.scheduler import ScheduledRequests class Drafter(ABC): + """Abstract base class for all drafter implementations.""" @abstractmethod def prepare_draft_tokens( self, scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, ) -> None: """ Prepare the drafter tokens for the forward computation this step. 
+ + Args: + scheduled_requests: The scheduled requests for this iteration """ raise NotImplementedError diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py new file mode 100644 index 00000000000..ac195ccf515 --- /dev/null +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +import traceback +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.logger import logger + +from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig +from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager +from ..pyexecutor.sampler import Sampler, SampleState +from ..pyexecutor.scheduler import ScheduledRequests +from ..pyexecutor.seq_slot_manager import SeqSlotManager +from .drafter import Drafter + +if TYPE_CHECKING: + from ..pyexecutor.model_engine import ModelEngine + + +class ModelDrafter(Drafter): + """Model-based drafter that uses a draft model to generate draft tokens.""" + + def __init__( + self, + spec_config: "DecodingBaseConfig", + draft_model_engine: "ModelEngine", + max_draft_tokens: int, + draft_seq_slot_manager: SeqSlotManager, + sampler: Sampler, + spec_resource_manager: Optional[BaseResourceManager] = None, + ): + # Validate required parameters + if draft_model_engine is None: + raise ValueError("draft_model_engine cannot be None") + if max_draft_tokens < 0: + raise ValueError(f"max_draft_tokens must be >= 0") + + # Model and resource management + self.draft_model_engine = draft_model_engine + self.draft_seq_slot_manager = draft_seq_slot_manager + self.spec_resource_manager = spec_resource_manager + + # Configuration + self.spec_config = spec_config + self.max_draft_tokens = max_draft_tokens + + # Sampling + self.sampler = sampler + + def _create_draft_request(self, request_id: int, max_new_tokens: int, + input_tokens: Optional[List], + sampling_config: SamplingConfig, + return_perf_metrics: bool) -> LlmRequest: + """Create a draft request with common parameters.""" + return LlmRequest(request_id=request_id, + max_new_tokens=max_new_tokens, + input_tokens=input_tokens, + sampling_config=sampling_config, + return_perf_metrics=return_perf_metrics, + is_streaming=False, + is_draft=True) + + def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]: + """Initialize draft token tracking for a request.""" + num_draft_tokens = len( + request.py_last_draft_tokens + ) if request.py_last_draft_tokens is not None else 0 + request.py_draft_tokens = [] + + num_accepted_tokens = request.py_num_accepted_draft_tokens + num_rejected_tokens = num_draft_tokens - num_accepted_tokens + assert num_rejected_tokens >= 0 + + return num_draft_tokens, num_accepted_tokens + + def _create_context_request(self, request: LlmRequest, + input_tokens: Any) -> LlmRequest: + """Create a context request for first-time drafting.""" + return self._create_draft_request(request.py_request_id, + request.py_max_new_tokens, + input_tokens, request.sampling_config, + request.return_perf_metrics) + + def _create_generation_request(self, request: LlmRequest, + input_tokens: Any) -> LlmRequest: + """Create a generation request when no tokens were accepted.""" + new_request = self._create_draft_request(request.py_request_id, + request.py_max_new_tokens, + input_tokens[:-1], + request.sampling_config, + request.return_perf_metrics) + # Explicitly add the last token so get_last_tokens() 
returns the right value + new_request.add_new_token(input_tokens[-1], 0) + new_request.state = LlmRequestState.GENERATION_IN_PROGRESS + return new_request + + def _create_chunked_context_request(self, request: LlmRequest, + input_tokens: Any, + num_accepted_tokens: int) -> LlmRequest: + """Create a chunked context request when some tokens were accepted.""" + new_request = self._create_draft_request(request.py_request_id, + request.py_max_new_tokens, + input_tokens, + request.sampling_config, + request.return_perf_metrics) + new_request.context_chunk_size = num_accepted_tokens + 1 + new_request.context_current_position = len( + input_tokens) - num_accepted_tokens - 1 + return new_request + + def _create_draft_request_for_request( + self, request: LlmRequest) -> Optional[LlmRequest]: + """Create a draft request based on the original request state.""" + num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens( + request) + input_tokens = self.spec_config.get_draft_model_prompt( + request.get_tokens()[0]) + + # First time seeing this request - context request + if request.max_beam_num_tokens - 1 == request.py_prompt_len: + # This is the first time the draft model is seeing this request. + # Prepare a context request. We discard the first token and take + # the newly decoded one - this is the convention for EAGLE 2 and 3. + assert num_draft_tokens == 0 + return self._create_context_request(request, input_tokens) + + # No tokens accepted - generation request + elif num_accepted_tokens == 0: + return self._create_generation_request(request, input_tokens) + + # Tokens accepted - chunked context request + else: + return self._create_chunked_context_request(request, input_tokens, + num_accepted_tokens) + + def _add_to_draft_batch(self, draft_batch: ScheduledRequests, + draft_request: LlmRequest, + original_request: LlmRequest) -> None: + """Add the draft request to the appropriate batch list.""" + # Copy additional properties + draft_request.py_stop_words_list = original_request.py_stop_words_list + + # Add to appropriate batch based on request type + if draft_request.state == LlmRequestState.GENERATION_IN_PROGRESS: + draft_batch.generation_requests.append(draft_request) + else: + draft_batch.context_requests.append(draft_request) + + @nvtx_range("_prepare_draft_batch") + def _prepare_draft_batch( + self, scheduled_requests: ScheduledRequests) -> ScheduledRequests: + """ + Prepares a batch for the draft model engine. Draft tokens are only produced + for generation requests. + + The requests are prepared as follows: + 1. The first time the draft engine sees a request, it's a context request. + 2. Otherwise, if draft tokens were accepted on the last target model decoding + step, it's a chunked context request (we process all the accepted tokens together). + 3. Otherwise, it's a generation request. + + Args: + scheduled_requests: The scheduled requests to prepare draft batch for + + Returns: + ScheduledRequests: The prepared draft batch + """ + try: + draft_batch = ScheduledRequests() + + for request in scheduled_requests.generation_requests: + if request.py_draft_pages_allocated == 0: + # No space for draft tokens + continue + + # Stop drafting when we hit the max seqlen. We still need dummy draft + # tokens attached to the requests to make sure everything works properly + # with CUDA graph. These dummy tokens are already added by + # _prepare_draft_requests to make the KV cache/scheduler aware of the fact + # that we want to do spec decoding, so no need to do anything else here. 
+ # This makes the perf for this case suboptimal, but that's OK - this is + # a corner case for weird models like the llama 3.1 8b EAGLE3 implementation. + if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: + continue + + draft_request = self._create_draft_request_for_request(request) + if draft_request is not None: + self._add_to_draft_batch(draft_batch, draft_request, + request) + + return draft_batch + + except Exception as e: + logger.error(f"Error in _prepare_draft_batch: {str(e)}") + traceback.print_exc() + raise e + + def _should_disable_cuda_graph( + self, previous_batch: Optional[SampleState]) -> bool: + """Check if CUDA graph should be disabled for the current forward pass.""" + if previous_batch is not None: + return False + return self.spec_config.spec_dec_mode.needs_kv_cache_recompute() + + def _forward_draft_model( + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_batch: Optional[SampleState] = None) -> Dict[str, Any]: + """Forward pass through the draft model.""" + if self._should_disable_cuda_graph(previous_batch): + with self.draft_model_engine.no_cuda_graph(): + outputs = self.draft_model_engine.forward( + draft_batch, resource_manager) + else: + new_tensors_device = previous_batch.device if previous_batch else None + outputs = self.draft_model_engine.forward( + draft_batch, + resource_manager, + new_tensors_device=new_tensors_device) + + # Handle d2t data if available + if hasattr(self.draft_model_engine.model.model, 'd2t'): + outputs['d2t'] = self.draft_model_engine.model.model.d2t.data + + return outputs + + def _sample_async(self, draft_batch: ScheduledRequests, + outputs: Dict[str, Any]) -> Optional[SampleState]: + """Sample tokens from draft model outputs.""" + try: + if self.sampler is not None: + return self.sampler.sample_async(draft_batch, outputs) + return None + except Exception as e: + logger.error(f"Error in sampling: {str(e)}") + return None + + def _update_request_states(self, + scheduled_requests: ScheduledRequests) -> None: + """Update request states after processing.""" + for request in scheduled_requests.context_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + request.move_to_next_context_chunk() + if request.context_remaining_length == 0: + request.state = LlmRequestState.GENERATION_IN_PROGRESS + + def _update_requests(self, sample_state: SampleState) -> None: + """Update requests with sample state.""" + if self.sampler is not None: + self.sampler.update_requests(sample_state) + + def _process_decoded_tokens( + self, draft_batch: ScheduledRequests, + req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]: + """Process decoded tokens and determine which requests to continue processing.""" + new_requests = [] + for req in draft_batch.all_requests(): + target_model_req = req_id_to_old_request[req.py_request_id] + target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) + if req.state != LlmRequestState.GENERATION_COMPLETE and len( + target_model_req.py_draft_tokens + ) < target_model_req.py_draft_pages_allocated: + new_requests.append(req) + else: + self.draft_seq_slot_manager.free_resources(req) + + return new_requests + + def _pad_to_max_draft_tokens(self, + scheduled_requests: ScheduledRequests) -> None: + """Pad draft tokens to maximum length for all generation requests.""" + for req in scheduled_requests.generation_requests: + max_draft_tokens = self.max_draft_tokens + num_draft_tokens = len(req.py_draft_tokens) + req.py_draft_tokens.extend( + 0 
for _ in range(max_draft_tokens - num_draft_tokens)) + + @nvtx_range("prepare_draft_tokens") + def prepare_draft_tokens( + self, + scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, + ) -> None: + """ + Prepare draft tokens for the scheduled requests. + + Args: + scheduled_requests: The scheduled requests for this iteration + resource_manager: The resource manager for this iteration + """ + if not self.draft_model_engine: + raise ValueError("Draft model engine is not set") + + if resource_manager is None: + raise ValueError("Resource manager is required") + + try: + draft_batch = self._prepare_draft_batch(scheduled_requests) + + if draft_batch.batch_size == 0: + return + + self.draft_seq_slot_manager.prepare_resources(draft_batch) + + req_id_to_old_request = { + req.py_request_id: req + for req in scheduled_requests.all_requests() + } + + # Initial forward pass + outputs = self._forward_draft_model(draft_batch, resource_manager) + sample_state = self._sample_async(draft_batch, outputs) + previous_batch = sample_state + + self._update_request_states(draft_batch) + + # Convert context requests to generation requests + draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests + draft_batch.context_requests = [] + + # Generate remaining draft tokens iteratively + for i in range(self.max_draft_tokens - 1): + if len(draft_batch.generation_requests) == 0: + break + + outputs = self._forward_draft_model(draft_batch, + resource_manager, + previous_batch) + sample_state = self._sample_async(draft_batch, outputs) + self._update_request_states(draft_batch) + if previous_batch is not None: + self._update_requests(previous_batch) + new_requests = self._process_decoded_tokens( + previous_batch.scheduled_requests, + req_id_to_old_request) + else: + new_requests = [] + draft_batch.generation_requests = new_requests + previous_batch = sample_state + + # Final cleanup + if previous_batch is not None: + self._update_requests(previous_batch) + self._process_decoded_tokens(previous_batch.scheduled_requests, + req_id_to_old_request) + self._pad_to_max_draft_tokens(scheduled_requests) + + except Exception as e: + traceback.print_exc() + error_msg = str(e) + logger.error(f"Encountered an error in decode: {error_msg}") + raise e diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 57f3045e664..9113900ef94 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -5,7 +5,7 @@ from tensorrt_llm.logger import logger from ..pyexecutor.llm_request import * -from ..pyexecutor.resource_manager import BaseResourceManager +from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager from ..pyexecutor.scheduler import ScheduledRequests from .drafter import Drafter @@ -59,10 +59,10 @@ def __init__(self, spec_config: "NGramDecodingConfig", self.start_index = {} def get_max_resource_count(self) -> int: - raise self.max_num_requests + return self.max_num_requests def get_needed_resource_to_completion(self, request: LlmRequest) -> int: - raise 0 + return 0 def prepare_resources(self, scheduled_batch: ScheduledRequests): pass @@ -173,6 +173,7 @@ def __init__( def prepare_draft_tokens( self, scheduled_requests: ScheduledRequests, + resource_manager: Optional[ResourceManager] = None, ) -> None: # Sort by request_id when py_batch_idx is None as a fallback. 
# This happens in the disagg case: for a set of new requests, we draft diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 667d1a14b0e..2519584274f 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,9 +1,11 @@ from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler from tensorrt_llm._torch.speculative.interface import SpecMetadata +from ..pyexecutor.seq_slot_manager import SeqSlotManager from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3SpecMetadata) +from .model_drafter import ModelDrafter from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata, MTPWorker) from .ngram import NGramDrafter, NGramPoolManager @@ -112,14 +114,26 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") -def get_spec_drafter(model_engine, spec_resource_manager): +def get_spec_drafter(model_engine, draft_model_engine, sampler, + spec_resource_manager): spec_config = model_engine.spec_config if spec_config is None: return None - if spec_config.spec_dec_mode.is_ngram(): - return NGramDrafter(spec_config, spec_resource_manager) + if spec_config.spec_dec_mode.is_user_provided(): return spec_config.drafter + + max_num_requests = model_engine.batch_size + if spec_config.spec_dec_mode.is_draft_target( + ) or spec_config.spec_dec_mode.is_eagle3(): + return ModelDrafter(spec_config, draft_model_engine, + spec_config.max_draft_len, + SeqSlotManager(max_num_requests), sampler, + spec_resource_manager) + + if spec_config.spec_dec_mode.is_ngram(): + return NGramDrafter(spec_config, spec_resource_manager) + return None
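The central change in this diff is that PyExecutor no longer special-cases n-gram versus draft-model speculation: the executor loop only checks whether self.drafter is set and then calls prepare_draft_tokens(scheduled_batch, self.resource_manager) on whatever Drafter it was given. The sketch below illustrates that contract with hypothetical stand-in types (FakeRequest, FakeScheduledRequests, EchoLastTokenDrafter); these are illustration-only placeholders, not the real TensorRT-LLM classes.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeRequest:                       # stand-in for LlmRequest (illustration only)
    request_id: int
    tokens: List[int]
    py_draft_tokens: List[int] = field(default_factory=list)


@dataclass
class FakeScheduledRequests:             # stand-in for ScheduledRequests
    generation_requests: List[FakeRequest]


class Drafter(ABC):
    """Mirrors the updated interface: resource_manager is now an optional argument."""

    @abstractmethod
    def prepare_draft_tokens(self, scheduled_requests: FakeScheduledRequests,
                             resource_manager: Optional[object] = None) -> None:
        raise NotImplementedError


class EchoLastTokenDrafter(Drafter):
    """Toy drafter that proposes the last generated token again."""

    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None):
        for req in scheduled_requests.generation_requests:
            req.py_draft_tokens = [req.tokens[-1]]


# The executor-side call is now uniform for every drafter implementation.
drafter: Optional[Drafter] = EchoLastTokenDrafter()
batch = FakeScheduledRequests([FakeRequest(0, [11, 42])])
if drafter is not None:
    drafter.prepare_draft_tokens(batch, resource_manager=None)
assert batch.generation_requests[0].py_draft_tokens == [42]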
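get_spec_drafter now receives the draft model engine and the sampler and decides which Drafter, if any, to construct. The sketch below mirrors that selection order under a simplifying assumption: plain strings stand in for the real spec_dec_mode enum and for the drafter classes.

from typing import Optional


def pick_drafter(mode: Optional[str],
                 user_drafter: Optional[str] = None) -> Optional[str]:
    """Return which drafter the factory would construct for a given decoding mode."""
    if mode is None:
        return None                      # no speculative decoding configured
    if mode == "user_provided":
        return user_drafter              # the config supplies its own Drafter
    if mode in ("draft_target", "eagle3"):
        return "ModelDrafter"            # two-model path: needs draft engine + sampler
    if mode == "ngram":
        return "NGramDrafter"            # pattern-matching path: no draft engine
    return None                          # one-model modes (e.g. MTP) need no drafter


assert pick_drafter("eagle3") == "ModelDrafter"
assert pick_drafter("ngram") == "NGramDrafter"
assert pick_drafter("mtp") is None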
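ModelDrafter._create_draft_request_for_request keeps the three-way split that used to live inline in PyExecutor._prepare_draft_batch. The helper below is a sketch of that arithmetic; it assumes, as the diff's first-iteration check does, that max_beam_num_tokens - 1 counts the tokens decoded so far, and the token list at the end is purely hypothetical.

def classify_draft_request(max_beam_num_tokens: int, prompt_len: int,
                           num_accepted_tokens: int) -> str:
    if max_beam_num_tokens - 1 == prompt_len:
        # Draft model has never seen this request: plain context request.
        return "context"
    if num_accepted_tokens == 0:
        # Nothing was accepted last step: single-token generation request.
        return "generation"
    # Some draft tokens were accepted: chunked context over the accepted span.
    return "chunked_context"


assert classify_draft_request(max_beam_num_tokens=9, prompt_len=8,
                              num_accepted_tokens=0) == "context"
assert classify_draft_request(max_beam_num_tokens=12, prompt_len=8,
                              num_accepted_tokens=0) == "generation"
assert classify_draft_request(max_beam_num_tokens=12, prompt_len=8,
                              num_accepted_tokens=2) == "chunked_context"

# Chunked-context positioning, as set in _create_chunked_context_request:
# the chunk covers exactly the accepted tokens plus the newly sampled one.
input_tokens = list(range(12))                                     # hypothetical draft prompt
num_accepted = 2
context_chunk_size = num_accepted + 1                              # -> 3
context_current_position = len(input_tokens) - num_accepted - 1    # -> 9
assert context_current_position + context_chunk_size == len(input_tokens)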
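The CUDA-graph handling that PyExecutor applied only to the first draft forward is now isolated in _should_disable_cuda_graph. A small sketch of that decision follows, with needs_kv_cache_recompute passed in as a plain boolean rather than read from the spec-decoding mode object.

from typing import Optional


def should_disable_cuda_graph(previous_batch: Optional[object],
                              needs_kv_cache_recompute: bool) -> bool:
    if previous_batch is not None:
        return False                     # later draft iterations always use the graph
    return needs_kv_cache_recompute      # first iteration: depends on the spec mode


assert should_disable_cuda_graph(None, True) is True
assert should_disable_cuda_graph(None, False) is False
assert should_disable_cuda_graph(object(), True) is False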
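prepare_draft_tokens runs one initial forward pass and then up to max_draft_tokens - 1 more, dropping a request once it finishes or fills its allocated draft pages, and finally pads every target generation request to a uniform draft length because the TRTLLM attention kernels cannot handle generation requests with different sequence lengths. A sketch of that bookkeeping, with dictionaries standing in for LlmRequest objects:

from typing import Dict, List

MAX_DRAFT_TOKENS = 4


def keep_drafting(req: Dict) -> bool:
    """Mirror of the continue condition in _process_decoded_tokens."""
    return (not req["finished"]
            and len(req["py_draft_tokens"]) < req["py_draft_pages_allocated"])


def pad_to_max_draft_tokens(requests: List[Dict]) -> None:
    """Mirror of _pad_to_max_draft_tokens: pad every request to the same length."""
    for req in requests:
        missing = MAX_DRAFT_TOKENS - len(req["py_draft_tokens"])
        req["py_draft_tokens"].extend(0 for _ in range(missing))


reqs = [
    {"py_draft_tokens": [7, 7], "py_draft_pages_allocated": 4, "finished": False},
    {"py_draft_tokens": [3], "py_draft_pages_allocated": 1, "finished": False},
]
active = [r for r in reqs if keep_drafting(r)]    # the second request stops drafting
pad_to_max_draft_tokens(reqs)
assert len(active) == 1
assert [len(r["py_draft_tokens"]) for r in reqs] == [4, 4]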