
Commit c5e68a6

fix autodeploy tokenizer
Signed-off-by: leslie-fang25 <[email protected]>
1 parent 9ac52c7 commit c5e68a6

5 files changed: 26 additions, 18 deletions

tensorrt_llm/executor/executor.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
                             enable_worker_single_process_for_tp1, print_colored,
                             print_colored_debug)
@@ -356,6 +357,7 @@ def create(
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
         llm_args: Optional[TorchLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
@@ -384,6 +386,7 @@ def create(
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
             "hf_model_dir": hf_model_dir,
+            "tokenizer": tokenizer,
             "llm_args": llm_args,
         }
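
For context, a minimal sketch of the pass-through pattern this hunk introduces (not the actual TensorRT-LLM code; TokenizerBase here is a local stand-in): create() accepts an optional tokenizer and forwards it untouched in the kwargs handed to the worker.

from typing import Any, Optional


class TokenizerBase:
    """Stand-in for tensorrt_llm.llmapi.tokenizer.TokenizerBase."""


def create(hf_model_dir: Optional[str] = None,
           tokenizer: Optional[TokenizerBase] = None,
           llm_args: Optional[Any] = None) -> dict:
    # The executor never uses the tokenizer itself; it only relays it so the
    # worker can build its config without reconstructing the tokenizer.
    return {
        "hf_model_dir": hf_model_dir,
        "tokenizer": tokenizer,
        "llm_args": llm_args,
    }


worker_kwargs = create(hf_model_dir="/models/llama", tokenizer=TokenizerBase())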

tensorrt_llm/executor/worker.py

Lines changed: 6 additions & 1 deletion
@@ -20,6 +20,7 @@
 from ..builder import ConfigEncoder, Engine, EngineConfig
 from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -61,6 +62,7 @@ def __init__(
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
         llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
@@ -102,7 +104,8 @@ def _get_comm_ranks_device_id():

         def _create_py_executor(executor_config):
             assert executor_config is None, "expect an empty executor_config is _create_py_executor"
-            executor_config = llm_args.get_executor_config(hf_model_dir)
+            executor_config = llm_args.get_executor_config(
+                hf_model_dir, tokenizer)
             # Persist so downstream code (e.g., default max_tokens deduction) has access
             self._executor_config = executor_config
             executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
@@ -662,6 +665,7 @@ def worker_main(
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
         hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
         llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
@@ -790,6 +794,7 @@ def notify_proxy_threads_to_quit():
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
             hf_model_dir=hf_model_dir,
+            tokenizer=tokenizer,
             llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
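
The worker-side counterpart, again as a hedged sketch with illustrative class names rather than the real ones: the tokenizer received through the worker kwargs is threaded into get_executor_config() as an explicit argument, instead of being read from state previously stashed on llm_args.

from typing import Optional


class ExecutorConfig:
    def __init__(self, tokenizer: Optional[object] = None):
        self.tokenizer = tokenizer


class LlmArgs:
    def get_executor_config(self,
                            hf_model_dir: Optional[str] = None,
                            tokenizer: Optional[object] = None) -> ExecutorConfig:
        # The tokenizer arrives as a plain argument: no hidden state on self.
        return ExecutorConfig(tokenizer=tokenizer)


def create_py_executor(llm_args: LlmArgs,
                       hf_model_dir: Optional[str],
                       tokenizer: Optional[object]) -> ExecutorConfig:
    # Mirrors the changed call site: both values are passed through together.
    return llm_args.get_executor_config(hf_model_dir, tokenizer)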

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 3 deletions
@@ -966,9 +966,6 @@ def _build_model(self):
                                                self.tokenizer)
         self._tokenizer = self.input_processor.tokenizer

-        # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config
-        self.args.set_tokenizer(self.tokenizer)
-
         # TODO: revisit gather_context_logits
         return_logits = self.args.gather_generation_logits
         self._executor = self._executor_cls.create(
@@ -987,6 +984,7 @@ def _build_model(self):
             is_llm_executor=True,
             lora_config=self.args.lora_config,
             hf_model_dir=self._hf_model_dir,
+            tokenizer=self.tokenizer,
             llm_args=self.args)

     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
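
The motivation for dropping set_tokenizer(), shown as a sketch (simplified names, not the real LLM class): passing the tokenizer as an explicit create() argument keeps runtime objects off the args object, which would otherwise carry a live tokenizer into every place the args are serialized or inspected.

class _EchoExecutor:
    """Toy executor factory that just records what it was given."""

    @staticmethod
    def create(**kwargs):
        return kwargs


class Llm:
    def __init__(self, args, tokenizer, hf_model_dir):
        self.args = args
        self.tokenizer = tokenizer
        self._hf_model_dir = hf_model_dir
        self._executor_cls = _EchoExecutor

    def _build_model(self):
        # Before this commit: self.args.set_tokenizer(self.tokenizer)
        # After: the tokenizer travels as its own argument.
        self._executor = self._executor_cls.create(
            hf_model_dir=self._hf_model_dir,
            tokenizer=self.tokenizer,
            llm_args=self.args)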

tensorrt_llm/llmapi/llm_args.py

Lines changed: 15 additions & 12 deletions
@@ -1837,12 +1837,11 @@ def _load_config_from_ckpt(self, ckpt_dir: Path):
                                  moe_tp_size=moe_tp_size,
                                  moe_ep_size=moe_ep_size)

-    def set_tokenizer(self, tokenizer):
-        self.tokenizer = tokenizer
-
-    def get_executor_config(self,
-                            _hf_model_dir: Optional[Path] = None
-                            ) -> _ExecutorConfig:
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
         executor_config = _ExecutorConfig(
             max_beam_width=self.max_beam_width,
             scheduler_config=PybindMirror.maybe_to_pybind(
@@ -1867,13 +1866,15 @@ def get_executor_config(self,
         if self.decoding_config is not None:
             executor_config.decoding_config = self.decoding_config
         if self.guided_decoding_backend == 'xgrammar':
+            assert tokenizer is not None
             executor_config.guided_decoding_config = _GuidedDecodingConfig(
                 backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
-                **_xgrammar_tokenizer_info(self.tokenizer))
+                **_xgrammar_tokenizer_info(tokenizer))
         elif self.guided_decoding_backend == 'llguidance':
+            assert tokenizer is not None
             executor_config.guided_decoding_config = _GuidedDecodingConfig(
                 backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
-                **_llguidance_tokenizer_info(self.tokenizer))
+                **_llguidance_tokenizer_info(tokenizer))
         elif self.guided_decoding_backend is not None:
             raise ValueError(
                 f"Unsupported guided decoding backend {self.guided_decoding_backend}"
@@ -2460,10 +2461,12 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
     def set_mm_encoder_only(self, mm_encoder_only):
         self.mm_encoder_only = mm_encoder_only

-    def get_executor_config(self,
-                            _hf_model_dir: Optional[Path] = None
-                            ) -> _ExecutorConfig:
-        executor_config = super().get_executor_config(_hf_model_dir)
+    def get_executor_config(
+        self,
+        _hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+    ) -> _ExecutorConfig:
+        executor_config = super().get_executor_config(_hf_model_dir, tokenizer)
         executor_config.mm_encoder_only = self.mm_encoder_only
         return executor_config
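
A sketch of the new contract in get_executor_config() (the helper name and returned dict are illustrative, not the _GuidedDecodingConfig API): guided decoding now requires the caller to supply a tokenizer, and the asserts fail fast if one is missing rather than letting the backend crash on a None tokenizer later.

from typing import Optional


def guided_decoding_info(backend: Optional[str],
                         tokenizer: Optional[object]) -> Optional[dict]:
    if backend is None:
        return None
    # Both backends derive their grammar tables from tokenizer vocabulary,
    # so a missing tokenizer is a programming error, caught here.
    assert tokenizer is not None, f"{backend} guided decoding needs a tokenizer"
    if backend == "xgrammar":
        return {"backend": "XGRAMMAR", "tokenizer_info": tokenizer}
    if backend == "llguidance":
        return {"backend": "LLGUIDANCE", "tokenizer_info": tokenizer}
    raise ValueError(f"Unsupported guided decoding backend {backend}")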

tensorrt_llm/llmapi/mm_encoder.py

Lines changed: 1 addition & 2 deletions
@@ -56,8 +56,6 @@ def _build_model(self):
         self._tokenizer = self.input_processor.tokenizer

         assert isinstance(self.args, TorchLlmArgs)
-        # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config
-        self.args.set_tokenizer(self.tokenizer)
         self.args.set_mm_encoder_only(True)

         self._executor = self._executor_cls.create(
@@ -69,6 +67,7 @@ def _build_model(self):
                 self.args.parallel_config.world_size),
             is_llm_executor=True,  # TODO: check if this is correct or needed
             hf_model_dir=self._hf_model_dir,
+            tokenizer=self.tokenizer,
             llm_args=self.args)

     def _validate_mm_args_for_torch_backend(self, kwargs: dict) -> None:
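
Putting the five files together, a hedged end-to-end sketch of the tokenizer flow after this commit (function names are illustrative, and it reuses the LlmArgs stand-in from the worker.py sketch above): both LLM._build_model and the multimodal encoder pass the tokenizer to the executor factory, the factory relays it to the worker, and only the worker's config-building step actually consumes it.

def build_config_end_to_end(llm_args, hf_model_dir, tokenizer):
    # 1. llm.py / mm_encoder.py: _build_model passes tokenizer to create().
    worker_kwargs = {"hf_model_dir": hf_model_dir,
                     "tokenizer": tokenizer,
                     "llm_args": llm_args}
    # 2. executor.py: create() relays worker_kwargs to worker_main().
    # 3. worker.py: _create_py_executor is the single consumer.
    return llm_args.get_executor_config(worker_kwargs["hf_model_dir"],
                                        worker_kwargs["tokenizer"])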
