Skip to content

Commit e832516

Browse files
authored
【Hackathon 9th No.63】add test_draft_model_postprocess.py (#3757)
* add test_draft_model_postprocess.py
* fix
* fix
1 parent ac46ef4 commit e832516

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import unittest
15+
16+
import numpy as np
17+
import paddle
18+
19+
from fastdeploy.model_executor.ops.gpu import draft_model_postprocess
20+
21+
22+
def draft_model_postprocess_cpu(
    base_model_draft_tokens,
    base_model_seq_lens_encoder,
    base_model_stop_flags,
):
    """Pure-Python reference implementation of the ``draft_model_postprocess`` GPU op.

    Computes, per request in the batch, how many tokens to schedule this step:
      * stopped requests get 0;
      * decode-phase requests (``seq_lens_encoder == 0`` and not stopped) get
        the count of valid (!= -1) entries in their draft-token row;
      * all other requests keep the default of 1.

    Args:
        base_model_draft_tokens: int64 tensor of shape [bsz, draft_len];
            -1 marks an empty draft slot.
        base_model_seq_lens_encoder: int32 tensor of shape [bsz]; non-zero
            presumably means the request is still in the prefill phase —
            TODO(review) confirm against the GPU kernel.
        base_model_stop_flags: bool tensor of shape [bsz].

    Returns:
        int32 paddle tensor of shape [bsz] with per-request token counts.
    """
    bsz = base_model_draft_tokens.shape[0]
    base_model_seq_lens_this_time = paddle.ones((bsz), dtype=paddle.int32)
    for tid in range(bsz):
        if base_model_stop_flags[tid]:
            # Stopped requests schedule no tokens this step.
            base_model_seq_lens_this_time[tid] = 0
        elif base_model_seq_lens_encoder[tid] == 0:
            # Vectorized count of valid (!= -1) draft tokens. Replaces the
            # original per-element tensor-indexing loop, which issued one
            # tensor op per slot (up to 8192 per row) and dominated runtime.
            base_model_seq_lens_this_time[tid] = int((base_model_draft_tokens[tid] != -1).sum())
    return base_model_seq_lens_this_time
43+
44+
45+
class TestDraftModelPostProcess(unittest.TestCase):
    """Compare the GPU ``draft_model_postprocess`` op against the Python reference."""

    def _test_draft_model_postprocess(self, batch_size=1, base_model_draft_token_len=8192):
        """Run one randomized CPU-vs-GPU comparison for the given shape.

        Args:
            batch_size: number of requests in the batch.
            base_model_draft_token_len: width of the draft-token matrix.
        """
        paddle.seed(66)  # fixed seed keeps the randomized inputs reproducible
        # Values in {-1, 0}: -1 marks an empty draft slot, 0 is a valid token.
        base_model_draft_tokens = paddle.randint(
            low=-1,
            high=1,
            shape=[batch_size, base_model_draft_token_len],
            dtype="int64",
        )
        # Roughly half the requests get a non-zero encoder length.
        base_model_seq_lens_encoder = paddle.randint(low=0, high=2, shape=[batch_size], dtype="int32")
        # Roughly half the requests are flagged as stopped.
        random_floats = paddle.rand(shape=[batch_size])
        base_model_stop_flags = random_floats >= 0.5

        expected = draft_model_postprocess_cpu(
            base_model_draft_tokens,
            base_model_seq_lens_encoder,
            base_model_stop_flags,
        )
        # The GPU op updates this tensor in place; the initial value of 1
        # matches the reference implementation's default.
        base_model_seq_lens_this_time_gpu = paddle.ones((batch_size), dtype=paddle.int32)
        draft_model_postprocess(
            base_model_draft_tokens,
            base_model_seq_lens_this_time_gpu,
            base_model_seq_lens_encoder,
            base_model_stop_flags,
        )
        np.testing.assert_allclose(expected.numpy(), base_model_seq_lens_this_time_gpu.numpy())

    def test_enough_cases(self):
        """Sweep batch sizes and draft lengths, including non-power-of-two widths."""
        cases = [
            (100, 1024),
            (1, 11),
            (1, 8192),
            (2, 2048),
            (3, 1023),
            (4, 2047),
            (5, 4095),
            (10, 9191),
        ]
        for batch_size, draft_len in cases:
            # subTest so one failing shape reports without masking the rest.
            with self.subTest(batch_size=batch_size, draft_len=draft_len):
                self._test_draft_model_postprocess(batch_size, draft_len)
81+
82+
83+
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()

0 commit comments

Comments
 (0)