Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,49 @@
logger.setLevel(logging.INFO)


class _DeferredZmqFuture(Future):
"""Future that defers ZMQ recv() until result() is called.

This allows vLLM's EngineCore to overlap work (e.g., grammar bitmask computation)
with remote model execution when using non_block=True.
"""

def __init__(self, sockets: list["zmq.Socket"], unique_reply_rank: int | None = None):
super().__init__()
self._sockets = sockets
self._unique_reply_rank = unique_reply_rank

def result(self, timeout=None):
if not self.done():
try:
outputs = []
for socket in self._sockets:
outputs.append(pickle.loads(socket.recv()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-critical critical

The use of pickle.loads() on data received from a network socket introduces a critical security vulnerability. Deserializing data with pickle can lead to arbitrary code execution if the data is crafted maliciously. While this communication is likely between trusted internal workers, it's a significant security risk if the network is not completely isolated and secure. An attacker who can intercept or inject traffic on this ZMQ channel could compromise the worker process.

It is strongly recommended to replace pickle with a safer serialization format, such as JSON. If complex Python objects must be transferred, consider using a library that provides cryptographically signed serialization to ensure data integrity and authenticity.

Copy link
Contributor Author

@jreiml jreiml Jan 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current change follows the existing logic. This would have to be done in another PR.

for output in outputs:
if isinstance(output, Exception):
raise output
if self._unique_reply_rank is not None:
self.set_result(outputs[self._unique_reply_rank])
else:
self.set_result(outputs)
except Exception as e:
self.set_exception(e)
raise
return super().result(timeout)


class ExternalZeroMQDistributedExecutor(Executor):
"""An executor that engines are launched by external ray actors."""

uses_ray: bool = False

def _init_executor(self) -> None:
# _DeferredZmqFuture relies on ZMQ REQ/REP send-recv ordering (one request in-flight).
# This is only safe when max_concurrent_batches == 1 (which is the default in Executor).
assert self.max_concurrent_batches == 1, (
f"ExternalZeroMQDistributedExecutor requires max_concurrent_batches=1, got {self.max_concurrent_batches}"
)

dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size

Expand Down Expand Up @@ -119,33 +156,33 @@ def _init_executor(self) -> None:
def execute_model(
self, scheduler_output: "SchedulerOutput", non_block: bool = False
) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]":
output = self.collective_rpc("execute_model", args=(scheduler_output,))
result = output[0]
if non_block:
f = Future()
f.set_result(result)
return f
return result
return self.collective_rpc(
"execute_model",
args=(scheduler_output,),
non_block=non_block,
unique_reply_rank=0,
)

def sample_tokens(
self, grammar_output: "GrammarOutput | None", non_block: bool = False
) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]":
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
result = output[0]
if non_block:
f = Future()
f.set_result(result)
return f
return result
return self.collective_rpc(
"sample_tokens",
args=(grammar_output,),
non_block=non_block,
unique_reply_rank=0,
)

def collective_rpc(
self,
method: str | Callable,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict[str, Any]] = None,
non_block: bool = False,
unique_reply_rank: int | None = None,
**kwargs_extra: Any,
) -> list[Any]:
) -> Any | list[Any] | Future[Any | list[Any]]:
if isinstance(method, str):
sent_method = method
else:
Expand All @@ -156,13 +193,19 @@ def collective_rpc(
for socket in self.sockets:
socket.send(message, zmq.DONTWAIT)

if non_block:
return _DeferredZmqFuture(self.sockets, unique_reply_rank=unique_reply_rank)

outputs = []
for socket in self.sockets:
outputs.append(pickle.loads(socket.recv()))

for output in outputs:
if isinstance(output, Exception):
raise output

if unique_reply_rank is not None:
return outputs[unique_reply_rank]
return outputs

def check_health(self):
Expand Down