                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
@@ -60,7 +60,8 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,29 +82,51 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args
 
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)
 
         if isinstance(engine, list):
             engine = engine[self.rank]
 
-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
+        def _create_py_executor(comm_ranks, device_ids):
 
-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
+            executor_config = llm_args.get_executor_config(hf_model_dir)
+            # Persist so downstream code (e.g., default max_tokens deduction) has access
+            self._executor_config = executor_config
 
-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+            args = {
+                "executor_config": executor_config,
+                "checkpoint_dir": executor_config.hf_model_dir,
+            }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should be with backend in _create_py_executor"
+            if executor_config.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["lora_config"] = lora_config
+                args[
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {executor_config.backend}")
+            return create_executor(**args)
+
+        def _create_engine(comm_ranks, device_ids):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
 
-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)
 
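Note: the new `_create_py_executor` path builds its config from `llm_args.get_executor_config(hf_model_dir)` instead of receiving a pre-built `executor_config`, and the worker flags the PyTorch backend from `llm_args` alone. A minimal sketch of that dispatch, using a hypothetical `_StubTorchLlmArgs` stand-in (not the real `TorchLlmArgs`; field defaults are assumptions):

from dataclasses import dataclass
from typing import Optional


@dataclass
class _StubTorchLlmArgs:  # stand-in for TorchLlmArgs (assumption)
    backend: str = "pytorch"
    garbage_collection_gen0_threshold: int = 20480


def pick_creation_path(llm_args: Optional[_StubTorchLlmArgs]) -> str:
    # Mirrors the worker's new dispatch: llm_args present -> _create_py_executor,
    # otherwise the legacy executor_config-driven _create_engine path.
    if llm_args is not None and llm_args.backend == "pytorch":
        return "_create_py_executor"
    return "_create_engine"


assert pick_creation_path(_StubTorchLlmArgs()) == "_create_py_executor"
assert pick_creation_path(None) == "_create_engine"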
@@ -122,14 +145,7 @@ def _create_engine():
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
             }
-            if executor_config.backend == "pytorch":
-                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
-                    create_py_executor
-                create_executor = create_py_executor
-                args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
+            if executor_config.backend == "_autodeploy":
                 from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                     create_autodeploy_executor
                 create_executor = create_autodeploy_executor
@@ -138,7 +154,17 @@ def _create_engine():
                     f"Unsupported backend config: {executor_config.backend}")
             return create_executor(**args)
 
-        self.engine = _create_engine()
+        device_id = self.global_rank % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        # Make sure C++ executor would use same devices/ranks as py_executor
+        global_rank = global_mpi_rank()
+        comm_ranks = mpi_comm().allgather(global_rank)
+        device_ids = mpi_comm().allgather(device_id)
+
+        self.engine = _create_py_executor(
+            comm_ranks, device_ids) if llm_args is not None else _create_engine(
+                comm_ranks, device_ids)
 
         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
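Note: both creators now receive the same `comm_ranks`/`device_ids` gathered once up front, so the C++ and Python executors agree on rank-to-device placement. A toy sketch of that mapping with plain lists instead of MPI allgather (the two-node, four-GPU world size is purely illustrative):

def device_for_rank(global_rank: int, gpus_per_node: int) -> int:
    # Same round-robin mapping as global_rank % torch.cuda.device_count()
    return global_rank % gpus_per_node


world_size, gpus_per_node = 8, 4
comm_ranks = list(range(world_size))  # stands in for allgather of global_rank
device_ids = [device_for_rank(r, gpus_per_node) for r in comm_ranks]  # allgather of device_id
assert device_ids == [0, 1, 2, 3, 0, 1, 2, 3]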
@@ -430,14 +456,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
                 context_phase_params = request.disaggregated_params.get_context_phase_params(
                 )
 
-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )
 
         assert request.id is not None
 
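Note: the overlap-scheduler guard now consults `self.llm_args.disable_overlap_scheduler` directly instead of digging into `executor_config.pytorch_backend_config`. A self-contained sketch of the same decision logic, with plain booleans standing in for the worker's state (`self.engine.kv_cache_transceiver`, request type):

def check_context_only_allowed(is_pytorch_backend: bool,
                               disable_overlap_scheduler: bool,
                               is_disaggregated: bool,
                               is_context_only: bool) -> None:
    # Same rule as the diff: with overlap enabled on the PyTorch backend,
    # disaggregated context-only requests are rejected.
    if is_pytorch_backend and not disable_overlap_scheduler:
        if is_disaggregated and is_context_only:
            raise ValueError(
                "Context only requests are not supported in pytorch backend "
                "when overlap is enabled.")


# Overlap scheduler on + disaggregated + context-only request -> rejected.
try:
    check_context_only_allowed(True, False, True, True)
except ValueError:
    pass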
@@ -641,7 +669,8 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -768,7 +797,8 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
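Note: callers that previously forwarded a bare `garbage_collection_gen0_threshold` now pass `hf_model_dir` plus the full `llm_args` object, and the GC threshold travels inside `llm_args`. A hypothetical sketch of the forwarded kwargs (only `hf_model_dir` and `llm_args` mirror the diff; the stub class and path are illustrative):

from dataclasses import dataclass
from pathlib import Path


@dataclass
class _StubTorchLlmArgs:  # stand-in for TorchLlmArgs; fields/defaults are assumptions
    backend: str = "pytorch"
    garbage_collection_gen0_threshold: int = 20480


worker_kwargs = dict(
    is_llm_executor=True,
    lora_config=None,
    hf_model_dir=Path("/models/example"),  # illustrative path
    llm_args=_StubTorchLlmArgs(),
)
# The old garbage_collection_gen0_threshold kwarg is gone; the value is now read
# from llm_args.garbage_collection_gen0_threshold inside the worker.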