Skip to content
Draft
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import logging
import math
import queue
import re
import threading
import time
from collections import defaultdict
Expand Down Expand Up @@ -303,22 +305,64 @@ def map_request_id(self, request_id: ReqId, transfer_id: TransferId):
self.transfer_id_to_request_id[transfer_id] = request_id
self.request_id_to_transfer_id[request_id] = transfer_id

# Per-transfer suffix that MoRI-IO appends to ``request.request_id``
# between ``update_state_after_alloc`` (alloc-time) and
# ``request_finished`` (finish-time) on the sidecar-fronted decode
# path. Used by ``unmap_request_id`` to strip the suffix when the
# exact-match lookup misses.
_MORIIO_RID_SUFFIX_RE = re.compile(r"-[0-9a-f]{8}$")

def unmap_request_id(self, request_id: ReqId):
if request_id in self.request_id_to_transfer_id:
transfer_id = self.request_id_to_transfer_id[request_id]
del self.request_id_to_transfer_id[request_id]
# In multi-pod disagg routing, MoRI-IO can append a
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The engine adds the suffix here. It was introduced by this PR. The issue was addressed in this PR. Could you share an example of the request IDs you were seeing?

# "-[0-9a-f]{8}" per-transfer suffix to ``request.request_id``
# between the call that populated ``request_id_to_transfer_id``
# (``update_state_after_alloc``) and the call that drains it
# (``request_finished``). The dict lookup is exact-match, so the
# suffix mutation produces a spurious
#
# "Could not find <rid> in transfer_id_to_request_id lookup
# table. This could lead to a possible hang."
#
# warning, leaks the dict entry, and ships stale state to the
# worker via ``meta.transfer_id_to_request_id`` -- causing
# rank-asymmetric MoRI-IO transfer-id lookup failures in worker
# logs on the pod where the suffix gets appended (decode-master
# in Wide-EP DP=16, ranks 0..7).
#
# Resolution: try exact match first (preserves the no-suffix
# fast path -- zero overhead and bit-identical to the
# pre-patch behaviour for callers that already pass the
# canonical rid), then fall back to stripping a trailing
# "-[0-9a-f]{8}" suffix and retrying.
lookup_id = request_id
if lookup_id not in self.request_id_to_transfer_id:
base = self._MORIIO_RID_SUFFIX_RE.sub("", str(request_id))
if base != request_id and base in self.request_id_to_transfer_id:
logger.debug(
"MoRI-IO unmap suffix-strip: %r -> %r",
request_id,
base,
)
lookup_id = base
if lookup_id in self.request_id_to_transfer_id:
transfer_id = self.request_id_to_transfer_id[lookup_id]
del self.request_id_to_transfer_id[lookup_id]
if transfer_id in self.transfer_id_to_request_id:
del self.transfer_id_to_request_id[transfer_id]
else:
logger.warning(
"transfer id not in transfer_id_to_request_id lookup"
"table. there is likely a bug!"
"MoRI-IO unmap: transfer_id %s not in "
"transfer_id_to_request_id for rid %s",
transfer_id,
lookup_id,
)
else:
logger.warning(
"Could not find %s in transfer_id_to_request_id"
"lookup table. This could lead to a possible hang.",
"MoRI-IO unmap MISS: rid=%r table_size=%d "
"(suffix-strip fallback also missed; map_request_id "
"likely never fired for this request)",
request_id,
len(self.request_id_to_transfer_id),
)

