Skip to content

Commit a551183

Browse files
committed
Some code cleanup
Signed-off-by: Nick Hill <[email protected]>
1 parent 1ca3d15 commit a551183

File tree

1 file changed

+70
-58
lines changed

1 file changed

+70
-58
lines changed

vllm/v1/engine/core_client.py

Lines changed: 70 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections.abc import Awaitable
1212
from concurrent.futures import Future
1313
from dataclasses import dataclass
14+
from enum import Enum, auto
1415
from threading import Thread
1516
from typing import Any, Callable, Optional, TypeVar, Union
1617

@@ -257,11 +258,20 @@ def collective_rpc(self,
257258
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
258259

259260

261+
class CoreEngineState(Enum):
262+
NEW = auto()
263+
CONNECTED = auto()
264+
READY = auto()
265+
266+
260267
class CoreEngine:
261268
"""One per data parallel rank."""
262269

263-
def __init__(self, index: int = 0):
270+
def __init__(self, index: int = 0, local: bool = True):
271+
self.local = local
264272
self.identity = index.to_bytes(length=2, byteorder="little")
273+
274+
self.state = CoreEngineState.NEW
265275
self.num_reqs_in_flight = 0
266276

267277

@@ -352,20 +362,12 @@ def sigusr1_handler(signum, frame):
352362
self.resources = BackgroundResources(ctx=sync_ctx)
353363
self._finalizer = weakref.finalize(self, self.resources)
354364

355-
# TODO move address setup to separate method
356365
parallel_config = vllm_config.parallel_config
357366
dp_size = parallel_config.data_parallel_size
358367
local_engine_count = parallel_config.data_parallel_size_local
359368

360-
if local_engine_count == dp_size:
361-
input_address = get_open_zmq_ipc_path()
362-
output_address = get_open_zmq_ipc_path()
363-
else:
364-
host = parallel_config.data_parallel_master_ip
365-
input_port = parallel_config.data_parallel_rpc_port
366-
output_port = get_open_port()
367-
input_address = f"tcp://{host}:{input_port}"
368-
output_address = f"tcp://{host}:{output_port}"
369+
input_address, output_address = self._get_zmq_addresses(
370+
parallel_config)
369371

370372
# Create input and output sockets.
371373
self.input_socket = self.resources.input_socket = make_zmq_socket(
@@ -386,14 +388,35 @@ def sigusr1_handler(signum, frame):
386388
local_engine_count=local_engine_count,
387389
start_index=0)
388390

389-
self.core_engines = [CoreEngine(i) for i in range(dp_size)]
391+
self.core_engines = [
392+
CoreEngine(index=i, local=(i < local_engine_count))
393+
for i in range(dp_size)
394+
]
390395
self.core_engine = self.core_engines[0]
391396

392397
# Wait for engine core process(es) to start.
393398
self._wait_for_engine_startup(output_address, parallel_config)
394399

395400
self.utility_results: dict[int, AnyFuture] = {}
396401

402+
@staticmethod
403+
def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]:
404+
"""Returns (input_address, output_address)."""
405+
dp_size = parallel_config.data_parallel_size
406+
local_engine_count = parallel_config.data_parallel_size_local
407+
408+
if local_engine_count == dp_size:
409+
input_address = get_open_zmq_ipc_path()
410+
output_address = get_open_zmq_ipc_path()
411+
else:
412+
host = parallel_config.data_parallel_master_ip
413+
input_port = parallel_config.data_parallel_rpc_port
414+
output_port = get_open_port()
415+
input_address = f"tcp://{host}:{input_port}"
416+
output_address = f"tcp://{host}:{output_port}"
417+
418+
return input_address, output_address
419+
397420
def _wait_for_engine_startup(self, output_address: str,
398421
parallel_config: ParallelConfig):
399422
# Get a sync handle to the socket which can be sync or async.
@@ -402,60 +425,39 @@ def _wait_for_engine_startup(self, output_address: str,
402425
# TODO offline case compatibility
403426

404427
# Wait for engine core process(es) to send ready messages.
405-
local_engine_count = parallel_config.data_parallel_size_local
406-
remote_engine_count = len(self.core_engines) - local_engine_count
407-
408-
# TODO simplify the startup tracking logic below!
409-
pending_hello_local = set(range(local_engine_count))
410-
pending_hello_remote = set(
411-
range(local_engine_count, len(self.core_engines)))
412-
pending_ready_local = set(pending_hello_local)
413-
pending_ready_remote = set(pending_hello_remote)
414-
while pending_ready_local or pending_ready_remote:
428+
local_count = parallel_config.data_parallel_size_local
429+
remote_count = len(self.core_engines) - local_count
430+
# [local, remote] counts
431+
conn_pending, start_pending = [local_count, remote_count], [0, 0]
432+
433+
while any(conn_pending) or any(start_pending):
415434
while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS):
416-
local_conn = local_engine_count - len(pending_hello_local)
417-
local_ready = local_engine_count - len(pending_ready_local)
418-
if local_ready != local_engine_count:
435+
if any(conn_pending):
419436
logger.info(
420-
"Waiting for local core engine procs: "
421-
"%d/%d connected, %d/%d ready.", local_conn,
422-
local_engine_count, local_ready, local_engine_count)
423-
if remote_engine_count:
424-
remote_conn = remote_engine_count - len(
425-
pending_hello_remote)
426-
remote_ready = remote_engine_count - len(
427-
pending_ready_remote)
428-
if remote_ready != remote_engine_count:
429-
logger.info(
430-
"Waiting for remote core engine procs: "
431-
"%d/%d connected, %d/%d ready.", remote_conn,
432-
remote_engine_count, remote_ready,
433-
remote_engine_count)
437+
"Waiting for %d local, %d remote core engine proc(s) "
438+
"to connect.", *conn_pending)
439+
if any(start_pending):
440+
logger.info(
441+
"Waiting for %d local, %d remote core engine proc(s) "
442+
"to start.", *start_pending)
434443

