Skip to content

Commit 5bf48de

Browse files
ltd0924 and one co-author authored
[KVCache] support unified cache backend (#4903)
* [Feature] support unified cache backend * fix * fix * fix * fix * Update metax_model_runner.py * fix * update * Update test_moba_attention_backend.py --------- Co-authored-by: ltd0924 <[email protected]>
1 parent 76e60e9 commit 5bf48de

19 files changed: +273 −194 lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def parse_args():
5252
parser.add_argument("--rank", type=int, default=0, help="current rank")
5353
parser.add_argument("--device_id", type=int, default=0, help="device id")
5454
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
55-
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
56-
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
55+
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
56+
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
5757
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
5858
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
5959
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
@@ -71,8 +71,6 @@ def parse_args():
7171
default=9923,
7272
help="engine worker queue port",
7373
)
74-
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
75-
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
7674
parser.add_argument(
7775
"--cache_dtype",
7876
type=str,
@@ -764,38 +762,59 @@ def main():
764762
cache_type = args.cache_dtype
765763
speculative_config = SpeculativeConfig(args.speculative_config)
766764
num_extra_layers = speculative_config.num_extra_cache_layer
767-
num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
765+
key_cache_shape_list = [int(i) for i in args.key_cache_shape.split(",")]
766+
value_cache_shape_list = []
767+
if args.value_cache_shape:
768+
value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")]
769+
total_gpu_blocks = key_cache_shape_list[0]
770+
num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
768771
gpu_cache_kvs = {}
769772
gpu_cache_k_tensors = []
770773
gpu_cache_v_tensors = []
771774

772775
logger.info(f"[rank {rank}/{args.mp_num}] Initializing kv cache for all layers.")
773776
for i in range(args.num_layers + num_extra_layers):
774-
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
775-
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
776-
logger.info(f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {cache_shape}")
777+
num_gpu_blocks = total_gpu_blocks if i < args.num_layers else num_extra_layer_gpu_blocks
778+
key_cache_shape = [
779+
num_gpu_blocks,
780+
key_cache_shape_list[1],
781+
key_cache_shape_list[2],
782+
key_cache_shape_list[3],
783+
]
784+
value_cache_shape = []
785+
if value_cache_shape_list:
786+
value_cache_shape = [
787+
num_gpu_blocks,
788+
value_cache_shape_list[1],
789+
value_cache_shape_list[2],
790+
value_cache_shape_list[3],
791+
]
792+
logger.info(
793+
f"[rank {rank}/{args.mp_num}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
794+
)
777795

778796
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
779-
shape=cache_shape,
797+
shape=key_cache_shape,
780798
fill_value=0,
781799
dtype=cache_type,
782800
)
783801
gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
784-
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
785-
shape=cache_shape,
786-
fill_value=0,
787-
dtype=cache_type,
788-
)
789-
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
790-
791802
set_data_ipc(
792803
gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
793804
f"key_caches_{i}_rank{rank}.device{device}",
794805
)
795-
set_data_ipc(
796-
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
797-
f"value_caches_{i}_rank{rank}.device{device}",
798-
)
806+
if value_cache_shape_list:
807+
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
808+
shape=value_cache_shape,
809+
fill_value=0,
810+
dtype=cache_type,
811+
)
812+
gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
813+
814+
set_data_ipc(
815+
gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
816+
f"value_caches_{i}_rank{rank}.device{device}",
817+
)
799818
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
800819
logger.info(f"device :{device}")
801820
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,50 +58,41 @@ def parse_args():
5858
parser.add_argument("--rank", type=int, default=0, help="current rank")
5959
parser.add_argument("--device_id", type=int, default=0, help="device id")
6060
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
61-
parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
62-
parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
63-
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
6461
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
6562
parser.add_argument(
66-
"--protocol",
63+
"--cache_dtype",
6764
type=str,
68-
default="ipc",
69-
help="cache transfer protocol, only support ipc now",
65+
default="bfloat16",
66+
choices=["uint8", "bfloat16"],
67+
help="cache dtype",
7068
)
71-
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
69+
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
70+
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
7271
parser.add_argument("--cache_queue_port", type=int, default=9923, help="cache queue port")
72+
parser.add_argument("--enable_splitwise", type=int, default=0, help="enable splitwise ")
7373
parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
7474
parser.add_argument(
7575
"--engine_worker_queue_port",
7676
type=int,
7777
default=9923,
7878
help="engine worker queue port",
7979
)
80-
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
81-
82-
parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
8380
parser.add_argument("--num_cpu_blocks", type=int, default=4, help="cpu cache block number")
84-
parser.add_argument("--block_size", type=int, default=64, help="cache block size(tokens)")
85-
parser.add_argument(
86-
"--bytes_per_layer_per_block",
87-
type=int,
88-
default=1024,
89-
help="per layer per block bytes",
90-
)
81+
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
9182
parser.add_argument(
92-
"--cache_dtype",
83+
"--protocol",
9384
type=str,
94-
default="bfloat16",
95-
choices=["uint8", "bfloat16"],
96-
help="cache dtype",
85+
default="ipc",
86+
help="cache transfer protocol, only support ipc now",
9787
)
88+
parser.add_argument("--local_data_parallel_id", type=int, default=0)
89+
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
9890
parser.add_argument(
9991
"--speculative_config",
10092
type=json.loads,
10193
default="{}",
10294
help="speculative config",
10395
)
104-
parser.add_argument("--local_data_parallel_id", type=int, default=0)
10596
parser.add_argument("--create_cache_tensor", action="store_true")
10697

