                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -60,7 +61,9 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,54 +84,49 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
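+        # Backend detection now comes from llm_args (PyTorch flow) rather than from executor_config.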
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args

         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

         if isinstance(engine, list):
             engine = engine[self.rank]

-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
-
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
-
-        def _create_engine():
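+        # Helper shared by both construction paths below.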
+        def _get_comm_ranks_device_id():
             device_id = self.global_rank % torch.cuda.device_count()
             torch.cuda.set_device(device_id)
-
             # Make sure C++ executor would use same devices/ranks as py_executor
             global_rank = global_mpi_rank()
             comm_ranks = mpi_comm().allgather(global_rank)
             device_ids = mpi_comm().allgather(device_id)
+            return comm_ranks, device_ids
+
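+        # PyTorch/_autodeploy path: the ExecutorConfig is derived from llm_args and used to build the Python executor.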
+        def _create_py_executor(executor_config):
+            assert executor_config is None, "expect an empty executor_config in _create_py_executor"
+            executor_config = llm_args.get_executor_config(
+                hf_model_dir, tokenizer)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
-
-            if isinstance(engine, Engine):
-                return tllm.Executor(engine.engine,
-                                     json.dumps(engine.config.to_dict(),
-                                                cls=ConfigEncoder),
-                                     tllm.ModelType.DECODER_ONLY,
-                                     executor_config=executor_config,
-                                     managed_weights=engine.managed_weights)
-
-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should have a backend in _create_py_executor"
             if executor_config.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                     create_py_executor
                 create_executor = create_py_executor
                 args["lora_config"] = lora_config
                 args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
             elif executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
@@ -138,7 +136,30 @@ def _create_engine():
                     f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)

-        self.engine = _create_engine()
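+        # TensorRT-engine path: keep the original C++ tllm.Executor construction when no llm_args are passed.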
+        def _create_engine(executor_config):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            comm_ranks, device_ids = _get_comm_ranks_device_id()
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+
+            if isinstance(engine, Engine):
+                return tllm.Executor(engine.engine,
+                                     json.dumps(engine.config.to_dict(),
+                                                cls=ConfigEncoder),
+                                     tllm.ModelType.DECODER_ONLY,
+                                     executor_config=executor_config,
+                                     managed_weights=engine.managed_weights)
+
+            assert not hasattr(executor_config, "backend")
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)
+
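+        # llm_args present -> Python executor (PyTorch/_autodeploy); otherwise fall back to the C++ executor.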
+        self.engine = _create_py_executor(
+            executor_config) if llm_args is not None else _create_engine(
+                executor_config)

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -161,7 +182,7 @@ def _create_engine():
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()

-        if getattr(executor_config, "backend",
+        if getattr(self._executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
             from tensorrt_llm._torch.pyexecutor.resource_manager import \
                 ResourceManagerType
@@ -430,14 +451,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )

-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
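+        # The overlap-scheduler setting is now read from llm_args instead of executor_config.pytorch_backend_config.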
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )

         assert request.id is not None

@@ -641,7 +664,9 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +793,9 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            tokenizer=tokenizer,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())