Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import aiohttp
import msgpack
import regex as re
import zmq
from quart import Quart, Request, make_response, request

Expand All @@ -25,32 +24,10 @@
request_nums = 0
app = Quart(__name__)

IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")


TRANSFER_TYPE = None


def _append_whole_dict_unique(target_list, data_dict):
    """Append ``data_dict`` to ``target_list`` unless an equal entry exists.

    Two entries are considered equal when all of their items match, ignoring
    the ``"index"`` key.

    The first successfully appended entry fixes the module-level
    ``TRANSFER_TYPE``; later entries must carry the same ``"transfer_mode"``
    (a missing key is treated as ``"unknown"``).

    Returns:
        ``True`` if the entry was appended, ``False`` if it was a duplicate.

    Raises:
        ValueError: if the entry's transfer mode differs from the established
            ``TRANSFER_TYPE``. The entry is *not* appended in that case.
    """
    global TRANSFER_TYPE

    new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
    for existed in target_list:
        existed_filtered = {k: v for k, v in existed.items() if k != "index"}
        if existed_filtered == new_filtered:
            return False

    # Validate before mutating the list so that a rejected instance is never
    # left registered after the exception propagates.
    transfer_mode = data_dict.get("transfer_mode", "unknown")
    if TRANSFER_TYPE is not None and transfer_mode != TRANSFER_TYPE:
        raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")

    logger.debug("Appending instance: %s", data_dict)
    target_list.append(data_dict)

    if TRANSFER_TYPE is None:
        TRANSFER_TYPE = transfer_mode
        logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)

    return True


_list_lock = threading.RLock()


Expand All @@ -68,23 +45,81 @@ def _listen_for_register(hostname, port):
if router_socket in socks:
remote_addr, msg = router_socket.recv_multipart()
data = msgpack.loads(msg)
if data["type"] == "HELLO":
if data.get("type") == "HELLO":
pass
elif (
data["type"] == "register"
and data["role"] == "P"
and data["request_address"] not in prefill_instances
):
elif data.get("type") in ("P", "D"):
role = data["type"]
required_keys = {
"http_address",
"zmq_address",
"dp_size",
"tp_size",
"transfer_mode",
}
missing = required_keys - data.keys()
if missing:
logger.error(
"Registration message missing required keys %s; skipping",
missing,
)
continue
# Derive request_address from http_address
# api path suffix is appended at request time
instance = {
"role": role,
"request_address": f"http://{data['http_address']}/v1",
"http_address": data["http_address"],
"zmq_address": data["zmq_address"],
"dp_size": data["dp_size"],
"tp_size": data["tp_size"],
"transfer_mode": data["transfer_mode"],
}
# zmq_address format: "host:IP,handshake:PORT,notify:PORT"
# Stored verbatim; embedded into the request_id by handle_request.

global TRANSFER_TYPE
transfer_mode = instance["transfer_mode"]
target_list = prefill_instances if role == "P" else decode_instances
with _list_lock:
_append_whole_dict_unique(prefill_instances, data)

elif (
data["type"] == "register"
and data["role"] == "D"
and data["request_address"] not in decode_instances
):
with _list_lock:
_append_whole_dict_unique(decode_instances, data)
if TRANSFER_TYPE is None:
TRANSFER_TYPE = transfer_mode
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
elif transfer_mode != TRANSFER_TYPE:
logger.error(
"Mismatched transfer mode: expected %s, got %s;"
" skipping registration of %s",
TRANSFER_TYPE,
transfer_mode,
data["http_address"],
)
continue
existing_idx = next(
(
idx
for idx, i in enumerate(target_list)
if i.get("http_address") == data["http_address"]
),
None,
)
if existing_idx is not None:
target_list[existing_idx] = instance
logger.info(
"Updated existing %s instance: %s",
"Prefill" if role == "P" else "Decode",
instance,
)
Comment thread
simondanielsson marked this conversation as resolved.
else:
target_list.append(instance)
logger.info(
"Registered %s instance: %s",
"Prefill" if role == "P" else "Decode",
instance,
)
else:
logger.warning(
"Received message with unrecognized type %r; ignoring",
data.get("type"),
)


