Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cda6939
feat(omni_coordinator): add coordinator module and corresponding unit…
NumberWan Feb 25, 2026
026f35a
edited messages
NumberWan Feb 26, 2026
0381d5b
edited messages
NumberWan Feb 26, 2026
1e25124
fix(omni_coordinator): address PR review feedback - TOCTOU, license, …
NumberWan Mar 2, 2026
5aa32a3
fix(omni_coordinator): address PR review feedback - TOCTOU, license, …
NumberWan Mar 2, 2026
6226ca9
fix(omni_coordinator): address PR review feedback - TOCTOU, license, …
NumberWan Mar 2, 2026
f79c9cb
fix(omni_coordinator): address PR review feedback - dynamic ports
NumberWan Mar 3, 2026
ef42c8e
Add periodic loop for queue_length-only broadcasts
NumberWan Mar 3, 2026
ba0968d
fix(omni_coordinator): cleanup DOWN instances and use dedicated zmq c…
NumberWan Mar 3, 2026
39f80e8
fix(omni_coordinator): ClientForStage and ClientForHub Retry function…
NumberWan Mar 3, 2026
4eb65a4
fix(omni_coordinator): ClientForStage and ClientForHub Retry function…
NumberWan Mar 3, 2026
a1bbe0a
Merge branch 'vllm-project:main' into omni-coordinator
NumberWan Mar 3, 2026
530a4ae
Merge branch 'main' into omni-coordinator
NumberWan Mar 3, 2026
e440f30
Merge branch 'main' into omni-coordinator
NumberWan Mar 5, 2026
f42f19a
Merge branch 'vllm-project:main' into omni-coordinator
NumberWan Mar 6, 2026
1c4b3c1
Replace zmq_addr with input_addr/output_addr
NumberWan Mar 6, 2026
13eb1c5
ERROR/DOWN stage instance handling in coordinator
NumberWan Mar 6, 2026
dc2ecd4
Comment Edition
NumberWan Mar 6, 2026
2d3d744
Merge branch 'main' into omni-coordinator
NumberWan Mar 6, 2026
70bbeea
Merge branch 'main' into omni-coordinator
NumberWan Mar 10, 2026
4783e9a
chore: retrigger CI
NumberWan Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/distributed/omni_coordinator/test_load_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from time import time

from vllm_omni.distributed.omni_coordinator import (
InstanceInfo,
RandomBalancer,
StageStatus,
)


def test_load_balancer_select_returns_valid_index():
"""Verify RandomBalancer.select() returns a valid index for instances."""
# Task structure mirrors async_omni; RandomBalancer ignores task contents.
task: dict = {
"request_id": "test",
"engine_inputs": None,
"sampling_params": None,
}

now = time()
instances = [
InstanceInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
queue_length=1,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
input_addr="tcp://host:10003",
output_addr="tcp://host:10003-out",
stage_id=1,
status=StageStatus.UP,
queue_length=2,
last_heartbeat=now,
registered_at=now,
),
]

balancer = RandomBalancer()

index = balancer.select(task, instances)

assert isinstance(index, int)
assert 0 <= index < len(instances)
119 changes: 119 additions & 0 deletions tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import time

import pytest
import zmq

from vllm_omni.distributed.omni_coordinator import (
InstanceList,
OmniCoordClientForHub,
)


def _bind_pub() -> tuple[zmq.Context, zmq.Socket, str]:
ctx = zmq.Context.instance()
pub = ctx.socket(zmq.PUB)
pub.bind("tcp://127.0.0.1:*")
endpoint = pub.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
return ctx, pub, endpoint


def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> bool:
start = time.time()
while time.time() - start < timeout:
if cond():
return True
time.sleep(interval)
return False


def test_hub_client_caches_instance_list_from_pub():
"""Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list()."""
ctx, pub, endpoint = _bind_pub()

client = OmniCoordClientForHub(endpoint)
# ZMQ PUB/SUB slow-joiner: allow SUB to finish connecting before first send
time.sleep(0.2)

now = time.time()
instances_payload = [
{
"input_addr": "tcp://stage:10001",
"output_addr": "tcp://stage:10001-out",
"stage_id": 0,
"status": "up",
"queue_length": 0,
"last_heartbeat": now,
"registered_at": now,
},
{
"input_addr": "tcp://stage:10002",
"output_addr": "tcp://stage:10002-out",
"stage_id": 0,
"status": "up",
"queue_length": 1,
"last_heartbeat": now,
"registered_at": now,
},
{
"input_addr": "tcp://stage:10003",
"output_addr": "tcp://stage:10003-out",
"stage_id": 1,
"status": "error",
"queue_length": 5,
"last_heartbeat": now,
"registered_at": now,
},
]

