Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 31 additions & 75 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,87 +464,43 @@ def send_kvcache_slice(
)
return -1

layers_params = [
(
src_k_ptrs[layer_id],
dst_k_ptrs[layer_id],
src_kv_item_len,
dst_kv_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_current_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
src_kv_item_len,
dst_kv_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
)
for layer_id in range(layers_current_pp_stage)
]

def process_layer_tp_aware(layer_params):
(
src_ptr,
dst_ptr,
src_item_len,
dst_item_len,
src_head_slice_offset,
dst_head_slice_offset,
heads_bytes_per_token_to_send,
) = layer_params
src_addr_list = []
dst_addr_list = []
length_list = []

# Calculate strides for a single token slot
bytes_per_token_on_prefill = src_item_len // page_size
bytes_per_token_on_decode = dst_item_len // page_size

for i in range(len(prefill_kv_indices)):
prefill_page_idx = int(prefill_kv_indices[i])
decode_page_idx = int(dst_kv_indices[i])

# Get the starting addresses for the current src and dst pages
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len

# Iterate through each valid token slot within the current page
for token_slot_in_page in range(page_size):
# Calculate the start address of the current token slot
src_token_slot_start_addr = (
src_page_start_addr
+ token_slot_in_page * bytes_per_token_on_prefill
)
dst_token_slot_start_addr = (
dst_page_start_addr
+ token_slot_in_page * bytes_per_token_on_decode
)

# Calculate final src and dst addresses by applying head-slice offsets
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset

src_addr_list.append(src_slice_addr)
dst_addr_list.append(dst_slice_addr)
length_list.append(heads_bytes_per_token_to_send)
prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(-1, 1)
dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1)
token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1)
bytes_per_token_on_prefill = src_kv_item_len // page_size
bytes_per_token_on_decode = dst_kv_item_len // page_size
src_token_offsets_base = (
token_offsets * bytes_per_token_on_prefill + src_head_slice_offset
)
dst_token_offsets_base = (
token_offsets * bytes_per_token_on_decode + dst_head_slice_offset
)

def process_layer_tp_aware(ptrs):
    """Transfer the head slice of every token slot for one (src, dst) buffer pair.

    Args:
        ptrs: A ``(src_ptr, dst_ptr)`` pair of base addresses for a single
            K or V buffer.

    Returns:
        ``0`` when there are no addresses to transfer; otherwise whatever
        ``self.engine.batch_transfer_sync`` returns.
    """
    src_base, dst_base = ptrs

    # The reshaped page indices are column vectors (-1, 1) and the token
    # offset bases are row vectors (1, -1), so each sum broadcasts to one
    # address per (page, token slot) pair.
    src_grid = (
        src_base + prefill_kv_indices_reshaped * src_kv_item_len
    ) + src_token_offsets_base
    dst_grid = (
        dst_base + dst_kv_indices_reshaped * dst_kv_item_len
    ) + dst_token_offsets_base

    sources = src_grid.reshape(-1).tolist()
    if len(sources) == 0:
        # Nothing to send for this buffer.
        return 0

    destinations = dst_grid.reshape(-1).tolist()
    # Every chunk carries the same number of bytes.
    chunk_lengths = [heads_bytes_per_token_to_send for _ in sources]

    return self.engine.batch_transfer_sync(
        mooncake_session_id, sources, destinations, chunk_lengths
    )

futures = [
executor.submit(
process_layer_tp_aware,
layer_params,
futures = []
for i in range(layers_current_pp_stage):
futures.append(
executor.submit(process_layer_tp_aware, (src_k_ptrs[i], dst_k_ptrs[i]))
)
for i in range(layers_current_pp_stage):
futures.append(
executor.submit(process_layer_tp_aware, (src_v_ptrs[i], dst_v_ptrs[i]))
)
for layer_params in layers_params
]

for future in concurrent.futures.as_completed(futures):
status = future.result()
Expand Down
Loading