From e04d0563faef32272f9f1ec8927409ed3a8b6f8e Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 15:49:13 +0000 Subject: [PATCH 1/8] [Bugfix][KVConnector][MoRI-IO] Synthesize transfer_id when llm-d sidecar omits it The llm-d routing-sidecar (`--kv-connector=nixlv2`) splits an incoming chat completion into separate prefill and decode requests, forwarding NIXL-shaped `kv_transfer_params` on each leg. Those params include `do_remote_decode` / `do_remote_prefill` plus the `remote_engine_id`/`host`/`port`/`block_ids` triplet for the NIXL READ flow, but do not include MoRI-IO's own `transfer_id` field -- that is a MoRI-IO concept the sidecar has no knowledge of. Without this patch, `MoRIIOConnectorScheduler.update_state_after_alloc` unconditionally dereferences `params["transfer_id"]` and the prefill engine crashes on first traffic with `KeyError: 'transfer_id'`. Synthesize a stable `transfer_id` from `request.request_id` so both producer and consumer (which see the same `request_id` through the sidecar fan-out) end up with the same `transfer_id`, without requiring any wire-protocol change in the sidecar. This is the MoRI-IO-side counterpart to PR #39276's "Fix C", which only fixed the analogous `KeyError` for `remote_handshake_port` / `remote_notify_port`. See-also: vllm-project/vllm#39276 (Fix C, by @raviguptaamd: same defensive .get()-with-default pattern, applied there to remote_handshake_port and remote_notify_port; this commit extends it to transfer_id). Signed-off-by: shikpate --- .../kv_connector/v1/moriio/moriio_connector.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 167eef6e1ca8..a7d86511f364 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -387,7 +387,14 @@ def update_state_after_alloc( params = request.kv_transfer_params if not params: return - transfer_id = params["transfer_id"] + # LLM-D sidecar compat: the routing-sidecar (--kv-connector=nixlv2) + # emits NIXL-shaped kv_transfer_params that do not include MoRI-IO's + # transfer_id. Synthesize one deterministically from request_id so the + # producer and consumer (both downstream of the same sidecar fan-out) + # observe the same transfer_id without requiring a wire-protocol + # change in the sidecar. + transfer_id = params.get("transfer_id") or f"sidecar-{request.request_id}" + params.setdefault("transfer_id", transfer_id) request_id = request.request_id self.map_request_id(request_id, transfer_id) if params.get("do_remote_decode"): From 99ce0f29f8ae26498961694d6ece40f566e35cd5 Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 15:49:49 +0000 Subject: [PATCH 2/8] [Bugfix][V1][Scheduler] Tolerate finished_sending for already-removed requests `Scheduler._update_from_kv_xfer_finished` currently asserts that every `req_id` reported in `kv_connector_output.finished_sending` is still present in `self.requests`. That invariant holds for connectors that report completion synchronously inside the same scheduler step, but it is violated by WRITE-mode connectors (e.g. MoRI-IO in WRITE mode) which report `finished_sending` out-of-band from a deferred-write task that can complete one or more steps after the scheduler already removed the request via the normal finish path. Symptom: `AssertionError` in `_update_from_kv_xfer_finished` randomly under sustained disagg P/D traffic, killing the engine. Fix: replace the `assert req_id in self.requests` with a skip-if-missing guard, so the late completion is silently ignored. The block-free already happened on the synchronous finish; nothing else needs to be done. The matching `finished_recving` branch above already keeps the "missing" req_id in `finished_recving_kv_req_ids` for one more step and is handled separately, so this change is asymmetric on purpose. Related: vllm-project/vllm#39276 (engine-side timeouts for the same async-MoRI-IO completion races this commit handles scheduler-side; independent fixes, complementary surface). Signed-off-by: shikpate --- vllm/v1/core/sched/scheduler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 73d3dcb4b65e..e5f6f4ed62e3 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2175,8 +2175,14 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): self._free_blocks(self.requests[req_id]) for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) - assert req_id in self.requests - self._free_blocks(self.requests[req_id]) + # Skip-if-missing: the worker-side KV connector can report a + # completion for a request that has already been removed from + # ``self.requests`` (e.g. WRITE-mode connectors like MoRI-IO + # report finished_sending out-of-band, racing with the + # scheduler's normal lifecycle removal). Asserting here causes + # spurious crashes; tolerate the race by skipping the free. + if req_id in self.requests: + self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( self, From e8a42ce151732848a2ab35b9c0fec650b22b2645 Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 15:50:49 +0000 Subject: [PATCH 3/8] [Bugfix][V1][DP] Wake other DP engines on first request of first wave `DPEngineCoreProc.add_request` currently gates the ``engines_running`` flip and the ``start_wave`` broadcast on ``request_wave != self.current_wave``. Both ``current_wave`` and ``request_wave`` default to ``0``, so on the very first request after engine init the gate is False and the broadcast never happens. Consequence on collectives-heavy models (Wide-EP, large DP): * The DP rank that received the first request enters its forward pass and blocks on a collective (e.g. EP all2all, MoE all2all, ``has_unfinished_dp`` all-reduce). * The other DP ranks observe ``engines_running == False`` and ``local_unfinished_reqs == False`` in ``run_busy_loop``, take the ``continue`` path, and never call ``execute_dummy_batch`` -- so they never enter the collective. * The busy rank hangs forever on the collective until the ``multiproc_executor`` 1800 s timeout fires: ``RPC call to sample_tokens timed out``. Warm requests work fine because ``current_wave`` has already advanced past 0, so subsequent first-of-wave requests do trigger the broadcast; only the very first cold request after engine init hangs. This makes the bug invisible in CI but reliably reproduces in production startup of large DP topologies. Reproduced on DeepSeek-V3, DP=8, DP=16, TP=1, EP=8/16, on a fresh engine -- 100% deterministic hang on the first request, never on subsequent ones. Fix: drop the ``request_wave != self.current_wave`` outer gate. We still ``return`` early in steady state because ``engines_running`` is already True. When engines are idle and the scheduler is unpaused, we wake them up via the same code path that previously only fired for ``request_wave > current_wave``. Related: #36594, #36608, #37024, #38009 all touch the same region but for a different pause/resume race. This first-wave race is distinct and not fixed by any of them. Signed-off-by: shikpate --- vllm/v1/engine/core.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c21a4de5d309..4d9968df1cb8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1760,15 +1760,25 @@ def _pause_complete(self) -> bool: def add_request(self, request: Request, request_wave: int = 0): super().add_request(request, request_wave) - if self.has_coordinator and request_wave != self.current_wave: + if self.has_coordinator: if request_wave > self.current_wave: self.current_wave = request_wave - elif ( + # NB: don't gate this on ``request_wave != self.current_wave``. + # Both default to 0, so the very first request after engine init + # would otherwise hit ``0 != 0 == False``, ``engines_running`` + # would stay False, and no ``start_wave`` would be broadcast to + # the DP coordinator. The rank that received the first request + # enters its forward pass and blocks forever on the EP all2all + # collective because the other DP ranks (seeing + # ``engines_running=False`` and no local work) skip + # ``execute_dummy_batch`` and never participate. + # We re-broadcast whenever engines are idle and the scheduler + # is unpaused, which is harmless in steady-state (engines are + # already running) and required for the cold first wave. + if ( not self.engines_running and self.scheduler.pause_state == PauseState.UNPAUSED ): - # Request received for an already-completed wave, notify - # front-end that we need to start the next one. self.engines_running = True self.output_queue.put_nowait( (-1, EngineCoreOutputs(start_wave=self.current_wave)) From 52b634f901627d9bd9fc5f73f216e390d0cf258d Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 15:53:33 +0000 Subject: [PATCH 4/8] [Bugfix][V1][DP] AsyncLLM: stable per-request data_parallel_rank fallback When the OpenAI server is run with ``--api-server-count N`` (N > 1), Linux SO_REUSEPORT shuffles incoming connections across ApiServer processes. Two legs of a disaggregated prefill/decode pair (which share a ``request_id``) can land on different ApiServers and be load-balanced to different DP ranks. KV-transfer protocols that pin source/target by DP rank (MoRI-IO, NIXL WRITE-mode, ...) then end up exchanging handshakes with the wrong peer and the request deadlocks at the connector level. The result is rank-asymmetric: requests that happen to land on a ``(prefill DP=H, decode DP=H)`` pair succeed, all others time out after ``VLLM_MORIIO_DEFERRED_TIMEOUT_S`` (300 s by default). This patch adds a ``_pick_dp_rank_for_request`` helper that ``AsyncLLM.add_request`` consults when the caller did not supply a ``data_parallel_rank``. The helper synthesizes a stable rank in this order: 1. ``params.extra_args["kv_transfer_params"]["dp_rank_hint"]`` if the caller (or an upstream routing sidecar) already picked the rank. 2. Otherwise a stable ``blake2s(request_id) % effective_dp_size`` hash. Because the disagg sidecar uses the same ``request_id`` on the prefill and decode legs, both sides hash to the same rank H and the SO_REUSEPORT shuffle is neutralised. When ``data_parallel_size_local`` is set and smaller than ``data_parallel_size`` (multi-pod DP, "Wide-EP"), the modulus is capped to the local pod size so that both legs route to the same pod -- cross-pod handshake requires a coordinator that may not exist in the disagg orchestrator. The helper returns ``None`` when there is no DP fan-out to disambiguate, leaving the existing dispatch path unchanged. Callers that already pass an explicit ``data_parallel_rank`` (e.g. via the ``X-data-parallel-rank`` header) are untouched. Once ``data_parallel_rank`` is set, ``DPLBAsyncMPClient.get_core_engine_for_request`` already honours the hint and dispatches the request to ``EngineCore_DPH`` instead of load-balancing -- no changes are needed in the dispatch core. Related: vllm-project/vllm#39276 (multi-node engine_id collision fix: same theme of deterministic DP routing in P/D pairs; that PR handles --headless multi-node DP, this commit handles --api-server-count > 1 ApiServer fan-out). Signed-off-by: shikpate --- vllm/v1/engine/async_llm.py | 73 +++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 419e15163a9f..2abd5eacbc26 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import hashlib import os import socket import time @@ -277,6 +278,69 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self._supported_tasks + def _pick_dp_rank_for_request( + self, + request_id: str, + params: SamplingParams | PoolingParams, + ) -> int | None: + """Pick a stable ``data_parallel_rank`` for a request. + + When the OpenAI server is run with ``--api-server-count N`` (N > 1), + Linux SO_REUSEPORT shuffles incoming connections across ApiServer + processes. Two legs of a disaggregated prefill/decode pair (which + share a ``request_id``) can land on different ApiServers and be + load-balanced to different DP ranks. KV-transfer protocols that + pin source/target by DP rank (MoRI-IO, NIXL WRITE-mode, ...) then + end up exchanging handshakes with the wrong peer and the request + deadlocks at the connector level. + + To work around this we synthesize a per-request DP rank that both + legs will independently agree on, in this order: + + 1. ``params.extra_args["kv_transfer_params"]["dp_rank_hint"]`` + if the caller (or an upstream routing sidecar) has already + picked the rank. + 2. Otherwise a stable ``blake2s(request_id) % effective_dp_size`` + hash. + + When ``data_parallel_size_local`` is set and smaller than + ``data_parallel_size`` (multi-pod DP, "Wide-EP"), the modulus is + capped to the local pod size so that both legs route to a rank in + the same pod -- cross-pod handshake requires a coordinator that + may not exist in the disagg orchestrator. + + Returns ``None`` when there is no DP fan-out to disambiguate + (``effective_dp_size <= 1``); callers should leave + ``data_parallel_rank`` unset in that case. + """ + pc = self.vllm_config.parallel_config + try: + dp_size = int(pc.data_parallel_size) + except Exception: + dp_size = 1 + if dp_size <= 1: + return None + try: + dp_local_raw = getattr(pc, "data_parallel_size_local", None) + dp_local = int(dp_local_raw) if dp_local_raw else dp_size + if dp_local > 0: + dp_size = min(dp_size, dp_local) + except Exception: + pass + if dp_size <= 1: + return None + extra = getattr(params, "extra_args", None) or {} + if isinstance(extra, dict): + ktp = extra.get("kv_transfer_params") or {} + if isinstance(ktp, dict): + hint = ktp.get("dp_rank_hint") + if isinstance(hint, int) and 0 <= hint < dp_size: + return hint + digest = hashlib.blake2s( + str(request_id).encode("utf-8"), digest_size=8 + ).digest() + return int.from_bytes(digest, "big") % dp_size + async def add_request( self, request_id: str, @@ -297,6 +361,15 @@ async def add_request( ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" + if data_parallel_rank is None: + data_parallel_rank = self._pick_dp_rank_for_request(request_id, params) + if data_parallel_rank is not None: + logger.debug( + "Auto-routed request %s to data_parallel_rank=%d", + request_id, + data_parallel_rank, + ) + if self.errored: raise EngineDeadError() From 6863d2f83dc2dcc25c02d816bc42dae6ad1dea45 Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 16:02:08 +0000 Subject: [PATCH 5/8] [Bugfix][KVConnector][MoRI-IO] Hash-route decode->prefill notify in DP>1 In a DP>1 disaggregated deployment, ``request_finished`` on the decode-side scheduler reads the prefill DP rank from ``request.kv_transfer_params.get("remote_dp_rank", 0)``. That field is a **static** value injected by the routing sidecar (e.g. via ``--moriio-prefill-dp-rank``, default 0), so every decode->prefill notify lands on a single prefill DP rank. When the prefill dispatcher actually routes requests round-robin / hash / least-loaded across multiple DP ranks, every request whose prefill leg ran on rank N>0 never gets its ``done_remote_allocate`` notify -- those workers spin in ``save_kv_layer`` until the deferred write task expires after ``VLLM_MORIIO_DEFERRED_TIMEOUT_S`` (300 s default), and the request fails. In practice, **only requests that happen to land on the pinned prefill rank succeed**. Counts collected from a 138 minute run with DP=8, TP=1, EP=8, DeepSeek-V3, sidecar pinned to prefill DP0: Worker_DP0: 0 EXPIRED (works) Worker_DP1: 183 Worker_DP2: 61 Worker_DP3: 183 Worker_DP4: 427 Worker_DP5: 122 Worker_DP6: 61 Worker_DP7: 61 Fix: when ``remote_dp_size > 1`` and the caller did not explicitly set ``remote_dp_rank_override``, compute the prefill DP rank from a stable hash of ``request_id``: remote_dp_rank = int.from_bytes( hashlib.blake2s(request_id, digest_size=8).digest(), "big" ) % remote_dp_size The dispatcher-side helper (``AsyncLLM._pick_dp_rank_for_request``) uses the same blake2s scheme so both legs (prefill dispatch + decode notify) agree on the rank, neutralising the SO_REUSEPORT shuffle that the disagg sidecar can otherwise induce. By the time ``request_finished`` runs, MoRI-IO has appended a per-transfer suffix ``-<8 hex>`` to ``request.request_id`` (it isn't on the AsyncLLM rid the dispatcher hashes). Strip the suffix so both legs hash the same canonical base id. For multi-pod DP topologies (Wide-EP-16: 8 ranks per pod on master+child), cap the modulus to ``remote_dp_size_local`` when set so the notify lands on the same pod the dispatcher routed to. Without the cap, hash mod 16 can pick a rank on the other pod and the notify goes to an engine that never serviced the request. Single-DP and unspecified-DP deployments are unchanged (``_dp_size <= 1`` short-circuits to the previous behaviour). Companion to the dispatcher-side ``AsyncLLM._pick_dp_rank_for_request`` (separate PR). The two patches are independently useful but together provide end-to-end stable DP routing for disagg prefill/decode pairs. Related: vllm-project/vllm#39276 (Fix E + engine_id collision: same deterministic-DP-routing problem space; companion to the dispatcher-side AsyncLLM blake2s fix in the previous commit). Signed-off-by: shikpate --- .../v1/moriio/moriio_connector.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index a7d86511f364..cf8ec5227d1a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib import logging import math import queue +import re import threading import time from collections import defaultdict @@ -438,6 +440,67 @@ def update_state_after_alloc( remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) + # Wide-EP DP>1 fix: the disagg routing sidecar injects a + # STATIC ``remote_dp_rank`` (e.g. always 0) into + # ``kv_transfer_params``. With DP>1, that pins every + # decode->prefill notify to a single prefill DP rank, so + # all prefill ranks other than that one never receive + # their ``done_remote_allocate`` notify and their + # deferred-write tasks expire after + # ``VLLM_MORIIO_DEFERRED_TIMEOUT_S``. Most requests hang. + # + # Compute a per-request prefill DP rank from a stable + # hash of ``request_id``. The matching helper on + # ``AsyncLLM.add_request`` uses the same blake2s scheme, + # so both legs (prefill dispatch + decode notify) agree. + # + # When ``remote_dp_size_local`` is set and smaller than + # ``remote_dp_size`` (multi-pod DP, "Wide-EP"), cap the + # modulus to the per-pod size so the notify lands on the + # same pod that the producer dispatch routes to. + # + # By the time we reach ``request_finished``, MoRI-IO has + # appended a per-transfer suffix ``-<8 hex>`` to + # ``request.request_id`` (it isn't on the AsyncLLM rid + # that the dispatcher hashes). Strip that suffix so both + # legs hash the same canonical base id. + _dp_size = int( + request.kv_transfer_params.get("remote_dp_size", 1) or 1 + ) + try: + _dp_local = int( + request.kv_transfer_params.get("remote_dp_size_local", 0) + or 0 + ) + if _dp_local > 0: + _dp_size = min(_dp_size, _dp_local) + except (TypeError, ValueError): + pass + # Defense-in-depth handshake with the llm-d routing + # sidecar (patch 0013, shipped in + # pd-sidecar-moriio-write-widep-v0.8.0+): when the + # sidecar is in path it already pins the prefill DP rank + # via its own pickDPRank(uuid, dpSize) and stamps both + # ``remote_dp_rank`` and ``remote_dp_rank_override=True`` + # on the kv_transfer_params. Honouring that sentinel + # makes this branch dormant in production while still + # acting as a fail-safe for sidecar-less debug runs and + # for any future sidecar regression that drops the + # override stamp. Avoids cross-language hash divergence + # (Go blake2s-256 vs Python blake2s-8) when both sides + # would otherwise hash independently. + if ( + _dp_size > 1 + and "remote_dp_rank_override" not in request.kv_transfer_params + ): + _base_rid = re.sub( + r"-[0-9a-f]{8}$", "", str(request.request_id) + ) + _digest = hashlib.blake2s( + _base_rid.encode("utf-8"), digest_size=8 + ).digest() + remote_dp_rank = int.from_bytes(_digest, "big") % _dp_size + peer_zmq = get_peer_zmq_from_request_id( request.request_id, is_producer=False ) From ca48ff49cafd5f484a73db5ad73c1213900ab0af Mon Sep 17 00:00:00 2001 From: shikpate Date: Fri, 29 May 2026 16:04:15 +0000 Subject: [PATCH 6/8] [Bugfix][KVConnector][MoRI-IO] Tolerate per-transfer rid suffix in unmap ``MoRIIOConnectorScheduler.unmap_request_id`` (decode-side, run from ``request_finished``) currently does an exact-match lookup on ``self.request_id_to_transfer_id`` and warns on miss with:: Could not find in transfer_id_to_request_id lookup table. This could lead to a possible hang. In multi-pod disagg routing we observe in production that MoRI-IO appends a "-[0-9a-f]{8}" per-transfer suffix to ``request.request_id`` between the call that populated the map (``update_state_after_alloc``, alloc-time) and the call that drains it (``request_finished``, finish-time). The lookup is exact-match, so the suffix mutation produces a spurious warning, leaks the dict entry, and ships stale state to the worker via ``meta.transfer_id_to_request_id`` -- which manifests as rank-asymmetric MoRI-IO transfer-id lookup failures in worker logs on the pod where the suffix gets appended (decode-master in Wide-EP DP=16, ranks 0..7). Concretely:: map request_id = "cmpl-bda091899755d21b-0" (no suffix) unmap request_id = "cmpl-bda091899755d21b-0-956053a4" (8-hex suffix) Same canonical request, different keys -> dict lookup misses, dict entry leaks. This patch makes ``unmap_request_id`` robust to the suffix mutation: 1. Try exact match first. This is the existing fast path and is zero-overhead / bit-identical to the pre-patch behaviour for callers that pass the canonical rid (decode-child, ranks 8..15). 2. If the exact-match misses, strip a trailing ``-[0-9a-f]{8}`` suffix and retry. If the canonical base id is present, log a ``debug``-level note and proceed with the canonical id. 3. If both miss, log a more informative warning (table size + canonical base id) so a real "never mapped" bug is still easy to tell apart from the suffix mutation. The regex is declared as a private class-level constant (``_MORIIO_RID_SUFFIX_RE``) so it is compiled once at import time, not per-call. This is the matching scheduler-side fix to the ``request_finished`` hash-routing patch (separate PR) and uses the same suffix shape that patch already understands. The two are independently useful but together provide end-to-end rid normalisation for the sidecar-fronted decode path. Signed-off-by: shikpate --- .../v1/moriio/moriio_connector.py | 56 ++++++++++++++++--- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index cf8ec5227d1a..ceca0d5ccf32 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -305,22 +305,64 @@ def map_request_id(self, request_id: ReqId, transfer_id: TransferId): self.transfer_id_to_request_id[transfer_id] = request_id self.request_id_to_transfer_id[request_id] = transfer_id + # Per-transfer suffix that MoRI-IO appends to ``request.request_id`` + # between ``update_state_after_alloc`` (alloc-time) and + # ``request_finished`` (finish-time) on the sidecar-fronted decode + # path. Used by ``unmap_request_id`` to strip the suffix when the + # exact-match lookup misses. + _MORIIO_RID_SUFFIX_RE = re.compile(r"-[0-9a-f]{8}$") + def unmap_request_id(self, request_id: ReqId): - if request_id in self.request_id_to_transfer_id: - transfer_id = self.request_id_to_transfer_id[request_id] - del self.request_id_to_transfer_id[request_id] + # In multi-pod disagg routing, MoRI-IO can append a + # "-[0-9a-f]{8}" per-transfer suffix to ``request.request_id`` + # between the call that populated ``request_id_to_transfer_id`` + # (``update_state_after_alloc``) and the call that drains it + # (``request_finished``). The dict lookup is exact-match, so the + # suffix mutation produces a spurious + # + # "Could not find in transfer_id_to_request_id lookup + # table. This could lead to a possible hang." + # + # warning, leaks the dict entry, and ships stale state to the + # worker via ``meta.transfer_id_to_request_id`` -- causing + # rank-asymmetric MoRI-IO transfer-id lookup failures in worker + # logs on the pod where the suffix gets appended (decode-master + # in Wide-EP DP=16, ranks 0..7). + # + # Resolution: try exact match first (preserves the no-suffix + # fast path -- zero overhead and bit-identical to the + # pre-patch behaviour for callers that already pass the + # canonical rid), then fall back to stripping a trailing + # "-[0-9a-f]{8}" suffix and retrying. + lookup_id = request_id + if lookup_id not in self.request_id_to_transfer_id: + base = self._MORIIO_RID_SUFFIX_RE.sub("", str(request_id)) + if base != request_id and base in self.request_id_to_transfer_id: + logger.debug( + "MoRI-IO unmap suffix-strip: %r -> %r", + request_id, + base, + ) + lookup_id = base + if lookup_id in self.request_id_to_transfer_id: + transfer_id = self.request_id_to_transfer_id[lookup_id] + del self.request_id_to_transfer_id[lookup_id] if transfer_id in self.transfer_id_to_request_id: del self.transfer_id_to_request_id[transfer_id] else: logger.warning( - "transfer id not in transfer_id_to_request_id lookup" - "table. there is likely a bug!" + "MoRI-IO unmap: transfer_id %s not in " + "transfer_id_to_request_id for rid %s", + transfer_id, + lookup_id, ) else: logger.warning( - "Could not find %s in transfer_id_to_request_id" - "lookup table. This could lead to a possible hang.", + "MoRI-IO unmap MISS: rid=%r table_size=%d " + "(suffix-strip fallback also missed; map_request_id " + "likely never fired for this request)", request_id, + len(self.request_id_to_transfer_id), ) def get_num_new_matched_tokens( From 015e2c5dfc3811986c9ab4a47101bb5524b0f14f Mon Sep 17 00:00:00 2001 From: Shiksha Patel Date: Mon, 1 Jun 2026 21:02:29 +0000 Subject: [PATCH 7/8] [Bugfix][KVConnector][MoRI-IO] Skip non-disagg requests in build_connector_meta A request that arrives without ``kv_transfer_params`` (smoke test, mis-routed gateway request, kubelet probe POST, ...) is scheduled like any other request and shows up in ``scheduled_cached_reqs`` on the next tick. ``MoRIIOConnectorScheduler.build_connector_meta`` then unconditionally indexes ``self._reqs_need_pending_save[req_id]`` -- a dict that is only populated for true disagg requests -- and raises ``KeyError``, which crashes the EngineCore and cascades the whole producer pod down. This is the small drive-by hot-fix from the Wide-EP multi-pod patch series, extracted standalone because it applies equally to single-pod deployments running behind any gateway / EPP / health-probe path. Skip the loop body silently when ``req_id`` is not in the pending-save dict; preserves all behaviour for true disagg requests. Signed-off-by: shikpate --- .../kv_connector/v1/moriio/moriio_connector.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index ceca0d5ccf32..e1b1e18a09da 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -611,6 +611,15 @@ def build_connector_meta( if new_block_ids is not None: block_ids = new_block_ids[0] # TODO : hybrid attn, etc + # A request that arrived without ``kv_transfer_params`` + # (smoke test, mis-routed gateway request, ...) is + # scheduled normally but is never registered in + # ``_reqs_need_pending_save``. The unconditional dict + # access below would raise ``KeyError`` and crash the + # EngineCore, taking the whole replica down. Skip + # silently for non-disagg requests on a producer pod. + if req_id not in self._reqs_need_pending_save: + continue req, existing_blocks = self._reqs_need_pending_save[req_id] updated_blocks = list(existing_blocks) + (block_ids) self._reqs_need_pending_save[req_id] = (req, updated_blocks) From f4fd067ae0ef97fea83939b6e6896e9760cb3901 Mon Sep 17 00:00:00 2001 From: Shiksha Patel Date: Mon, 1 Jun 2026 21:03:21 +0000 Subject: [PATCH 8/8] [ROCm] Gate disagg-DP fixes behind current_platform.is_rocm() The three preceding fixes touch shared (non-MoRI-IO-specific) files: - vllm/v1/core/sched/scheduler.py (finished_sending skip-if-missing) - vllm/v1/engine/async_llm.py (DP-rank hash fallback) - vllm/v1/engine/core.py (DP first-wave wake) While each fix is correctness-positive for any disagg+DP user, gating them behind ``current_platform.is_rocm()`` keeps this branch a clean ROCm-only divergence from upstream main and avoids subtly changing default behaviour for CUDA / TPU / CPU users until each fix is upstreamed individually. CUDA path is bit-identical to upstream HEAD; ROCm path runs the new behaviour. The MoRI-IO connector module (``vllm/distributed/kv_transfer/kv_connector/v1/moriio/``) is already ROCm-only by virtue of needing the MoRI runtime library, so its edits do not need an explicit guard. Signed-off-by: shikpate --- vllm/v1/core/sched/scheduler.py | 18 ++++++---- vllm/v1/engine/async_llm.py | 8 ++++- vllm/v1/engine/core.py | 58 ++++++++++++++++++++------------- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e5f6f4ed62e3..ec45972594dd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -29,6 +29,7 @@ ) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.encoder_budget import MultiModalBudget +from vllm.platforms import current_platform from vllm.v1.core.encoder_cache_manager import ( EncoderCacheManager, EncoderDecoderCacheManager, @@ -2175,13 +2176,16 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): self._free_blocks(self.requests[req_id]) for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) - # Skip-if-missing: the worker-side KV connector can report a - # completion for a request that has already been removed from - # ``self.requests`` (e.g. WRITE-mode connectors like MoRI-IO - # report finished_sending out-of-band, racing with the - # scheduler's normal lifecycle removal). Asserting here causes - # spurious crashes; tolerate the race by skipping the free. - if req_id in self.requests: + if current_platform.is_rocm(): + # ROCm/MoRI-IO skip-if-missing: WRITE-mode connectors like + # MoRI-IO can report ``finished_sending`` out-of-band from a + # deferred-write task, racing with the scheduler's normal + # lifecycle removal. Tolerate the race by skipping the free + # if the request is already gone. + if req_id in self.requests: + self._free_blocks(self.requests[req_id]) + else: + assert req_id in self.requests self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 2abd5eacbc26..9a42565322f6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -27,6 +27,7 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.renderers import renderer_from_config from vllm.renderers.inputs.preprocess import extract_prompt_components @@ -361,7 +362,12 @@ async def add_request( ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" - if data_parallel_rank is None: + # ROCm-only: stable per-request DP-rank fallback to neutralise the + # SO_REUSEPORT shuffle on disagg P/D pairs (MoRI-IO, NIXL WRITE). + # Gated to ROCm because (a) MoRI-IO is the in-tree consumer that + # exercises this path and (b) we don't want to silently change the + # default DP load-balancing behaviour for CUDA users. + if current_platform.is_rocm() and data_parallel_rank is None: data_parallel_rank = self._pick_dp_rank_for_request(request_id, params) if data_parallel_rank is not None: logger.debug( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4d9968df1cb8..169904329e07 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -31,6 +31,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tracing import instrument, maybe_init_worker_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value @@ -1760,29 +1761,40 @@ def _pause_complete(self) -> bool: def add_request(self, request: Request, request_wave: int = 0): super().add_request(request, request_wave) - if self.has_coordinator: - if request_wave > self.current_wave: - self.current_wave = request_wave - # NB: don't gate this on ``request_wave != self.current_wave``. - # Both default to 0, so the very first request after engine init - # would otherwise hit ``0 != 0 == False``, ``engines_running`` - # would stay False, and no ``start_wave`` would be broadcast to - # the DP coordinator. The rank that received the first request - # enters its forward pass and blocks forever on the EP all2all - # collective because the other DP ranks (seeing - # ``engines_running=False`` and no local work) skip - # ``execute_dummy_batch`` and never participate. - # We re-broadcast whenever engines are idle and the scheduler - # is unpaused, which is harmless in steady-state (engines are - # already running) and required for the cold first wave. - if ( - not self.engines_running - and self.scheduler.pause_state == PauseState.UNPAUSED - ): - self.engines_running = True - self.output_queue.put_nowait( - (-1, EngineCoreOutputs(start_wave=self.current_wave)) - ) + if current_platform.is_rocm(): + # ROCm/Wide-EP first-wave wake fix: drop the + # ``request_wave != self.current_wave`` outer gate so the very + # first request after engine init also broadcasts + # ``start_wave`` (otherwise ``0 != 0`` skips the broadcast and + # the first DP rank hangs forever on the EP all2all collective + # because the other ranks never call ``execute_dummy_batch``). + # Steady-state remains correct because ``engines_running`` is + # already True so the inner branch short-circuits. + if self.has_coordinator: + if request_wave > self.current_wave: + self.current_wave = request_wave + if ( + not self.engines_running + and self.scheduler.pause_state == PauseState.UNPAUSED + ): + self.engines_running = True + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(start_wave=self.current_wave)) + ) + else: + if self.has_coordinator and request_wave != self.current_wave: + if request_wave > self.current_wave: + self.current_wave = request_wave + elif ( + not self.engines_running + and self.scheduler.pause_state == PauseState.UNPAUSED + ): + # Request received for an already-completed wave, notify + # front-end that we need to start the next one. + self.engines_running = True + self.output_queue.put_nowait( + (-1, EngineCoreOutputs(start_wave=self.current_wave)) + ) def resume_scheduler(self): if self.pending_pause or (self.engines_running and self.ignore_start_dp_wave):