Skip to content

Commit f323276

Browse files
authored
[UnitTest][MTP]add test_eagle_get_hidden_states (#3876)
1 parent 976aa88 commit f323276

File tree

2 files changed

+274
-0
lines changed

2 files changed

+274
-0
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import eagle_get_hidden_states
21+
22+
23+
def ComputeOrderKernel(
    seq_lens_this_time,
    seq_lens_encoder,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    accept_nums,
    position_map,
    output_token_num,
    bsz,
    actual_draft_token_num,
    input_token_num,
):
    """Record, for every kept input token, the output row it is copied to.

    The batch is walked sequentially while a running input offset and a
    running output offset are maintained. For request ``i``:

    * ``seq_lens_encoder[i] > 0``: that many consecutive input tokens are
      kept.
    * otherwise, ``base_model_seq_lens_this_time[i] != 0`` and
      ``seq_lens_this_time[i] == 0``: the base-model tokens are skipped.
    * otherwise, both lengths are zero: nothing is consumed.
    * otherwise: the first ``accept_nums[i]`` tokens are kept and the input
      offset then advances past the rest of the request's
      ``base_model_seq_lens_this_time[i]`` tokens.

    Args:
        seq_lens_this_time: Per-request lengths, indexable by batch index.
        seq_lens_encoder: Per-request encoder lengths.
        base_model_seq_lens_this_time: Per-request base-model lengths.
        base_model_seq_lens_encoder: Unused by this function.
        accept_nums: Per-request number of accepted tokens.
        position_map: Mutable sequence indexed by input token. Entries of
            kept tokens are overwritten with their output index; all other
            entries are left untouched.
        output_token_num: Mutable single-element sequence; element 0
            receives the total number of kept tokens.
        bsz: Number of requests to process.
        actual_draft_token_num: Unused by this function.
        input_token_num: Unused by this function.

    Returns:
        None. Results are written into ``position_map`` and
        ``output_token_num``.
    """
    in_offset = 0
    out_offset = 0
    for i in range(bsz):
        # Coerce to Python ints so the running offsets and the range()
        # bounds stay plain integers regardless of the container type.
        cur_base_model_seq_lens_this_time = int(base_model_seq_lens_this_time[i])
        cur_seq_lens_this_time = int(seq_lens_this_time[i])
        accept_num = int(accept_nums[i])
        cur_seq_lens_encoder = int(seq_lens_encoder[i])

        # 1. eagle encoder. Base step=1
        if cur_seq_lens_encoder > 0:
            for _ in range(cur_seq_lens_encoder):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1
        # 2. Base model stop at last verify-step.
        elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0:
            in_offset += cur_base_model_seq_lens_this_time
        # 3. stopped
        elif cur_base_model_seq_lens_this_time == 0 and cur_seq_lens_this_time == 0:
            pass
        # 4. keep the accepted tokens, skip the remaining ones
        else:
            # A throwaway loop variable avoids shadowing the batch index.
            for _ in range(accept_num):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1
            in_offset += cur_base_model_seq_lens_this_time - accept_num
    output_token_num[0] = out_offset
62+
63+
64+
def rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt):
    """Copy the kept rows of ``input`` into ``out`` element by element.

    The flat element indices ``0 .. elem_cnt - 1`` are split into an input
    row and a column. An element is copied only when the row's entry in
    ``position_map`` is non-negative; that entry is the destination row.

    Args:
        input: Two-level indexable source, addressed as ``input[row][col]``.
        position_map: Maps an input row to its output row; negative entries
            mark rows that are not copied.
        out: Two-level indexable destination, modified in place.
        dim_embed: Number of columns per row.
        elem_cnt: Number of flat elements to visit.

    Returns:
        None. ``out`` is updated in place.
    """
    for elem_idx in range(elem_cnt):
        # Integer divmod keeps the index math exact (no float round trip).
        ori_token_idx, offset = divmod(elem_idx, dim_embed)
        token_idx = position_map[ori_token_idx]
        if token_idx >= 0:
            out[token_idx][offset] = input[ori_token_idx][offset]
