
Commit fa34cb7

[refactor] Clean up drafter/resource manager creation logic (#5805)
Signed-off-by: Mike Iovine <[email protected]>
1 parent e0836f9 commit fa34cb7

5 files changed, 18 insertions(+), 27 deletions(-)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 6 deletions
@@ -360,18 +360,19 @@ def create_py_executor(
             if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
         kv_cache_creator.build_managers(resources)
 
-    # Drafter for speculative decoding
-    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-        drafter = get_spec_drafter(model_engine)
-
     # Resource managers for speculative decoding
+    # For user-specified drafters, use extra_resource_managers in PyTorchBackend config
+    # to provide a resource manager if required.
     spec_resource_manager = get_spec_resource_manager(model_engine,
-                                                      draft_model_engine,
-                                                      drafter)
+                                                      draft_model_engine)
     if spec_resource_manager is not None:
         resources[
             ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
 
+    # Drafter for speculative decoding
+    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
+        drafter = get_spec_drafter(model_engine, spec_resource_manager)
+
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.INIT_EXTRA_RESOURCES
             if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 0 additions & 8 deletions
@@ -1,18 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import Optional
 
-from ..pyexecutor.resource_manager import BaseResourceManager
 from ..pyexecutor.scheduler import ScheduledRequests
 
 
 class Drafter(ABC):
 
-    def __init__(
-        self,
-        spec_resource_manager: Optional[BaseResourceManager] = None,
-    ):
-        self.spec_resource_manager = spec_resource_manager
-
     @abstractmethod
     def prepare_draft_tokens(
         self,
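
With the constructor gone, Drafter is a pure interface again: subclasses that need a resource manager keep their own reference, as NGramDrafter now does in ngram.py below. A minimal sketch of a custom drafter under the new contract follows; the class name, the request-iteration attributes, and the single-argument prepare_draft_tokens signature are assumptions for illustration, since the hunk truncates the method's parameter list.

from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class FixedTokenDrafter(Drafter):
    """Hypothetical drafter that proposes the same token for every request."""

    def __init__(self, draft_token_id: int, max_draft_len: int):
        # There is no base-class __init__ to call anymore; subclasses own their state.
        self.draft_token_id = draft_token_id
        self.max_draft_len = max_draft_len

    def prepare_draft_tokens(self, scheduled_requests: ScheduledRequests) -> None:
        # Attribute names below (generation_requests, py_draft_tokens) are
        # assumptions about the surrounding pyexecutor request types.
        for request in scheduled_requests.generation_requests:
            request.py_draft_tokens = [self.draft_token_id] * self.max_draft_len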

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 1 addition & 1 deletion
@@ -167,8 +167,8 @@ def __init__(
         ngram_pool_manager: NGramPoolManager = None,
     ):
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
-        super().__init__(spec_resource_manager=ngram_pool_manager)
         self.max_draft_len = spec_config.max_draft_len
+        self.spec_resource_manager = ngram_pool_manager
 
     def prepare_draft_tokens(
         self,

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 7 additions & 10 deletions
@@ -55,9 +55,7 @@ def get_spec_metadata(spec_config,
     return None
 
 
-def get_spec_resource_manager(model_engine,
-                              draft_model_engine=None,
-                              drafter=None):
+def get_spec_resource_manager(model_engine, draft_model_engine=None):
     spec_config = model_engine.spec_config
     if spec_config is None:
         return None
@@ -93,9 +91,10 @@ def get_spec_resource_manager(model_engine,
             max_seq_len,
             max_num_tokens,
         )
-    if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
-        assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
-        return drafter.spec_resource_manager
+    if spec_dec_mode.is_ngram():
+        return NGramPoolManager(spec_config, max_num_requests)
+    if spec_dec_mode.is_user_provided():
+        return spec_config.resource_manager
     return None
 
 
@@ -113,14 +112,12 @@ def get_spec_decoder(sampler_args: TorchSampler.Args,
         f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
 
 
-def get_spec_drafter(model_engine):
+def get_spec_drafter(model_engine, spec_resource_manager):
     spec_config = model_engine.spec_config
-    max_num_requests = model_engine.batch_size
     if spec_config is None:
         return None
     if spec_config.spec_dec_mode.is_ngram():
-        return NGramDrafter(spec_config,
-                            NGramPoolManager(spec_config, max_num_requests))
+        return NGramDrafter(spec_config, spec_resource_manager)
     if spec_config.spec_dec_mode.is_user_provided():
         return spec_config.drafter
     return None
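
Taken together, the two factories now compose in one direction only: the resource manager is built first (an NGramPoolManager for ngram mode, spec_config.resource_manager for user-provided mode) and the drafter is handed that manager, reversing the old flow in which the drafter built the pool and get_spec_resource_manager read it back via drafter.spec_resource_manager. A minimal sketch of the new call order, assuming the absolute module path follows the file path above; the helper function name is hypothetical.

from tensorrt_llm._torch.speculative.utils import (get_spec_drafter,
                                                    get_spec_resource_manager)


def build_speculative_pieces(model_engine, draft_model_engine=None):
    # Resource manager first: ngram mode constructs the NGramPoolManager here,
    # user-provided mode returns spec_config.resource_manager (possibly None).
    spec_resource_manager = get_spec_resource_manager(model_engine, draft_model_engine)
    # Drafter second: it receives the already-built manager instead of
    # constructing its own pool internally.
    drafter = get_spec_drafter(model_engine, spec_resource_manager)
    return spec_resource_manager, drafter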

tensorrt_llm/llmapi/llm_args.py

Lines changed: 3 additions & 2 deletions
@@ -354,8 +354,9 @@ def get_draft_model_prompt(self,
 
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
-    # Type should be Drafter, but it leads to circular import
-    drafter: object
+    # Cannot use real type annotations due to circular imports
+    drafter: object  # Type is Drafter
+    resource_manager: object = None  # Type is Optional[ResourceManager]
 
     @classmethod
     def from_dict(cls, data: dict):
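
The new resource_manager field gives user-provided decoding a place to register a resource manager alongside the drafter. A hypothetical configuration sketch follows; only the drafter and resource_manager fields come from the hunk above, while max_draft_len and the speculative_config keyword on LLM are assumptions about the wider llmapi surface.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import UserProvidedDecodingConfig

# FixedTokenDrafter is the hypothetical Drafter subclass sketched earlier.
spec_config = UserProvidedDecodingConfig(
    max_draft_len=4,  # assumed to be inherited from DecodingBaseConfig
    drafter=FixedTokenDrafter(draft_token_id=0, max_draft_len=4),
    resource_manager=None,  # optional; supply a resource manager if the drafter needs one
)

llm = LLM(model="/path/to/model", speculative_config=spec_config)  # assumed wiring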
