Skip to content

Commit 25197f8

Browse files
committed
[refactor] Clean up drafter/resource manager creation logic
Signed-off-by: Mike Iovine <[email protected]>
1 parent 1191555 commit 25197f8

File tree

6 files changed

+30
-30
lines changed

6 files changed

+30
-30
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,18 +358,19 @@ def create_py_executor(
358358
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
359359
kv_cache_creator.build_managers(resources)
360360

361-
# Drafter for speculative decoding
362-
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
363-
drafter = get_spec_drafter(model_engine)
364-
365361
# Resource managers for speculative decoding
362+
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
363+
# to provide a resource manager if required.
366364
spec_resource_manager = get_spec_resource_manager(model_engine,
367-
draft_model_engine,
368-
drafter)
365+
draft_model_engine)
369366
if spec_resource_manager is not None:
370367
resources[
371368
ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
372369

370+
# Drafter for speculative decoding
371+
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
372+
drafter = get_spec_drafter(model_engine, spec_resource_manager)
373+
373374
with mem_monitor.observe_creation_stage(
374375
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
375376
if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
from abc import ABC, abstractmethod
2-
from typing import Optional
32

4-
from ..pyexecutor.resource_manager import BaseResourceManager
53
from ..pyexecutor.scheduler import ScheduledRequests
64

75

86
class Drafter(ABC):
97

10-
def __init__(
11-
self,
12-
spec_resource_manager: Optional[BaseResourceManager] = None,
13-
):
14-
self.spec_resource_manager = spec_resource_manager
15-
168
@abstractmethod
179
def prepare_draft_tokens(
1810
self,

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ def __init__(
194194
spec_config: SpecConfig,
195195
ngram_pool_manager: NGramPoolManager = None,
196196
):
197+
super().__init__()
197198
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
198-
super().__init__(spec_resource_manager=ngram_pool_manager)
199+
self.spec_resource_manager = ngram_pool_manager
199200
self.max_num_draft_tokens = spec_config.max_draft_tokens
200201

201202
def prepare_draft_tokens(

tensorrt_llm/_torch/speculative/user_provided.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22
from typing import Optional
33

4+
from tensorrt_llm._torch.pyexecutor.resource_manager import BaseResourceManager
45
from tensorrt_llm._torch.speculative.drafter import Drafter
56

67
from .interface import SpecConfig, SpeculativeDecodingMode
@@ -17,10 +18,17 @@ class UserProvidedConfig(SpecConfig):
1718
num_extra_kv_tokens: int = 0
1819
max_draft_tokens: int = 0
1920
drafter: Optional[Drafter] = None
21+
# For convenience, this will default to drafter.spec_resource_manager if such
22+
# an attribute exists and resource_manager has not been explicitly specified.
23+
resource_manager: Optional[BaseResourceManager] = None
2024

2125
def __post_init__(self) -> None:
2226
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
2327
self.spec_dec_name)
2428

29+
if self.resource_manager is None and self.drafter is not None and hasattr(
30+
self.drafter, "spec_resource_manager"):
31+
self.resource_manager = self.drafter.spec_resource_manager
32+
2533
def update_from_model_config(self, model_config):
2634
pass

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ def get_spec_metadata(spec_config,
6060
return None
6161

6262

63-
def get_spec_resource_manager(model_engine,
64-
draft_model_engine=None,
65-
drafter=None):
63+
def get_spec_resource_manager(model_engine, draft_model_engine=None):
6664
spec_config = model_engine.spec_config
6765
if spec_config is None:
6866
return None
@@ -98,9 +96,10 @@ def get_spec_resource_manager(model_engine,
9896
max_seq_len,
9997
max_num_tokens,
10098
)
101-
if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
102-
assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
103-
return drafter.spec_resource_manager
99+
if spec_dec_mode.is_ngram():
100+
return NGramPoolManager(spec_config, max_num_requests)
101+
if spec_dec_mode.is_user_provided():
102+
return spec_config.resource_manager
104103
return None
105104

106105

@@ -117,16 +116,13 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
117116
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
118117

119118

120-
def get_spec_drafter(model_engine):
119+
def get_spec_drafter(model_engine, spec_resource_manager):
121120
spec_config = model_engine.spec_config
122-
max_num_requests = model_engine.batch_size
121+
model_engine.batch_size
123122
if spec_config is None:
124123
return None
125124
if spec_config.spec_dec_mode.is_ngram():
126-
return NGramDrafter(spec_config,
127-
NGramPoolManager(spec_config, max_num_requests))
128-
if spec_config.spec_dec_mode.is_user_provided():
129-
return spec_config.drafter
125+
return NGramDrafter(spec_config, spec_resource_manager)
130126
return None
131127

132128

tensorrt_llm/llmapi/llm_args.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,9 @@ def from_dict(cls, data: dict):
278278

279279

280280
class UserProvidedDecodingConfig(DecodingBaseConfig):
281-
# Type should be Drafter, but it leads to circular import
282-
drafter: object
281+
# Cannot use real type annotations due to circular imports
282+
drafter: object # Type is Drafter
283+
resource_manager: object = None # Type is Optional[ResourceManager]
283284

284285
@classmethod
285286
def from_dict(cls, data: dict):
@@ -1401,7 +1402,8 @@ def validate_speculative_config(self):
14011402
from tensorrt_llm._torch.speculative import UserProvidedConfig
14021403
self.speculative_config = UserProvidedConfig(
14031404
max_draft_tokens=self.speculative_config.max_draft_len,
1404-
drafter=self.speculative_config.drafter)
1405+
drafter=self.speculative_config.drafter,
1406+
resource_manager=self.speculative_config.resource_manager)
14051407
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
14061408
self.build_config.max_draft_len = self.speculative_config.max_draft_tokens
14071409
else:

0 commit comments

Comments (0)