diff --git a/tests/operators/test_pre_cache_len_concat.py b/tests/operators/test_pre_cache_len_concat.py new file mode 100644 index 00000000000..dd2fc8dcc49 --- /dev/null +++ b/tests/operators/test_pre_cache_len_concat.py @@ -0,0 +1,95 @@ +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import pre_cache_len_concat + + +def ref_pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, block_size): + """ + Reference implementation. + """ + bsz = len(seq_lens_this_time) + cu_seqlens_k = np.zeros(bsz + 1, dtype=np.int32) + batch_ids = [] + tile_ids_per_batch = [] + total_tokens = 0 + gridx = 0 + + for bid in range(bsz): + cache_len = int(seq_lens_decoder[bid]) + q_len = int(seq_lens_this_time[bid]) + if q_len <= 0: + cache_len = 0 + loop_times = (cache_len + block_size - 1) // block_size # div_up + for tile_id in range(loop_times): + batch_ids.append(bid) + tile_ids_per_batch.append(tile_id) + gridx += loop_times + total_tokens += cache_len + q_len + cu_seqlens_k[bid + 1] = total_tokens + + return ( + cu_seqlens_k, + np.array(batch_ids, dtype=np.int32), + np.array(tile_ids_per_batch, dtype=np.int32), + np.array([gridx], dtype=np.int32), + np.array([total_tokens], dtype=np.int32), + ) + + +class TestPreCacheLenConcat(unittest.TestCase): + def setUp(self): + paddle.set_device("gpu") + + def test_smoke_shapes(self): + bsz = 3 + max_dec_len, block_size = 16, 4 + + seq_lens_decoder = np.array([8, 4, 2], dtype=np.int32) + seq_lens_this_time = np.array([2, 3, 1], dtype=np.int32) + + seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32") + seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32") + + outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size) + cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs] + + # Shape checks + self.assertEqual(cu_seqlens_k.shape[0], bsz + 1) + self.assertEqual(batch_ids.shape, tile_ids.shape) + self.assertEqual(num_blocks.shape, (1,)) + self.assertEqual(kv_token_num.shape, (1,)) + + # Basic value sanity checks + self.assertTrue(np.all(np.diff(cu_seqlens_k) >= 0)) # monotonic + self.assertGreaterEqual(num_blocks[0], 0) + self.assertGreaterEqual(kv_token_num[0], 0) + + def test_strict_values_with_ref(self): + max_dec_len, block_size = 16, 4 + + seq_lens_decoder = np.array([8, 4, 2], dtype=np.int32) + seq_lens_this_time = np.array([2, 3, 1], dtype=np.int32) + + seq_lens_decoder_t = paddle.to_tensor(seq_lens_decoder, dtype="int32") + seq_lens_this_time_t = paddle.to_tensor(seq_lens_this_time, dtype="int32") + + outputs = pre_cache_len_concat(seq_lens_decoder_t, seq_lens_this_time_t, max_dec_len, block_size) + cu_seqlens_k, batch_ids, tile_ids, num_blocks, kv_token_num = [out.numpy() for out in outputs] + + # Reference implementation + ref_outputs = ref_pre_cache_len_concat(seq_lens_decoder, seq_lens_this_time, block_size) + ref_cu, ref_batch_ids, ref_tile_ids, ref_num_blocks, ref_kv_token_num = ref_outputs + + # Compare all outputs against reference + np.testing.assert_array_equal(cu_seqlens_k, ref_cu) + np.testing.assert_array_equal(batch_ids[: len(ref_batch_ids)], ref_batch_ids) + np.testing.assert_array_equal(tile_ids[: len(ref_tile_ids)], ref_tile_ids) + self.assertEqual(num_blocks[0], ref_num_blocks[0]) + self.assertEqual(kv_token_num[0], ref_kv_token_num[0]) + + +if __name__ == "__main__": + unittest.main()