diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 97520499c310..41b480d947b8 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -743,8 +743,12 @@ def send_kvcache_slice( num_heads_to_send = dst_heads_per_rank dst_head_start_offset = 0 + # torch.int exceeds np.int64 range on Intel XPU (addresses have bit 63 set, e.g. + # 0xffff81ab54e01000). Use np.uint64 to prevent overflow on XPU. + kv_data_ptrs = np.array(self.kv_args.kv_data_ptrs, dtype=np.uint64) + dst_kv_ptrs = np.array(dst_kv_ptrs, dtype=np.uint64) src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( - self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) + self.get_mha_kv_ptrs_with_pp(kv_data_ptrs, dst_kv_ptrs) ) # Calculate precise byte offset and length for the sub-slice within the token src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send @@ -765,11 +769,11 @@ def send_kvcache_slice( for layer_id in range(layers_current_pp_stage) ] - prefill_indices = np.asarray(prefill_kv_indices, dtype=np.int64) - dst_indices = np.asarray(dst_kv_indices, dtype=np.int64) + prefill_indices = np.asarray(prefill_kv_indices, dtype=np.uint64) + dst_indices = np.asarray(dst_kv_indices, dtype=np.uint64) bytes_per_token_prefill = src_kv_item_len // page_size bytes_per_token_decode = dst_kv_item_len // page_size - token_offsets = np.arange(page_size, dtype=np.int64) + token_offsets = np.arange(page_size, dtype=np.uint64) src_addrs = [] dst_addrs = []