Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,13 @@ def _connect(self, endpoint: str, is_ipv6: bool = False):
return socket

def get_mha_kv_ptrs_with_pp(
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int], dst_non_draft_kv_data_lens: int,
) -> Tuple[List[int], List[int], List[int], List[int], int]:
start_layer = self.kv_args.prefill_start_layer
num_kv_layers = len(src_kv_ptrs) // 2
end_layer = start_layer + num_kv_layers
dst_num_total_layers = len(dst_kv_ptrs) // 2
dst_non_draft_num_total_layers = dst_non_draft_kv_data_lens // 2
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
if num_kv_layers == dst_num_total_layers:
Expand All @@ -318,6 +319,7 @@ def get_mha_kv_ptrs_with_pp(
elif (
num_kv_layers < dst_num_total_layers
and dst_num_total_layers % num_kv_layers != 0
and self.kv_args.prefill_pp_size == 1
):
# Case: Decode has draft model KV while Prefill is deployed without speculative decoding
# dst_kv_ptrs layout: [K_main..., V_main..., draft_K..., draft_V...]
Expand All @@ -331,7 +333,7 @@ def get_mha_kv_ptrs_with_pp(
# Decode pp size should be equal to prefill pp size or 1
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
dst_non_draft_num_total_layers + start_layer : dst_non_draft_num_total_layers + end_layer
]
layers_current_pp_stage = len(src_k_ptrs)
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _init_kv_manager(self) -> CommonKVManager:
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
kv_args.non_draft_kv_data_lens = len(kv_data_ptrs)
if self.draft_token_to_kv_pool is not None:
# We should also transfer draft model kv cache. The indices are
# always shared with a target model.
Expand Down
81 changes: 72 additions & 9 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class KVArgsRegisterInfo:
dst_port: int
mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_non_draft_kv_data_lens: int
dst_aux_ptrs: list[int]
dst_state_data_ptrs: list[int]
dst_tp_rank: int
Expand Down Expand Up @@ -139,6 +140,7 @@ def from_zmq(cls, msg: List[bytes]):
if len(msg) > 11 and len(msg[11]) > 0
else []
),
dst_non_draft_kv_data_lens=int(msg[12].decode("ascii")),
)


