Skip to content

Commit ffbbd36

Browse files
committed
add test_eagle_get_hidden_states.py
1 parent 3790505 commit ffbbd36

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import eagle_get_hidden_states
21+
22+
23+
class TestEagleGetHiddenStates(unittest.TestCase):
24+
def test_eagle_get_hidden_states(self):
25+
np.random.seed(2023)
26+
paddle.seed(2023)
27+
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
28+
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
29+
dim_embed = np.array(1024, dtype=np.int32)
30+
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
31+
32+
seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
33+
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
34+
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
35+
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
36+
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
37+
38+
seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
39+
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
40+
41+
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
42+
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
43+
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
44+
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
45+
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
46+
47+
seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
48+
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
49+
50+
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
51+
input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
52+
gpu_out = eagle_get_hidden_states(
53+
input_tensor,
54+
seq_lens_this_time_tensor,
55+
seq_lens_encoder_tensor,
56+
seq_lens_decoder_tensor,
57+
stop_flags_tensor,
58+
accept_nums_tensor,
59+
base_model_seq_lens_this_time_tensor,
60+
base_model_seq_lens_encoder_tensor,
61+
actual_draft_token_num,
62+
)
63+
out_ref = np.array([6, 4, 3, 3], dtype=np.float16)
64+
np.testing.assert_allclose(gpu_out.numpy()[0][0:4], out_ref)
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import eagle_get_self_hidden_states
21+
22+
23+
class TestEagleGetSelfHiddenStates(unittest.TestCase):
24+
def test_eagle_get_self_hidden_states(self):
25+
paddle.seed(2023)
26+
np.random.seed(2023)
27+
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
28+
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
29+
dim_embed = np.array(1024, dtype=np.int32)
30+
31+
last_seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32)
32+
seq_lens_this_time = np.random.randint(0, input_token_num // bs, bs, dtype=np.int32)
33+
step_idx = np.arange(0, bs, dtype=np.int32)
34+
35+
last_seq_lens_this_time_tensor = paddle.to_tensor(last_seq_lens_this_time, dtype=paddle.int32)
36+
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
37+
step_idx_tensor = paddle.to_tensor(step_idx, dtype=paddle.int64)
38+
39+
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
40+
input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
41+
gpu_out = eagle_get_self_hidden_states(
42+
input_tensor,
43+
last_seq_lens_this_time_tensor,
44+
seq_lens_this_time_tensor,
45+
step_idx_tensor,
46+
)
47+
out_ref = np.array([5, 4, 2, 8], dtype=np.float16)
48+
np.testing.assert_allclose(gpu_out.numpy()[0][0:4], out_ref)
49+
50+
51+
if __name__ == "__main__":
52+
unittest.main()

0 commit comments

Comments
 (0)