Merged
Changes from all commits
Commits
145 commits
4730522
[Update] LMcache connector v1 implementation
ApostaC Apr 17, 2025
4162650
[Add] examples for disaggregated prefill
ApostaC Apr 17, 2025
3ccd34c
[add] extra information about evns
ApostaC Apr 18, 2025
161010c
Initial stubs for P/D scheduling changes
tlrmchlsmth Apr 18, 2025
38a2eb8
Merge branch 'main' into local-dev/lmcache-v1-connector-pr
tlrmchlsmth Apr 19, 2025
6c3191f
Merge branch 'local-dev/lmcache-v1-connector-pr' into pd_scheduling_l…
tlrmchlsmth Apr 19, 2025
1f708e9
Updates
tlrmchlsmth Apr 19, 2025
038f2f8
Rs branch (#3)
robertgshaw2-redhat Apr 20, 2025
5c4fc6f
Rs branch (#5)
robertgshaw2-redhat Apr 20, 2025
1800689
Remove Unneeded Arguments (#7)
robertgshaw2-redhat Apr 21, 2025
7a1f25f
Improve disagg-example.sh (#8)
tlrmchlsmth Apr 21, 2025
2385d8e
updated
robertgshaw2-redhat Apr 22, 2025
6eeb47c
updated
robertgshaw2-redhat Apr 22, 2025
266fcee
updated
robertgshaw2-redhat Apr 22, 2025
f7e16f1
updated
robertgshaw2-redhat Apr 22, 2025
f591b8e
added connector
robertgshaw2-redhat Apr 22, 2025
184d0b6
updated
robertgshaw2-redhat Apr 22, 2025
d4a9e5b
updated
robertgshaw2-redhat Apr 22, 2025
4b0d1dc
updated
robertgshaw2-redhat Apr 22, 2025
bfef039
updated
robertgshaw2-redhat Apr 22, 2025
54f4a43
updated
robertgshaw2-redhat Apr 22, 2025
e604b09
updated
robertgshaw2-redhat Apr 22, 2025
2fc00ad
updated
robertgshaw2-redhat Apr 22, 2025
e5967b6
updated
robertgshaw2-redhat Apr 22, 2025
f1bc0f7
updated
robertgshaw2-redhat Apr 22, 2025
1cea2bb
updated
robertgshaw2-redhat Apr 22, 2025
489e4c0
updated
robertgshaw2-redhat Apr 22, 2025
437ac91
updated
robertgshaw2-redhat Apr 22, 2025
ea47af7
updated
robertgshaw2-redhat Apr 22, 2025
554b27d
updated
robertgshaw2-redhat Apr 22, 2025
1aea5ba
updated
robertgshaw2-redhat Apr 22, 2025
e0c112b
updated
robertgshaw2-redhat Apr 22, 2025
c7717c1
update
robertgshaw2-redhat Apr 22, 2025
e0af1db
remove
robertgshaw2-redhat Apr 22, 2025
9533471
updated
robertgshaw2-redhat Apr 22, 2025
2eb068e
updated
robertgshaw2-redhat Apr 22, 2025
0f2b7e3
updated
robertgshaw2-redhat Apr 22, 2025
6127cb8
updated
robertgshaw2-redhat Apr 22, 2025
568249e
updated
robertgshaw2-redhat Apr 23, 2025
ccb44ea
seems to load properly
robertgshaw2-redhat Apr 23, 2025
3785905
updated
robertgshaw2-redhat Apr 23, 2025
8a94b2e
updated
robertgshaw2-redhat Apr 24, 2025
ac19437
updated
robertgshaw2-redhat Apr 24, 2025
6391ec9
updated
robertgshaw2-redhat Apr 24, 2025
7dd764b
updated
robertgshaw2-redhat Apr 24, 2025
97316d9
updated
robertgshaw2-redhat Apr 24, 2025
2771353
Revert "updated"
robertgshaw2-redhat Apr 24, 2025
baed1bf
updated
robertgshaw2-redhat Apr 24, 2025
d0ad6d9
updated
robertgshaw2-redhat Apr 24, 2025
055885e
updated
robertgshaw2-redhat Apr 24, 2025
5ed3806
updated
robertgshaw2-redhat Apr 24, 2025
58266b5
updated
robertgshaw2-redhat Apr 24, 2025
344d9da
stash
robertgshaw2-redhat Apr 24, 2025
2996638
added
robertgshaw2-redhat Apr 24, 2025
bcc88dc
diffs for local dev on macos
Apr 24, 2025
62205ae
updated
Apr 24, 2025
b4609a5
update
Apr 24, 2025
5d78ba6
updaed
Apr 25, 2025
c1f26b9
updated
Apr 25, 2025
9b9ef36
updated
Apr 25, 2025
c60639e
Checkpoint.
tlrmchlsmth Apr 25, 2025
006dda3
Merge branch 'pd_scheduling_nixl' of https://github.com/robertgshaw2-…
tlrmchlsmth Apr 25, 2025
c5e023e
updated
Apr 25, 2025
8b0c93c
Cleanup
tlrmchlsmth Apr 26, 2025
5e45d90
WIP
tlrmchlsmth Apr 26, 2025
20a5491
updated
Apr 27, 2025
cee3c61
updated
Apr 27, 2025
5972571
updated on scheduler side
Apr 27, 2025
1b69d33
updated
Apr 27, 2025
74e105a
Merge remote-tracking branch 'rs/pd_scheduling_rob_dev' into nixl_int…
tlrmchlsmth Apr 27, 2025
8adf1ad
updated
Apr 27, 2025
21ab3d9
updated
Apr 27, 2025
3a27bbc
updated
Apr 27, 2025
f252df9
updated
Apr 27, 2025
8104803
updated
Apr 27, 2025
10bbe21
Hacking away
tlrmchlsmth Apr 27, 2025
a14278c
Merge remote-tracking branch 'rs/pd_scheduling_rob_dev_2' into nixl_i…
tlrmchlsmth Apr 27, 2025
65ea91f
cleanup
Apr 27, 2025
f2550ef
ensure request removed from running list
Apr 27, 2025
985bac3
Runs E2E. Garbage output. Crashes on 2nd request
tlrmchlsmth Apr 27, 2025
bf37a7d
update
tlrmchlsmth Apr 27, 2025
ebe1263
updated
Apr 27, 2025
a008aa3
updated
Apr 27, 2025
195dceb
rename files
Apr 27, 2025
e2cc365
updated
Apr 27, 2025
2324a50
Merge remote-tracking branch 'rs/pd_scheduling_rob_dev_2' into nixl_i…
tlrmchlsmth Apr 27, 2025
b4b64fe
updated
Apr 27, 2025
6686397
updated
Apr 27, 2025
8736043
updated
Apr 27, 2025
dcbf6e5
updated
Apr 27, 2025
7c8e21a
update
Apr 27, 2025
a4855d2
Second request no longer crashes
tlrmchlsmth Apr 27, 2025
0914040
Merge remote-tracking branch 'rs/pd_scheduling_rob_dev_2' into nixl_i…
tlrmchlsmth Apr 27, 2025
c5b3053
Remove gpu_model_runner hacks
tlrmchlsmth Apr 27, 2025
7502819
Clean up Justfile
tlrmchlsmth Apr 28, 2025
7768b96
[Bugfix] Stale finished requests in EMPTY_MODEL_RUNNER_OUTPUT
tlrmchlsmth Apr 28, 2025
a5950b7
update
tlrmchlsmth Apr 28, 2025
610a357
justfile edits
tlrmchlsmth Apr 28, 2025
5b026ab
Update
tlrmchlsmth Apr 28, 2025
f2fadd6
Fixes - lm_eval gsm8k has correctness
tlrmchlsmth Apr 29, 2025
4060f86
"just delete the assert"
tlrmchlsmth Apr 29, 2025
bfe9d19
fixup precommit issues
tlrmchlsmth Apr 29, 2025
ced529a
Fixes
tlrmchlsmth Apr 29, 2025
83f2872
updated (#12)
robertgshaw2-redhat Apr 30, 2025
e853b3c
Add Accuracy Test (#13)
robertgshaw2-redhat Apr 30, 2025
1c45ed1
Preemption Bugfixes (#15)
robertgshaw2-redhat May 1, 2025
a45a694
updated (#16)
robertgshaw2-redhat May 1, 2025
f6d0ac5
Merge branch 'main' into nixl_integration
tlrmchlsmth May 1, 2025
2f9a3f3
Fix Bad Merge | Fix Memory Leak in Upstream (#18)
robertgshaw2-redhat May 2, 2025
90ba831
updated
robertgshaw2-redhat May 2, 2025
9378594
cleanup code
robertgshaw2-redhat May 2, 2025
790c1b2
cleanup code
robertgshaw2-redhat May 2, 2025
e4802fd
updated
robertgshaw2-redhat May 2, 2025
f4c2915
updated
robertgshaw2-redhat May 2, 2025
6346a64
updated
robertgshaw2-redhat May 2, 2025
a8832ec
stash
robertgshaw2-redhat May 2, 2025
dd0935a
complete merge
robertgshaw2-redhat May 3, 2025
42a28ff
updated
robertgshaw2-redhat May 3, 2025
422a9ac
updated
robertgshaw2-redhat May 3, 2025
836e76b
updatted
robertgshaw2-redhat May 3, 2025
0aafe4a
updated
robertgshaw2-redhat May 3, 2025
4fe1829
updated
robertgshaw2-redhat May 3, 2025
d6b2531
Merge remote-tracking branch 'nm-fork/disagg_pd_dev' into tp-gt-1
robertgshaw2-redhat May 3, 2025
1bbd623
revert
robertgshaw2-redhat May 3, 2025
afdcd2f
more spurious changes
robertgshaw2-redhat May 3, 2025
6790c00
updated
robertgshaw2-redhat May 3, 2025
87277d6
updated
robertgshaw2-redhat May 3, 2025
8ff421e
updated
robertgshaw2-redhat May 3, 2025
79af352
updated
robertgshaw2-redhat May 3, 2025
e21f5f9
updated
robertgshaw2-redhat May 3, 2025
39fee21
updated
robertgshaw2-redhat May 3, 2025
99a5afd
updated
robertgshaw2-redhat May 3, 2025
1e0db0b
updated
robertgshaw2-redhat May 3, 2025
9bdbe38
updated
robertgshaw2-redhat May 3, 2025
911e480
updated
robertgshaw2-redhat May 3, 2025
357bd03
updated
robertgshaw2-redhat May 3, 2025
93a32eb
updated
robertgshaw2-redhat May 3, 2025
181d68d
updated
robertgshaw2-redhat May 3, 2025
01e5864
updated
robertgshaw2-redhat May 3, 2025
04cba85
updated
robertgshaw2-redhat May 4, 2025
06c5c39
updated
robertgshaw2-redhat May 4, 2025
9a87c34
updated
robertgshaw2-redhat May 4, 2025
027689d
updated
robertgshaw2-redhat May 4, 2025
48add56
Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
robertgshaw2-redhat May 4, 2025
ed6fd4f
Update vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
robertgshaw2-redhat May 4, 2025
285 changes: 160 additions & 125 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -17,6 +17,9 @@
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.sampling_params import KVTransferParams
from vllm.utils import round_down
@@ -47,8 +50,6 @@ class NixlAgentMetadata(
dict=True):
engine_id: str
agent_metadata: bytes
# Base addr for each layer for KVs
# NOTE: we will need another list for TP>1
kv_caches_base_addr: list[int]
num_blocks: int

@@ -222,47 +223,53 @@ def __init__(self, engine_id: str):

# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> list[agent_names] (1 per rank).
self._remote_agents: dict[str, list[str]] = {}
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}

# Metadata.
self.engine_id = engine_id
self.rank = 0
self.rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()

# KV Caches and nixl tracking data.
self.kv_caches: dict[str, torch.Tensor] = {}

# Map of engine_id -> kv_caches_base_addr
# For Local: base addr for *this* rank, each layer for K,V
# For Remote: base addr for *each* rank, each layer for K,V
# KV_CACHES_ADDR_TYPE = Union[list[tuple[int, int]],
# list[list[tuple[int, int]]]]
self.kv_caches_base_addr: dict[str, list[int]] = {}

# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
self.num_regions = 0

# Map of tp_mult -> nixl_prepped_dlist_handle (int).
self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> map[tp_mult -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: defaultdict[str,
dict[int,
int]] = defaultdict(dict)
# nixl_prepped_dlist_handle (int).
self.src_xfer_side_handle: int = 0
# Map of engine_id -> nixl_prepped_dlist_handle (int).
self.dst_xfer_side_handles: dict[str, int] = {}

# Map of engine_id -> num_blocks.
self.dst_num_blocks: dict[str, int] = {}
self._registered_descs: list[Any] = []

# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers: dict[str, list[Any]] = defaultdict(list[Any])
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
list[Any])

# Complete transfer tracker. Used by rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self._done_recving_count: defaultdict[str,
int] = defaultdict(lambda: 0)
self._done_sending_count: defaultdict[str,
int] = defaultdict(lambda: 0)

# Background thread for establishing new connections.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None

@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event):
ready_event: threading.Event, rank: int):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach like an ETCD server in the future.
@@ -280,8 +287,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT
with zmq_ctx(zmq.ROUTER, f"tcp://{host}:{port}") as sock:
# NOTE(rob): we need each rank to have a unique port. This
# hack keeps us moving. We will switch when moving to etcd
# or to a single ZMQ socket in the scheduler.
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
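# Illustrative: if VLLM_NIXL_SIDE_CHANNEL_PORT were 5600 (hypothetical
# value), rank 0 would bind 5600 and rank 1 would bind 5601.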
path = f"tcp://{host}:{port}"
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
while True:
identity, _, msg = sock.recv_multipart()
@@ -294,7 +306,12 @@ def _nixl_handshake(self, host: str, port: int):
"""Do a NIXL handshake with a remote instance."""

start_time = time.perf_counter()
with zmq_ctx(zmq.REQ, f"tcp://{host}:{port}") as sock:
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or to a single ZMQ socket in the scheduler.
path = f"tcp://{host}:{port + self.rank}"
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
# Send query for the request.
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
@@ -364,90 +381,125 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event),
args=(metadata, ready_event, self.rank),
daemon=True,
name="nixl_handshake_listener")
import os
if os.getenv("SKIP", None) != "1":
self._nixl_handshake_listener_t.start()
ready_event.wait()
self._nixl_handshake_listener_t.start()
ready_event.wait()

def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, tp_idx=0):
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
engine_id = nixl_agent_meta.engine_id
if engine_id in self._remote_agents:
return

num_blocks = nixl_agent_meta.num_blocks
logger.debug("Adding remote agent %s %s", engine_id, str(num_blocks))

agent_names = [
self.nixl_wrapper.add_remote_agent(nixl_agent_meta.agent_metadata)
]

self._remote_agents[engine_id] = agent_names
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr

# NOTE: once we support heterogeneous TP, we will need maintain the
# src for each TP multiplier.
# NOTE(rob): Dynamo only supports D TP size > P TP size.
# https://github.com/vllm-project/vllm/pull/16124/files#diff-876efa5533f5dcff3fba850e8684a47d53c112e287988957c115b11691374f4bR331 # noqa: E501
# Create descs and xfer side handles.
tp_multiplier = 1
dst_block_len = self.block_len // tp_multiplier
if tp_multiplier not in self.src_xfer_side_handles:
# Create descs and xfer side handles.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for i in range(tp_multiplier):
tp_multiplier_offset = tp_idx * dst_block_len
blocks_data.append(
(base_addr + block_offset + tp_multiplier_offset,
dst_block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handles[tp_multiplier] = (
self.nixl_wrapper.prep_xfer_dlist("", descs))

# create dst xfer side handles
self.dst_num_blocks[engine_id] = num_blocks
# Create src descs and xfer side handles.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)

# Create dst descs and xfer side handles.
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
blocks_data = []
for base_addr in self.kv_caches_base_addr[engine_id]:
for block_id in range(num_blocks):
block_offset = block_id * dst_block_len
blocks_data.append((base_addr + block_offset, dst_block_len,
self.rank * tp_multiplier))
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s and rank %s",
len(blocks_data), engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[engine_id][tp_idx] = (
self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id][self.rank * tp_multiplier +
tp_idx], descs))
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id], descs)

def get_finished(self) -> tuple[set[str], set[str]]:
"""Get requests that are done sending or recving."""
"""
Get requests that are done sending or recving.

In a TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in each worker and
creates the done_sending and done_recving sets that Rank 0
sends to the scheduler via ModelRunnerOutput. To avoid races
and to ensure transfers are complete before they are reported
as finished, Ranks 1 to N-1 notify Rank 0 once their
transaction is done; Rank 0 only reports a request as
finished once all ranks are complete.
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"get_finished: %s requests done sending "
"and %s requests done recving", len(done_sending),
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.rank, len(done_sending),
len(done_recving))
return done_sending, done_recving

if self.world_size == 1:
return done_sending, done_recving

# Rank 0: get finished from all other ranks.
if self.rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1

# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
robertgshaw2-redhat (Collaborator, Author): this is how Dynamo does it (with the tp_group). I wonder if there is a better way cc @njhill

njhill (Member): @robertgshaw2-redhat here is an alternative to consider: robertgshaw2-redhat#7. Guess this might be preferable latency-wise, since we don't have an additional gather collective, but not sure (since now the scheduler needs to receive from all ranks... though it was doing this anyhow until recently).

robertgshaw2-redhat (Collaborator, Author): let's just time things and see which one is faster.

robertgshaw2-redhat (Collaborator, Author, May 4, 2025): Looks like for TP=2, the setup I have is taking <1ms, so I think this is good enough for now, as I would prefer to keep the changes in this file if possible.
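As a concrete reference for the pattern discussed in this thread, here is a minimal sketch of the tp_group-based aggregation (assuming a tp_group object exposing send_object/recv_object as in the diff above; the function name and the single shared counter are illustrative simplifications, not the PR's exact code):

```python
from collections import defaultdict

def aggregate_finished(rank: int, world_size: int, tp_group,
                       local_finished: set[str],
                       done_count: "defaultdict[str, int]") -> set[str]:
    """Release a request id only once every TP rank has reported it done."""
    if rank == 0:
        # Count rank 0's own finished ids.
        for req_id in local_finished:
            done_count[req_id] += 1
        # Collect finished ids reported by ranks 1..N-1.
        for src in range(1, world_size):
            for req_id in tp_group.recv_object(src=src):
                done_count[req_id] += 1
        # An id is done once all world_size ranks have reported it.
        all_done = {r for r, c in done_count.items() if c == world_size}
        for req_id in all_done:
            del done_count[req_id]
        return all_done
    # Ranks 1..N-1 report to rank 0; only rank 0's result reaches the scheduler.
    tp_group.send_object(list(local_finished), dst=0)
    return set()
```

The PR itself tracks sending and receiving in separate counters (_done_sending_count and _done_recving_count); the sketch collapses them into one to highlight the release condition.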

for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1

# Return ids that finished on all ranks to the scheduler.
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
del self._done_recving_count[req_id]
all_done_recving.add(req_id)

all_done_sending: set[str] = set()
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)

return all_done_sending, all_done_recving

# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)

# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving

def _get_new_notifs(self) -> set[str]:
"""Get req_ids which got a remote xfer message."""

notified_req_ids: set[str] = set()
# TODO: handle the TP case (N notifies for TP=N).
# See: vllm/worker/worker_base.py L476 in DynamoPR.
for req_ids in self.nixl_wrapper.get_new_notifs().values():
for req_id in req_ids:
assert req_id not in notified_req_ids
@@ -539,61 +591,44 @@ def _read_blocks(
if len(local_block_ids) == 0:
return

# TODO: support TP multipliers.
tp_multiplier = 1
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, "all", remote_block_ids)
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]

# Read the data from the remote.
for i in range(tp_multiplier):
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id,
"all",
local_block_ids,
i=None, #TODO: Enable both tp_multiplier and staging_ranges.
tp_multiplier=tp_multiplier,
staging_ranges=None)
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
remote_xfer_side_handle = self.dst_xfer_side_handles[
dst_engine_id][i]

# NOTE(rob): we use the request_id as the notify msg, so we
# must use the same request_id in both the p and d workers.
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=request_id.encode("utf-8"),
)

# Call transfer to begin the async transfer
# We will check this is done in the next forward pass.
self.nixl_wrapper.transfer(handle)
self._recving_transfers[request_id].append(handle)
# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

def _get_block_descs_ids(self,
engine_id,
region_ids,
block_ids,
i=None,
tp_multiplier=1,
staging_ranges=None):
# Get descs ids.
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
assert len(local_block_descs_ids) == len(remote_block_descs_ids)

# Prepare transfer with Nixl.
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=request_id.encode("utf-8"),
)

if region_ids == "all":
region_ids = range(self.num_regions)
if block_ids == "all":
block_ids = range(self.num_blocks)
# Begin async xfer.
self.nixl_wrapper.transfer(handle)

descs_ids = []
# Use handle to check completion in future step().
self._recving_transfers[request_id].append(handle)

if i is not None:
raise NotImplementedError("Prefill and Decode instances must have "
"the same TP size.")
def _get_block_descs_ids(self, engine_id: str,
block_ids: list[int]) -> list[int]:
"""Get the descs ids for a set of block ids."""
# TODO(rob): should we precompute this?

# range(1) for MLA, range(2) otherwise.
region_ids = range(self.num_regions)
num_blocks = self.dst_num_blocks[engine_id]

# Compute the desc ids for each block.
descs_ids: list[int] = []
for reg_id in region_ids:
for block_id in block_ids:
descs_ids.append(reg_id * num_blocks + block_id)
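For reference, a standalone sketch of the descriptor-id layout computed by _get_block_descs_ids, with made-up region/block counts for illustration:

```python
def get_block_descs_ids(num_regions: int, num_blocks: int,
                        block_ids: list[int]) -> list[int]:
    """Descriptor ids are laid out region-major: all blocks of region 0,
    then all blocks of region 1, and so on."""
    descs_ids: list[int] = []
    for reg_id in range(num_regions):
        for block_id in block_ids:
            descs_ids.append(reg_id * num_blocks + block_id)
    return descs_ids

# Example: 2 regions (K and V caches), 4 blocks per region, blocks 1 and 3
# -> [1, 3, 5, 7]: ids 1 and 3 in region 0, then 4+1=5 and 4+3=7 in region 1.
print(get_block_descs_ids(2, 4, [1, 3]))
```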