71+
72+
73+
def eagle_get_hidden_states_ref(
    input,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    stop_flags,
    accept_nums,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    actual_draft_token_num,
):
    """Pure-Python reference used to check ``eagle_get_hidden_states``.

    First builds a map from input rows to output rows with
    ``ComputeOrderKernel``, then gathers the mapped rows of ``input`` into a
    freshly allocated tensor with ``rebuildHiddenStatesKernel``.

    Args:
        input: 2-D tensor; ``shape[0]`` is the number of input tokens and
            ``shape[1]`` the embedding width.
        seq_lens_this_time: Per-request lengths; its dtype is reused for the
            internal index tensors and ``shape[0]`` is the batch size.
        seq_lens_encoder: Per-request encoder lengths.
        seq_lens_decoder: Accepted for signature parity; not read here.
        stop_flags: Accepted for signature parity; not read here.
        accept_nums: Per-request number of accepted tokens.
        base_model_seq_lens_this_time: Per-request base-model lengths.
        base_model_seq_lens_encoder: Forwarded to ``ComputeOrderKernel``.
        actual_draft_token_num: Forwarded to ``ComputeOrderKernel``.

    Returns:
        Tensor of shape ``[kept_token_count, dim_embed]`` with the dtype of
        ``input``.
    """
    input_token_num = input.shape[0]
    dim_embed = input.shape[1]
    bsz = seq_lens_this_time.shape[0]
    # Explicit -1 marks "row not kept". 0xFFFFFFFF does not fit a signed
    # 32-bit dtype and would only mean -1 if the fill happens to wrap around.
    position_map = paddle.full([input_token_num], -1, seq_lens_this_time.dtype)
    # Element 0 is always written by ComputeOrderKernel before it is read.
    output_token_num = paddle.empty([1], seq_lens_this_time.dtype)
    ComputeOrderKernel(
        seq_lens_this_time,
        seq_lens_encoder,
        base_model_seq_lens_this_time,
        base_model_seq_lens_encoder,
        accept_nums,
        position_map,
        output_token_num,
        bsz,
        actual_draft_token_num,
        input_token_num,
    )

    output_token_num_cpu = output_token_num[0]
    out = paddle.empty([output_token_num_cpu, dim_embed], input.dtype)
    elem_cnt = input_token_num * dim_embed
    rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt)
    return out
107+
108+
109+
class TestEagleGetHiddenStates(unittest.TestCase):
    """Compares the ``eagle_get_hidden_states`` op with its Python reference."""

    def test_eagle_get_hidden_states(self):
        """Feed identical seeded random inputs to both and compare outputs."""
        np.random.seed(2023)
        paddle.seed(2023)
        bs, input_token_num, dim_embed = 2, 10, 512
        len_upper = input_token_num // bs + 1  # exclusive bound for the length draws

        # The order of the NumPy draws below determines the generated values;
        # keep it unchanged.
        actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
        seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
        seq_lens_encoder = np.random.randint(0, len_upper, bs, dtype=np.int32)
        accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
        base_model_seq_lens_this_time = np.random.randint(0, len_upper, bs, dtype=np.int32)
        base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
        seq_lens_decoder = np.random.randint(0, len_upper, bs, dtype=np.int32)
        stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
        hidden_states = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)

        def as_int32(array):
            return paddle.to_tensor(array, dtype=paddle.int32)

        # One argument tuple shared by the op and the reference, so both see
        # exactly the same tensors in the same positional order.
        op_args = (
            paddle.to_tensor(hidden_states, dtype=paddle.float16),
            as_int32(seq_lens_this_time),
            as_int32(seq_lens_encoder),
            as_int32(seq_lens_decoder),
            as_int32(stop_flags),
            as_int32(accept_nums),
            as_int32(base_model_seq_lens_this_time),
            as_int32(base_model_seq_lens_encoder),
            actual_draft_token_num,
        )
        out = eagle_get_hidden_states(*op_args)
        out_ref = eagle_get_hidden_states_ref(*op_args)
        np.testing.assert_allclose(out.numpy(), out_ref.numpy())
161+
162+
163+
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import eagle_get_self_hidden_states
21+
22+
23+
def computeOrderKernel(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, output_token_num, bsz):
    """Record, for every output row, the input row it is read from.

    The batch is walked in order while the number of input rows consumed so
    far and the number of output rows produced so far are tracked. A request
    with ``seq_lens_this_time[i] > 0`` produces exactly one output row, whose
    source is the last input row consumed for that request.

    Args:
        last_seq_lens_this_time: Per-request lengths from the previous step.
        seq_lens_this_time: Per-request lengths for the current step.
        step_idx: Per-request step counter; the value 1 is treated specially.
        src_map: Mutable sequence indexed by output row; receives the source
            input row of each produced output row.
        output_token_num: Mutable single-element sequence; element 0
            receives the number of output rows.
        bsz: Number of requests to process.

    Returns:
        None. Results are written into ``src_map`` and ``output_token_num``.
    """
    consumed = 0
    produced = 0
    for req in range(bsz):
        now_len = seq_lens_this_time[req]
        prev_len = last_seq_lens_this_time[req]
        at_first_step = step_idx[req] == 1

        if now_len > 0:
            # Active request: step 1 (encoder) consumes a single input row,
            # any other step (decoder) consumes the previous step's length.
            if at_first_step:
                consumed += 1
            else:
                consumed += prev_len
            src_map[produced] = consumed - 1
            produced += 1
        elif at_first_step:
            # Stopped at the first token: at most one input row to skip.
            if prev_len > 0:
                consumed += 1
            else:
                consumed += 0
        else:
            # Normal stop: skip everything the previous step contributed.
            consumed += prev_len
    output_token_num[0] = produced
