diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index cf5f64e146e3..c8c47aeba396 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -464,87 +464,43 @@ def send_kvcache_slice( ) return -1 - layers_params = [ - ( - src_k_ptrs[layer_id], - dst_k_ptrs[layer_id], - src_kv_item_len, - dst_kv_item_len, - src_head_slice_offset, - dst_head_slice_offset, - heads_bytes_per_token_to_send, - ) - for layer_id in range(layers_current_pp_stage) - ] + [ - ( - src_v_ptrs[layer_id], - dst_v_ptrs[layer_id], - src_kv_item_len, - dst_kv_item_len, - src_head_slice_offset, - dst_head_slice_offset, - heads_bytes_per_token_to_send, - ) - for layer_id in range(layers_current_pp_stage) - ] - - def process_layer_tp_aware(layer_params): - ( - src_ptr, - dst_ptr, - src_item_len, - dst_item_len, - src_head_slice_offset, - dst_head_slice_offset, - heads_bytes_per_token_to_send, - ) = layer_params - src_addr_list = [] - dst_addr_list = [] - length_list = [] - - # Calculate strides for a single token slot - bytes_per_token_on_prefill = src_item_len // page_size - bytes_per_token_on_decode = dst_item_len // page_size - - for i in range(len(prefill_kv_indices)): - prefill_page_idx = int(prefill_kv_indices[i]) - decode_page_idx = int(dst_kv_indices[i]) - - # Get the starting addresses for the current src and dst pages - src_page_start_addr = src_ptr + prefill_page_idx * src_item_len - dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len - - # Iterate through each valid token slot within the current page - for token_slot_in_page in range(page_size): - # Calculate the start address of the current token slot - src_token_slot_start_addr = ( - src_page_start_addr - + token_slot_in_page * bytes_per_token_on_prefill - ) - dst_token_slot_start_addr = ( - dst_page_start_addr - + token_slot_in_page * bytes_per_token_on_decode - ) - - # Calculate final src and dst addresses by applying head-slice offsets - src_slice_addr = src_token_slot_start_addr + src_head_slice_offset - dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset - - src_addr_list.append(src_slice_addr) - dst_addr_list.append(dst_slice_addr) - length_list.append(heads_bytes_per_token_to_send) + prefill_kv_indices_reshaped = prefill_kv_indices.astype(np.int64).reshape(-1, 1) + dst_kv_indices_reshaped = dst_kv_indices.astype(np.int64).reshape(-1, 1) + token_offsets = np.arange(page_size, dtype=np.int64).reshape(1, -1) + bytes_per_token_on_prefill = src_kv_item_len // page_size + bytes_per_token_on_decode = dst_kv_item_len // page_size + src_token_offsets_base = ( + token_offsets * bytes_per_token_on_prefill + src_head_slice_offset + ) + dst_token_offsets_base = ( + token_offsets * bytes_per_token_on_decode + dst_head_slice_offset + ) + def process_layer_tp_aware(ptrs): + src_ptr, dst_ptr = ptrs + src_page_starts = src_ptr + prefill_kv_indices_reshaped * src_kv_item_len + dst_page_starts = dst_ptr + dst_kv_indices_reshaped * dst_kv_item_len + src_addrs = src_page_starts + src_token_offsets_base + dst_addrs = dst_page_starts + dst_token_offsets_base + src_addr_list = src_addrs.reshape(-1).tolist() + if not src_addr_list: + return 0 + dst_addr_list = dst_addrs.reshape(-1).tolist() + total_chunks = len(src_addr_list) + length_list = [heads_bytes_per_token_to_send] * total_chunks return self.engine.batch_transfer_sync( mooncake_session_id, src_addr_list, dst_addr_list, length_list ) - futures = [ - executor.submit( - process_layer_tp_aware, - layer_params, + futures = [] + for i in range(layers_current_pp_stage): + futures.append( + executor.submit(process_layer_tp_aware, (src_k_ptrs[i], dst_k_ptrs[i])) + ) + for i in range(layers_current_pp_stage): + futures.append( + executor.submit(process_layer_tp_aware, (src_v_ptrs[i], dst_v_ptrs[i])) ) - for layer_params in layers_params - ] for future in concurrent.futures.as_completed(futures): status = future.result()