payload = {"instances": instances_payload, "timestamp": now}
pub.send(json.dumps(payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3)

inst_list = client.get_instance_list()
assert isinstance(inst_list, InstanceList)
assert len(inst_list.instances) == 3

for src, inst in zip(instances_payload, inst_list.instances, strict=True):
assert inst.input_addr == src["input_addr"]
assert inst.output_addr == src["output_addr"]
assert inst.stage_id == src["stage_id"]
assert inst.status.value == src["status"]

stage0 = client.get_instances_for_stage(0)
stage1 = client.get_instances_for_stage(1)

assert all(inst.stage_id == 0 for inst in stage0.instances)
assert all(inst.stage_id == 1 for inst in stage1.instances)

# Send an updated list with fewer instances and verify cache refresh.
updated_payload = {
"instances": instances_payload[:2],
"timestamp": now + 1.0,
}
pub.send(json.dumps(updated_payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 2)
updated_list = client.get_instance_list()
assert len(updated_list.instances) == 2

client.close()
pub.close(0)
ctx.term()


def test_hub_client_close_closes_sub_socket():
"""Verify OmniCoordClientForHub.close() marks client as closed; second close raises."""
ctx, pub, endpoint = _bind_pub()
client = OmniCoordClientForHub(endpoint)
client.close()

with pytest.raises(RuntimeError, match="already closed"):
client.close()

pub.close(0)
ctx.term()
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import zmq

from vllm_omni.distributed.omni_coordinator import (
OmniCoordClientForStage,
StageStatus,
)


def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]:
ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:*")
endpoint = router.getsockopt(zmq.LAST_ENDPOINT).decode("ascii")
return ctx, router, endpoint


def _recv_event(router: zmq.Socket) -> dict:
frames = router.recv_multipart()
# ROUTER adds identity frame; the last frame is the payload.
payload = frames[-1]
return json.loads(payload.decode("utf-8"))


def test_stage_client_auto_register_on_init():
"""Verify OmniCoordClientForStage automatically sends initial registration/status-up event when created."""
ctx, router, endpoint = _bind_router()

input_addr = "tcp://stage:10001"
output_addr = "tcp://stage:10001-out"
stage_id = 0

client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id)

event = _recv_event(router)

assert event["event_type"] == "update"
assert event["status"] == StageStatus.UP.value
assert event["stage_id"] == stage_id
assert event["input_addr"] == input_addr
assert event["output_addr"] == output_addr

client.close()
router.close(0)
ctx.term()


def test_stage_client_update_info_sends_correct_event():
"""Verify OmniCoordClientForStage.update_info() sends status/load update events with expected fields."""
ctx, router, endpoint = _bind_router()

input_addr = "tcp://stage:10002"
output_addr = "tcp://stage:10002-out"
stage_id = 1

client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id)

# Discard initial registration event.
_recv_event(router)

client.update_info(status=StageStatus.ERROR)
client.update_info(queue_length=10)

first = _recv_event(router)
second = _recv_event(router)

assert first["status"] == StageStatus.ERROR.value
assert first["stage_id"] == stage_id
assert first["input_addr"] == input_addr
assert first["output_addr"] == output_addr

assert second["queue_length"] == 10
assert second["stage_id"] == stage_id
assert second["input_addr"] == input_addr
assert second["output_addr"] == output_addr

client.close()
router.close(0)
ctx.term()


def test_stage_client_close_sends_down_status():
"""Verify close() sends final status-down event before closing underlying socket."""
ctx, router, endpoint = _bind_router()

input_addr = "tcp://stage:10003"
output_addr = "tcp://stage:10003-out"
stage_id = 2

client = OmniCoordClientForStage(endpoint, input_addr, output_addr, stage_id)

# Discard initial registration event.
_recv_event(router)

client.close()

event = _recv_event(router)
assert event["status"] == StageStatus.DOWN.value
assert event["stage_id"] == stage_id
assert event["input_addr"] == input_addr
assert event["output_addr"] == output_addr

assert client._socket.closed # DEALER socket no longer usable after close

router.close(0)
ctx.term()
Loading
Loading