Expand Down Expand Up @@ -247,6 +249,7 @@ def _send_kvcache_generic(
mooncake_session_id: str,
src_data_ptrs: list[int],
dst_data_ptrs: list[int],
dst_non_draft_kv_data_lens : int,
item_lens: list[int],
prefill_data_indices: npt.NDArray[np.int32],
dst_data_indices: npt.NDArray[np.int32],
Expand Down Expand Up @@ -278,7 +281,7 @@ def _send_kvcache_generic(
]
else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs, dst_non_draft_kv_data_lens)
)
# item_lens structure: [k_layer0, k_layer1, ..., k_layerN, v_layer0, v_layer1, ..., v_layerN]
# Use correct item lengths for K and V separately
Expand Down Expand Up @@ -355,13 +358,15 @@ def send_kvcache(
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_non_draft_kv_data_lens : int,
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
return self._send_kvcache_generic(
mooncake_session_id=mooncake_session_id,
src_data_ptrs=self.kv_args.kv_data_ptrs,
dst_data_ptrs=dst_kv_ptrs,
dst_non_draft_kv_data_lens=dst_non_draft_kv_data_lens,
item_lens=self.kv_args.kv_item_lens,
prefill_data_indices=prefill_kv_indices,
dst_data_indices=dst_kv_indices,
Expand All @@ -373,6 +378,7 @@ def send_kvcache_slice(
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_non_draft_kv_data_lens : int,
dst_kv_indices: npt.NDArray[np.int32],
dst_tp_rank: int,
dst_attn_tp_size: int,
Expand Down Expand Up @@ -426,7 +432,7 @@ def send_kvcache_slice(
dst_head_start_offset = 0

src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs, dst_non_draft_kv_data_lens)
)

# Calculate precise byte offset and length for the sub-slice within the token
Expand Down Expand Up @@ -589,6 +595,7 @@ def maybe_send_extra(
req: TransferInfo,
prefill_state_indices: list[int],
dst_state_data_ptrs: list[int],
dst_non_draft_kv_data_lens: int,
executor: concurrent.futures.ThreadPoolExecutor,
target_rank_registration_info: Optional[KVArgsRegisterInfo] = None,
):
Expand All @@ -597,6 +604,41 @@ def maybe_send_extra(

if state_type == "mamba":
# Check if we need slice transfer for different TP sizes
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", [])
dst_state_item_lens = (
target_rank_registration_info.dst_state_item_lens
if target_rank_registration_info is not None
else []
)
dst_state_dim_per_tensor = (
target_rank_registration_info.dst_state_dim_per_tensor
if target_rank_registration_info is not None
else []
)
mamba_layer_ids = self.kv_args.total_mamba_layer_ids
layer_indices = self.kv_args.mamba_layer_ids
total_layers = len(mamba_layer_ids)
num_tensors = len(prefill_state_data_ptrs) // total_layers
if num_tensors * total_layers == len(prefill_state_data_ptrs):
indices = [
base + idx
for base in range(0, total_layers * num_tensors, total_layers)
for idx in layer_indices
]
def slice_list(values):
if not values:
return []
return [values[i] for i in indices]

prefill_state_data_ptrs = slice_list(prefill_state_data_ptrs)
prefill_state_item_lens = slice_list(prefill_state_item_lens)
src_state_dim_per_tensor = slice_list(src_state_dim_per_tensor)
dst_state_data_ptrs = slice_list(dst_state_data_ptrs)
dst_state_item_lens = slice_list(dst_state_item_lens)
dst_state_dim_per_tensor = slice_list(dst_state_dim_per_tensor)

if (
target_rank_registration_info is not None
and self.attn_tp_size != target_rank_registration_info.dst_attn_tp_size
Expand All @@ -605,16 +647,21 @@ def maybe_send_extra(
req,
prefill_state_indices,
dst_state_data_ptrs,
target_rank_registration_info.dst_state_item_lens,
target_rank_registration_info.dst_state_dim_per_tensor,
dst_state_item_lens,
dst_state_dim_per_tensor,
target_rank_registration_info.dst_tp_rank,
target_rank_registration_info.dst_attn_tp_size,
prefill_state_data_ptrs=prefill_state_data_ptrs,
prefill_state_item_lens=prefill_state_item_lens,
src_state_dim_per_tensor=src_state_dim_per_tensor,
)
else:
return self._send_mamba_state(
req,
prefill_state_indices,
dst_state_data_ptrs,
prefill_state_data_ptrs=prefill_state_data_ptrs,
prefill_state_item_lens=prefill_state_item_lens,
)
elif state_type in ["swa", "nsa"]:
# SWA and NSA hybrid models do not support different TP sizes yet
Expand All @@ -640,6 +687,7 @@ def maybe_send_extra(
mooncake_session_id=req.mooncake_session_id,
src_data_ptrs=self.kv_args.state_data_ptrs,
dst_data_ptrs=dst_state_data_ptrs,
dst_non_draft_kv_data_lens=dst_non_draft_kv_data_lens,
item_lens=self.kv_args.state_item_lens,
prefill_data_indices=prefill_state_indices,
dst_data_indices=dst_state_indices,
Expand All @@ -653,13 +701,17 @@ def _send_mamba_state(
req: TransferInfo,
prefill_mamba_index: list[int],
dst_state_data_ptrs: list[int],
prefill_state_data_ptrs: Optional[list[int]] = None,
prefill_state_item_lens: Optional[list[int]] = None,
):
"""Transfer Mamba states."""
assert len(prefill_mamba_index) == 1, "Mamba should have single state index"

transfer_blocks = []
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
if prefill_state_data_ptrs is None:
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
if prefill_state_item_lens is None:
prefill_state_item_lens = self.kv_args.state_item_lens

for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
length = prefill_state_item_lens[i]
Expand All @@ -678,6 +730,9 @@ def _send_mamba_state_slice(
dst_state_dim_per_tensor: list[int],
dst_tp_rank: int,
dst_attn_tp_size: int,
prefill_state_data_ptrs: Optional[list[int]] = None,
prefill_state_item_lens: Optional[list[int]] = None,
src_state_dim_per_tensor: Optional[list[int]] = None,
):
"""Transfer Mamba states with TP slice support.

Expand All @@ -696,9 +751,12 @@ def _send_mamba_state_slice(
assert len(prefill_mamba_index) == 1, "Mamba should have single state index"

transfer_blocks = []
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", [])
if prefill_state_data_ptrs is None:
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
if prefill_state_item_lens is None:
prefill_state_item_lens = self.kv_args.state_item_lens
if src_state_dim_per_tensor is None:
src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", [])

# If no dimension info available, fall back to regular transfer
if not src_state_dim_per_tensor or not dst_state_dim_per_tensor:
Expand Down Expand Up @@ -824,6 +882,7 @@ def transfer_worker(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
target_rank_registration_info.dst_non_draft_kv_data_lens,
chunked_dst_kv_indice,
executor,
)
Expand All @@ -832,6 +891,7 @@ def transfer_worker(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
target_rank_registration_info.dst_kv_ptrs,
target_rank_registration_info.dst_non_draft_kv_data_lens,
chunked_dst_kv_indice,
target_rank_registration_info.dst_tp_rank,
target_rank_registration_info.dst_attn_tp_size,
Expand Down Expand Up @@ -867,6 +927,7 @@ def transfer_worker(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
target_rank_registration_info.dst_non_draft_kv_data_lens,
executor,
target_rank_registration_info,
)
Expand Down Expand Up @@ -1256,6 +1317,7 @@ def _register_kv_args(self):
dst_tp_rank = str(tp_rank).encode("ascii")
dst_attn_tp_size = str(self.kv_mgr.attn_tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii")
non_draft_kv_data_lens = str(self.kv_mgr.kv_args.non_draft_kv_data_lens).encode("ascii")

sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
with lock:
Expand All @@ -1273,6 +1335,7 @@ def _register_kv_args(self):
dst_kv_item_len,
packed_state_item_lens,
packed_state_dim_per_tensor,
non_draft_kv_data_lens,
]
)

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def _init_kv_manager(self) -> CommonKVManager:
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
kv_args.prefill_pp_size = self.pp_size
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
kv_args.total_mamba_layer_ids = self.token_to_kv_pool.total_mamba_layer_ids
kv_args.mamba_layer_ids = self.token_to_kv_pool.mamba_layer_ids
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
Expand Down
Loading
Loading