def start_service_discovery(hostname, port):
Expand All @@ -101,20 +136,16 @@ def start_service_discovery(hostname, port):


async def send_request_to_prefill(
endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank
endpoint, req_data, request_id, selected_prefill_dp_rank
):
req_data_copy = req_data

req_data_copy["kv_transfer_params"].update(
{
"do_remote_decode": True,
"do_remote_prefill": False,

@simondanielsson simondanielsson Apr 22, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explanation: These were MoRI-specific fields. For uniformity, we instead embed them into the zmq_address, which is then injected into the request id, similar to P2pNccl.

"remote_handshake_port": d_endpoint["handshake_port"],
"remote_notify_port": d_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": dip,
"remote_port": dport,
}
)
req_data_copy["stream"] = False
Expand Down Expand Up @@ -197,14 +228,7 @@ async def handle_request(api: str, request: Request):
global request_nums
request_nums += 1

def extract_ip_port_fast(url):
match = IP_PORT_PATTERN.search(url)
if not match:
raise ValueError(f"Invalid URL format: {url}")
return match.groups()

req_data = await request.get_json()
request_id = str(uuid.uuid4())

prefill_instance_endpoint = None
decode_instance_endpoint = None
Expand All @@ -230,7 +254,14 @@ def extract_ip_port_fast(url):
prefill_instance_endpoint["dp_size"],
)

dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
# Embed both zmq_addresses in the request_id so the connector can parse
# the peer's host/ports from it, similar to P2P-NCCL
uid = str(uuid.uuid4()).replace("-", "")
request_id = (
f"___prefill_addr_{prefill_instance_endpoint['zmq_address']}"
f"___decode_addr_{decode_instance_endpoint['zmq_address']}"
f"_{uid}"
)

transfer_id = f"{MoRIIOConstants.TRANSFER_PREFIX}-{str(uuid.uuid4())}"

Expand All @@ -251,35 +282,30 @@ def extract_ip_port_fast(url):
prefill_request_url,
req_data_to_prefill,
request_id,
decode_instance_endpoint,
dip,
dport,
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_request_url)

req_data["max_tokens"] -= 1

req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
"remote_notify_port": prefill_instance_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": ip,
"remote_port": port,
"transfer_id": transfer_id,
}
Comment thread
simondanielsson marked this conversation as resolved.
if TRANSFER_TYPE == "READ":
# In read mode, prefill and decode are executed serially.
prefill_response = await send_prefill_task
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
"kv_transfer_params"
]["remote_engine_id"]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
"kv_transfer_params"
]["remote_block_ids"]
prefill_kv = prefill_response["kv_transfer_params"]
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv[
"remote_engine_id"
]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv[
"remote_block_ids"
]
req_data["kv_transfer_params"]["transfer_id"] = prefill_kv["transfer_id"]

req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
"dp_size"
Expand All @@ -290,7 +316,6 @@ def extract_ip_port_fast(url):

if selected_prefill_dp_rank is not None:
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["transfer_id"] = transfer_id

decode_request_url = decode_instance_endpoint["request_address"] + api
decode_request_task = asyncio.create_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any

import msgspec
import regex as re
import torch
import zmq

Expand Down Expand Up @@ -239,14 +240,72 @@ class MoRIIOConstants:
COMPLETION_PREFIX = "cmpl"
TRANSFER_PREFIX = "tx"

PING_INTERVAL = 5
PING_INTERVAL = 3
MAX_PING_RETRIES = 100
DEFAULT_HANDSHAKE_PORT = "6301"
DEFAULT_NOTIFY_PORT = "61005"

VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600


