Skip to content

Commit 4408dc7

Browse files
authored
【Hackathon 9th No.49】add test_pre_cache_len_concat (#3847)
* add test_pre_cache_len_concat * fix according review, add ref_pre_cache_len_concat
1 parent ef4a1aa commit 4408dc7

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import unittest
2+
3+
import numpy as np
4+
import paddle
5+
6+
from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat
7+
8+
9+
def ref_pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, block_size):
10+
"""
11+
Reference implementation.
12+
"""
13+
bsz = len(seq_lens_this_time)
14+
cu_seqlens_k = np.zeros(bsz + 1, dtype=np.int32)
15+
batch_ids = []
16+
tile_ids_per_batch = []
17+
total_tokens = 0
18+
gridx = 0
19+
20+
for bid in range(bsz):
21+
cache_len = int(seq_lens_decoder[bid])
22+
q_len = int(seq_lens_this_time[bid])
23+
if q_len <= 0:
24+
cache_len = 0
25+
loop_times = (cache_len + block_size - 1) // block_size # div_up
26+
for tile_id in range(loop_times):
27+
batch_ids.append(bid)
28+
tile_ids_per_batch.append(tile_id)
29+
gridx += loop_times
30+
total_tokens += cache_len + q_len
31+
cu_seqlens_k[bid + 1] = total_tokens
32+
33+
return (
34+
cu_seqlens_k,
35+
np.array(batch_ids, dtype=np.int32),
36+
np.array(tile_ids_per_batch, dtype=np.int32),
37+
np.array([gridx], dtype=np.int32),
38+
np.array([total_tokens], dtype=np.int32),
39+
)
40+
41+
42+
class TestPreCacheLenConcat(unittest.TestCase):
43+
def setUp(self):
44+
paddle.set_device("gpu")
45+
46+
def test_smoke_shapes(self):
47+
bsz = 3
48+
max_dec_len, block_size = 16, 4
49+
50+
seq_lens_decoder = np.array([8, 4, 2], dtype=np.int32)
51+
seq_lens_this_time = np.array([2, 3, 1], dtype=np.int32)
52+
53+
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
54+
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
55+
56+
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
57+
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
58+
59+
# Shape checks
60+
self.assertEqual(cu_seqlens_k.shape[0], bsz + 1)
61+
self.assertEqual(batch_ids.shape, tile_ids.shape)
62+
self.assertEqual(num_blocks.shape, (1,))
63+
self.assertEqual(kv_token_num.shape, (1,))
64+
65+
# Basic value sanity checks
66+
self.assertTrue(np.all(np.diff(cu_seqlens_k) >= 0)) # monotonic
67+
self.assertGreaterEqual(num_blocks[0], 0)
68+
self.assertGreaterEqual(kv_token_num[0], 0)
69+
70+
def test_strict_values_with_ref(self):
71+
max_dec_len, block_size = 16, 4
72+
73+
seq_lens_decoder = np.array([8, 4, 2], dtype=np.int32)
74+
seq_lens_this_time = np.array([2, 3, 1], dtype=np.int32)
75+
76+
seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32")
77+
seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32")
78+
79+
outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size)
80+
cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs]
81+
82+
# Reference implementation
83+
ref_outputs = ref_pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, block_size)
84+
ref_cu, ref_batch_ids, ref_tile_ids, ref_num_blocks, ref_kv_token_num = ref_outputs
85+
86+
# Compare all outputs against reference
87+
np.testing.assert_array_equal(cu_seqlens_k, ref_cu)
88+
np.testing.assert_array_equal(batch_ids[: len(ref_batch_ids)], ref_batch_ids)
89+
np.testing.assert_array_equal(tile_ids[: len(ref_tile_ids)], ref_tile_ids)
90+
self.assertEqual(num_blocks[0], ref_num_blocks[0])
91+
self.assertEqual(kv_token_num[0], ref_kv_token_num[0])
92+
93+
94+
if __name__ == "__main__":
95+
unittest.main()

0 commit comments

Comments
 (0)