Skip to content

Commit 8effa8e

Browse files
[Fix] KVCache creation with call_pure_packed (#1930)
With apache/tvm#16684 merged in, the KV cache creation will fail when compiling models. This PR fixes the problem by using `call_pure_packed`.
1 parent dd6a795 commit 8effa8e

File tree

2 files changed

+35
-41
lines changed

2 files changed

+35
-41
lines changed

python/mlc_chat/compiler_pass/dispatch_kv_cache_creation.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,33 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
1616
assert len(func.body.blocks[0].bindings) == 2
1717
assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)
1818
assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)
19-
assert isinstance(func.body.blocks[0].bindings[0].value.op, relax.ExternFunc)
20-
assert (
21-
func.body.blocks[0].bindings[0].value.op.global_symbol
22-
== "mlc.create_paged_kv_cache_generic"
23-
)
24-
19+
assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get("relax.call_pure_packed")
2520
args = func.body.blocks[0].bindings[0].value.args
26-
assert len(args) == 10
27-
assert isinstance(args[0], relax.ShapeExpr)
28-
assert len(args[0].values) == 4
29-
for i in range(1, 9):
21+
assert isinstance(args[0], relax.ExternFunc)
22+
assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
23+
24+
assert len(args) == 11
25+
assert isinstance(args[1], relax.ShapeExpr)
26+
assert len(args[1].values) == 4
27+
for i in range(2, 10):
3028
assert isinstance(args[i], relax.PrimValue)
3129
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
32-
assert isinstance(args[9], relax.DataTypeImm)
30+
assert isinstance(args[10], relax.DataTypeImm)
3331

3432
return {
35-
"max_batch_size": args[0].values[0],
36-
"max_total_seq_len": args[0].values[1],
37-
"prefill_chunk_size": args[0].values[2],
38-
"page_size": args[0].values[3],
39-
"num_hidden_layers": args[1].value.value,
40-
"num_attention_heads": args[2].value.value,
41-
"num_key_value_heads": args[3].value.value,
42-
"head_dim": args[4].value.value,
43-
"rope_mode": args[5].value.value,
44-
"rope_scale": args[6].value.value,
45-
"rope_theta": args[7].value.value,
46-
"rotary_dim": args[8].value.value,
47-
"dtype": args[9].value,
33+
"max_batch_size": args[1].values[0],
34+
"max_total_seq_len": args[1].values[1],
35+
"prefill_chunk_size": args[1].values[2],
36+
"page_size": args[1].values[3],
37+
"num_hidden_layers": args[2].value.value,
38+
"num_attention_heads": args[3].value.value,
39+
"num_key_value_heads": args[4].value.value,
40+
"head_dim": args[5].value.value,
41+
"rope_mode": args[6].value.value,
42+
"rope_scale": args[7].value.value,
43+
"rope_theta": args[8].value.value,
44+
"rotary_dim": args[9].value.value,
45+
"dtype": args[10].value,
4846
}
4947

5048

python/mlc_chat/nn/kv_cache.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,19 @@ def create_generic( # pylint: disable=too-many-arguments
6262
if rotary_dim is None:
6363
rotary_dim = head_dim
6464
return PagedKVCache(
65-
_expr=rx.Call(
66-
rx.extern("mlc.create_paged_kv_cache_generic"),
67-
args=[
68-
rx.ShapeExpr(
69-
[max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]
70-
),
71-
rx.PrimValue(num_hidden_layers),
72-
rx.PrimValue(num_attention_heads),
73-
rx.PrimValue(num_key_value_heads),
74-
rx.PrimValue(head_dim),
75-
rx.PrimValue(rope_mode),
76-
rx.PrimValue(rope_scale),
77-
rx.PrimValue(rope_theta),
78-
rx.PrimValue(rotary_dim),
79-
rx.DataTypeImm(dtype),
80-
],
81-
sinfo_args=[rx.ObjectStructInfo()],
65+
_expr=rx.call_pure_packed(
66+
"mlc.create_paged_kv_cache_generic",
67+
rx.ShapeExpr([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size]),
68+
rx.PrimValue(num_hidden_layers),
69+
rx.PrimValue(num_attention_heads),
70+
rx.PrimValue(num_key_value_heads),
71+
rx.PrimValue(head_dim),
72+
rx.PrimValue(rope_mode),
73+
rx.PrimValue(rope_scale),
74+
rx.PrimValue(rope_theta),
75+
rx.PrimValue(rotary_dim),
76+
rx.DataTypeImm(dtype),
77+
sinfo_args=rx.ObjectStructInfo(),
8278
),
8379
_name=name,
8480
)

0 commit comments

Comments
 (0)