Skip to content

Commit 97d4a60

Browse files
committed
add speculate_update_v3 test
1 parent 3790505 commit 97d4a60

File tree

1 file changed: 171 additions and 0 deletions.
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import speculate_update
21+
22+
23+
def speculate_update_v3_np(
    seq_lens_encoder,
    seq_lens_decoder,
    not_need_stop,
    draft_tokens,
    actual_draft_token_nums,
    accept_tokens,
    accept_num,
    stop_flags,
    seq_lens_this_time,
    is_block_step,
    stop_nums,
):
    """NumPy reference implementation used to check ``speculate_update``.

    Every slot ``bid`` in ``range(stop_flags.shape[0])`` is visited and the
    array arguments are updated **in place**:

    * Slots with ``bid >= seq_lens_this_time.shape[0]`` are inactive: nothing
      is written for them and each one counts as stopped.
    * Slots whose ``is_block_step`` entry is truthy are skipped entirely and
      never count as stopped.
    * For every remaining slot:

      - if ``seq_lens_encoder[bid] == 0``, ``seq_lens_decoder[bid]`` grows by
        ``accept_num[bid]``; when additionally ``seq_lens_this_time[bid] > 1``
        the draft length in ``actual_draft_token_nums[bid]`` is adapted: it
        grows by 2 (capped at ``draft_tokens.shape[1] - 1``) when
        ``accept_num[bid] - 1`` equals the current draft length, otherwise it
        shrinks by 1 (never below 1);
      - otherwise ``seq_lens_encoder[bid]`` is added to
        ``seq_lens_decoder[bid]`` and then reset to 0;
      - ``draft_tokens[bid, 0]`` receives the last accepted token
        ``accept_tokens[bid, accept_num[bid] - 1]``;
      - if ``stop_flags[bid]`` is set, ``seq_lens_decoder[bid]`` is reset to 0
        and the slot counts as stopped.

    Finally ``not_need_stop[0]`` is set to
    ``stopped_count < stop_nums[0]``.

    Returns:
        The tuple ``(seq_lens_encoder, seq_lens_decoder, not_need_stop,
        draft_tokens, actual_draft_token_nums)`` -- the same (mutated) objects
        that were passed in.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    draft_len_cap = draft_tokens.shape[1] - 1

    stop_sum = 0
    for bid in range(max_bsz):
        # Slots past the real batch size are padding: count them as stopped.
        if bid >= real_bsz:
            stop_sum += 1
            continue
        # Block-step slots are neither updated nor counted.
        if is_block_step[bid]:
            continue

        stopped = bool(stop_flags[bid])

        if seq_lens_encoder[bid] == 0:
            seq_lens_decoder[bid] += accept_num[bid]
            if seq_lens_this_time[bid] > 1:
                cur_len = actual_draft_token_nums[bid]
                if accept_num[bid] - 1 == cur_len:
                    # Grow by two, clamped to the cap.
                    cur_len = min(cur_len + 2, draft_len_cap)
                else:
                    cur_len = max(1, cur_len - 1)
                actual_draft_token_nums[bid] = cur_len
        else:
            # Fold the encoder length into the decoder length exactly once.
            seq_lens_decoder[bid] += seq_lens_encoder[bid]
            seq_lens_encoder[bid] = 0

        draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]

        if stopped:
            seq_lens_decoder[bid] = 0
            stop_sum += 1

    not_need_stop[0] = stop_sum < stop_nums[0]

    return (
        seq_lens_encoder,
        seq_lens_decoder,
        not_need_stop,
        draft_tokens,
        actual_draft_token_nums,
    )
89+
90+
91+
def gen_inputs(
    max_bsz=512,
    max_draft_tokens=16,
    real_bsz=123,
    seed=2022,
):
    """Build a reproducible random input set for the ``speculate_update`` test.

    Args:
        max_bsz: Length of every per-slot array except ``seq_lens_this_time``.
        max_draft_tokens: Second dimension of ``draft_tokens`` and
            ``accept_tokens``; also the exclusive upper bound for the draft,
            accept and this-time counts. Must be at least 2, because those
            counts are drawn from ``[1, max_draft_tokens)``.
        real_bsz: Length of ``seq_lens_this_time``.
        seed: Seed for ``numpy.random.default_rng``.

    Returns:
        A dict of NumPy arrays. The insertion order of the keys matters: the
        test in this file passes ``dict.values()`` positionally to the op, so
        the keys must stay in this order.
    """
    rng = np.random.default_rng(seed)

    seq_lens_encoder = rng.integers(0, 3, size=max_bsz, dtype=np.int32)
    seq_lens_decoder = rng.integers(0, 20, size=max_bsz, dtype=np.int32)
    # ``high`` is exclusive, so (0, 2) is required to draw both False and
    # True; (0, 1) would always produce False.
    not_need_stop = rng.integers(0, 2, size=1, dtype=np.bool_)
    draft_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
    actual_draft_nums = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
    accept_tokens = rng.integers(0, 1000, size=(max_bsz, max_draft_tokens), dtype=np.int64)
    # accept_num >= 1 keeps ``accept_num - 1`` a valid column index.
    accept_num = rng.integers(1, max_draft_tokens, size=max_bsz, dtype=np.int32)
    stop_flags = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
    is_block_step = rng.integers(0, 2, size=max_bsz, dtype=np.bool_)
    stop_nums = np.array([5], dtype=np.int64)

    seq_lens_this_time = rng.integers(1, max_draft_tokens, size=real_bsz, dtype=np.int32)

    return {
        "seq_lens_encoder": seq_lens_encoder,
        "seq_lens_decoder": seq_lens_decoder,
        "not_need_stop": not_need_stop,
        "draft_tokens": draft_tokens,
        "actual_draft_token_nums": actual_draft_nums,
        "accept_tokens": accept_tokens,
        "accept_num": accept_num,
        "stop_flags": stop_flags,
        "seq_lens_this_time": seq_lens_this_time,
        "is_block_step": is_block_step,
        "stop_nums": stop_nums,
    }
125+
126+
127+
class TestSpeculateUpdateV3(unittest.TestCase):
    """Compares the ``speculate_update`` op with the NumPy reference."""

    def test_speculate_update_v3(self):
        """Run op and reference on identical random inputs and compare outputs."""
        inputs = gen_inputs(max_bsz=512, max_draft_tokens=32, real_bsz=201)

        paddle_inputs = {key: paddle.to_tensor(value) for key, value in inputs.items()}

        # Snapshot the inputs BEFORE the op runs so the reference starts from
        # untouched data. The reference mutates its arguments, and the op
        # presumably updates its tensors in place as well -- TODO confirm.
        np_inputs = {key: tensor.numpy().copy() for key, tensor in paddle_inputs.items()}

        # The op takes its arguments positionally, in the key order produced
        # by gen_inputs.
        pd_tensors = list(speculate_update(*paddle_inputs.values()))
        out_np = speculate_update_v3_np(**np_inputs)

        names = [
            "seq_lens_encoder",
            "seq_lens_decoder",
            "not_need_stop",
            "draft_tokens",
            "actual_draft_token_nums",
        ]
        self.assertEqual(len(pd_tensors), len(names))

        for name, pd_val, np_val in zip(names, pd_tensors, out_np):
            # All outputs are integer or boolean, so require exact equality
            # and say which output diverged.
            with self.subTest(output=name):
                np.testing.assert_array_equal(
                    pd_val.numpy(),
                    np_val,
                    err_msg=f"speculate_update output '{name}' differs from the NumPy reference",
                )
168+
169+
170+
# Script entry point: run every test case defined in this module.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)