48+
49+
50+
def rebuildSelfHiddenStatesKernel(input, src_map, output, dim_embed, elem_cnt):
    """Gather rows of ``input`` into ``output`` element by element.

    The flat element indices ``0 .. elem_cnt - 1`` are split into an output
    row and a column; the value is read from the input row that ``src_map``
    gives for that output row.

    Args:
        input: Two-level indexable source, addressed as ``input[row][col]``.
        src_map: Maps an output row to the input row it is copied from.
        output: Two-level indexable destination, modified in place.
        dim_embed: Number of columns per row.
        elem_cnt: Number of flat output elements to fill.

    Returns:
        None. ``output`` is updated in place.
    """
    for elem_id in range(elem_cnt):
        # Integer divmod keeps the index math exact (no float round trip).
        output_token_idx, offset = divmod(elem_id, dim_embed)
        input_token_idx = src_map[output_token_idx]
        output[output_token_idx][offset] = input[input_token_idx][offset]
56+
57+
58+
def eagle_get_self_hidden_states_ref(input, last_seq_lens_this_time, seq_lens_this_time, step_idx):
    """Pure-Python reference used to check ``eagle_get_self_hidden_states``.

    Builds an output-row to input-row map with ``computeOrderKernel`` and
    then gathers those rows of ``input`` with
    ``rebuildSelfHiddenStatesKernel``.

    Args:
        input: 2-D tensor; ``shape[0]`` is the number of input tokens and
            ``shape[1]`` the embedding width.
        last_seq_lens_this_time: Per-request lengths from the previous step.
        seq_lens_this_time: Per-request lengths for the current step; its
            dtype is reused for the internal index tensors and ``shape[0]``
            is the batch size.
        step_idx: Per-request step counter.

    Returns:
        Tensor of shape ``[output_token_count, dim_embed]`` with the dtype of
        ``input``.
    """
    token_total, dim_embed = input.shape[0], input.shape[1]
    batch = seq_lens_this_time.shape[0]
    index_dtype = seq_lens_this_time.dtype

    # Scratch tensors written by computeOrderKernel: the row count starts at
    # 0 and every map slot starts at -1.
    kept_count = paddle.full([1], 0, index_dtype)
    src_map = paddle.full([token_total], -1, index_dtype)
    computeOrderKernel(last_seq_lens_this_time, seq_lens_this_time, step_idx, src_map, kept_count, batch)

    rows_out = kept_count[0]
    out = paddle.full([rows_out, dim_embed], -1, input.dtype)
    rebuildSelfHiddenStatesKernel(input, src_map, out, dim_embed, rows_out * dim_embed)
    return out
74+
75+
76+
class TestEagleGetSelfHiddenStates(unittest.TestCase):
    """Compares the ``eagle_get_self_hidden_states`` op with its Python reference."""

    def test_eagle_get_self_hidden_states(self):
        """Feed identical seeded random inputs to both and compare outputs."""
        paddle.seed(2023)
        np.random.seed(2023)
        bs, input_token_num, dim_embed = 2, 10, 512
        len_upper = input_token_num // bs  # exclusive bound for the length draws

        # The order of the NumPy draws below determines the generated values;
        # keep it unchanged.
        last_seq_lens_this_time = np.random.randint(0, len_upper, bs, dtype=np.int32)
        seq_lens_this_time = np.random.randint(0, len_upper, bs, dtype=np.int32)
        step_idx = np.arange(0, bs, dtype=np.int32)
        hidden_states = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)

        # One argument tuple shared by the op and the reference, so both see
        # exactly the same tensors in the same positional order.
        op_args = (
            paddle.to_tensor(hidden_states, dtype=paddle.float16),
            paddle.to_tensor(last_seq_lens_this_time, dtype=paddle.int32),
            paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32),
            paddle.to_tensor(step_idx, dtype=paddle.int64),
        )
        out = eagle_get_self_hidden_states(*op_args)
        out_ref = eagle_get_self_hidden_states_ref(*op_args)
        np.testing.assert_allclose(out.numpy(), out_ref.numpy())
107+
108+
109+
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()

0 commit comments

Comments
 (0)