@@ -58,50 +58,41 @@ def parse_args():
5858 parser .add_argument ("--rank" , type = int , default = 0 , help = "current rank" )
5959 parser .add_argument ("--device_id" , type = int , default = 0 , help = "device id" )
6060 parser .add_argument ("--num_layers" , type = int , default = 1 , help = "model num layers" )
61- parser .add_argument ("--head_dim" , type = int , default = 1 , help = "model head dim" )
62- parser .add_argument ("--kv_num_head" , type = int , default = 1 , help = "model kv num head" )
63- parser .add_argument ("--rdma_port" , type = str , default = "" , help = "rmda port" )
6461 parser .add_argument ("--mp_num" , type = int , default = 1 , help = "number of model parallel" )
6562 parser .add_argument (
66- "--protocol " ,
63+ "--cache_dtype " ,
6764 type = str ,
68- default = "ipc" ,
69- help = "cache transfer protocol, only support ipc now" ,
65+ default = "bfloat16" ,
66+ choices = ["uint8" , "bfloat16" ],
67+ help = "cache dtype" ,
7068 )
71- parser .add_argument ("--enable_splitwise" , type = int , default = 0 , help = "enable splitwise " )
69+ parser .add_argument ("--key_cache_shape" , type = str , default = "" , help = "key cache shape" )
70+ parser .add_argument ("--value_cache_shape" , type = str , default = "" , help = "value cache shape" )
7271 parser .add_argument ("--cache_queue_port" , type = int , default = 9923 , help = "cache queue port" )
72+ parser .add_argument ("--enable_splitwise" , type = int , default = 0 , help = "enable splitwise " )
7373 parser .add_argument ("--pod_ip" , type = str , default = "0.0.0.0" , help = "pod ip" )
7474 parser .add_argument (
7575 "--engine_worker_queue_port" ,
7676 type = int ,
7777 default = 9923 ,
7878 help = "engine worker queue port" ,
7979 )
80- parser .add_argument ("--engine_pid" , type = str , default = None , help = "engine pid" )
81-
82- parser .add_argument ("--num_gpu_blocks" , type = int , default = 1 , help = "gpu cache block number" )
8380 parser .add_argument ("--num_cpu_blocks" , type = int , default = 4 , help = "cpu cache block number" )
84- parser .add_argument ("--block_size" , type = int , default = 64 , help = "cache block size(tokens)" )
85- parser .add_argument (
86- "--bytes_per_layer_per_block" ,
87- type = int ,
88- default = 1024 ,
89- help = "per layer per block bytes" ,
90- )
81+ parser .add_argument ("--engine_pid" , type = str , default = None , help = "engine pid" )
9182 parser .add_argument (
92- "--cache_dtype " ,
83+ "--protocol " ,
9384 type = str ,
94- default = "bfloat16" ,
95- choices = ["uint8" , "bfloat16" ],
96- help = "cache dtype" ,
85+ default = "ipc" ,
86+ help = "cache transfer protocol, only support ipc now" ,
9787 )
88+ parser .add_argument ("--local_data_parallel_id" , type = int , default = 0 )
89+ parser .add_argument ("--rdma_port" , type = str , default = "" , help = "rdma port" )
9890 parser .add_argument (
9991 "--speculative_config" ,
10092 type = json .loads ,
10193 default = "{}" ,
10294 help = "speculative config" ,
10395 )
104- parser .add_argument ("--local_data_parallel_id" , type = int , default = 0 )
10596 parser .add_argument ("--create_cache_tensor" , action = "store_true" )
10697
10798 args = parser .parse_args ()
@@ -124,8 +115,13 @@ def __init__(self, args):
124115 self .gpu_cache_k_tensors = []
125116 self .gpu_cache_v_tensors = []
126117 self .speculative_config = SpeculativeConfig (args .speculative_config )
118+ self .key_cache_shape = [int (i ) for i in args .key_cache_shape .split ("," )]
119+ self .value_cache_shape = []
120+ if args .value_cache_shape :
121+ self .value_cache_shape = [int (i ) for i in args .value_cache_shape .split ("," )]
122+ self .num_gpu_blocks = self .key_cache_shape [0 ]
127123 self .num_extra_layers = self .speculative_config .num_extra_cache_layer
128- self .num_extra_layer_gpu_blocks = int (args .num_gpu_blocks * self .speculative_config .num_gpu_block_expand_ratio )
124+ self .num_extra_layer_gpu_blocks = int (self .num_gpu_blocks * self .speculative_config .num_gpu_block_expand_ratio )
129125
130126 self .swap_to_cpu_thread_pool = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
131127 self .swap_to_gpu_thread_pool = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
@@ -164,8 +160,9 @@ def __init__(self, args):
164160
165161 self .num_cpu_blocks = args .num_cpu_blocks
166162
167- self ._init_cpu_cache (args )
168163 self ._init_gpu_cache (args )
164+ if self .num_cpu_blocks > 0 :
165+ self ._init_cpu_cache (args )
169166
170167 cache_task_broadcast_data = np .zeros (shape = [1 ], dtype = np .int32 )
171168 self .cache_task_broadcast_signal = IPCSignal (
@@ -209,28 +206,47 @@ def _init_gpu_cache(self, args):
209206 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] Initializing kv cache for all layers." )
210207 set_device (self .device )
211208 for i in range (args .num_layers + self .num_extra_layers ):
212- num_gpu_blocks = args .num_gpu_blocks if i < args .num_layers else self .num_extra_layer_gpu_blocks
213- cache_shape = [num_gpu_blocks , args .kv_num_head , args .block_size , args .head_dim ]
209+ num_gpu_blocks = self .num_gpu_blocks if i < args .num_layers else self .num_extra_layer_gpu_blocks
214210 key_name = f"key_caches_{ i } _rank{ self .rank } .device{ self .device } "
215211 val_name = f"value_caches_{ i } _rank{ self .rank } .device{ self .device } "
216-
212+ key_cache_shape = [
213+ num_gpu_blocks ,
214+ self .key_cache_shape [1 ],
215+ self .key_cache_shape [2 ],
216+ self .key_cache_shape [3 ],
217+ ]
218+ value_cache_shape = []
219+ if self .value_cache_shape :
220+ value_cache_shape = [
221+ num_gpu_blocks ,
222+ self .value_cache_shape [1 ],
223+ self .value_cache_shape [2 ],
224+ self .value_cache_shape [3 ],
225+ ]
217226 if args .create_cache_tensor :
218- logger .info (f"[rank { self .rank } /{ self .n_ranks } ] ..creating kv cache for layer { i } : { cache_shape } " )
219- key_cache = paddle .full (shape = cache_shape , fill_value = 0 , dtype = args .cache_dtype )
220- val_cache = paddle .full (shape = cache_shape , fill_value = 0 , dtype = args .cache_dtype )
227+ logger .info (
228+ f"[rank { self .rank } /{ self .n_ranks } ] ..creating kv cache for layer { i } : { key_cache_shape } { value_cache_shape } "
229+ )
230+ key_cache = paddle .full (shape = key_cache_shape , fill_value = 0 , dtype = args .cache_dtype )
221231 set_data_ipc (key_cache , key_name )
222- set_data_ipc (val_cache , val_name )
232+ if self .value_cache_shape :
233+ val_cache = paddle .full (shape = value_cache_shape , fill_value = 0 , dtype = args .cache_dtype )
234+ set_data_ipc (val_cache , val_name )
223235 else :
224- logger .info (f"[rank { self .rank } /{ self .n_ranks } ] ..attaching kv cache for layer { i } : { cache_shape } " )
236+ logger .info (
237+ f"[rank { self .rank } /{ self .n_ranks } ] ..attaching kv cache for layer { i } : { key_cache_shape } { value_cache_shape } "
238+ )
225239 key_cache = paddle .empty (shape = [], dtype = args .cache_dtype )
226240 val_cache = paddle .empty (shape = [], dtype = args .cache_dtype )
227- key_cache = share_external_data_ (key_cache , key_name , cache_shape , True )
228- val_cache = share_external_data_ (val_cache , val_name , cache_shape , True )
241+ key_cache = share_external_data_ (key_cache , key_name , key_cache_shape , True )
242+ if self .value_cache_shape :
243+ val_cache = share_external_data_ (val_cache , val_name , value_cache_shape , True )
229244
230245 self .gpu_cache_kvs [key_name ] = key_cache
231- self .gpu_cache_kvs [val_name ] = val_cache
232246 self .gpu_cache_k_tensors .append (self .gpu_cache_kvs [key_name ])
233- self .gpu_cache_v_tensors .append (self .gpu_cache_kvs [val_name ])
247+ if args .value_cache_shape :
248+ self .gpu_cache_kvs [val_name ] = val_cache
249+ self .gpu_cache_v_tensors .append (self .gpu_cache_kvs [val_name ])
234250
235251 if args .create_cache_tensor :
236252 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] ✅ kv cache is ready!" )
@@ -242,6 +258,22 @@ def _init_gpu_cache(self, args):
242258 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] done init cache (full) gmem alloc : { memory_allocated ()} " )
243259
244260 def _init_cpu_cache (self , args ):
261+ key_cache_size = self .key_cache_shape [1 ] * self .key_cache_shape [2 ] * self .key_cache_shape [3 ]
262+ if args .value_cache_shape :
263+ value_cache_size = self .value_cache_shape [1 ] * self .value_cache_shape [2 ] * self .value_cache_shape [3 ]
264+ else :
265+ value_cache_size = 0
266+ if args .cache_dtype == "bfloat16" :
267+ cache_bytes = 2
268+ elif args .cache_dtype == "uint8" :
269+ cache_bytes = 1
270+ else :
271+ raise ValueError (f"Unsupported cache dtype: { args .cache_dtype } " )
272+ key_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * key_cache_size
273+ value_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * value_cache_size
274+ logger .info (
275+ f"[rank { self .rank } /{ self .n_ranks } ] ..swap space size : { (key_need_to_allocate_bytes + value_need_to_allocate_bytes ) / 1024 ** 3 :.2f} GB"
276+ )
245277 if args .num_cpu_blocks == 0 :
246278 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] 💡 no swap space (cpu cache) is specified." )
247279 self .swap_space_ready_signal .value [self .rank ] = 1
@@ -253,14 +285,14 @@ def _init_cpu_cache(self, args):
253285 for i in range (args .num_layers + self .num_extra_layers ):
254286 key_name = f"key_caches_{ i } _rank{ self .rank } "
255287 val_name = f"value_caches_{ i } _rank{ self .rank } "
256- need_to_allocate_bytes = args .num_cpu_blocks * args .bytes_per_layer_per_block
257288 logger .info (
258- f"[rank { self .rank } /{ self .n_ranks } ] ..creating cpu cache for layer { i } : { 2 * need_to_allocate_bytes / 1024 ** 3 :.2f} GB"
289+ f"[rank { self .rank } /{ self .n_ranks } ] ..creating cpu cache for layer { i } : { ( key_need_to_allocate_bytes + value_need_to_allocate_bytes ) / 1024 ** 3 :.2f} GB"
259290 )
260- self .cpu_cache_kvs [key_name ] = cuda_host_alloc (need_to_allocate_bytes )
291+ self .cpu_cache_kvs [key_name ] = cuda_host_alloc (key_need_to_allocate_bytes )
261292 self .k_dst_ptrs .append (self .cpu_cache_kvs [key_name ])
262- self .cpu_cache_kvs [val_name ] = cuda_host_alloc (need_to_allocate_bytes )
263- self .v_dst_ptrs .append (self .cpu_cache_kvs [val_name ])
293+ if value_need_to_allocate_bytes > 0 :
294+ self .cpu_cache_kvs [val_name ] = cuda_host_alloc (value_need_to_allocate_bytes )
295+ self .v_dst_ptrs .append (self .cpu_cache_kvs [val_name ])
264296 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] ✅ swap space (cpu cache) is ready!" )
265297 self .swap_space_ready_signal .value [self .rank ] = 1
266298
0 commit comments