Skip to content

Commit 0a5a93b

Browse files
committed
add test_eagle_get_hidden_states.py
1 parent 3790505 commit 0a5a93b

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import eagle_get_hidden_states
21+
22+
23+
class TestEagleGetHiddenStates(unittest.TestCase):
24+
def test_eagle_get_hidden_states(self):
25+
np.random.seed(2023)
26+
paddle.seed(2023)
27+
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
28+
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
29+
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
30+
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
31+
32+
seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
33+
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
34+
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
35+
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
36+
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
37+
# don't care
38+
seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
39+
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
40+
41+
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
42+
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
43+
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
44+
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
45+
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
46+
# don't care
47+
seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
48+
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
49+
50+
# fp32 test
51+
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
52+
input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
53+
gpu_out = eagle_get_hidden_states(
54+
input_tensor,
55+
seq_lens_this_time_tensor,
56+
seq_lens_encoder_tensor,
57+
seq_lens_decoder_tensor,
58+
stop_flags_tensor,
59+
accept_nums_tensor,
60+
base_model_seq_lens_this_time_tensor,
61+
base_model_seq_lens_encoder_tensor,
62+
actual_draft_token_num,
63+
)
64+
print(gpu_out.numpy())
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()

0 commit comments

Comments
 (0)