435444
# Receive HELLO and READY messages from the input socket.
436445
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
437446
eng_index = int.from_bytes(eng_identity, byteorder="little")
447+
if eng_index > len(self.core_engines):
448+
raise RuntimeError(
449+
f"Message from engine rank larger than "
450+
f"configured data parallel size: {eng_index}")
451+
engine = self.core_engines[eng_index]
438452
msg = msgspec.msgpack.decode(ready_msg_bytes)
439453
status, local = msg["status"], msg["local"]
440-
hello_set = pending_hello_local if local else pending_hello_remote
441-
ready_set = pending_ready_local if local else pending_ready_remote
442-
if status == "HELLO":
443-
index_set = hello_set
444-
elif status == "READY":
445-
index_set = ready_set
446-
else:
447-
raise RuntimeError(f"{'Local' if local else 'Remote'} engine "
448-
f"{eng_index} failed: {status}")
449-
if eng_index not in index_set:
450-
raise RuntimeError(
451-
f"Unexpected or duplicate {status} "
452-
f"{'local' if local else 'remote'} engine: {eng_index}")
453-
if status == "READY" and eng_index in hello_set:
454-
raise RuntimeError(
455-
f"Unexpected READY before HELLO for "
456-
f"{'local' if local else 'remote'} engine: {eng_index}")
454+
if local != engine.local:
455+
raise RuntimeError(f"{status} message from "
456+
f"{'local' if local else 'remote'} "
457+
f" engine {eng_index}, expected it to be "
458+
f"{'local' if engine.local else 'remote'}")
459+
if status == "HELLO" and engine.state == CoreEngineState.NEW:
457460

458-
if status == "HELLO":
459461
# Send init message with DP config info.
460462
init_message = self.encoder.encode({
461463
"output_socket_address": output_address,
@@ -470,10 +472,20 @@ def _wait_for_engine_startup(self, output_address: str,
470472
})
471473
sync_input_socket.send_multipart((eng_identity, init_message),
472474
copy=False)
475+
conn_pending[0 if local else 1] -= 1
476+
start_pending[0 if local else 1] += 1
477+
engine.state = CoreEngineState.CONNECTED
478+
elif status == "READY" and (engine.state
479+
== CoreEngineState.CONNECTED):
480+
start_pending[0 if local else 1] -= 1
481+
engine.state = CoreEngineState.READY
482+
else:
483+
raise RuntimeError(f"Unexpected {status} message for "
484+
f"{'local' if local else 'remote'} engine "
485+
f"{eng_index} in {engine.state} state.")
473486

474487
logger.debug("%s from %s core engine process %s.", status,
475488
"local" if local else "remote", eng_index)
476-
index_set.discard(eng_index)
477489

478490
# Double check that the process are running.
479491
engine_manager = self.resources.local_engine_manager

0 commit comments

Comments
 (0)