def get_num_new_matched_tokens(
Expand Down Expand Up @@ -387,7 +431,14 @@ def update_state_after_alloc(
params = request.kv_transfer_params
if not params:
return
transfer_id = params["transfer_id"]
# LLM-D sidecar compat: the routing-sidecar (--kv-connector=nixlv2)
# emits NIXL-shaped kv_transfer_params that do not include MoRI-IO's
# transfer_id. Synthesize one deterministically from request_id so the
# producer and consumer (both downstream of the same sidecar fan-out)
# observe the same transfer_id without requiring a wire-protocol
# change in the sidecar.
transfer_id = params.get("transfer_id") or f"sidecar-{request.request_id}"
params.setdefault("transfer_id", transfer_id)
request_id = request.request_id
self.map_request_id(request_id, transfer_id)
if params.get("do_remote_decode"):
Expand Down Expand Up @@ -431,6 +482,67 @@ def update_state_after_alloc(

remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)

# Wide-EP DP>1 fix: the disagg routing sidecar injects a
# STATIC ``remote_dp_rank`` (e.g. always 0) into
# ``kv_transfer_params``. With DP>1, that pins every
# decode->prefill notify to a single prefill DP rank, so
# all prefill ranks other than that one never receive
# their ``done_remote_allocate`` notify and their
# deferred-write tasks expire after
# ``VLLM_MORIIO_DEFERRED_TIMEOUT_S``. Most requests hang.
#
# Compute a per-request prefill DP rank from a stable
# hash of ``request_id``. The matching helper on
# ``AsyncLLM.add_request`` uses the same blake2s scheme,
# so both legs (prefill dispatch + decode notify) agree.
#
# When ``remote_dp_size_local`` is set and smaller than
# ``remote_dp_size`` (multi-pod DP, "Wide-EP"), cap the
# modulus to the per-pod size so the notify lands on the
# same pod that the producer dispatch routes to.
#
# By the time we reach ``request_finished``, MoRI-IO has
# appended a per-transfer suffix ``-<8 hex>`` to
# ``request.request_id`` (it isn't on the AsyncLLM rid
# that the dispatcher hashes). Strip that suffix so both
# legs hash the same canonical base id.
_dp_size = int(
request.kv_transfer_params.get("remote_dp_size", 1) or 1
)
try:
_dp_local = int(
request.kv_transfer_params.get("remote_dp_size_local", 0)
or 0
)
if _dp_local > 0:
_dp_size = min(_dp_size, _dp_local)
except (TypeError, ValueError):
pass
# Defense-in-depth handshake with the llm-d routing
# sidecar (patch 0013, shipped in
# pd-sidecar-moriio-write-widep-v0.8.0+): when the
# sidecar is in path it already pins the prefill DP rank
# via its own pickDPRank(uuid, dpSize) and stamps both
# ``remote_dp_rank`` and ``remote_dp_rank_override=True``
# on the kv_transfer_params. Honouring that sentinel
# makes this branch dormant in production while still
# acting as a fail-safe for sidecar-less debug runs and
# for any future sidecar regression that drops the
# override stamp. Avoids cross-language hash divergence
# (Go blake2s-256 vs Python blake2s-8) when both sides
# would otherwise hash independently.
if (
_dp_size > 1
and "remote_dp_rank_override" not in request.kv_transfer_params
):
_base_rid = re.sub(
r"-[0-9a-f]{8}$", "", str(request.request_id)
)
_digest = hashlib.blake2s(
_base_rid.encode("utf-8"), digest_size=8
).digest()
remote_dp_rank = int.from_bytes(_digest, "big") % _dp_size

peer_zmq = get_peer_zmq_from_request_id(
request.request_id, is_producer=False
)
Expand Down Expand Up @@ -499,6 +611,15 @@ def build_connector_meta(
if new_block_ids is not None:
block_ids = new_block_ids[0]
# TODO : hybrid attn, etc
# A request that arrived without ``kv_transfer_params``
# (smoke test, mis-routed gateway request, ...) is
# scheduled normally but is never registered in
# ``_reqs_need_pending_save``. The unconditional dict
# access below would raise ``KeyError`` and crash the
# EngineCore, taking the whole replica down. Skip
# silently for non-disagg requests on a producer pod.
if req_id not in self._reqs_need_pending_save:
continue
req, existing_blocks = self._reqs_need_pending_save[req_id]
updated_blocks = list(existing_blocks) + (block_ids)
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
Expand Down
14 changes: 12 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.platforms import current_platform
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
Expand Down Expand Up @@ -2175,8 +2176,17 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
self._free_blocks(self.requests[req_id])
for req_id in kv_connector_output.finished_sending or ():
logger.debug("Finished sending KV transfer for request %s", req_id)
assert req_id in self.requests
self._free_blocks(self.requests[req_id])
if current_platform.is_rocm():
Copy link
Copy Markdown
Contributor

@rasmith rasmith Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you can't pass some metadata and have the prefill worker or decode worker help with this task?
You can get metadata in the worker with `start_load_kv`, and send metadata by returning the appropriate metadata from `build_connector_meta`. The scheduler and worker processes communicate. You can then return finished requests in `get_finished`. Maybe you can modify `update_connector_output` to ensure requests are deleted when appropriate.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback, @rasmith. Checking this out.

# ROCm/MoRI-IO skip-if-missing: WRITE-mode connectors like
# MoRI-IO can report ``finished_sending`` out-of-band from a
# deferred-write task, racing with the scheduler's normal
# lifecycle removal. Tolerate the race by skipping the free
# if the request is already gone.
if req_id in self.requests:
self._free_blocks(self.requests[req_id])
else:
assert req_id in self.requests
self._free_blocks(self.requests[req_id])

def _update_requests_with_invalid_blocks(
self,
Expand Down
79 changes: 79 additions & 0 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import hashlib
import os
import socket
import time
Expand All @@ -26,6 +27,7 @@
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs.preprocess import extract_prompt_components
Expand Down Expand Up @@ -277,6 +279,69 @@ async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:

return self._supported_tasks

def _pick_dp_rank_for_request(
    self,
    request_id: str,
    params: SamplingParams | PoolingParams,
) -> int | None:
    """Pick a stable ``data_parallel_rank`` for a request.

    When the OpenAI server is run with ``--api-server-count N`` (N > 1),
    Linux SO_REUSEPORT shuffles incoming connections across ApiServer
    processes. Two legs of a disaggregated prefill/decode pair (which
    share a ``request_id``) can land on different ApiServers and be
    load-balanced to different DP ranks. KV-transfer protocols that
    pin source/target by DP rank (MoRI-IO, NIXL WRITE-mode, ...) then
    end up exchanging handshakes with the wrong peer and the request
    deadlocks at the connector level.

    To work around this we synthesize a per-request DP rank that both
    legs will independently agree on, in this order:

    1. ``params.extra_args["kv_transfer_params"]["dp_rank_hint"]``
       if the caller (or an upstream routing sidecar) has already
       picked the rank. The hint is honoured only when it is a real
       ``int`` (``bool`` is rejected) inside ``[0, effective_dp_size)``.
    2. Otherwise a stable ``blake2s(request_id) % effective_dp_size``
       hash (8-byte digest, big-endian).

    When ``data_parallel_size_local`` is set and smaller than
    ``data_parallel_size`` (multi-pod DP, "Wide-EP"), the modulus is
    capped to the local pod size so that both legs route to a rank in
    the same pod -- cross-pod handshake requires a coordinator that
    may not exist in the disagg orchestrator.

    Args:
        request_id: Request identifier; it is stringified and UTF-8
            encoded before hashing.
        params: Request parameters. Only ``extra_args`` is read, and
            only when it is a ``dict``.

    Returns:
        The chosen rank in ``[0, effective_dp_size)``, or ``None`` when
        there is no DP fan-out to disambiguate
        (``effective_dp_size <= 1``); callers should leave
        ``data_parallel_rank`` unset in that case.
    """
    pc = self.vllm_config.parallel_config

    # Best-effort parsing: a missing or malformed size means "no DP
    # fan-out" rather than a failed request. Only the errors ``int()``
    # can raise are swallowed, so unrelated bugs still surface.
    try:
        dp_size = int(getattr(pc, "data_parallel_size", 1))
    except (TypeError, ValueError, OverflowError):
        dp_size = 1
    if dp_size <= 1:
        return None

    # Cap the modulus to the per-pod size when one is configured. An
    # unset (falsy), non-positive or malformed value leaves dp_size as is.
    dp_local_raw = getattr(pc, "data_parallel_size_local", None)
    try:
        dp_local = int(dp_local_raw) if dp_local_raw else 0
    except (TypeError, ValueError, OverflowError):
        dp_local = 0
    if dp_local > 0:
        dp_size = min(dp_size, dp_local)
    if dp_size <= 1:
        return None

    # 1. Explicit hint from the caller / routing sidecar.
    extra = getattr(params, "extra_args", None) or {}
    if isinstance(extra, dict):
        ktp = extra.get("kv_transfer_params") or {}
        if isinstance(ktp, dict):
            hint = ktp.get("dp_rank_hint")
            # ``bool`` is a subclass of ``int``; without the explicit
            # exclusion a JSON ``true`` would be returned as-is and
            # silently route the request to rank 1.
            if (
                isinstance(hint, int)
                and not isinstance(hint, bool)
                and 0 <= hint < dp_size
            ):
                return hint

    # 2. Deterministic fallback: both legs hash the same request_id.
    digest = hashlib.blake2s(
        str(request_id).encode("utf-8"), digest_size=8
    ).digest()
    return int.from_bytes(digest, "big") % dp_size

async def add_request(
self,
request_id: str,
Expand All @@ -297,6 +362,20 @@ async def add_request(
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""

# ROCm-only: stable per-request DP-rank fallback to neutralise the
# SO_REUSEPORT shuffle on disagg P/D pairs (MoRI-IO, NIXL WRITE).
# Gated to ROCm because (a) MoRI-IO is the in-tree consumer that
# exercises this path and (b) we don't want to silently change the
# default DP load-balancing behaviour for CUDA users.
if current_platform.is_rocm() and data_parallel_rank is None:
data_parallel_rank = self._pick_dp_rank_for_request(request_id, params)
if data_parallel_rank is not None:
logger.debug(
"Auto-routed request %s to data_parallel_rank=%d",
request_id,
data_parallel_rank,
)

if self.errored:
raise EngineDeadError()

Expand Down
48 changes: 35 additions & 13 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
Expand Down Expand Up @@ -1760,19 +1761,40 @@ def _pause_complete(self) -> bool:

def add_request(self, request: Request, request_wave: int = 0):
super().add_request(request, request_wave)
if self.has_coordinator and request_wave != self.current_wave:
if request_wave > self.current_wave:
self.current_wave = request_wave
elif (
not self.engines_running
and self.scheduler.pause_state == PauseState.UNPAUSED
):
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self.engines_running = True
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave))
)
if current_platform.is_rocm():
# ROCm/Wide-EP first-wave wake fix: drop the
# ``request_wave != self.current_wave`` outer gate so the very
# first request after engine init also broadcasts
# ``start_wave`` (otherwise ``0 != 0`` skips the broadcast and
# the first DP rank hangs forever on the EP all2all collective
# because the other ranks never call ``execute_dummy_batch``).
# Steady-state remains correct because ``engines_running`` is
# already True so the inner branch short-circuits.
if self.has_coordinator:
if request_wave > self.current_wave:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the outer `if current_platform.is_rocm()` isn't necessary and this could be simplified.

self.current_wave = request_wave
if (
not self.engines_running
and self.scheduler.pause_state == PauseState.UNPAUSED
):
self.engines_running = True
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave))
)
else:
if self.has_coordinator and request_wave != self.current_wave:
if request_wave > self.current_wave:
self.current_wave = request_wave
elif (
not self.engines_running
and self.scheduler.pause_state == PauseState.UNPAUSED
):
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self.engines_running = True
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(start_wave=self.current_wave))
)

def resume_scheduler(self):
if self.pending_pause or (self.engines_running and self.ignore_start_dp_wave):
Expand Down
Loading