Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
86c5ed7
Integrate OmniCoordinator into stage engine pipeline
chickeyton May 13, 2026
d9fabcd
Propagate queue_length on heartbeat events
chickeyton May 13, 2026
62b4791
Fix pre-commit lint and format violations
chickeyton May 13, 2026
d3aff5f
fix "cannot assign address" bug
chickeyton May 14, 2026
d423020
fix dp=1 bug
chickeyton May 14, 2026
4e242db
bugfix
chickeyton May 14, 2026
7044642
bugfix
chickeyton May 15, 2026
73301a8
ad hoc fix for replicas connector port with env variable
herotai214 May 15, 2026
8a637ab
Merge branch 'main' into omni_coord_itg_rebase2
chickeyton May 15, 2026
ba30955
bugfix
chickeyton May 15, 2026
7b754b0
Merge remote-tracking branch 'origin/main' into omni_coord_itg_rebase2
chickeyton May 18, 2026
478c5c0
Port PR#2631 zmq.asyncio.Poller; fix Instance->Replica test renames
chickeyton May 18, 2026
982c4d4
fix tests
chickeyton May 18, 2026
f62b6b7
Merge branch 'main' into omni_coord_itg_rebase2
chickeyton May 18, 2026
5f8b7a5
Close partial-death gap in StageDiffusionProc; rename StageStatus to …
chickeyton May 18, 2026
16c24f6
Merge branch 'main' into omni_coord_itg_rebase2
chickeyton May 18, 2026
78148a1
fix CI issues
chickeyton May 18, 2026
172e4e0
Merge branch 'main' into omni_coord_itg_rebase2
hsliuustc0106 May 18, 2026
963e67d
fix review comments
chickeyton May 19, 2026
5f4d943
Merge branch 'omni_coord_itg_rebase2' of https://github.com/chickeyto…
chickeyton May 19, 2026
6b794b9
Merge branch 'main' into omni_coord_itg_rebase2
chickeyton May 19, 2026
1a3d08c
Merge branch 'omni_coord_itg_rebase2' of https://github.com/chickeyto…
chickeyton May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/diffusion/test_diffusion_engine_rpc_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _make_engine_with_loop(
calls and a real ``RequestScheduler``.
"""
engine = DiffusionEngine.__new__(DiffusionEngine)
engine._closed = False
engine.executor = _ConcurrencyTrackingExecutor(rpc_delay=rpc_delay)

sched = RequestScheduler()
Expand Down Expand Up @@ -371,6 +372,7 @@ def test_collective_rpc_before_loop_starts_calls_executor_directly():
on the caller's thread without enqueueing.
"""
engine = DiffusionEngine.__new__(DiffusionEngine)
engine._closed = False
engine._loop_started = False
engine._rpc_lock = threading.RLock()
engine._cv = threading.Condition(engine._rpc_lock)
Expand All @@ -391,6 +393,7 @@ def test_collective_rpc_before_loop_starts_serializes_concurrent_callers():
serialize them so they cannot race on the shared executor MQ pair.
"""
engine = DiffusionEngine.__new__(DiffusionEngine)
engine._closed = False
engine._loop_started = False
engine._rpc_lock = threading.RLock()
engine._cv = threading.Condition(engine._rpc_lock)
Expand Down
100 changes: 50 additions & 50 deletions tests/distributed/omni_coordinator/test_load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
import pytest

from vllm_omni.distributed.omni_coordinator import (
InstanceInfo,
LeastQueueLengthBalancer,
RandomBalancer,
ReplicaInfo,
ReplicaStatus,
RoundRobinBalancer,
StageStatus,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def test_load_balancer_select_returns_valid_index():
"""Verify RandomBalancer.select() returns a valid index for instances."""
"""Verify RandomBalancer.select() returns a valid index for replicas."""
# Task structure mirrors async_omni; RandomBalancer ignores task contents.
task: dict = {
"request_id": "test",
Expand All @@ -26,30 +26,30 @@ def test_load_balancer_select_returns_valid_index():
}

now = time()
instances = [
InstanceInfo(
replicas = [
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=1,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10003",
output_addr="tcp://host:10003-out",
stage_id=1,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=2,
last_heartbeat=now,
registered_at=now,
Expand All @@ -58,74 +58,74 @@ def test_load_balancer_select_returns_valid_index():

balancer = RandomBalancer()

index = balancer.select(task, instances)
index = balancer.select(task, replicas)

assert isinstance(index, int)
assert 0 <= index < len(instances)
assert 0 <= index < len(replicas)


def test_round_robin_balancer_cycles_instances():
def test_round_robin_balancer_cycles_replicas():
now = time()
instances = [
InstanceInfo(
replicas = [
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=2,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=1,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10003",
output_addr="tcp://host:10003-out",
stage_id=1,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
]

balancer = RoundRobinBalancer()
results = [balancer.select({}, instances) for _ in range(5)]
results = [balancer.select({}, replicas) for _ in range(5)]

# Default start_index=0 => 0,1,2,0,1
assert results == [0, 1, 2, 0, 1]


def test_round_robin_balancer_empty_instances_raises():
with pytest.raises(ValueError, match="instances must not be empty"):
def test_round_robin_balancer_empty_replicas_raises():
with pytest.raises(ValueError, match="replicas must not be empty"):
RoundRobinBalancer().select({}, [])


def test_round_robin_balancer_after_large_index_and_shorter_list():
"""Large start_index % len(instances) then counter wraps with shorter list."""
"""Large start_index % len(replicas) then counter wraps with shorter list."""
now = time()
two = [
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
Expand All @@ -138,72 +138,72 @@ def test_round_robin_balancer_after_large_index_and_shorter_list():

def test_least_queue_length_balancer_picks_min_queue():
now = time()
instances = [
InstanceInfo(
replicas = [
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=2,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=0,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10003",
output_addr="tcp://host:10003-out",
stage_id=1,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=5,
last_heartbeat=now,
registered_at=now,
),
]

balancer = LeastQueueLengthBalancer()
index = balancer.select({}, instances)
index = balancer.select({}, replicas)
assert index == 1


def test_least_queue_length_balancer_empty_instances_raises():
with pytest.raises(ValueError, match="instances must not be empty"):
def test_least_queue_length_balancer_empty_replicas_raises():
with pytest.raises(ValueError, match="replicas must not be empty"):
LeastQueueLengthBalancer().select({}, [])


def test_least_queue_length_balancer_equal_queues_uses_choice(mocker):
now = time()
instances = [
InstanceInfo(
replicas = [
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=3,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10002",
output_addr="tcp://host:10002-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=3,
last_heartbeat=now,
registered_at=now,
),
InstanceInfo(
ReplicaInfo(
input_addr="tcp://host:10003",
output_addr="tcp://host:10003-out",
stage_id=1,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=3,
last_heartbeat=now,
registered_at=now,
Expand All @@ -214,21 +214,21 @@ def test_least_queue_length_balancer_equal_queues_uses_choice(mocker):
"vllm_omni.distributed.omni_coordinator.load_balancer.random.choice",
return_value=2,
)
assert balancer.select({}, instances) == 2
assert balancer.select({}, replicas) == 2


def test_least_queue_length_balancer_negative_queue_raises():
now = time()
instances = [
InstanceInfo(
replicas = [
ReplicaInfo(
input_addr="tcp://host:10001",
output_addr="tcp://host:10001-out",
stage_id=0,
status=StageStatus.UP,
status=ReplicaStatus.UP,
queue_length=-1,
last_heartbeat=now,
registered_at=now,
),
]
with pytest.raises(ValueError, match="queue_length must be non-negative"):
LeastQueueLengthBalancer().select({}, instances)
LeastQueueLengthBalancer().select({}, replicas)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import zmq

from vllm_omni.distributed.omni_coordinator import (
InstanceList,
OmniCoordClientForHub,
ReplicaList,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
Expand All @@ -32,16 +32,16 @@ def _wait_for_condition(cond, timeout: float = 2.0, interval: float = 0.01) -> b
return False


def test_hub_client_caches_instance_list_from_pub():
"""Verify OmniCoordClientForHub receives instance list updates from OmniCoordinator and caches for get_instance_list()."""
def test_hub_client_caches_replica_list_from_pub():
"""Verify OmniCoordClientForHub receives replica list updates from OmniCoordinator and caches for get_replica_list()."""
ctx, pub, endpoint = _bind_pub()

client = OmniCoordClientForHub(endpoint)
# ZMQ PUB/SUB slow-joiner: allow SUB to finish connecting before first send
time.sleep(0.2)

now = time.time()
instances_payload = [
replicas_payload = [
{
"input_addr": "tcp://stage:10001",
"output_addr": "tcp://stage:10001-out",
Expand Down Expand Up @@ -71,37 +71,37 @@ def test_hub_client_caches_instance_list_from_pub():
},
]

payload = {"instances": instances_payload, "timestamp": now}
payload = {"replicas": replicas_payload, "timestamp": now}
pub.send(json.dumps(payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 3)
assert _wait_for_condition(lambda: len(client.get_replica_list().replicas) == 3)

inst_list = client.get_instance_list()
assert isinstance(inst_list, InstanceList)
assert len(inst_list.instances) == 3
rep_list = client.get_replica_list()
assert isinstance(rep_list, ReplicaList)
assert len(rep_list.replicas) == 3

for src, inst in zip(instances_payload, inst_list.instances, strict=True):
assert inst.input_addr == src["input_addr"]
assert inst.output_addr == src["output_addr"]
assert inst.stage_id == src["stage_id"]
assert inst.status.value == src["status"]
for src, rep in zip(replicas_payload, rep_list.replicas, strict=True):
assert rep.input_addr == src["input_addr"]
assert rep.output_addr == src["output_addr"]
assert rep.stage_id == src["stage_id"]
assert rep.status.value == src["status"]

stage0 = client.get_instances_for_stage(0)
stage1 = client.get_instances_for_stage(1)
stage0 = client.get_replicas_for_stage(0)
stage1 = client.get_replicas_for_stage(1)

assert all(inst.stage_id == 0 for inst in stage0.instances)
assert all(inst.stage_id == 1 for inst in stage1.instances)
assert all(rep.stage_id == 0 for rep in stage0.replicas)
assert all(rep.stage_id == 1 for rep in stage1.replicas)

# Send an updated list with fewer instances and verify cache refresh.
# Send an updated list with fewer replicas and verify cache refresh.
updated_payload = {
"instances": instances_payload[:2],
"replicas": replicas_payload[:2],
"timestamp": now + 1.0,
}
pub.send(json.dumps(updated_payload).encode("utf-8"))

assert _wait_for_condition(lambda: len(client.get_instance_list().instances) == 2)
updated_list = client.get_instance_list()
assert len(updated_list.instances) == 2
assert _wait_for_condition(lambda: len(client.get_replica_list().replicas) == 2)
updated_list = client.get_replica_list()
assert len(updated_list.replicas) == 2

client.close()
pub.close(0)
Expand Down
Loading
Loading