10798
args = parser.parse_args()
@@ -124,8 +115,13 @@ def __init__(self, args):
124115
self.gpu_cache_k_tensors = []
125116
self.gpu_cache_v_tensors = []
126117
self.speculative_config = SpeculativeConfig(args.speculative_config)
118+
self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
119+
self.value_cache_shape = []
120+
if args.value_cache_shape:
121+
self.value_cache_shape = [int(i) for i in args.value_cache_shape.split(",")]
122+
self.num_gpu_blocks = self.key_cache_shape[0]
127123
self.num_extra_layers = self.speculative_config.num_extra_cache_layer
128-
self.num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
124+
self.num_extra_layer_gpu_blocks = int(self.num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
129125

130126
self.swap_to_cpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
131127
self.swap_to_gpu_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
@@ -164,8 +160,9 @@ def __init__(self, args):
164160

165161
self.num_cpu_blocks = args.num_cpu_blocks
166162

167-
self._init_cpu_cache(args)
168163
self._init_gpu_cache(args)
164+
if self.num_cpu_blocks > 0:
165+
self._init_cpu_cache(args)
169166

170167
cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
171168
self.cache_task_broadcast_signal = IPCSignal(
@@ -209,28 +206,47 @@ def _init_gpu_cache(self, args):
209206
logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
210207
set_device(self.device)
211208
for i in range(args.num_layers + self.num_extra_layers):
212-
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
213-
cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
209+
num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
214210
key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
215211
val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
216-
212+
key_cache_shape = [
213+
num_gpu_blocks,
214+
self.key_cache_shape[1],
215+
self.key_cache_shape[2],
216+
self.key_cache_shape[3],
217+
]
218+
value_cache_shape = []
219+
if self.value_cache_shape:
220+
value_cache_shape = [
221+
num_gpu_blocks,
222+
self.value_cache_shape[1],
223+
self.value_cache_shape[2],
224+
self.value_cache_shape[3],
225+
]
217226
if args.create_cache_tensor:
218-
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
219-
key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
220-
val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
227+
logger.info(
228+
f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
229+
)
230+
key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype)
221231
set_data_ipc(key_cache, key_name)
222-
set_data_ipc(val_cache, val_name)
232+
if self.value_cache_shape:
233+
val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype)
234+
set_data_ipc(val_cache, val_name)
223235
else:
224-
logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
236+
logger.info(
237+
f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
238+
)
225239
key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
226240
val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
227-
key_cache = share_external_data_(key_cache, key_name, cache_shape, True)
228-
val_cache = share_external_data_(val_cache, val_name, cache_shape, True)
241+
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
242+
if self.value_cache_shape:
243+
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
229244

230245
self.gpu_cache_kvs[key_name] = key_cache
231-
self.gpu_cache_kvs[val_name] = val_cache
232246
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
233-
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
247+
if args.value_cache_shape:
248+
self.gpu_cache_kvs[val_name] = val_cache
249+
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
234250

235251
if args.create_cache_tensor:
236252
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
@@ -242,6 +258,22 @@ def _init_gpu_cache(self, args):
242258
logger.info(f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {memory_allocated()}")
243259

244260
def _init_cpu_cache(self, args):
261+
key_cache_size = self.key_cache_shape[1] * self.key_cache_shape[2] * self.key_cache_shape[3]
262+
if args.value_cache_shape:
263+
value_cache_size = self.value_cache_shape[1] * self.value_cache_shape[2] * self.value_cache_shape[3]
264+
else:
265+
value_cache_size = 0
266+
if args.cache_dtype == "bfloat16":
267+
cache_bytes = 2
268+
elif args.cache_dtype == "uint8":
269+
cache_bytes = 1
270+
else:
271+
raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
272+
key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
273+
value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
274+
logger.info(
275+
f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
276+
)
245277
if args.num_cpu_blocks == 0:
246278
logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
247279
self.swap_space_ready_signal.value[self.rank] = 1
@@ -253,14 +285,14 @@ def _init_cpu_cache(self, args):
253285
for i in range(args.num_layers + self.num_extra_layers):
254286
key_name = f"key_caches_{i}_rank{self.rank}"
255287
val_name = f"value_caches_{i}_rank{self.rank}"
256-
need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
257288
logger.info(
258-
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
289+
f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
259290
)
260-
self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
291+
self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
261292
self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
262-
self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
263-
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
293+
if value_need_to_allocate_bytes > 0:
294+
self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
295+
self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
264296
logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
265297
self.swap_space_ready_signal.value[self.rank] = 1
266298

0 commit comments

Comments (0)