@@ -215,6 +215,7 @@ def _init_gpu_cache(self, args):
215215 self .key_cache_shape [2 ],
216216 self .key_cache_shape [3 ],
217217 ]
218+ value_cache_shape = []
218219 if self .value_cache_shape :
219220 value_cache_shape = [
220221 num_gpu_blocks ,
@@ -257,9 +258,9 @@ def _init_gpu_cache(self, args):
257258 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] done init cache (full) gmem alloc : { memory_allocated ()} " )
258259
259260 def _init_cpu_cache (self , args ):
260- key_cache_size = args .key_cache_shape [1 ] * args .key_cache_shape [2 ] * args .key_cache_shape [3 ]
261+ key_cache_size = self .key_cache_shape [1 ] * self .key_cache_shape [2 ] * self .key_cache_shape [3 ]
261262 if args .value_cache_shape :
262- value_cache_size = args .value_cache_shape [1 ] * args .value_cache_shape [2 ] * args .value_cache_shape [3 ]
263+ value_cache_size = self .value_cache_shape [1 ] * self .value_cache_shape [2 ] * self .value_cache_shape [3 ]
263264 else :
264265 value_cache_size = 0
265266 if args .cache_dtype == "bfloat16" :
@@ -270,7 +271,9 @@ def _init_cpu_cache(self, args):
270271 raise ValueError (f"Unsupported cache dtype: { args .cache_dtype } " )
271272 key_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * key_cache_size
272273 value_need_to_allocate_bytes = args .num_cpu_blocks * cache_bytes * value_cache_size
273- # logger.info(f"[rank {self.rank}/{self.n_ranks}] ..swap space size : { / 1024 ** 3:.2f}GB")
274+ logger .info (
275+ f"[rank { self .rank } /{ self .n_ranks } ] ..swap space size : { (key_need_to_allocate_bytes + value_need_to_allocate_bytes ) / 1024 ** 3 :.2f} GB"
276+ )
274277 if args .num_cpu_blocks == 0 :
275278 logger .info (f"[rank { self .rank } /{ self .n_ranks } ] 💡 no swap space (cpu cache) is specified." )
276279 self .swap_space_ready_signal .value [self .rank ] = 1
0 commit comments