-
Notifications
You must be signed in to change notification settings - Fork 6.4k
NixlKVManager: async multi-threaded KV transfer #20680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
76640b2
563b805
4c7f7c1
7c0f891
da64ed0
89467c9
9299bba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,7 +19,10 @@ | |
| CommonKVReceiver, | ||
| CommonKVSender, | ||
| ) | ||
| from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous | ||
| from sglang.srt.disaggregation.common.utils import ( | ||
| FastQueue, | ||
| group_concurrent_contiguous, | ||
| ) | ||
| from sglang.srt.disaggregation.utils import ( | ||
| DisaggregationMode, | ||
| filter_kv_indices_for_cp_rank, | ||
|
|
@@ -68,6 +71,17 @@ def from_zmq(cls, msg: List[bytes]): | |
| ) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class TransferKVChunk: | ||
| room: int | ||
| prefill_kv_indices: npt.NDArray[np.int32] | ||
| index_slice: slice | ||
| is_last: bool | ||
| chunk_id: int | ||
| prefill_aux_index: Optional[int] | ||
| state_indices: Optional[List[int]] | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class KVArgsRegisterInfo: | ||
| """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread.""" | ||
|
|
@@ -203,6 +217,14 @@ def __init__( | |
| self.register_buffer_to_engine() | ||
|
|
||
| if self.disaggregation_mode == DisaggregationMode.PREFILL: | ||
| transfer_queue_size = envs.SGLANG_DISAGGREGATION_QUEUE_SIZE.get() | ||
| self.transfer_queues: List[FastQueue] = [ | ||
| FastQueue() for _ in range(transfer_queue_size) | ||
| ] | ||
| for queue in self.transfer_queues: | ||
| threading.Thread( | ||
| target=self.transfer_worker, args=(queue,), daemon=True | ||
| ).start() | ||
| self._start_bootstrap_thread() | ||
| elif self.disaggregation_mode == DisaggregationMode.DECODE: | ||
| self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( | ||
|
|
@@ -300,6 +322,118 @@ def _handle_node_failure(self, failed_bootstrap_addr): | |
| logger.error(f"Let room {room} be failed due to prefill down") | ||
| self.update_status(room, KVPoll.Failed) | ||
|
|
||
| def check_status(self, bootstrap_room: int): | ||
| return self.request_status.get(bootstrap_room, KVPoll.Bootstrapping) | ||
|
|
||
| def transfer_worker(self, queue: FastQueue): | ||
| while True: | ||
| kv_chunk: TransferKVChunk = queue.get() | ||
| room = kv_chunk.room | ||
| try: | ||
| if self.check_status(room) == KVPoll.Failed: | ||
| continue | ||
|
|
||
| assert room in self.transfer_infos | ||
|
|
||
| self.update_status(room, KVPoll.Transferring) | ||
|
|
||
| reqs_to_be_processed = list(self.transfer_infos[room].values()) | ||
| handles: List = [] | ||
|
|
||
| for req in reqs_to_be_processed: | ||
| assert room == req.room | ||
| if req.is_dummy(): | ||
| continue | ||
|
|
||
| chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] | ||
| if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices): | ||
| kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[ | ||
| : len(chunked_dst_kv_indice) | ||
| ] | ||
| assert req.agent_name in self.decode_kv_args_table | ||
|
|
||
| notif = f"{req.room}_kv_{kv_chunk.chunk_id}_{int(kv_chunk.is_last)}_{self.kv_args.engine_rank}" | ||
| decode_tp_size = self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].decode_tp_size | ||
|
|
||
| if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): | ||
| kv_xfer_handle = self.send_kvcache( | ||
| req.agent_name, | ||
| kv_chunk.prefill_kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| ) | ||
| else: | ||
| kv_xfer_handle = self.send_kvcache_slice( | ||
| req.agent_name, | ||
| kv_chunk.prefill_kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| prefill_tp_size=self.attn_tp_size, | ||
| decode_tp_size=decode_tp_size, | ||
| decode_tp_rank=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].decode_tp_rank, | ||
| dst_kv_item_len=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].dst_kv_item_len, | ||
| ) | ||
| handles.append(kv_xfer_handle) | ||
|
|
||
| if kv_chunk.is_last: | ||
| if kv_chunk.state_indices is not None: | ||
| dst_info = self.decode_kv_args_table[req.agent_name] | ||
| state_xfer_handle = self.maybe_send_extra( | ||
| req.agent_name, | ||
| kv_chunk.state_indices, | ||
| dst_info.dst_state_data_ptrs, | ||
| req.dst_state_indices, | ||
| dst_info.gpu_id, | ||
| f"{req.room}_state_{self.kv_args.engine_rank}", | ||
| decode_tp_size, | ||
| decode_tp_rank=dst_info.decode_tp_rank, | ||
| dst_state_item_lens=dst_info.dst_state_item_lens, | ||
| dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, | ||
| ) | ||
| if state_xfer_handle is not None: | ||
| handles.append(state_xfer_handle) | ||
|
|
||
| if kv_chunk.prefill_aux_index is None: | ||
| raise RuntimeError("Missing aux index for last chunk") | ||
| aux_xfer_handle = self.send_aux( | ||
| req.agent_name, | ||
| kv_chunk.prefill_aux_index, | ||
| self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, | ||
| req.dst_aux_index, | ||
| f"{req.room}_aux", | ||
| ) | ||
| handles.append(aux_xfer_handle) | ||
|
|
||
| while handles: | ||
| states = [self.agent.check_xfer_state(h) for h in handles] | ||
| if any(s == "ERR" for s in states): | ||
| raise RuntimeError(f"NIXL transfer encountered ERR room={room}") | ||
| if all(s == "DONE" for s in states): | ||
| break | ||
| time.sleep(0) | ||
|
|
||
| if kv_chunk.is_last: | ||
| if room in self.transfer_infos: | ||
| del self.transfer_infos[room] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see multiple threads modifying and reading self.transfer_infos. Do we need locking? Or pass the data to the workers in a safer way? |
||
| self.update_status(room, KVPoll.Success) | ||
| else: | ||
| self.update_status(room, KVPoll.Transferring) | ||
| except Exception as e: | ||
| reason = f"Prefill transfer worker error room={room}: {e}" | ||
| logger.exception(reason) | ||
| self.record_failure(room, reason) | ||
| self.update_status(room, KVPoll.Failed) | ||
|
|
||
| def register_buffer_to_engine(self): | ||
| kv_addrs = [] | ||
| for kv_data_ptr, kv_data_len in zip( | ||
|
|
@@ -872,81 +1006,22 @@ def add_transfer_request( | |
| assert self.disaggregation_mode == DisaggregationMode.PREFILL | ||
| assert not is_last or (is_last and aux_index is not None) | ||
|
|
||
| reqs_to_be_processed = self.transfer_infos[bootstrap_room].values() | ||
| handles = [] | ||
| for req in reqs_to_be_processed: | ||
| assert bootstrap_room == req.room | ||
| if req.is_dummy(): | ||
| continue | ||
|
|
||
| chunked_dst_kv_indice = req.dst_kv_indices[index_slice] | ||
| assert len(chunked_dst_kv_indice) == len(kv_indices) | ||
| assert req.agent_name in self.decode_kv_args_table | ||
|
|
||
| notif = ( | ||
| f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.engine_rank}" | ||
| if bootstrap_room not in self.request_status: | ||
| self.update_status(bootstrap_room, KVPoll.Bootstrapping) | ||
|
|
||
| shard_idx = bootstrap_room % len(self.transfer_queues) | ||
| self.transfer_queues[shard_idx].put( | ||
| TransferKVChunk( | ||
| room=bootstrap_room, | ||
| prefill_kv_indices=kv_indices, | ||
| index_slice=index_slice, | ||
| is_last=is_last, | ||
| chunk_id=chunk_id, | ||
| prefill_aux_index=aux_index, | ||
| state_indices=state_indices, | ||
| ) | ||
| decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size | ||
|
|
||
| if self.is_mla_backend or (decode_tp_size == self.attn_tp_size): | ||
| kv_xfer_handle = self.send_kvcache( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| ) | ||
| else: | ||
| kv_xfer_handle = self.send_kvcache_slice( | ||
| req.agent_name, | ||
| kv_indices, | ||
| self.decode_kv_args_table[req.agent_name].dst_kv_ptrs, | ||
| chunked_dst_kv_indice, | ||
| self.decode_kv_args_table[req.agent_name].gpu_id, | ||
| notif, | ||
| prefill_tp_size=self.attn_tp_size, | ||
| decode_tp_size=decode_tp_size, | ||
| decode_tp_rank=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].decode_tp_rank, | ||
| dst_kv_item_len=self.decode_kv_args_table[ | ||
| req.agent_name | ||
| ].dst_kv_item_len, | ||
| ) | ||
|
|
||
| handles.append(kv_xfer_handle) | ||
| # Only the last chunk we need to send the aux data. | ||
| if is_last: | ||
| if state_indices is not None: | ||
| dst_info = self.decode_kv_args_table[req.agent_name] | ||
| state_xfer_handle = self.maybe_send_extra( | ||
| req.agent_name, | ||
| state_indices, | ||
| dst_info.dst_state_data_ptrs, | ||
| req.dst_state_indices, | ||
| dst_info.gpu_id, | ||
| f"{req.room}_state_{self.kv_args.engine_rank}", | ||
| decode_tp_size, | ||
| decode_tp_rank=dst_info.decode_tp_rank, | ||
| dst_state_item_lens=dst_info.dst_state_item_lens, | ||
| dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor, | ||
| ) | ||
| if state_xfer_handle is not None: | ||
| handles.append(state_xfer_handle) | ||
|
|
||
| assert aux_index is not None | ||
| aux_xfer_handle = self.send_aux( | ||
| req.agent_name, | ||
| aux_index, | ||
| self.decode_kv_args_table[req.agent_name].dst_aux_ptrs, | ||
| req.dst_aux_index, | ||
| f"{req.room}_aux", | ||
| ) | ||
| handles.append(aux_xfer_handle) | ||
| if is_last: | ||
| del self.transfer_infos[bootstrap_room] | ||
| return handles | ||
| ) | ||
| return None | ||
|
|
||
| def update_transfer_status(self): | ||
| # Process notifications from received transfers. | ||
|
|
@@ -1035,7 +1110,6 @@ def __init__( | |
| pp_rank: int, | ||
| ): | ||
| super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank) | ||
| self.xfer_handles = [] | ||
| self.has_sent = False | ||
| self.chunk_id = 0 | ||
|
|
||
|
|
@@ -1062,7 +1136,7 @@ def send( | |
| self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Success) | ||
| return | ||
|
|
||
| new_xfer_handles = self.kv_mgr.add_transfer_request( | ||
| self.kv_mgr.add_transfer_request( | ||
| self.bootstrap_room, | ||
| kv_indices, | ||
| index_slice, | ||
|
|
@@ -1071,21 +1145,21 @@ def send( | |
| self.aux_index, | ||
| state_indices, | ||
| ) | ||
| self.xfer_handles.extend(new_xfer_handles) | ||
| self.chunk_id += 1 | ||
| if is_last: | ||
| self.has_sent = True | ||
| del self.kv_mgr.request_status[self.bootstrap_room] | ||
|
|
||
| def poll(self) -> KVPoll: | ||
| status = self.kv_mgr.check_status(self.bootstrap_room) | ||
| if not self.has_sent: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like we return status here in all cases, so we can just return directly and remove the ifs |
||
| return self.kv_mgr.check_status(self.bootstrap_room) | ||
| states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] | ||
| if all([x == "DONE" for x in states]): | ||
| return KVPoll.Success # type: ignore | ||
| if any([x == "ERR" for x in states]): | ||
| raise Exception("KVSender transfer encountered an error.") | ||
| return KVPoll.WaitingForInput # type: ignore | ||
| return status | ||
| if status in (KVPoll.Success, KVPoll.Failed): | ||
| return status | ||
| return status | ||
|
|
||
| def clear(self): | ||
| if self.bootstrap_room in self.kv_mgr.request_status: | ||
| self.kv_mgr.request_status.pop(self.bootstrap_room) | ||
|
|
||
| def failure_exception(self): | ||
| raise RuntimeError("NIXL KVSender Exception") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old code has:
assert len(chunked_dst_kv_indice) == len(kv_indices)Why was this changed?