13 changes: 7 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -358,18 +358,19 @@ def create_py_executor(
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
kv_cache_creator.build_managers(resources)

# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
drafter = get_spec_drafter(model_engine)

# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
# to provide a resource manager if required.
spec_resource_manager = get_spec_resource_manager(model_engine,
draft_model_engine,
drafter)
draft_model_engine)
if spec_resource_manager is not None:
resources[
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager

# Drafter for speculative decoding
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
drafter = get_spec_drafter(model_engine, spec_resource_manager)

with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
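Read as a whole, this hunk reorders the two speculative-decoding setup steps: the resource manager is now created before the drafter, and the drafter receives it as an argument instead of owning it. A simplified sketch of the resulting flow in create_py_executor (names taken from the diff above; the surrounding executor setup is elided):

    # Resource managers for speculative decoding are built first.
    spec_resource_manager = get_spec_resource_manager(model_engine, draft_model_engine)
    if spec_resource_manager is not None:
        resources[ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager

    # The drafter is built afterwards and is handed the manager it should use.
    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
        drafter = get_spec_drafter(model_engine, spec_resource_manager)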
8 changes: 0 additions & 8 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -1,18 +1,10 @@
from abc import ABC, abstractmethod
from typing import Optional

from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.scheduler import ScheduledRequests


class Drafter(ABC):

def __init__(
self,
spec_resource_manager: Optional[BaseResourceManager] = None,
):
self.spec_resource_manager = spec_resource_manager

@abstractmethod
def prepare_draft_tokens(
self,
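Because the Drafter base class no longer accepts or stores a spec_resource_manager, a subclass that needs one simply keeps its own reference, exactly as the NGram drafter below now does. A minimal sketch of a custom drafter under the new interface (MyDrafter is hypothetical; the full prepare_draft_tokens signature is truncated in this diff, so only the scheduled_requests argument is shown; module paths follow the file locations in this PR):

    from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
    from tensorrt_llm._torch.speculative.drafter import Drafter


    class MyDrafter(Drafter):
        # Hypothetical drafter that keeps its own resource-manager reference,
        # since the base class no longer does.

        def __init__(self, spec_config, resource_manager=None):
            self.spec_resource_manager = resource_manager
            self.max_draft_len = spec_config.max_draft_len

        def prepare_draft_tokens(self, scheduled_requests: ScheduledRequests) -> None:
            # Populate draft tokens for each scheduled request here.
            ...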
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/speculative/ngram.py
@@ -167,8 +167,8 @@ def __init__(
ngram_pool_manager: NGramPoolManager = None,
):
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
super().__init__(spec_resource_manager=ngram_pool_manager)
self.max_draft_len = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager

def prepare_draft_tokens(
self,
17 changes: 7 additions & 10 deletions tensorrt_llm/_torch/speculative/utils.py
@@ -55,9 +55,7 @@ def get_spec_metadata(spec_config,
return None


def get_spec_resource_manager(model_engine,
draft_model_engine=None,
drafter=None):
def get_spec_resource_manager(model_engine, draft_model_engine=None):
spec_config = model_engine.spec_config
if spec_config is None:
return None
@@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine,
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
return drafter.spec_resource_manager
if spec_dec_mode.is_ngram():
return NGramPoolManager(spec_config, max_num_requests)
if spec_dec_mode.is_user_provided():
return spec_config.resource_manager
return None


@@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")


def get_spec_drafter(model_engine):
def get_spec_drafter(model_engine, spec_resource_manager):
spec_config = model_engine.spec_config
max_num_requests = model_engine.batch_size
if spec_config is None:
return None
if spec_config.spec_dec_mode.is_ngram():
return NGramDrafter(spec_config,
NGramPoolManager(spec_config, max_num_requests))
return NGramDrafter(spec_config, spec_resource_manager)
if spec_config.spec_dec_mode.is_user_provided():
return spec_config.drafter
return None
5 changes: 3 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -354,8 +354,9 @@ def get_draft_model_prompt(self,


class UserProvidedDecodingConfig(DecodingBaseConfig):
# Type should be Drafter, but it leads to circular import
drafter: object
# Cannot use real type annotations due to circular imports
drafter: object # Type is Drafter
resource_manager: object = None # Type is Optional[ResourceManager]

@classmethod
def from_dict(cls, data: dict):
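With the new resource_manager field, the user-provided path takes two objects: the drafter itself and an optional resource manager that get_spec_resource_manager returns and the executor registers as the spec resource manager. A hedged usage sketch (MyDrafter and MyPoolManager are placeholder user classes; how the config is ultimately passed to the LLM API is outside this diff):

    from tensorrt_llm.llmapi.llm_args import UserProvidedDecodingConfig

    # Placeholder user objects standing in for a real Drafter subclass and
    # its optional resource manager.
    pool = MyPoolManager()
    drafter = MyDrafter(resource_manager=pool)

    speculative_config = UserProvidedDecodingConfig(
        drafter=drafter,         # required: the user-supplied Drafter
        resource_manager=pool,   # optional: registered as the spec resource manager
    )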