Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
import zmq
from quart import Quart, make_response, request

from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOConstants,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
prefill_instances: list[dict] = []
Expand Down Expand Up @@ -213,6 +217,8 @@ def extract_ip_port_fast(url):

dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])

transfer_id = f"{MoRIIOConstants.TRANSFER_PREFIX}-{str(uuid.uuid4())}"

req_data_to_prefill = copy.deepcopy(req_data)
req_data_to_prefill["kv_transfer_params"] = {}
req_data["kv_transfer_params"] = {}
Expand All @@ -222,6 +228,7 @@ def extract_ip_port_fast(url):
req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = (
decode_instance_endpoint["tp_size"]
)
req_data_to_prefill["kv_transfer_params"]["transfer_id"] = transfer_id

send_prefill_task = asyncio.create_task(
send_request_to_prefill(
Expand Down Expand Up @@ -267,6 +274,7 @@ def extract_ip_port_fast(url):

if selected_prefill_dp_rank is not None:
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["transfer_id"] = transfer_id

decode_request_task = asyncio.create_task(
start_decode_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@
Transfer = tuple[int, float]
EngineId = str
ReqId = str
TransferId = str


@dataclass
class WriteTask:
request_id: str
request_id: ReqId
transfer_id: TransferId
dst_engine_id: str
local_block_ids: list[int]
remote_block_ids_hint: list[int] | None
Expand All @@ -59,7 +61,8 @@ class WriteTask:
class LayerTransferPlan:
"""Plan for transferring a single layer."""

request_id: str
request_id: ReqId
transfer_id: TransferId
layer_name: str
sess_idx: int
transfer_local_offsets: list[int]
Expand Down Expand Up @@ -234,6 +237,7 @@ class MoRIIOConstants:
POP_DONE_RECV = b"pop_done_recv"
OVER = b"OVER"
COMPLETION_PREFIX = "cmpl"
TRANSFER_PREFIX = "tx"

PING_INTERVAL = 5
MAX_PING_RETRIES = 100
Expand All @@ -247,6 +251,7 @@ class MoRIIOConstants:
class ReqMeta:
"""Metadata for a single request."""

transfer_id: TransferId
local_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
Expand All @@ -263,21 +268,15 @@ def __init__(self):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
self.transfer_id_to_request_id: dict[TransferId, ReqId] = {}

def __repr__(self):
return_str = ""
for req_id, req_meta in self.reqs_to_recv.items():
return_str += (
f"{req_id = },{req_meta.local_block_ids = },"
f"{req_meta.remote_host = },{req_meta.remote_port = }"
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
)
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"

for req_id, expiry in self.reqs_to_send.items():
return_str += f"{req_id = },{expiry = }"
return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
return return_str
return (
f"MoRIIOConnectorMetadata: reqs_to_recv={self.reqs_to_recv}, "
f"reqs_to_save={self.reqs_to_save}, "
f"reqs_to_send={self.reqs_to_send}, "
f"transfer_id_to_request_id={self.transfer_id_to_request_id}"
)

def add_new_req(
self,
Expand All @@ -286,7 +285,9 @@ def add_new_req(
kv_transfer_params: dict[str, Any],
write_mode=False,
):
transfer_id = kv_transfer_params["transfer_id"]
_req = ReqMeta(
transfer_id=transfer_id,
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MoRIIOMode,
ReqId,
ReqMeta,
TransferId,
WriteTask,
get_moriio_mode,
get_port_offset,
Expand Down Expand Up @@ -277,6 +278,30 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self.paths: dict[str, zmq.Socket] = {}
self.transfer_id_to_request_id: dict[TransferId, ReqId] = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @rasmith a lot! Would it be better to define two mapping functions here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the function name, this could save some horizontal space, but what other benefit could be had?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For versatility and maintainability. That said, if it's only used within the class, I think using a dictionary is also fine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you meant, actually, but I added map and unmap functions for request-id/transfer-id pairs. If you mean swapping the dictionaries for functions, it's also possible to implement a class with `__getitem__` and use that instead. That would be easy to do if the need arises, but for the time being I don't see a need for it.

self.request_id_to_transfer_id: dict[ReqId, TransferId] = {}

def map_request_id(self, request_id: ReqId, transfer_id: TransferId):
    """Record the pairing between a request id and its transfer id.

    Both lookup tables are updated so the pair can be resolved in either
    direction: ``transfer_id_to_request_id`` and
    ``request_id_to_transfer_id``.

    Args:
        request_id: Identifier of the request.
        transfer_id: Identifier of the transfer paired with the request.
    """
    by_transfer = self.transfer_id_to_request_id
    by_request = self.request_id_to_transfer_id
    # Same write order as before: reverse table first, then forward table.
    by_transfer[transfer_id] = request_id
    by_request[request_id] = transfer_id

def unmap_request_id(self, request_id: ReqId):
    """Forget the request-id/transfer-id pairing stored for ``request_id``.

    Removes ``request_id`` from ``request_id_to_transfer_id`` and the
    matching reverse entry from ``transfer_id_to_request_id``. A missing
    entry in either table is reported with a warning instead of raising.

    Args:
        request_id: Identifier of the request whose mapping is removed.
    """
    if request_id not in self.request_id_to_transfer_id:
        # The lookup that failed is the request -> transfer table, so the
        # message names that table. The trailing space keeps the implicitly
        # concatenated literals from running together.
        logger.warning(
            "Could not find %s in request_id_to_transfer_id "
            "lookup table. This could lead to a possible hang.",
            request_id,
        )
        return

    transfer_id = self.request_id_to_transfer_id.pop(request_id)
    if transfer_id in self.transfer_id_to_request_id:
        del self.transfer_id_to_request_id[transfer_id]
    else:
        # The forward entry existed but the reverse one did not: the two
        # tables have drifted apart.
        logger.warning(
            "transfer id not in transfer_id_to_request_id lookup "
            "table. there is likely a bug!"
        )

def get_num_new_matched_tokens(
self,
Expand Down Expand Up @@ -309,7 +334,12 @@ def get_num_new_matched_tokens(
return len(token_ids) - 1 - num_computed_tokens, False

def send_notify_block(
self, req_id: str, block_notify_list: list[int], host=None, port=None
self,
req_id: ReqId,
transfer_id: TransferId,
block_notify_list: list[int],
host=None,
port=None,
):
path = make_zmq_path("tcp", host, port)
if path not in self.paths:
Expand All @@ -321,6 +351,7 @@ def send_notify_block(

data = {
"req_id": req_id,
"transfer_id": transfer_id,
"block_notify_list": block_notify_list or [],
"decode_rank": self.dp_rank,
"type": "remote_blocks",
Expand All @@ -338,6 +369,9 @@ def update_state_after_alloc(
params = request.kv_transfer_params
if not params:
return
transfer_id = params["transfer_id"]
request_id = request.request_id
self.map_request_id(request_id, transfer_id)
if params.get("do_remote_decode"):
local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids)
Expand Down Expand Up @@ -386,6 +420,7 @@ def update_state_after_alloc(

self.send_notify_block(
req_id=request.request_id,
transfer_id=request.kv_transfer_params["transfer_id"],
block_notify_list=blocks.get_block_ids()[0],
host=params.get("remote_host"),
port=target_port,
Expand All @@ -400,6 +435,7 @@ def build_connector_meta(
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = MoRIIOConnectorMetadata()
meta.transfer_id_to_request_id = self.transfer_id_to_request_id

if self.mode == MoRIIOMode.WRITE:
# when async_load_kv finished,
Expand Down Expand Up @@ -506,6 +542,9 @@ def request_finished(
should be freed now or will be sent asynchronously and freed later.
"""

request_id = request.request_id
self.unmap_request_id(request_id)

params = request.kv_transfer_params
logger.debug(
"MoriioConnector request_finished, request_status=%s, "
Expand Down Expand Up @@ -728,14 +767,16 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.cache_config.cache_dtype,
use_mla=self.use_mla,
)
self.transfer_id_to_request_id: dict[TransferId, ReqId] = {}

# TODO: consider the integration of flashinfer or other backends.
self.backend_name = backend.get_name()
logger.debug("Detected attention backend %s", self.backend_name)

def schedule_write_blocks(
self,
request_id: str,
request_id: ReqId,
transfer_id: TransferId,
dst_engine_id: str,
local_block_ids: list[int],
remote_block_ids: list[int] | None,
Expand All @@ -748,6 +789,7 @@ def schedule_write_blocks(

Args:
request_id: Unique identifier for the request
transfer_id: Unique identifier for the transfer
dst_engine_id: Destination engine ID
local_block_ids: Local block IDs to transfer
remote_block_ids: Hint for remote block IDs
Expand All @@ -768,6 +810,7 @@ def schedule_write_blocks(

task = WriteTask(
request_id=request_id,
transfer_id=transfer_id,
dst_engine_id=dst_engine_id,
local_block_ids=local_block_ids,
remote_block_ids_hint=remote_block_ids,
Expand Down Expand Up @@ -1010,7 +1053,7 @@ def _moriio_handshake(
return {remote_agent_name}

def _background_moriio_handshake(
self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta
self, req_id: ReqId, remote_engine_id: EngineId, meta: ReqMeta
):
# Do MoRIIO handshake in background and add to _ready_requests when done.
fut = None
Expand Down Expand Up @@ -1189,6 +1232,13 @@ def get_finished(self) -> tuple[set[str], set[str]]:
else:
done_recving = self._pop_done_transfers()

done_recving = {
self.transfer_id_to_request_id[id]
for id in filter(
lambda id: id in self.transfer_id_to_request_id, done_recving
)
}

return done_sending, done_recving

def _pop_done_transfers(self) -> set[str]:
Expand Down Expand Up @@ -1269,6 +1319,7 @@ def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
Start loading by triggering non-blocking moriio_xfer.
We check for these trnxs to complete in each step().
"""
self.transfer_id_to_request_id = metadata.transfer_id_to_request_id
if self.is_producer:
self.moriio_wrapper.async_wait_reqid()
return
Expand Down Expand Up @@ -1332,9 +1383,10 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
remote_notify_port=meta.remote_notify_port,
)

def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer):
def _write_blocks_for_req(self, req_id: ReqId, meta: ReqMeta, layer_name, kv_layer):
self.schedule_write_blocks(
request_id=req_id,
transfer_id=meta.transfer_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
Expand Down
Loading
Loading