@@ -742,5 +742,39 @@ def main(
     assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0
 
 
+def test_reconstruct_from_cache():
+    num_heads = 1
+    head_dim = 8
+    vec_size = 8
+    block_size = 16
+    num_tokens = 8
+    num_blocks = 1
+
+    dev = tvm.device("cuda", 0)
+
+    key = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev)
+    value = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev)
+    slot_mapping = tvm.nd.array(np.arange(num_tokens).astype("int32"), dev)
+
+    k_cache = tvm.nd.array(
+        np.random.randn(num_blocks, num_heads, head_dim // vec_size, block_size, vec_size).astype(
+            "float16"
+        ),
+        dev,
+    )
+    v_cache = tvm.nd.array(
+        np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16"), dev
+    )
+
+    reshape_and_cache_func = tvm.get_global_func("tvm.contrib.vllm.reshape_and_cache")
+    reconstruct_from_cache_func = tvm.get_global_func("tvm.contrib.vllm.reconstruct_from_cache")
+
+    reshape_and_cache_func(key, value, k_cache, v_cache, slot_mapping)
+    out = reconstruct_from_cache_func(k_cache, v_cache, slot_mapping)
+
+    np.testing.assert_equal(key.numpy(), out[0].numpy())
+    np.testing.assert_equal(value.numpy(), out[1].numpy())
+
+
 if __name__ == "__main__":
     tvm.testing.main()
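
For readers unfamiliar with the blocked KV-cache layout this test round-trips, here is a minimal NumPy sketch of the scatter/gather semantics, assuming the vLLM-style layouts used above: a key cache of shape `[num_blocks, num_heads, head_dim // vec_size, block_size, vec_size]` and a value cache of shape `[num_blocks, num_heads, head_dim, block_size]`. The `reference_reshape_and_cache` / `reference_reconstruct_from_cache` helpers are hypothetical stand-ins for illustration, not part of the TVM or vLLM API, and the actual CUDA kernels may be implemented differently.

```python
import numpy as np


def reference_reshape_and_cache(key, value, k_cache, v_cache, slot_mapping):
    """Hypothetical reference: scatter per-token K/V vectors into the blocked caches.

    Assumed layouts (not taken from the kernels themselves):
      k_cache: [num_blocks, num_heads, head_dim // vec_size, block_size, vec_size]
      v_cache: [num_blocks, num_heads, head_dim, block_size]
    """
    vec_size = k_cache.shape[-1]
    block_size = v_cache.shape[-1]
    num_heads = key.shape[1]
    for i, slot in enumerate(slot_mapping):
        # A flat slot index maps to (block, offset within block).
        block_idx, block_offset = divmod(int(slot), block_size)
        for h in range(num_heads):
            # Keys are stored split into head_dim // vec_size chunks of vec_size.
            k_cache[block_idx, h, :, block_offset, :] = key[i, h].reshape(-1, vec_size)
            # Values keep head_dim contiguous, indexed by the in-block offset.
            v_cache[block_idx, h, :, block_offset] = value[i, h]


def reference_reconstruct_from_cache(k_cache, v_cache, slot_mapping):
    """Hypothetical reference: gather per-token K/V vectors back out of the caches."""
    num_heads, head_dim = v_cache.shape[1], v_cache.shape[2]
    block_size = v_cache.shape[-1]
    num_tokens = len(slot_mapping)
    key = np.empty((num_tokens, num_heads, head_dim), dtype=k_cache.dtype)
    value = np.empty((num_tokens, num_heads, head_dim), dtype=v_cache.dtype)
    for i, slot in enumerate(slot_mapping):
        block_idx, block_offset = divmod(int(slot), block_size)
        for h in range(num_heads):
            key[i, h] = k_cache[block_idx, h, :, block_offset, :].reshape(-1)
            value[i, h] = v_cache[block_idx, h, :, block_offset]
    return key, value


# Host-side round trip mirroring the test's shapes (no GPU needed).
key = np.random.randn(8, 1, 8).astype("float16")
value = np.random.randn(8, 1, 8).astype("float16")
k_cache = np.zeros((1, 1, 1, 16, 8), dtype="float16")
v_cache = np.zeros((1, 1, 8, 16), dtype="float16")
slot_mapping = np.arange(8, dtype="int32")

reference_reshape_and_cache(key, value, k_cache, v_cache, slot_mapping)
k_out, v_out = reference_reconstruct_from_cache(k_cache, v_cache, slot_mapping)
np.testing.assert_equal(key, k_out)
np.testing.assert_equal(value, v_out)
```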