# The router embeds both zmq_addresses in the request_id (similar to
# P2pNcclConnector), using the layout:
#   "___prefill_addr_{zmq}___decode_addr_{zmq}_{32-hex-uuid}"
# MoRIIO zmq_address format: "host:IP,handshake:PORT,notify:PORT"
#
# This lets each connector side parse the peer's connection info without
# requiring the router to pass it explicitly in kv_transfer_params.
#
# Group 1 is the prefill zmq_address: the shortest (lazy) run of text between
# the "___prefill_addr_" and "___decode_addr_" markers.
_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_")
# vLLM wraps the router's X-Request-Id as "cmpl-<id>-<seq>-<hex>" so there may
# be a trailing "-<seq>-<hex>" suffix after the 32-char UUID. Allow it.
# Group 1 is the decode zmq_address: everything after "___decode_addr_" up to
# the "_" that precedes the 32 lowercase-hex characters.
_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$")


def parse_moriio_zmq_address(
    zmq_address: str,
) -> tuple[str, int, int]:
    """Parse the MoRI-IO zmq address into its components.

    Parses ``"host:IP,handshake:PORT,notify:PORT"`` into
    ``(host, handshake_port, notify_port)``.

    Each key-value pair is split on the *first* colon so that IPv6 addresses
    (e.g. ``host:::1``) are handled correctly. Whitespace around keys and
    values is ignored.

    Raises:
        ValueError: if any of the ``host``, ``handshake`` or ``notify`` keys
            is absent, if the host is empty, or if a port is non-numeric or
            outside the valid range 1-65535.
    """
    parts: dict[str, str] = {}
    for segment in zmq_address.split(","):
        key, _, val = segment.partition(":")
        parts[key.strip()] = val.strip()
    try:
        host = parts["host"]
        handshake_port = int(parts["handshake"])
        notify_port = int(parts["notify"])
        # An empty host or an out-of-range port would only fail later, at
        # connection time, with a far less helpful error; reject it here.
        if not host:
            raise ValueError("empty host")
        for port in (handshake_port, notify_port):
            if not 0 < port < 65536:
                raise ValueError(f"port {port} out of range")
    except (KeyError, ValueError) as e:
        raise ValueError(
            f"Malformed zmq_address {zmq_address!r}: expected "
            "'host:IP,handshake:PORT,notify:PORT' format"
        ) from e
    return host, handshake_port, notify_port


def get_peer_zmq_from_request_id(request_id: str, is_producer: bool) -> str:
    """Return the zmq_address of the *other* side embedded in ``request_id``.

    A producer (prefill) is given the decode address; a consumer (decode) is
    given the prefill address.

    Raises:
        ValueError: if the relevant pattern does not match ``request_id``.
    """
    # Pick the pattern for the opposite role: we want the peer, not ourselves.
    pattern = _DECODE_ZMQ_RE if is_producer else _PREFILL_ZMQ_RE
    found = pattern.search(request_id)
    if found is not None:
        return found.group(1)
    raise ValueError(
        f"Cannot parse peer zmq_address from request_id: {request_id!r}"
    )


@dataclass
class ReqMeta:
"""Metadata for a single request."""
Expand Down Expand Up @@ -286,15 +345,23 @@ def add_new_req(
write_mode=False,
):
transfer_id = kv_transfer_params["transfer_id"]

# Parse host/ports from the request_id. The router embeds both zmq_addresses
# in the request_id
peer_zmq = get_peer_zmq_from_request_id(request_id, is_producer=write_mode)
remote_host, remote_handshake_port, remote_notify_port = (
parse_moriio_zmq_address(peer_zmq)
)

_req = ReqMeta(
transfer_id=transfer_id,
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
remote_notify_port=kv_transfer_params["remote_notify_port"],
remote_host=remote_host,
remote_port=remote_handshake_port,
remote_handshake_port=remote_handshake_port,
remote_notify_port=remote_notify_port,
tp_size=kv_transfer_params.get("tp_size", 1),
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
)
Expand Down
Loading
Loading