@@ -8,7 +8,7 @@
 import time
 import traceback
 import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -53,10 +53,11 @@ def _init_executor(self) -> None:
 
         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")
 
         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ def _init_executor(self) -> None:
                 self._ensure_worker_termination(
                     [w.proc for w in unready_workers])
 
+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
     def start_worker_monitor(self):
         workers = self.workers
         self_ref = weakref.ref(self)
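
The single IO thread is a correctness requirement, not just a tuning choice: a `ThreadPoolExecutor` with `max_workers=1` runs submitted tasks strictly in submission order, so the response for batch N is always dequeued before the response for batch N+1. A minimal standalone sketch of that property (not part of the diff):

```python
import queue
from concurrent.futures import ThreadPoolExecutor

response_q: queue.Queue = queue.Queue()
for i in range(3):
    response_q.put(i)  # stand-ins for worker responses

# One worker thread => tasks run strictly in submission (FIFO) order.
io_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mp_exec_io")
futures = [io_pool.submit(response_q.get) for _ in range(3)]
print([f.result() for f in futures])  # [0, 1, 2] -- order preserved
io_pool.shutdown()
```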
@@ -145,7 +157,9 @@ def execute_model(
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         (output, ) = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                          timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output
 
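
With these keywords, `execute_model` returns a plain `ModelRunnerOutput` when `max_concurrent_batches == 1` and a `Future[ModelRunnerOutput]` otherwise, matching the widened return annotation. A hypothetical caller-side sketch (`executor` and `scheduler_output` are placeholder names, not from the diff):

```python
from concurrent.futures import Future

output = executor.execute_model(scheduler_output)
if isinstance(output, Future):
    output = output.result()  # blocks until the last PP stage replies
# output is now a ModelRunnerOutput in both cases
```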
@@ -154,7 +168,8 @@ def collective_rpc(self,
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
@@ -171,22 +186,35 @@ def collective_rpc(self,
             send_method = cloudpickle.dumps(
                 method, protocol=pickle.HIGHEST_PROTOCOL)
             self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))
 
-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)
 
                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
                         f"Worker failed with error '{result}', please check the"
                         " stack trace above for the root cause")
+                return result
 
-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit(  # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)
 
             return responses
         except TimeoutError as e:
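
A hypothetical semantics sketch of the two new keywords (assuming an initialized executor; `check_health` is a worker method already used in this file):

```python
# Broadcast (unique_reply_rank=None): every rank replies,
# one entry per worker in the returned list.
results = executor.collective_rpc("check_health", timeout=10)
assert len(results) == executor.world_size

# Single-reply: only the designated rank dequeues a response, so the
# list has exactly one entry (a Future instead of a value if
# non_block=True).
(result, ) = executor.collective_rpc("check_health",
                                     unique_reply_rank=executor.output_rank)
```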
@@ -225,6 +253,11 @@ def shutdown(self):
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
             for w in self.workers:
                 w.worker_response_mq = None
             self._ensure_worker_termination([w.proc for w in self.workers])
@@ -235,6 +268,22 @@ def check_health(self) -> None:
         self.collective_rpc("check_health", timeout=10)
         return
 
+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+
 
 @dataclass
 class UnreadyWorkerProcHandle:
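
A standalone check of the rank arithmetic in `_get_output_rank` above (not part of the diff): ranks are laid out TP-major, so the first TP worker of the last PP stage is `world_size - tp_size`.

```python
tp_size, pp_size = 8, 4
world_size = tp_size * pp_size                # 32
output_rank = world_size - tp_size            # 24
assert output_rank // tp_size == pp_size - 1  # last PP stage (PP rank 3)
assert output_rank % tp_size == 0             # TP rank 0 within that stage
```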
@@ -280,12 +329,14 @@ def __init__(
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
         all_kwargs[rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
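
Under the same TP-major layout, `rank % tensor_parallel_size == 0` promotes the first TP worker of every PP stage to a driver worker, instead of rank 0 alone. A standalone sketch for the TP=8, PP=4 example (not part of the diff):

```python
tp_size, world_size = 8, 32
drivers = [rank for rank in range(world_size) if rank % tp_size == 0]
assert drivers == [0, 8, 16, 24]  # one driver worker per pipeline stage
```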
@@ -455,7 +506,7 @@ class ResponseStatus(Enum):
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
 
             try:
                 if isinstance(method, str):
@@ -470,11 +521,11 @@ def worker_busy_loop(self):
                 logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                     self.worker_response_mq.enqueue(
                         (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue
 
-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))
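
The busy loop's reply gating generalizes the old `rank0_only` flag: the broadcast queue now carries `None` for "every rank replies" RPCs and a concrete rank for single-reply RPCs. A standalone sketch of the predicate (not part of the diff):

```python
def should_reply(rank: int, output_rank) -> bool:
    # Mirrors the busy-loop condition: None means every rank replies.
    return output_rank is None or rank == output_rank

assert all(should_reply(r, None) for r in range(32))          # broadcast
assert [r for r in range(32) if should_reply(r, 24)] == [24]  # single reply
```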