@@ -16,35 +16,33 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
1616 assert len (func .body .blocks [0 ].bindings ) == 2
1717 assert isinstance (func .body .blocks [0 ].bindings [0 ], relax .VarBinding )
1818 assert isinstance (func .body .blocks [0 ].bindings [0 ].value , relax .Call )
19- assert isinstance (func .body .blocks [0 ].bindings [0 ].value .op , relax .ExternFunc )
20- assert (
21- func .body .blocks [0 ].bindings [0 ].value .op .global_symbol
22- == "mlc.create_paged_kv_cache_generic"
23- )
24-
19+ assert func .body .blocks [0 ].bindings [0 ].value .op == tvm .ir .Op .get ("relax.call_pure_packed" )
2520 args = func .body .blocks [0 ].bindings [0 ].value .args
26- assert len (args ) == 10
27- assert isinstance (args [0 ], relax .ShapeExpr )
28- assert len (args [0 ].values ) == 4
29- for i in range (1 , 9 ):
21+ assert isinstance (args [0 ], relax .ExternFunc )
22+ assert args [0 ].global_symbol == "mlc.create_paged_kv_cache_generic"
23+
24+ assert len (args ) == 11
25+ assert isinstance (args [1 ], relax .ShapeExpr )
26+ assert len (args [1 ].values ) == 4
27+ for i in range (2 , 10 ):
3028 assert isinstance (args [i ], relax .PrimValue )
3129 assert isinstance (args [i ].value , (tvm .tir .IntImm , tvm .tir .FloatImm ))
32- assert isinstance (args [9 ], relax .DataTypeImm )
30+ assert isinstance (args [10 ], relax .DataTypeImm )
3331
3432 return {
35- "max_batch_size" : args [0 ].values [0 ],
36- "max_total_seq_len" : args [0 ].values [1 ],
37- "prefill_chunk_size" : args [0 ].values [2 ],
38- "page_size" : args [0 ].values [3 ],
39- "num_hidden_layers" : args [1 ].value .value ,
40- "num_attention_heads" : args [2 ].value .value ,
41- "num_key_value_heads" : args [3 ].value .value ,
42- "head_dim" : args [4 ].value .value ,
43- "rope_mode" : args [5 ].value .value ,
44- "rope_scale" : args [6 ].value .value ,
45- "rope_theta" : args [7 ].value .value ,
46- "rotary_dim" : args [8 ].value .value ,
47- "dtype" : args [9 ].value ,
33+ "max_batch_size" : args [1 ].values [0 ],
34+ "max_total_seq_len" : args [1 ].values [1 ],
35+ "prefill_chunk_size" : args [1 ].values [2 ],
36+ "page_size" : args [1 ].values [3 ],
37+ "num_hidden_layers" : args [2 ].value .value ,
38+ "num_attention_heads" : args [3 ].value .value ,
39+ "num_key_value_heads" : args [4 ].value .value ,
40+ "head_dim" : args [5 ].value .value ,
41+ "rope_mode" : args [6 ].value .value ,
42+ "rope_scale" : args [7 ].value .value ,
43+ "rope_theta" : args [8 ].value .value ,
44+ "rotary_dim" : args [9 ].value .value ,
45+ "dtype" : args [10 ].value ,
4846 }
4947
5048
0 commit comments