                       mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
@@ -60,7 +60,8 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -81,29 +82,49 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
+        self.llm_args = llm_args

         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

         if isinstance(engine, list):
             engine = engine[self.rank]

-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
+        def _create_py_executor(comm_ranks, device_ids):

-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
+            executor_config = llm_args.get_executor_config(hf_model_dir)

-        def _create_engine():
-            device_id = self.global_rank % torch.cuda.device_count()
-            torch.cuda.set_device(device_id)
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            executor_config.parallel_config = tllm.ParallelConfig(
+                participant_ids=comm_ranks, device_ids=device_ids)
+            args = {
+                "executor_config": executor_config,
+                "checkpoint_dir": executor_config.hf_model_dir,
+            }
+            assert hasattr(
+                executor_config, "backend"
+            ), "executor_config should be with backend in _create_py_executor"
+            if executor_config.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["lora_config"] = lora_config
+                args[
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {executor_config.backend}")
+            return create_executor(**args)
+
+        def _create_engine(comm_ranks, device_ids):
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)

-            # Make sure C++ executor would use same devices/ranks as py_executor
-            global_rank = global_mpi_rank()
-            comm_ranks = mpi_comm().allgather(global_rank)
-            device_ids = mpi_comm().allgather(device_id)
             executor_config.parallel_config = tllm.ParallelConfig(
                 participant_ids=comm_ranks, device_ids=device_ids)

@@ -122,14 +143,7 @@ def _create_engine():
122143 "executor_config" : executor_config ,
123144 "checkpoint_dir" : executor_config .hf_model_dir ,
124145 }
125- if executor_config .backend == "pytorch" :
126- from tensorrt_llm ._torch .pyexecutor .py_executor_creator import \
127- create_py_executor
128- create_executor = create_py_executor
129- args ["lora_config" ] = lora_config
130- args [
131- "garbage_collection_gen0_threshold" ] = garbage_collection_gen0_threshold
132- elif executor_config .backend == "_autodeploy" :
146+ if executor_config .backend == "_autodeploy" :
133147 from tensorrt_llm ._torch .auto_deploy .shim .ad_executor import \
134148 create_autodeploy_executor
135149 create_executor = create_autodeploy_executor
@@ -138,7 +152,17 @@ def _create_engine():
138152 f"Unsupported backend config: { executor_config .backend } " )
139153 return create_executor (** args )
140154
141- self .engine = _create_engine ()
155+ device_id = self .global_rank % torch .cuda .device_count ()
156+ torch .cuda .set_device (device_id )
157+
158+ # Make sure C++ executor would use same devices/ranks as py_executor
159+ global_rank = global_mpi_rank ()
160+ comm_ranks = mpi_comm ().allgather (global_rank )
161+ device_ids = mpi_comm ().allgather (device_id )
162+
163+ self .engine = _create_py_executor (
164+ comm_ranks , device_ids ) if llm_args is not None else _create_engine (
165+ comm_ranks , device_ids )
142166
143167 self ._lora_manager : Optional [LoraManager ] = None
144168 self ._prompt_adapter_manager : Optional [PromptAdapterManager ] = None
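
After this hunk, every rank pins one CUDA device and gathers the rank/device mapping once, then the worker picks a creation path based on whether llm_args was supplied. The sketch below restates that dispatch outside the class so the control flow is easier to follow; it is illustrative only, and build_engine is a hypothetical stand-in for the nested closures, not part of the diff.

# Illustrative sketch of the creation dispatch introduced above (assumes
# mpi4py-style allgather semantics, as used elsewhere in this file).
import torch
from tensorrt_llm._utils import global_mpi_rank, mpi_comm

def build_engine(llm_args, _create_py_executor, _create_engine):
    # Each rank pins one GPU, then all ranks exchange (rank, device) pairs so
    # the C++ executor later sees the same topology as the Python executor.
    device_id = global_mpi_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device_id)
    comm_ranks = mpi_comm().allgather(global_mpi_rank())  # e.g. [0, 1, 2, 3]
    device_ids = mpi_comm().allgather(device_id)          # e.g. [0, 1, 2, 3]

    # llm_args present -> PyTorch path (ExecutorConfig derived on the fly);
    # otherwise the legacy path driven by a caller-supplied executor_config.
    if llm_args is not None:
        return _create_py_executor(comm_ranks, device_ids)
    return _create_engine(comm_ranks, device_ids)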
@@ -430,14 +454,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
             context_phase_params = request.disaggregated_params.get_context_phase_params(
             )

-        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
-        if is_overlap_enabled:
-            is_disaggregated = self.engine.kv_cache_transceiver is not None
-            if is_disaggregated and (
-                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled."
-                )
+        if self._is_pytorch_backend:
+            assert isinstance(self.llm_args, TorchLlmArgs)
+            if not self.llm_args.disable_overlap_scheduler:
+                is_disaggregated = self.engine.kv_cache_transceiver is not None
+                if is_disaggregated and (
+                        request_type
+                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
+                    raise ValueError(
+                        "Context only requests are not supported in pytorch backend when overlap is enabled."
+                    )

         assert request.id is not None

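The rewritten guard above reads the overlap-scheduler setting from self.llm_args instead of executor_config.pytorch_backend_config. A compact restatement of the condition as a standalone predicate (illustrative only; the parameter names are invented for this sketch):

# Sketch of the guard above as a pure predicate; all parameter names are
# hypothetical placeholders for the values used in _enqueue_request.
def rejects_context_only_request(is_pytorch_backend: bool,
                                 disable_overlap_scheduler: bool,
                                 is_disaggregated: bool,
                                 is_context_only_request: bool) -> bool:
    # Context-only requests are rejected only when the PyTorch backend runs
    # with the overlap scheduler enabled in a disaggregated setup (i.e. a
    # kv-cache transceiver is present).
    return (is_pytorch_backend
            and not disable_overlap_scheduler
            and is_disaggregated
            and is_context_only_request)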
@@ -641,7 +667,8 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -761,14 +788,16 @@ def notify_proxy_threads_to_quit():
761788 "green" )
762789
763790 try :
791+ print ("---- worker_cls is: {}" .format (worker_cls ), flush = True )
764792 worker : GenerationExecutorWorker = worker_cls (
765793 engine ,
766794 executor_config ,
767795 batched_logits_processor ,
768796 postproc_worker_config = postproc_worker_config ,
769797 is_llm_executor = is_llm_executor ,
770798 lora_config = lora_config ,
771- garbage_collection_gen0_threshold = garbage_collection_gen0_threshold )
799+ hf_model_dir = hf_model_dir ,
800+ llm_args = llm_args )
772801 except Exception as e :
773802 logger .error (f"Failed to initialize executor on rank { mpi_rank ()} : { e } " )
774803 logger .error (traceback .format_exc ())
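
With this change worker_main no longer receives garbage_collection_gen0_threshold directly; the value is read from llm_args inside _create_py_executor. A caller-side sketch of the migration (the keyword and field names come from the diff; the model path is a placeholder and construction details of TorchLlmArgs are an assumption):

# Hypothetical caller-side sketch: the gen0 GC threshold now travels on
# TorchLlmArgs rather than as a separate worker_main keyword argument.
from pathlib import Path
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

llm_args = TorchLlmArgs(
    model="/models/llama",                     # placeholder checkpoint path
    garbage_collection_gen0_threshold=20000,   # previously a worker_main kwarg
)
# worker_main(..., hf_model_dir=Path("/models/llama"), llm_args=llm_args)
# Inside _create_py_executor this ends up as:
#   args["garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold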