 from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.infer_utils import mark_start, mark_end
 
 logger = init_logger(__name__)
 
@@ -84,9 +85,15 @@ def handle_finished_reqs(self, finished_reqs: List[InferReq]) -> List[InferReq]:
             else:
                 assert req.cpu_cache_task_status.is_not_started()
                 # Launch the task that offloads this request's kv cache into the cpu cache
+                # if self.backend.is_master_in_dp:
+                #     mark_start("blueswhen offload_kv_to_cpu")
+                torch.cuda.synchronize()
                 trans_task = self._start_kv_cache_offload_task(
                     req=req, cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream()
                 )
+                torch.cuda.synchronize()
+                # if self.backend.is_master_in_dp:
+                #     mark_end("blueswhen offload_kv_to_cpu")
 
                 if trans_task is not None:
                     self.cpu_cache_handle_queue.append(trans_task)
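The commented-out mark_start/mark_end calls, together with the added torch.cuda.synchronize() pair, bracket the offload launch for timing. A minimal sketch of that measurement pattern, assuming a CUDA device and using time.perf_counter in place of lightllm's mark_start/mark_end helpers (whose exact behavior is not shown here); the helper name and its usage are illustrative only:

    import time
    import torch

    def timed_gpu_section(label, fn, *args, **kwargs):
        torch.cuda.synchronize()              # drain earlier GPU work so it is not counted
        t0 = time.perf_counter()
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()              # wait for the work launched inside fn()
        print(f"{label}: {(time.perf_counter() - t0) * 1e3:.2f} ms")
        return result

    # hypothetical usage, mirroring the bracketed call in the hunk above:
    # trans_task = timed_gpu_section(
    #     "offload_kv_to_cpu", self._start_kv_cache_offload_task,
    #     req=req, cpu_kv_cache_stream=g_infer_context.get_cpu_kv_cache_stream(),
    # )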
@@ -101,44 +108,51 @@ def _start_kv_cache_offload_task(
         self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream
     ) -> Optional["TransTask"]:
         with torch.cuda.stream(cpu_kv_cache_stream):
-            # Recompute the hash over the full sequence, not only over the input
-            all_token_hash_list = self._compute_full_sequence_hash(req)
-            block_size = req.cur_kv_len // self.args.cpu_cache_token_page_size
-            move_block_size = min(block_size, len(all_token_hash_list))
-            if move_block_size == 0:
-                req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
-                return None
             if self.backend.is_master_in_dp:
-                self.cpu_cache_client.lock.acquire_sleep1ms()
-                page_list, ready_list = self.cpu_cache_client.allocate_pages(
-                    all_token_hash_list[:move_block_size],
-                    disk_offload_enable=self.args.enable_disk_cache,
-                )
-                self.cpu_cache_client.lock.release()
+                all_token_hash_list = self._compute_full_sequence_hash(req)
+                block_size = req.cur_kv_len // self.args.cpu_cache_token_page_size
+                move_block_size = min(block_size, len(all_token_hash_list))
+
+                if move_block_size == 0:
+                    dist.broadcast_object_list([0], group=self.gloo_group, group_src=0)
+                    req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
+                    return None
+
+                try:
+                    self.cpu_cache_client.lock.acquire_sleep1ms()
+                    page_list, ready_list = self.cpu_cache_client.allocate_pages(
+                        all_token_hash_list[:move_block_size],
+                        disk_offload_enable=self.args.enable_disk_cache,
+                    )
+                finally:
+                    self.cpu_cache_client.lock.release()
+
                 item_size = len(page_list)
-                dist.broadcast_object_list([item_size], group=self.gloo_group, group_src=0)
                 if item_size == 0:
+                    dist.broadcast_object_list([0], group=self.gloo_group, group_src=0)
                     req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
                     return None
-                dist.broadcast_object_list(page_list, group=self.gloo_group, group_src=0)
-                dist.broadcast_object_list(ready_list, group=self.gloo_group, group_src=0)
+
+                broadcast_data = {
+                    'item_size': item_size,
+                    'page_list': page_list,
+                    'ready_list': ready_list
+                }
+                dist.broadcast_object_list([broadcast_data], group=self.gloo_group, group_src=0)
             else:
                 recv_list = [None]
                 dist.broadcast_object_list(recv_list, group=self.gloo_group, group_src=0)
-                item_size = recv_list[0]
-                if item_size == 0:
+                if isinstance(recv_list[0], int) and recv_list[0] == 0:
                     req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
                     return None
-                page_list = [None] * item_size
-                ready_list = [None] * item_size
-                dist.broadcast_object_list(page_list, group=self.gloo_group, group_src=0)
-                dist.broadcast_object_list(ready_list, group=self.gloo_group, group_src=0)
+                broadcast_data = recv_list[0]
+                item_size = broadcast_data['item_size']
+                page_list = broadcast_data['page_list']
+                ready_list = broadcast_data['ready_list']
 
             page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True)
             page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True)
-
             token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0 : req.cur_kv_len]
-
             offload_gpu_kv_to_cpu(
                 token_indexes=token_indexes,
                 gpu_kv_cache=self.backend.model.mem_manager.kv_buffer,
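This hunk collapses the separate item_size / page_list / ready_list broadcasts into a single broadcast_object_list call that carries one dict, with a bare 0 as the "nothing to offload" sentinel. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized with a gloo group and that the installed PyTorch supports the group_src keyword used in the diff; the function name and arguments are illustrative only:

    import torch.distributed as dist

    def share_allocation(is_master, gloo_group, page_list=None, ready_list=None):
        # Rank 0 packs everything into one picklable object; the other ranks receive it.
        if is_master:
            payload = 0 if not page_list else {
                'item_size': len(page_list),
                'page_list': page_list,
                'ready_list': ready_list,
            }
            dist.broadcast_object_list([payload], group=gloo_group, group_src=0)
            return None if payload == 0 else (page_list, ready_list)
        recv = [None]
        dist.broadcast_object_list(recv, group=gloo_group, group_src=0)
        if isinstance(recv[0], int) and recv[0] == 0:
            return None                        # sentinel: nothing was allocated
        data = recv[0]
        return data['page_list'], data['ready_list']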
@@ -147,8 +161,7 @@ def _start_kv_cache_offload_task(
                 page_readies=page_readies,
             )
 
-            # Use an allreduce plus a sync_event to make sure every gpu worker has finished writing the cpu kv cache.
-            dist.all_reduce(tensor=self.sync_tensor, group=self.sync_group, async_op=False)
+            # dist.all_reduce(tensor=self.sync_tensor, group=self.sync_group, async_op=False)
             sync_event = torch.cuda.Event()
             sync_event.record()
             req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING
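With the blocking all_reduce commented out, completion of the offload now rests on the CUDA event recorded on the cpu-kv-cache stream; unlike the removed all_reduce, an event only tracks work enqueued on the local device. A minimal sketch of that event pattern, with placeholder names for the stream and the copy work:

    import torch

    copy_stream = torch.cuda.Stream()          # stands in for the cpu kv cache stream
    done_event = torch.cuda.Event()

    with torch.cuda.stream(copy_stream):
        # ... enqueue the GPU -> CPU page copies here ...
        done_event.record()                    # marks the point after the enqueued copies

    # Later, whoever consumes the CPU pages:
    if done_event.query():                     # non-blocking: True once the copies finished
        pass                                   # pages can be treated as written
    else:
        done_event.synchronize()               # or block until the recorded work completes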