Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,35 @@ def from_zmq(cls, msg: List[bytes]):
class TransferStatus:
"""Used by KV Receiver to know when a transfer is done."""

# KV chunk IDs that have been received.
received_kvs: Set[int] = dataclasses.field(default_factory=set)
# Number of kv chunks to expect, will know this after last chunk is received.
num_kvs_expected: Optional[int] = None
# KV chunks received per pp_rank: {pp_rank: set of chunk_ids}
received_kvs_per_pp: Dict[int, Set[int]] = dataclasses.field(
default_factory=lambda: defaultdict(set)
)
# Expected chunk count per pp_rank (set when is_last=True): {pp_rank: expected_count}
expected_kvs_per_pp: Dict[int, int] = dataclasses.field(default_factory=dict)
# Number of PP ranks expected to send data.
num_pp_ranks_expected: Optional[int] = None
# Whether aux data has been received.
received_aux: bool = False
# Mark as failed
is_failure: bool = False

def is_done(self):
if self.num_kvs_expected is None:
if self.is_failure:
return True
if self.num_pp_ranks_expected is None or not self.received_aux:
return False
# Check for failure state
if self.num_kvs_expected == -1:
return True # Failed transfers are considered "done"
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
# All PP ranks must have reported their expected count
if len(self.expected_kvs_per_pp) < self.num_pp_ranks_expected:
return False
# Each PP rank must have received all expected chunks
for pp_rank, expected in self.expected_kvs_per_pp.items():
if len(self.received_kvs_per_pp[pp_rank]) != expected:
return False
return True

def is_failed(self):
return self.num_kvs_expected == -1
return self.is_failure


class NixlKVManager(CommonKVManager):
Expand Down Expand Up @@ -244,8 +256,8 @@ def _handle_node_failure(self, failed_bootstrap_addr):
room in self.transfer_statuses
and not self.transfer_statuses[room].is_done()
):
# Mark the transfer as failed by setting a special state
self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
# Mark the transfer as failed
self.transfer_statuses[room].is_failure = True
affected_rooms.append(room)

logger.error(
Expand Down Expand Up @@ -573,7 +585,7 @@ def add_transfer_request(
assert len(chunked_dst_kv_indice) == len(kv_indices)
assert req.agent_name in self.decode_kv_args_table

notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
notif = f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}"
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size

if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
Expand Down Expand Up @@ -627,14 +639,26 @@ def update_transfer_status(self):
# the message sender. But the bootstrap room alone should be
# sufficient to map the status.
for msg in messages:
components = msg.decode("ascii").split("_")
components = msg.decode("ascii").split("_", 4)
room = int(components[0])
if components[1] == "kv":
chunk_id = int(components[2])
is_last = bool(int(components[3]))
self.transfer_statuses[room].received_kvs.add(chunk_id)
pp_rank = int(components[4]) if len(components) > 4 else 0
# Track received chunks per pp_rank
self.transfer_statuses[room].received_kvs_per_pp[pp_rank].add(
chunk_id
)
if is_last:
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
# Record expected chunk count for this pp_rank
self.transfer_statuses[room].expected_kvs_per_pp[pp_rank] = (
chunk_id + 1
)
# Set num_pp_ranks_expected from table (or default to 1)
if self.transfer_statuses[room].num_pp_ranks_expected is None:
self.transfer_statuses[room].num_pp_ranks_expected = (
self.required_prefill_response_num_table.get(room, 1)
)
elif components[1] == "aux":
self.transfer_statuses[room].received_aux = True

Expand Down
Loading