[WIP] Integrate OmniCoordinator into stage engine pipeline#3569

Open
chickeyton wants to merge 10 commits into
vllm-project:mainfrom
chickeyton:omni_coord_itg_rebase2

Conversation


@chickeyton chickeyton commented May 13, 2026

Integrate OmniCoordinator into the stage engine pipeline

Closes / relates to: #984

Motivation

OmniCoordinator, OmniCoordClientForStage, and OmniCoordClientForHub
already existed in vllm_omni/distributed/omni_coordinator/ but were not
wired into the running system. This PR is the integration work: it makes
the running pipeline actually use them so that

  1. The head runtime (vllm serve <model> --omni --stage-id N without
    --headless) hosts an OmniCoordinator alongside the existing
    OmniMasterServer.
  2. Every stage replica subprocess — head-local replicas and
    externally-launched headless replicas, AR or DiT — reports liveness
    to the coordinator from inside the engine subprocess via
    OmniCoordClientForStage.
  3. The head's Orchestrator discovers replicas and load-balances
    across them via a single OmniCoordClientForHub + a per-stage
    LoadBalancer injected into each StagePool.
  4. Headless replicas can be started and stopped independently of the
    head; the head attaches them dynamically through
    Orchestrator._attach_remote_replica / _detach_remote_replica.
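For orientation, the liveness reporting in item 2 can be sketched as below. This is an illustrative stand-in, not the actual OmniCoordClientForStage code: the transport (a ZMQ DEALER socket in the real system) is replaced by a plain callable, and the message fields are assumptions based on this description.

```python
import json
import time


class StageLivenessReporter:
    """Minimal stand-in for the heartbeat side of OmniCoordClientForStage."""

    def __init__(self, stage_id: int, replica_id: int, send):
        self.stage_id = stage_id
        self.replica_id = replica_id
        # In the real client this would be a ZMQ DEALER socket's send();
        # the real loop fires once per heartbeat_interval.
        self.send = send

    def heartbeat_once(self, queue_length: int) -> dict:
        msg = {
            "type": "heartbeat",
            "stage_id": self.stage_id,
            "replica_id": self.replica_id,
            # Live load, consumed by the least-queue-length balancer.
            "queue_length": queue_length,
            "ts": time.time(),
        }
        self.send(json.dumps(msg).encode())
        return msg


sent = []
rep = StageLivenessReporter(stage_id=1, replica_id=0, send=sent.append)
msg = rep.heartbeat_once(queue_length=3)
```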

Design constraints honoured

  • Only four new CLI flags: --omni-dp-size-local,
    --omni-lb-policy, --omni-heartbeat-timeout,
    --omni-replica-address. No new environment variables.
  • OmniCoordinator runs unconditionally when --stage-id is set
    and --headless is not — the user-stated invariant.
  • The coordinator address shares the host of --omni-master-address;
    its ROUTER/PUB ports are auto-picked and published to registrants
    through OmniMasterServer's registration reply.
  • --omni-dp-size-local is per-runtime, not per-cluster — each
    head and each headless invocation reads its own copy and launches
    that many replicas locally for its own --stage-id. Values may
    differ across processes; the master server's auto-assignment keeps
    replica ids globally unique within a stage.
  • Minimal impact on the deploy YAML: mainly the per-stage "devices"
    settings change.

Architecture

flowchart LR
  classDef new stroke-width:6px
  subgraph HEAD["Head: <code>vllm serve --omni --stage-id 0 --omni-dp-size-local 2</code>"]
    direction TB
    AsyncOmni["AsyncOmni<br/>(EngineClient API)"]
    Orch["Orchestrator<br/>(asyncio loop)<br/>+ remote replica attach/detach<br/>+ OmniCoordClientForHub"]
    Pools["StagePools<br/>(addr-keyed clients)<br/>+ LoadBalancer + affinity<br/>+ pick()"]
    OMS["OmniMasterServer<br/>tcp://H:9000 ROUTER"]

    subgraph OCR["<b>OmniCoordinatorRuntime</b>"]
      direction TB
      Coord["OmniCoordinator<br/>ROUTER + PUB<br/>(auto ports)"]
    end

    LocalMgr["<b>OmniCoreEngineProcManager</b>"]
    LocalProc0["StageEngineCoreProc<br/>stage 0, replica 0<br/>OmniCoordClientForStage"]
    LocalProc1["StageEngineCoreProc<br/>stage 0, replica 1<br/>OmniCoordClientForStage"]

    AsyncOmni --> Orch
    Orch --> Pools
    Coord -.PUB.-> Orch
    LocalMgr --> LocalProc0
    LocalMgr --> LocalProc1
    LocalProc0 -.DEALER.-> Coord
    LocalProc1 -.DEALER.-> Coord
    Pools <-->|ZMQ ROUTER/PULL| LocalProc0
    Pools <-->|ZMQ ROUTER/PULL| LocalProc1
    OMS --on_register--> Orch
  end

  subgraph HL1["Headless A: <code>--headless --stage-id 1 --omni-dp-size-local 4</code>"]
    direction TB
    HL1Main["run_headless<br/>main process"]
    HL1Mgr["<b>OmniCoreEngineProcManager</b>"]
    HL1P0["StageEngineCoreProc<br/>stage 1, replica 0<br/>OmniCoordClientForStage"]
    HL1P1["StageEngineCoreProc<br/>stage 1, replica 1<br/>OmniCoordClientForStage"]
    HL1P2["StageEngineCoreProc<br/>stage 1, replica 2<br/>OmniCoordClientForStage"]
    HL1P3["StageEngineCoreProc<br/>stage 1, replica 3<br/>OmniCoordClientForStage"]
    HL1Main --> HL1Mgr
    HL1Mgr --> HL1P0
    HL1Mgr --> HL1P1
    HL1Mgr --> HL1P2
    HL1Mgr --> HL1P3
  end

  subgraph HL2["Headless B: <code>--headless --stage-id 1 --omni-dp-size-local 1</code>"]
    direction TB
    HL2Main["run_headless"]
    HL2Mgr["<b>OmniCoreEngineProcManager</b>"]
    HL2P0["StageEngineCoreProc<br/>stage 1, replica 4<br/>OmniCoordClientForStage"]
    HL2Main --> HL2Mgr
    HL2Mgr --> HL2P0
  end

  subgraph HL3["Headless C: <code>--headless --stage-id 2 --omni-dp-size-local 3</code>"]
    direction TB
    HL3Main["run_headless"]
    HL3P0["StageDiffusionProc<br/>stage 2, replica 0<br/>OmniCoordClientForStage"]
    HL3P1["StageDiffusionProc<br/>stage 2, replica 1<br/>OmniCoordClientForStage"]
    HL3P2["StageDiffusionProc<br/>stage 2, replica 2<br/>OmniCoordClientForStage"]
    HL3Main --> HL3P0
    HL3Main --> HL3P1
    HL3Main --> HL3P2
  end

  HL1Main -.register DEALER.-> OMS
  HL2Main -.register DEALER.-> OMS
  HL3Main -.register DEALER.-> OMS
  HL1P0 -.DEALER.-> Coord
  HL1P1 -.DEALER.-> Coord
  HL1P2 -.DEALER.-> Coord
  HL1P3 -.DEALER.-> Coord
  HL2P0 -.DEALER.-> Coord
  HL3P0 -.DEALER.-> Coord
  HL3P1 -.DEALER.-> Coord
  HL3P2 -.DEALER.-> Coord
  Pools <-->|ZMQ| HL1P0
  Pools <-->|ZMQ| HL1P1
  Pools <-->|ZMQ| HL1P2
  Pools <-->|ZMQ| HL1P3
  Pools <-->|ZMQ| HL2P0
  Pools <-->|ZMQ| HL3P0
  Pools <-->|ZMQ| HL3P1
  Pools <-->|ZMQ| HL3P2

  %% Thicken borders of newly added classes (no fill change).
  class LocalMgr,HL1Mgr,HL2Mgr new
  style OCR stroke-width:6px

Legend

  • Solid arrows are in-process Python references.
  • Dashed arrows are ZMQ wire connections; the verb on the arrow names the
    socket pattern.
  • Pools <--> proc is the head-side StageEngineCoreClient /
    StageDiffusionClient ROUTER+PULL bound to ports allocated by
    OmniMasterServer; the engine subprocesses connect via DEALER+PUSH.
  • Boxes with a thicker border are newly added: OmniCoordinatorRuntime
    and OmniCoreEngineProcManager. Orchestrator and StagePool are
    existing classes that absorb new responsibilities (dispatch and
    remote-attach), so they keep the default thin border.

Where dispatch and attach logic live

To keep the class graph small, the routing and attach concerns are folded
into the two classes that already own those neighborhoods:

Concern                                                            | Class        | Methods added
Subscribe to OmniCoordinator's PUB; cache the cluster ReplicaList  | Orchestrator | constructs/owns one OmniCoordClientForHub
Turn OmniMasterServer register events into head-side stage clients | Orchestrator | _attach_remote_replica, _detach_remote_replica
LB + affinity + bounded-wait pick for one stage                    | StagePool    | pick(request_id, task), bind(request_id, addr), invalidate_addr(addr)
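The pick/bind/invalidate_addr contract above can be sketched in a few lines. The method names come from this table; the internals (and the random policy shown) are simplifying assumptions, not the actual StagePool implementation:

```python
import random


class StagePoolSketch:
    """Illustrative pick()/bind() where affinity beats the LB policy."""

    def __init__(self, up_addrs):
        self.up = list(up_addrs)   # UP replicas for this stage
        self.bound = {}            # request_id -> addr (affinity)

    def bind(self, request_id, addr):
        self.bound[request_id] = addr

    def invalidate_addr(self, addr):
        # Replica went down: drop it and any bindings pinned to it.
        self.up = [a for a in self.up if a != addr]
        self.bound = {r: a for r, a in self.bound.items() if a != addr}

    def pick(self, request_id):
        addr = self.bound.get(request_id)
        if addr in self.up:        # bound replica reused while it stays UP
            return addr
        addr = random.choice(self.up)   # fall back to the policy
        self.bind(request_id, addr)
        return addr


pool = StagePoolSketch(["tcp://a", "tcp://b"])
first = pool.pick("req-1")       # policy pick, then bound
again = pool.pick("req-1")       # affinity: same replica while UP
pool.invalidate_addr(first)
rebound = pool.pick("req-1")     # binding dropped with the replica
```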

Per-runtime --omni-dp-size-local

Every box that carries a --stage-id carries its own
--omni-dp-size-local. The flag is process-local: each invocation
reads its own copy and launches that many replicas locally. Values are
independent.

In the diagram above:

Runtime    | --stage-id | --omni-dp-size-local | Replicas it owns
Head       | 0          | 2                    | stage 0, replicas 0–1
Headless A | 1          | 4                    | stage 1, replicas 0–3
Headless B | 1          | 1                    | stage 1, replica 4
Headless C | 2          | 3                    | stage 2, replicas 0–2

Cluster total: 2 + 4 + 1 + 3 = 10 replicas across 3 stages. The head's
OmniCoordClientForHub (owned by Orchestrator) sees them all
uniformly; each StagePool picks among its own stage's UP replicas.
Replica IDs within a stage are auto-assigned by OmniMasterServer so
they stay unique even when contributors run with different local sizes.
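The auto-assignment can be pictured as a per-stage counter held by the master. This is a hypothetical sketch (class and method names are mine) that reproduces the diagram's numbering under the assumption of a simple monotonic allocator:

```python
from collections import defaultdict


class ReplicaIdAllocator:
    """Per-stage monotonic counter; the master is the sole authority."""

    def __init__(self):
        self._next = defaultdict(int)   # stage_id -> next free replica id

    def assign(self, stage_id: int, count: int) -> list[int]:
        start = self._next[stage_id]
        self._next[stage_id] += count
        return list(range(start, start + count))


alloc = ReplicaIdAllocator()
head = alloc.assign(0, 2)   # head: stage 0, dp_local=2
hl_a = alloc.assign(1, 4)   # headless A: stage 1, dp_local=4
hl_b = alloc.assign(1, 1)   # headless B: same stage, gets the next id
hl_c = alloc.assign(2, 3)   # headless C: stage 2, independent namespace
```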


Head and headless startup order

There is no hard ordering requirement between head and headless processes
— ZMQ DEALER queues registration messages client-side until the head's
OmniMasterServer ROUTER binds, and the headless waits up to 300 s
(_DEFAULT_STARTUP_TIMEOUT_S) for the reply. Two soft constraints apply:

  1. Headless registration timeout — the head's master server must be
    listening within 300 s of the headless's
    register_stage_with_omni_master call.
  2. Pre-allocated stages — when the head's YAML declares
    num_replicas: N for a remote stage, the head's
    connect_remote_engine_cores blocks at bring-up until N concrete
    --replica-id registrations arrive (bounded by
    --stage-init-timeout).

One additional inter-headless rule applies only when two headless
processes target the same stage with mixed concrete + auto-assign
(one uses --replica-id 0, another relies on --omni-dp-size-local):
the concrete registrant must finish registration before the
auto-assigner sends its first request, otherwise auto-assign can steal
slot 0 (bug #9). Pure-concrete and pure-auto-assign clusters are
unaffected.

Recommended pattern. Start the head first; wait for the
"OmniMasterServer] Listening on tcp" line in its log; then launch headless
processes in any order (subject to the mixed-mode rule above if
applicable). This is what every smoke-test scenario does and it is the
only pattern that has been validated end-to-end.


New CLI flags

Four flags are added to the OmniConfig argument group of vllm serve.
All four are consumed by head and/or headless invocations (where
applicable); none introduces a new environment variable.

--omni-dp-size-local <int> (default 1)

Number of stage replicas this runtime launches locally for its own
--stage-id. Mapped onto that stage's runtime_cfg.num_replicas for
this process only.

  • Process-local: every head and every headless invocation reads its
    own copy. Values may differ across invocations — e.g. one headless
    may run --omni-dp-size-local 4 while another runs
    --omni-dp-size-local 1 on the same stage; the master server's
    auto-assigned replica ids keep the cluster's per-stage namespace
    globally unique.
  • Requires --stage-id when the value is != 1 (validated in
    OmniServeCommand.validate; bare vllm serve --omni without a
    stage id is unaffected).
  • When > 1 on a headless invocation, the runtime narrows
    CUDA_VISIBLE_DEVICES per spawned replica (so two replicas do not
    stack on cuda:0) and gives each DiT replica a unique
    torch.distributed MASTER_PORT (bugs #6 and #7 above).
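The per-replica narrowing in the last bullet amounts to slicing the visible GPU list and offsetting the rendezvous port. A hypothetical helper (function name and port-offset scheme are mine, not the PR's):

```python
def replica_env(visible_devices: str, replica_index: int,
                gpus_per_replica: int = 1,
                base_master_port: int = 29500) -> dict:
    """Narrow CUDA_VISIBLE_DEVICES to this replica's slice and give it a
    unique torch.distributed MASTER_PORT, so co-located replicas neither
    stack on cuda:0 nor collide on the rendezvous port."""
    gpus = visible_devices.split(",")
    start = replica_index * gpus_per_replica
    return {
        "CUDA_VISIBLE_DEVICES": ",".join(gpus[start:start + gpus_per_replica]),
        "MASTER_PORT": str(base_master_port + replica_index),
    }


# Headless launched with CUDA_VISIBLE_DEVICES=2,3 --omni-dp-size-local 2:
env0 = replica_env("2,3", 0)
env1 = replica_env("2,3", 1)
```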

--omni-lb-policy <random|round-robin|least-queue-length> (default random)

Per-stage load-balancing policy used by the head's StagePool to
route incoming requests across UP replicas. Validated against the
LoadBalancingPolicy enum. Only consulted on the head runtime:
the orchestrator wires a fresh LoadBalancer of this kind into each
StagePool it owns; pools on headless processes never see the flag.

  • random — uniform random pick over UP replicas.
  • round-robin — cycle through UP replicas in registration order.
  • least-queue-length — pick the UP replica with the smallest
    queue_length reported via heartbeat (get_num_unfinished_requests()
    for LLM stages, in-flight task count for DiT stages).

Request affinity (CFG companions, multi-step requests) takes
precedence over the policy: once a request is bound to a replica, the
same replica is reused as long as it stays UP.
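The three policies reduce to simple functions over the UP set. This is an illustrative sketch, not the LoadBalancer classes themselves; the `queue_lengths` mapping stands in for the heartbeat-reported values:

```python
import itertools
import random


def pick_random(up_addrs):
    """random: uniform pick over UP replicas."""
    return random.choice(up_addrs)


def make_round_robin(up_addrs):
    """round-robin: cycle through UP replicas in registration order."""
    cyc = itertools.cycle(up_addrs)
    return lambda: next(cyc)


def pick_least_queue_length(queue_lengths):
    """least-queue-length: queue_lengths maps addr -> latest reported load."""
    return min(queue_lengths, key=queue_lengths.get)


rr = make_round_robin(["tcp://a", "tcp://b"])
lq = pick_least_queue_length({"tcp://a": 7, "tcp://b": 2, "tcp://c": 5})
```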

--omni-heartbeat-timeout <float> (default 30.0)

Seconds before an unreporting replica is marked ERROR in the
OmniCoordinator. Only consulted on the head runtime — it is the
parameter the coordinator's periodic loop uses to decide a replica is
stale. Headless processes always heartbeat on a fixed interval
(heartbeat_interval, ~5 s for the 30 s default) regardless of this
flag.

When a replica is flipped to ERROR, the orchestrator's
_watch_replica_list task enqueues an unregister_remote_replica
control message; pinned requests are aborted with a clear error and
the head-side client is torn down.
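One pass of the coordinator's staleness check can be sketched as below. The dict shape and function name are assumptions for illustration; only the timeout semantics follow the flag description:

```python
def sweep_stale(replicas: dict, now: float, timeout_s: float = 30.0) -> list:
    """Flip replicas whose last heartbeat is older than
    --omni-heartbeat-timeout to ERROR; return the flipped addresses
    (which would trigger detach and request abort on the head)."""
    flipped = []
    for addr, info in replicas.items():
        if info["state"] == "UP" and now - info["last_heartbeat"] > timeout_s:
            info["state"] = "ERROR"
            flipped.append(addr)
    return flipped


replicas = {
    "tcp://a": {"state": "UP", "last_heartbeat": 100.0},
    "tcp://b": {"state": "UP", "last_heartbeat": 95.0},
}
# At t=128: a is 28 s stale (still within 30 s), b is 33 s stale.
flipped = sweep_stale(replicas, now=128.0)
```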

--omni-replica-address <ip> / -ora (default auto-detect)

Local bind address (this host's IP) that the headless stage
advertises to the Omni master for its per-stage handshake / input /
output ZMQ sockets. Only consulted on headless runtimes (it is
validated to require --headless in OmniServeCommand.validate); the
head ignores the flag (its own stages are co-located with the master
and already use the master-rooted addresses).

  • Default = auto-detect. When unset, register_stage_with_omni_master
    runs a UDP-connect routing probe against
    --omni-master-address:--omni-master-port: it opens a SOCK_DGRAM
    socket, calls connect() (no packets actually sent — just forces a
    route lookup), and reads getsockname()[0] to learn the source IP
    the kernel would use to reach the master. That IP is what the
    headless's ZMQ sockets must bind on. On a single-host setup this
    returns the loopback / eth0 IP and the resulting registration is a
    behaviour-preserving no-op; on a cross-host pod it returns the
    headless's own routable interface IP.
  • Override only when auto-detect is wrong. Multi-NIC hosts where
    the master is reachable on the wrong interface, or environments where
    the source-route lookup picks an address that the head cannot reach
    back, are the cases that need the explicit override.
  • Cross-host wire format. The headless picks 3 locally-free ports
    via get_open_ports_list(count=3) and includes
    replica_bind_address + replica_{handshake,input,output}_port in
    the registration payload. OmniMasterServer._handle_registration
    recognises the fields and rewrites _stage_routes[(stage_id, replica_id)] so subsequent head-side lookups via
    get_engine_zmq_addresses(stage_id) return the new, headless-rooted
    addresses. Without this rewrite the head would hand back the
    master's own IP — which the headless cannot bind on a different host
    (zmq.error.ZMQError: Cannot assign requested address).

This flag is what makes the cross-server topology (head on one pod,
headless DiT on another) work without operator-supplied per-host
configuration in the common case, while still leaving an escape hatch
for unusual NIC layouts.
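The auto-detect probe described above is a standard trick; a minimal standalone version (the function name is mine, not the PR's):

```python
import socket


def detect_bind_ip(master_host: str, master_port: int) -> str:
    """connect() on a UDP socket sends no packets; it only forces a route
    lookup, after which getsockname() reveals the source IP the kernel
    would use to reach the master. That is the IP the headless should
    bind its ZMQ sockets on."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect((master_host, master_port))
        return s.getsockname()[0]


# Single-host setup: the loopback address comes back.
ip = detect_bind_ip("127.0.0.1", 9000)
```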

--replica-id <int> (deprecated, ignored)

The pre-existing --replica-id flag is now deprecated and ignored.
Replica ids are auto-assigned by OmniMasterServer so headless
processes carry no knowledge of their per-replica id at launch time —
the master is the sole authority on the per-stage namespace and a
launching headless cannot race or collide with any other registrant
(bug #9 is now structurally impossible). When --replica-id is
supplied on the CLI, run_headless emits a single warning log line
identifying the supplied value and continues with auto-assignment;
the flag itself stays in the parser only so existing launchers
(scripts, CI configs) keep working without modification.


CLI usage examples — BAGEL scenarios A, B, C

The bundled vllm_omni/deploy/bagel.yaml puts both stages on GPU 0.
The scenarios below use the test harness's bagel_two_gpu.yaml /
bagel_dp_local_2.yaml deploy configs to spread the AR thinker and
the DiT onto separate cards; you can use any deploy YAML that
declares runtime.devices consistently with the GPU mapping passed
via CUDA_VISIBLE_DEVICES.

Common variables used across the examples:

MODEL_ID="ByteDance-Seed/BAGEL-7B-MoT"
HOST=127.0.0.1
API_PORT=8000
MASTER_PORT=9000

# DEPLOY_TWO_GPU (scenario A) — download from the dev branch:
#   https://github.com/chickeyton/vllm-omni/blob/omni_coord_intg_2_base/coord_intg_tests_diffusion/configs/bagel_two_gpu.yaml
DEPLOY_TWO_GPU="path/to/bagel_two_gpu.yaml"

# DEPLOY_DP_LOCAL_2 (scenarios B, C) — download from the dev branch:
#   https://github.com/chickeyton/vllm-omni/blob/omni_coord_intg_2_base/coord_intg_tests_diffusion/configs/bagel_dp_local_2.yaml
DEPLOY_DP_LOCAL_2="path/to/bagel_dp_local_2.yaml"

Scenario A — single in-process runtime (AR + DiT)

One vllm serve process; no coordinator, no headless, no --stage-id.
The AR thinker and the DiT both run inside one engine, on GPUs 0 and 1.

CUDA_VISIBLE_DEVICES=0,1 \
vllm serve "$MODEL_ID" \
    --omni \
    --host "$HOST" \
    --port "$API_PORT" \
    --deploy-config "$DEPLOY_TWO_GPU"

Then issue a request against http://$HOST:$API_PORT/v1/chat/completions
with modalities=["image"].

Scenario B — head dp_local=2 (AR) + single headless dp_local=2 (DiT)

Two vllm serve processes, four replicas total. The head hosts the
AR thinker with two local replicas on GPUs 0,1 and runs the
coordinator + master server. A single headless hosts the DiT stage
with two replicas, exercising the single-headless multi-replica DiT
path (bugs #6 and #7 above).

# Head: stage 0 (AR thinker), 2 local replicas on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 \
vllm serve "$MODEL_ID" \
    --omni \
    --host "$HOST" \
    --port "$API_PORT" \
    --stage-id 0 \
    --omni-dp-size-local 2 \
    --omni-master-address "$HOST" \
    --omni-master-port "$MASTER_PORT" \
    --deploy-config "$DEPLOY_DP_LOCAL_2"

# Headless: stage 1 (DiT), 2 local replicas on GPUs 2,3
CUDA_VISIBLE_DEVICES=2,3 \
vllm serve "$MODEL_ID" \
    --omni \
    --headless \
    --stage-id 1 \
    --omni-dp-size-local 2 \
    --omni-master-address "$HOST" \
    --omni-master-port "$MASTER_PORT" \
    --deploy-config "$DEPLOY_DP_LOCAL_2"

Wait for two "Diffusion replica id=N for stage 1 is up" lines in the
headless log, then curl http://$HOST:$API_PORT/health before
issuing requests.

Scenario C — head + headless AR dp_local=2 + headless DiT

Three vllm serve processes, four replicas total — the AR analogue
of scenario B. The head runs 1 AR replica, a second headless runs 2
more AR replicas on the same stage (exercising
single-headless multi-replica AR at 7B scale, bugs #6 + #8),
and a third headless runs 1 DiT replica.

# Head: stage 0 (AR thinker), 1 local replica on GPU 0
CUDA_VISIBLE_DEVICES=0 \
vllm serve "$MODEL_ID" \
    --omni \
    --host "$HOST" \
    --port "$API_PORT" \
    --stage-id 0 \
    --omni-dp-size-local 1 \
    --omni-master-address "$HOST" \
    --omni-master-port "$MASTER_PORT" \
    --deploy-config "$DEPLOY_DP_LOCAL_2"

# Headless A: stage 0 (AR thinker), 2 more replicas on GPUs 1,2.
# These auto-assign replica ids above the head's pre-allocation and
# attach via the dynamic-attach LLM path.
CUDA_VISIBLE_DEVICES=1,2 \
vllm serve "$MODEL_ID" \
    --omni \
    --headless \
    --stage-id 0 \
    --omni-dp-size-local 2 \
    --omni-master-address "$HOST" \
    --omni-master-port "$MASTER_PORT" \
    --deploy-config "$DEPLOY_DP_LOCAL_2"

# Headless B: stage 1 (DiT), 1 replica on GPU 3
CUDA_VISIBLE_DEVICES=3 \
vllm serve "$MODEL_ID" \
    --omni \
    --headless \
    --stage-id 1 \
    --omni-dp-size-local 1 \
    --omni-master-address "$HOST" \
    --omni-master-port "$MASTER_PORT" \
    --deploy-config "$DEPLOY_DP_LOCAL_2"

Wait for two "Stage 0 replica id=N up" lines in headless A's log, one
"Diffusion replica id=0 for stage 1 is up" line in headless B's log,
and /health 200 on the head before issuing requests.




@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e751bcee5e


Comment on lines +286 to 288
info = self._replicas.get(input_addr)
if info is not None:
    info.last_heartbeat = time()

P2: Update queue length from heartbeat events

When using LeastQueueLengthBalancer, the only periodic source of live load is the new _on_heartbeat hook in the stage/diffusion procs, which sends the refreshed queue_length on heartbeat messages. This heartbeat path updates only last_heartbeat, so the coordinator keeps publishing the initial queue length (usually 0) and the least-queue policy routes with stale load information, allowing busy replicas to be selected as if they were idle. Please copy event.queue_length into info.queue_length for heartbeat events and schedule a broadcast when it changes.
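The suggested fix can be sketched as follows. The ReplicaInfo fields and the broadcast hook are simplified stand-ins for the coordinator's actual types, shown only to make the proposed change concrete:

```python
from dataclasses import dataclass
from time import time


@dataclass
class ReplicaInfo:
    last_heartbeat: float = 0.0
    queue_length: int = 0


def handle_heartbeat(replicas, broadcasts, input_addr, queue_length):
    """Refresh the reported load alongside liveness, and schedule a
    broadcast when it changes (here modeled as appending to a list)."""
    info = replicas.get(input_addr)
    if info is None:
        return
    info.last_heartbeat = time()
    if queue_length != info.queue_length:
        info.queue_length = queue_length
        broadcasts.append(input_addr)   # stand-in for a PUB re-broadcast


replicas = {"tcp://h:1": ReplicaInfo()}
broadcasts = []
handle_heartbeat(replicas, broadcasts, "tcp://h:1", queue_length=5)
```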


@hsliuustc0106
Collaborator

resolve conflicts please

@herotai214

I tested the PR code with ByteDance-Seed/BAGEL-7B-MoT running cross-node with multiple replicas.
Summary: 1 AR + 1 DiT works; failure for cases with more than 1 replica.

Version: chickeyton:omni_coord_itg_rebase2 branch commit d553d95287784cd7f23b82bf0274f8f6501166a1

Customized a yaml file (attached at the end) to use MooncakeTransferEngineConnector rdma_connector to support KV cache transfer between nodes:

  • When running 2 AR (on node 0) + 2 headless DiT (on node 1), an error appears due to a repeated address/port
    • I will fix it very soon and push to this PR.
Error log for 2 AR + 2 DiT:
I0514 07:31:03.679602 138654 rdma_context.cpp:140] RDMA device: mlx5_0, LID: 140, GID: (GID_Index 0) fe:80:00:00:00:00:00:00:a0:88:c2:03:00:3b:58:3a
E0514 07:31:03.679625 138654 transfer_metadata.cpp:863] Local segment descriptor not found
I0514 07:31:03.679651 138654 transfer_engine_impl.cpp:319] installTransport, type=rdma
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:03�[0m �[90m[mooncake_transfer_engine_connector.py:339]�[0m MooncakeTransferEngineConnector initialized at 10.248.12.160:15611
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:03�[0m �[90m[mooncake_transfer_engine_connector.py:342]�[0m Allocating RDMA Memory Pool: 4096.00 MB on cpu
�[0;36m(APIServer pid=138362)�[0;0m DEBUG 05-14 07:31:04 [v1/engine/utils.py:1168] Waiting for 1 local, 0 remote core engine proc(s) to start.
�[0;36m(APIServer pid=138362)�[0;0m DEBUG 05-14 07:31:04 [v1/engine/utils.py:1168] Waiting for 1 local, 0 remote core engine proc(s) to start.
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m MooncakeTransferEngineConnector config summary:
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Local: host=10.248.12.160, zmq_port=50151, rpc_port=16747
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Remote: sender_host=None, sender_zmq_port=None
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Role: can_put=True, configured_role=sender
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:389]�[0m MooncakeTransferEngineConnector started as SENDER (ZMQ listener on 10.248.12.160:50151)
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[factory.py:46]�[0m Created connector: MooncakeTransferEngineConnector
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:333]�[0m Sender connector eagerly initialized
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[33mWARNING�[0m �[90m05-14 07:31:05�[0m �[90m[base.py:188]�[0m [LLM Worker 0] Sleep Mode DISABLED.
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[33mWARNING�[0m �[90m05-14 07:31:05�[0m �[90m[base.py:188]�[0m [LLM Worker 0] Sleep Mode DISABLED.
�[0;36m(StageEngineCoreProc_stage0_replica0 pid=138649)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[v1/worker/gpu_model_runner.py:4777]�[0m Starting to load model /home/dyvm6xra/dyvm6xrauser68/hero/models/BAGEL-7B-MoT...
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m MooncakeTransferEngineConnector config summary:
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Local: host=10.248.12.160, zmq_port=50151, rpc_port=15611
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Remote: sender_host=None, sender_zmq_port=None
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[32mINFO�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:369]�[0m   Role: can_put=True, configured_role=sender
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[mooncake_transfer_engine_connector.py:1097]�[0m ZMQ bind failed on 10.248.12.160:50151: Address already in use (addr='tcp://10.248.12.160:50151') (errno=98)
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[factory.py:49]�[0m Failed to create connector MooncakeTransferEngineConnector: MooncakeTransferEngineConnector failed to bind ZMQ on 10.248.12.160:50151: Address already in use (addr='tcp://10.248.12.160:50151')
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m Failed to initialize OmniConnector
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m Traceback (most recent call last):
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py", line 1092, in _zmq_listener_loop
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     socket.bind(f"tcp://{self.host}:{self.zmq_port}")
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m   File "/home/dyvm6xra/dyvm6xrauser68/hero/env_54/lib/python3.12/site-packages/zmq/sugar/socket.py", line 320, in bind
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     super().bind(addr)
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m   File "zmq/backend/cython/_zmq.py", line 1009, in zmq.backend.cython._zmq.Socket.bind
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     _check_rc(rc)
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     ^^^^^^^^^^^
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m   File "zmq/backend/cython/_zmq.py", line 190, in zmq.backend.cython._zmq._check_rc
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     raise ZMQError(errno)
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m     ^^^^^^^^^^^
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m zmq.error.ZMQError: Address already in use (addr='tcp://10.248.12.160:50151')
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m 
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m The above exception was the direct cause of the following exception:
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m 
�[0;36m(StageEngineCoreProc_stage0_replica1 pid=138654)�[0;0m �[31mERROR�[0m �[90m05-14 07:31:05�[0m �[90m[kv_transfer_manager.py:435]�[0m Traceback (most recent call last):
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/factory.py", line 45, in create_connector
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]     connector = constructor(spec.extra)
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]                 ^^^^^^^^^^^^^^^^^^^^^^^
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/factory.py", line 102, in _create_mooncake_transfer_engine_connector
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]     return MooncakeTransferEngineConnector(config)
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py", line 385, in __init__
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]     raise RuntimeError(
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435] RuntimeError: MooncakeTransferEngineConnector failed to bind ZMQ on 10.248.12.160:50151: Address already in use (addr='tcp://10.248.12.160:50151')
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435] During handling of the above exception, another exception occurred:
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435] Traceback (most recent call last):
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py", line 433, in connector
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]     self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra))
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]   File "/home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/vllm_omni/distributed/omni_connectors/factory.py", line 50, in create_connector
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435]     raise ValueError(f"Failed to create connector {spec.name}: {e}")
(StageEngineCoreProc_stage0_replica1 pid=138654) ERROR 05-14 07:31:05 [kv_transfer_manager.py:435] ValueError: Failed to create connector MooncakeTransferEngineConnector: MooncakeTransferEngineConnector failed to bind ZMQ on 10.248.12.160:50151: Address already in use (addr='tcp://10.248.12.160:50151')
(StageEngineCoreProc_stage0_replica1 pid=138654) INFO 05-14 07:31:05 [kv_transfer_manager.py:333] Sender connector eagerly initialized
(StageEngineCoreProc_stage0_replica1 pid=138654) WARNING 05-14 07:31:05 [base.py:188] [LLM Worker 0] Sleep Mode DISABLED.
(StageEngineCoreProc_stage0_replica1 pid=138654) WARNING 05-14 07:31:05 [base.py:188] [LLM Worker 0] Sleep Mode DISABLED.
(StageEngineCoreProc_stage0_replica1 pid=138654) INFO 05-14 07:31:05 [v1/worker/gpu_model_runner.py:4777] Starting to load model /home/dyvm6xra/dyvm6xrauser68/hero/models/BAGEL-7B-MoT...
(StageEng
  • Currently, 1 AR + 1 headless DiT is working well, and gives proper output:
python /home/dyvm6xra/dyvm6xrauser68/hero/vllm-omni/examples/online_serving/bagel/openai_chat_client.py \
    --prompt "Make the cat stand up" \
    --modality img2img \
    --image-url /home/dyvm6xra/dyvm6xrauser68/hero/test_omni/test_hunyuan/input_0.png \
    --output transformed_1a1d.png \
    --server http://10.248.12.160:9878
input_0.png transformed_1a1d.png:

Relevant KV transfer log in stage 1:

DEBUG 05-14 07:47:54 [manager.py:222] No lora_request provided and adapters are already inactive
INFO 05-14 07:47:54 [mooncake_transfer_engine_connector.py:442] Sender info updated (default): host='10.248.12.160', zmq_port=50177
INFO 05-14 07:47:54 [kv_transfer_manager.py:713] Sender info updated: host=10.248.12.160, base_port=50177, adjusted_port=50177 (local_rank=0)
INFO 05-14 07:47:54 [kv_transfer_manager.py:1015] Wait for KV cache for request chatcmpl-8e30c878bd2e3567 from stage 0 to 1 via 1 key(s)...
INFO 05-14 07:47:54 [mooncake_transfer_engine_connector.py:904] [RDMA GET] omni_0_to_1_kv_cache_chatcmpl-8e30c878bd2e3567@0_1: query=4.0ms, alloc=0.0ms, rdma=28.3ms, sync=0.0ms, total=32.3ms, 12963.1 MB/s (fast_path, zero-copy)
INFO 05-14 07:47:54 [kv_transfer_manager.py:1105] Successfully received KV cache for chatcmpl-8e30c878bd2e3567, 439491236 bytes across 1 key(s), wait=0.056s, link=56.2ms
INFO 05-14 07:47:54 [kv_transfer_manager.py:1015] Wait for KV cache for request chatcmpl-8e30c878bd2e3567__cfg_text from stage 0 to 1 via 1 key(s)...
INFO 05-14 07:47:54 [mooncake_transfer_engine_connector.py:904] [RDMA GET] omni_0_to_1_kv_cache_chatcmpl-8e30c878bd2e3567__cfg_text@0_1: query=1.5ms, alloc=0.0ms, rdma=12.1ms, sync=0.0ms, total=13.6ms, 30679.1 MB/s (fast_path, zero-copy)
INFO 05-14 07:47:54 [kv_transfer_manager.py:1105] Successfully received KV cache for chatcmpl-8e30c878bd2e3567__cfg_text, 439089946 bytes across 1 key(s), wait=0.041s, link=41.0ms
INFO 05-14 07:47:54 [bagel.py:265] Collected CFG KV cache for role=cfg_text, rid=chatcmpl-8e30c878bd2e3567__cfg_text, size=439089946 bytes
INFO 05-14 07:47:54 [kv_transfer_manager.py:1015] Wait for KV cache for request chatcmpl-8e30c878bd2e3567__cfg_img from stage 0 to 1 via 1 key(s)...
INFO 05-14 07:47:54 [mooncake_transfer_engine_connector.py:904] [RDMA GET] omni_0_to_1_kv_cache_chatcmpl-8e30c878bd2e3567__cfg_img@0_1: query=1.8ms, alloc=0.0ms, rdma=1.6ms, sync=0.0ms, total=3.4ms, 114.0 MB/s (fast_path, zero-copy)
INFO 05-14 07:47:54 [kv_transfer_manager.py:1105] Successfully received KV cache for chatcmpl-8e30c878bd2e3567__cfg_img, 405873 bytes across 1 key(s), wait=0.004s, link=4.2ms
INFO 05-14 07:47:54 [bagel.py:265] Collected CFG KV cache for role=cfg_img, rid=chatcmpl-8e30c878bd2e3567__cfg_img, size=405873 bytes
INFO 05-14 07:47:54 [kv_transfer_manager.py:1225] Applied CFG KV caches: ['cfg_text_past_key_values', 'cfg_text_kv_metadata', 'cfg_img_past_key_values', 'cfg_img_kv_metadata']
DEBUG 05-14 07:47:54 [config/kernel.py:85] Setting IR op priority for rms_norm to ['vllm_c', 'native']
INFO 05-14 07:47:54 [pipeline_bagel.py:376] Using injected KV Cache (direct)
INFO 05-14 07:47:54 [pipeline_bagel.py:406] CFG enabled with injected branch KV context roles=[] active=None

The customized yaml config:

async_chunk: false

stages:
  - stage_id: 0
    max_num_batched_tokens: 32768
    max_num_seqs: 3
    gpu_memory_utilization: 0.7
    trust_remote_code: true
    enable_prefix_caching: false
    devices: "0,1"
    output_connectors:
      to_stage_1: rdma_connector           # <------ important
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 2048
      seed: 52
      detokenize: true
      repetition_penalty: 1.05

  - stage_id: 1
    max_num_batched_tokens: 32768
    max_num_seqs: 1
    enforce_eager: true
    trust_remote_code: true
    enable_prefix_caching: false
    gpu_memory_utilization: 0.7
    devices: "0,1"
    input_connectors:
      from_stage_0: rdma_connector           # <------ important
    default_sampling_params:
      seed: 52

connectors:
  rdma_connector:           # <------ important
    name: MooncakeTransferEngineConnector           # <------ important
    extra:
      host: "auto"
      zmq_port: 50077
      protocol: "rdma"
      device_name: "mlx5_0"           # <------ important; manually specify mlx5_0 or another concrete device here, as auto selection may fail.
      memory_pool_size: 4294967296
      memory_pool_device: "cpu"
  shared_memory_connector:
    name: SharedMemoryConnector
    extra:
      shm_threshold_bytes: 65536

@herotai214

herotai214 commented May 15, 2026

I tested the PR code with ByteDance-Seed/BAGEL-7B-MoT running across nodes with multiple replicas. Summary: 1 AR + 1 DiT works; configurations with more than one replica fail.

Version: chickeyton:omni_coord_itg_rebase2 branch commit d553d95287784cd7f23b82bf0274f8f6501166a1

I customized a YAML file (attached at the end) that uses the MooncakeTransferEngineConnector (rdma_connector) to support KV cache transfer between nodes:

  • When running 2 AR (on node 0) + 2 headless DiT (on node 1), an error appears because replicas try to bind the same connector address/port.

    • I will fix this very soon and push to this PR.

Error log for 2AR 2DiT

  • Currently, 1 AR + 1 headless DiT is working well, and gives proper output:
    ....

As a follow-up, I raised chickeyton#4 against this PR.
Verified cross-node 1A1D, 1A2D, and 2A2D cases with Bagel; all give proper output and run error-free.

It is an ad-hoc fix that:

  • avoids duplicated connector ports between replicas by saving each process's omni_replica_id to its os.environ["VLLM_OMNI_REPLICA_ID"] in
    vllm_omni/engine/stage_engine_core_proc.py
  • retrieves it later for port number computation with kv_zmq_port in vllm_omni/distributed/omni_connectors/utils/kv_utils.py

A more elegant way to assign port numbers is future work to discuss.
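For reviewers, the ad-hoc fix above can be sketched roughly as follows. The environment variable name (VLLM_OMNI_REPLICA_ID) and the helper name kv_zmq_port come from the description; the offset formula and function signatures are illustrative assumptions, not the actual implementation:

```python
import os

# Sketch of the ad-hoc fix: the engine subprocess records its replica id
# in the environment, and the KV utility later offsets the configured
# base ZMQ port by that id so replicas on the same host never bind the
# same address. Offset formula below is an assumption for illustration.

def record_replica_id(omni_replica_id: int) -> None:
    # Called inside the stage engine subprocess (stage_engine_core_proc.py).
    os.environ["VLLM_OMNI_REPLICA_ID"] = str(omni_replica_id)

def kv_zmq_port(base_port: int, local_rank: int, ranks_per_replica: int = 8) -> int:
    # Called later when the connector binds (kv_utils.py). Each replica
    # gets a disjoint port range of size ranks_per_replica.
    replica_id = int(os.environ.get("VLLM_OMNI_REPLICA_ID", "0"))
    return base_port + replica_id * ranks_per_replica + local_rank

record_replica_id(1)
print(kv_zmq_port(50077, 0))  # 50085: replica 1 no longer collides with replica 0
```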

@herotai214

herotai214 commented May 15, 2026

A similar deploy config to test HunyuanImage-3.0-Instruct with multiple replicas:

# HunyuanImage-3.0-Instruct deploy: AR (stage 0) + DiT (stage 1)
# with AR-to-DiT KV reuse. The base CUDA layout uses 2 GPUs for AR
# and 2 GPUs for DiT.
pipeline: hunyuan_image3
async_chunk: false
trust_remote_code: true

connectors:
  rdma_connector:
    name: MooncakeTransferEngineConnector
    extra:
      host: "auto"
      zmq_port: 50051
      protocol: "rdma"
      device_name: "mlx5_0"
      memory_pool_size: 4294967296
      memory_pool_device: "cpu"
  shared_memory_connector:
    name: SharedMemoryConnector
    extra:
      shm_threshold_bytes: 65536

stages:
  - stage_id: 0
    is_comprehension: false
    final_output: true
    final_output_type: text
    max_num_seqs: 1
    gpu_memory_utilization: 0.95
    enforce_eager: true
    max_num_batched_tokens: 32768
    devices: "0,1,2,3"
    tensor_parallel_size: 4
    hf_overrides:
      rope_parameters:
        mrope_section: [0, 32, 32]
        rope_type: default
    omni_kv_config:
      need_send_cache: true
    output_connectors:
      to_stage_1: rdma_connector
    default_sampling_params:
      temperature: 0.0
      top_p: 1
      top_k: -1
      max_tokens: 8192
      detokenize: true
      skip_special_tokens: false

  - stage_id: 1
    max_num_seqs: 1
    enforce_eager: true
    devices: "0,1,2,3"
    distributed_executor_backend: "mp"
    omni_kv_config:
      need_recv_cache: true
    parallel_config:
      tensor_parallel_size: 4
      enable_expert_parallel: true
    input_connectors:
      from_stage_0: rdma_connector
    default_sampling_params:
      num_inference_steps: 50
      guidance_scale: 0

edges:
  - from: 0
    to: 1
    window_size: -1
    max_inflight: 1

platforms:
  npu:
    stages:
      - stage_id: 0
        gpu_memory_utilization: 0.8
        devices: "0,1,2,3"
        tensor_parallel_size: 4
        max_num_batched_tokens: 8192
      - stage_id: 1
        gpu_memory_utilization: 0.65
        devices: "4,5,6,7"
        max_num_batched_tokens: 8192
        parallel_config:
          tensor_parallel_size: 4
          enable_expert_parallel: true

  xpu:
    stages:
      - stage_id: 0
        gpu_memory_utilization: 0.95
        devices: "0,1,2,3,4,5,6,7"
        tensor_parallel_size: 8
        max_num_batched_tokens: 32784
        quantization: fp8
        enable_expert_parallel: true
        worker_cls: vllm_omni.platforms.xpu.worker.xpu_ar_worker.XPUARWorker
        default_sampling_params:
          max_tokens: 2048
          seed: 42
          repetition_penalty: 1.1
      - stage_id: 1
        gpu_memory_utilization: 0.9
        devices: "0,1,2,3,4,5,6,7"
        quantization: fp8
        parallel_config:
          pipeline_parallel_size: 1
          data_parallel_size: 1
          tensor_parallel_size: 8
          enable_expert_parallel: true
          sequence_parallel_size: 1
          ulysses_degree: 1
          ring_degree: 1
          cfg_parallel_size: 1
          vae_patch_parallel_size: 1
          use_hsdp: false
          hsdp_shard_size: -1
          hsdp_replicate_size: 1

Signed-off-by: chickeyton <ngton2014@gmail.com>
LeastQueueLengthBalancer relies on heartbeats as the only periodic
source of live load, because StageEngineCoreProc/StageDiffusionProc
refresh ``queue_length`` just-in-time via the ``_on_heartbeat`` hook
before each heartbeat send. The coordinator's heartbeat handler was
only updating ``last_heartbeat`` though, so it kept publishing the
initial queue_length (usually 0) and the least-queue policy could
pick busy replicas as if they were idle.

Copy ``event.queue_length`` into ``info.queue_length`` on heartbeat
events and request a broadcast when it changes so subscribers see
fresh load promptly. Coalescing in the periodic loop keeps the wire
traffic bounded.

Also corrects the now-outdated docstring on ``_send_event`` that
claimed heartbeats sent ``queue_length=null``.
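A minimal sketch of the described fix, under assumed class and field names (only ``_on_heartbeat``, ``queue_length``, and ``last_heartbeat`` come from the commit message; the rest is illustrative):

```python
import time
from dataclasses import dataclass

# Sketch: copy the queue length carried by a heartbeat event into the
# replica's info record, and flag a broadcast when it changed so
# subscribers see fresh load promptly. The periodic loop coalesces
# pending broadcasts, keeping wire traffic bounded.

@dataclass
class ReplicaInfo:
    queue_length: int = 0
    last_heartbeat: float = 0.0

@dataclass
class HeartbeatEvent:
    replica_id: str
    queue_length: int

class CoordinatorSketch:
    def __init__(self) -> None:
        self.replicas: dict[str, ReplicaInfo] = {}
        self.broadcast_pending = False

    def on_heartbeat(self, event: HeartbeatEvent) -> None:
        info = self.replicas.setdefault(event.replica_id, ReplicaInfo())
        info.last_heartbeat = time.monotonic()
        if info.queue_length != event.queue_length:
            info.queue_length = event.queue_length  # previously dropped on the floor
            self.broadcast_pending = True  # coalesced by the periodic loop

coord = CoordinatorSketch()
coord.on_heartbeat(HeartbeatEvent("stage1_replica0", 3))
print(coord.replicas["stage1_replica0"].queue_length)  # 3
print(coord.broadcast_pending)  # True
```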

Signed-off-by: chickeyton <ngton2014@gmail.com>
- Drop unused ``vllm_config`` local in ``StageEngineCoreProc.run_stage_core``
  (F841); the comment about the removed hardcoded data_parallel_size is
  retained.
- Wrap the long ``[Headless] Launching ... OmniMasterServer`` log line
  in serve.py to keep it under the 120-char limit (E501).
- Reflow multi-line ``raise`` / ``logger`` calls that fit on one line
  per ``ruff format`` rules in stage_diffusion_proc, async_omni_engine,
  omni_coord_client_for_hub, omni_core_engine_proc_manager, orchestrator,
  stage_engine_core_proc and serve.

Signed-off-by: chickeyton <ngton2014@gmail.com>
Signed-off-by: chickeyton <ngton2014@gmail.com>
Signed-off-by: chickeyton <ngton2014@gmail.com>
Signed-off-by: chickeyton <ngton2014@gmail.com>
Signed-off-by: chickeyton <ngton2014@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
@chickeyton chickeyton force-pushed the omni_coord_itg_rebase2 branch from affb36a to 73301a8 Compare May 15, 2026 06:18
Signed-off-by: chickeyton <ngton2014@gmail.com>
@Gaohan123 Gaohan123 added the high priority high priority issue, needs to be done asap label May 15, 2026
@herotai214

Based on commit ba30955, I verified that in the multi-replica case:

  • only one replica handles any single request
  • across requests, all replicas are involved and working properly

The replicas are working well.
(*The input images are different every time)

I added some logging code; see whether it benefits this PR:
chickeyton#5
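The routing behavior verified above (one replica per request, load spread across all replicas) is what a least-queue pick with liveness filtering should produce. A hedged sketch, assuming each replica record carries the (queue_length, last_heartbeat) pair discussed in the heartbeat commit; all names here are illustrative, not the actual LeastQueueLengthBalancer:

```python
from __future__ import annotations
import time

# Sketch: pick the live replica with the smallest published queue
# length, dropping replicas whose last heartbeat is older than the
# timeout (cf. the --omni-heartbeat-timeout flag).

def pick_replica(replicas: dict[str, tuple[int, float]],
                 heartbeat_timeout: float = 10.0,
                 now: float | None = None) -> str:
    now = time.monotonic() if now is None else now
    live = {rid: qlen for rid, (qlen, hb) in replicas.items()
            if now - hb <= heartbeat_timeout}
    if not live:
        raise RuntimeError("no live replicas")
    return min(live, key=live.get)

replicas = {
    "replica0": (4, 100.0),   # busy
    "replica1": (0, 80.0),    # idle, but heartbeat is stale
    "replica2": (2, 101.0),   # moderately loaded
}
print(pick_replica(replicas, heartbeat_timeout=10.0, now=102.0))  # replica2
```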

@chickeyton
Contributor Author

chickeyton commented May 15, 2026

Test

Setup

model: tencent/HunyuanImage-3.0-Instruct
server 1: head CLI running stage 1 (DiT) with 2 replicas

vllm serve tencent/HunyuanImage-3.0-Instruct \
  --omni \
  --host $HEAD_HOST \
  --port  $API_PORT \
  --stage-id 1 \
  --quantization fp8 \
  --trust-remote-code \
  --omni-dp-size-local 2 \
  --omni-master-address $HEAD_HOST \
  --omni-master-port $MASTER_PORT \
  --deploy-config "config.yaml" \
  --stage-init-timeout 1800 \
  --init-timeout 2400

server 2: headless CLI running stage 0 (AR) with 1 replica

vllm serve tencent/HunyuanImage-3.0-Instruct \
  --omni \
  --headless \
  --stage-id 0 \
  --quantization fp8 \
  --trust-remote-code \
  --omni-master-address $HEAD_HOST \
  --omni-master-port $MASTER_PORT \
  --deploy-config "config.yaml" \
  --stage-init-timeout 1800

connector: RDMA Mooncake

Input

image:
input_0_0

curl -X POST http://localhost:${API_PORT}/v1/images/edits \
  -F "image=@./input.png" \
  -F "prompt=新年宠物海报,Q版圆润的可爱标题\"新年快乐汪\",副标题\"HAPPY NEW YEAR\"。 鱼眼镜头,背景是房间门口,近景,上传的主体歪头笑,围着红色围巾,戴着红色毛线帽,高清,绒毛细节,面部特写。 宝丽莱相纸,超现实主义,写实主义,胶片摄影,打印颗粒感肌理。肌理,超写实,复古感。" \
  -F "bot_task=it2i_think" \
  -F "n=1" \
  -F "num_inference_steps=50" \
  -F "guidance_scale=1" \
  -F "seed=42" \
  | jq -r '.data[0].b64_json' \
  | base64 -d > result.png

Output

result

@Gaohan123 Gaohan123 added the ready label to trigger buildkite CI label May 16, 2026
@yinpeiqi
Contributor

This PR adds useful functionality, but the implementation feels too tightly coupled. Right now topology planning, headless registration/handshake, replica lifecycle, routing, and backend-specific behavior are spread across serve.py, async_omni_engine.py, stage_engine_startup.py, orchestrator.py, and stage_pool.py. That makes the control flow hard to follow and will make future changes risky.

Upstream vLLM handles this more cleanly by separating concerns at the boundaries: config/parallel arguments are normalized first, then a small number of factory/launcher paths choose the execution mode, and launch/handshake logic is concentrated in dedicated utilities. The frontend mostly talks to abstract client/executor interfaces, and the DP coordinator stays a control-plane component rather than being mixed into request orchestration. I think vLLM-Omni should move in the same direction: introduce a clear topology/plan layer, isolate registration + bootstrap + handshake into dedicated runtime helpers, keep LLM/diffusion differences behind small runtime/client abstractions, let Orchestrator focus on request flow only, and let StagePool focus on routing rather than membership/lifecycle management. This would improve readability, reduce cross-module coupling, and make the multi-node path much easier to maintain.

@zengchuang-hw
Contributor

@fake0fan @Gaohan123 plz review this PR asap

@yinpeiqi
Contributor

@fake0fan @Gaohan123 plz review this PR asap

I will refactor it after it is merged
