diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 0437c0c23c3..bd17db117bc 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -83,12 +83,49 @@ logger.setLevel(logging.INFO) +class _DeferredZmqFuture(Future): + """Future that defers ZMQ recv() until result() is called. + + This allows vLLM's EngineCore to overlap work (e.g., grammar bitmask computation) + with remote model execution when using non_block=True. + """ + + def __init__(self, sockets: list["zmq.Socket"], unique_reply_rank: int | None = None): + super().__init__() + self._sockets = sockets + self._unique_reply_rank = unique_reply_rank + + def result(self, timeout=None): + if not self.done(): + try: + outputs = [] + for socket in self._sockets: + outputs.append(pickle.loads(socket.recv())) + for output in outputs: + if isinstance(output, Exception): + raise output + if self._unique_reply_rank is not None: + self.set_result(outputs[self._unique_reply_rank]) + else: + self.set_result(outputs) + except Exception as e: + self.set_exception(e) + raise + return super().result(timeout) + + class ExternalZeroMQDistributedExecutor(Executor): """An executor that engines are launched by external ray actors.""" uses_ray: bool = False def _init_executor(self) -> None: + # _DeferredZmqFuture relies on ZMQ REQ/REP send-recv ordering (one request in-flight). + # This is only safe when max_concurrent_batches == 1 (which is the default in Executor). + assert self.max_concurrent_batches == 1, ( + f"ExternalZeroMQDistributedExecutor requires max_concurrent_batches=1, got {self.max_concurrent_batches}" + ) + dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local tp_size = self.vllm_config.parallel_config.tensor_parallel_size @@ -119,24 +156,22 @@ def _init_executor(self) -> None: def execute_model( self, scheduler_output: "SchedulerOutput", non_block: bool = False ) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]": - output = self.collective_rpc("execute_model", args=(scheduler_output,)) - result = output[0] - if non_block: - f = Future() - f.set_result(result) - return f - return result + return self.collective_rpc( + "execute_model", + args=(scheduler_output,), + non_block=non_block, + unique_reply_rank=0, + ) def sample_tokens( self, grammar_output: "GrammarOutput | None", non_block: bool = False ) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]": - output = self.collective_rpc("sample_tokens", args=(grammar_output,)) - result = output[0] - if non_block: - f = Future() - f.set_result(result) - return f - return result + return self.collective_rpc( + "sample_tokens", + args=(grammar_output,), + non_block=non_block, + unique_reply_rank=0, + ) def collective_rpc( self, @@ -144,8 +179,10 @@ def collective_rpc( timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict[str, Any]] = None, + non_block: bool = False, + unique_reply_rank: int | None = None, **kwargs_extra: Any, - ) -> list[Any]: + ) -> Any | list[Any] | Future[Any | list[Any]]: if isinstance(method, str): sent_method = method else: @@ -156,6 +193,9 @@ def collective_rpc( for socket in self.sockets: socket.send(message, zmq.DONTWAIT) + if non_block: + return _DeferredZmqFuture(self.sockets, unique_reply_rank=unique_reply_rank) + outputs = [] for socket in self.sockets: outputs.append(pickle.loads(socket.recv())) @@ -163,6 +203,9 @@ def collective_rpc( for output in outputs: if isinstance(output, Exception): raise output + + if unique_reply_rank is not None: + return outputs[unique_reply_rank] return outputs def check_health(self):