Skip to content

Commit d59ee42

Browse files
committed
fix typo bug and add test for vllm reconstruct_from_cache kernel
1 parent 45532d7 commit d59ee42

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/runtime/contrib/vllm/cache_kernels.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ __global__ void reconstruct_from_cache_kernel(
9494
block_offset;
9595

9696
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
97-
value[src_value_idx] = __ldg(&value_cache[tgt_value_idx]);
97+
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
9898
}
9999
}
100100

tests/python/relax/test_contrib_vllm.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,5 +742,39 @@ def main(
742742
assert np.max(np.abs(out_value_cache - ref_value_cache)) == 0
743743

744744

745+
def test_reconstruct_from_cache():
746+
num_heads = 1
747+
head_dim = 8
748+
vec_size = 8
749+
block_size = 16
750+
num_tokens = 8
751+
num_blocks = 1
752+
753+
dev = tvm.device("cuda", 0)
754+
755+
key = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev)
756+
value = tvm.nd.array(np.random.randn(num_tokens, num_heads, head_dim).astype("float16"), dev)
757+
slot_mapping = tvm.nd.array(np.arange(num_tokens).astype("int32"), dev)
758+
759+
k_cache = tvm.nd.array(
760+
np.random.randn(num_blocks, num_heads, head_dim // vec_size, block_size, vec_size).astype(
761+
"float16"
762+
),
763+
dev,
764+
)
765+
v_cache = tvm.nd.array(
766+
np.random.randn(num_blocks, num_heads, head_dim, block_size).astype("float16"), dev
767+
)
768+
769+
reshape_and_cache_func = tvm.get_global_func("tvm.contrib.vllm.reshape_and_cache")
770+
reconstruct_from_cache_func = tvm.get_global_func("tvm.contrib.vllm.reconstruct_from_cache")
771+
772+
reshape_and_cache_func(key, value, k_cache, v_cache, slot_mapping)
773+
out = reconstruct_from_cache_func(k_cache, v_cache, slot_mapping)
774+
775+
np.testing.assert_equal(key.numpy(), out[0].numpy())
776+
np.testing.assert_equal(value.numpy(), out[1].numpy())
777+
778+
745779
if __name__ == "__main__":
746780
tvm.testing.main()

0 commit comments

Comments
 (0)