Skip to content

Commit 0fe1d62

Browse files
authored
[MTP] add test_draft_model_set_value_by_flags.py (#3741)
1 parent 18e5d35 commit 0fe1d62

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
import paddle
19+
20+
from fastdeploy.model_executor.ops.gpu import draft_model_set_value_by_flags
21+
22+
23+
class TestDraftModelSetValueByFlags(unittest.TestCase):
24+
def setUp(self):
25+
paddle.set_device("gpu")
26+
np.random.seed(42)
27+
28+
def test_basic_update(self):
29+
"""
30+
Test normal update behavior:
31+
batch0 performs a decoder step, batch1 performs an encoder step
32+
"""
33+
bs = 2
34+
pre_id_length = 5
35+
draft_tokens = paddle.to_tensor([[10, 11, 12], [20, 21, 22]], dtype="int64")
36+
pre_ids_all = paddle.zeros([bs, pre_id_length], dtype="int64")
37+
stop_flags = paddle.to_tensor([False, False], dtype="bool")
38+
seq_lens_this_time = paddle.to_tensor([3, 1], dtype="int32")
39+
seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32")
40+
seq_lens_decoder = paddle.to_tensor([0, 0], dtype="int32")
41+
step_idx = paddle.to_tensor([3, 1], dtype="int64") # batch0 decoder, batch1 encoder
42+
43+
""" Call custom op """
44+
draft_model_set_value_by_flags(
45+
draft_tokens, pre_ids_all, stop_flags, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx
46+
)
47+
48+
"""
49+
batch0: 3 tokens updated at decoder step
50+
batch1: 1 token updated at encoder step
51+
"""
52+
expected = np.array([[0, 10, 11, 12, 0], [0, 20, 0, 0, 0]], dtype=np.int64)
53+
54+
np.testing.assert_array_equal(pre_ids_all.numpy(), expected)
55+
np.testing.assert_array_equal(seq_lens_this_time.numpy(), [1, 1])
56+
57+
def test_stop_flags(self):
58+
"""
59+
batch0 is skipped (stop_flags=True), batch1 updates normally
60+
"""
61+
bs = 2
62+
pre_id_length = 4
63+
draft_tokens = paddle.to_tensor([[5, 6], [7, 8]], dtype="int64")
64+
pre_ids_all = paddle.zeros([bs, pre_id_length], dtype="int64")
65+
stop_flags = paddle.to_tensor([True, False], dtype="bool")
66+
seq_lens_this_time = paddle.to_tensor([2, 2], dtype="int32")
67+
seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32")
68+
seq_lens_decoder = paddle.to_tensor([0, 0], dtype="int32")
69+
step_idx = paddle.to_tensor([1, 2], dtype="int64")
70+
71+
draft_model_set_value_by_flags(
72+
draft_tokens, pre_ids_all, stop_flags, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, step_idx
73+
)
74+
75+
"""
76+
batch0: no update due to stop flag
77+
batch1: 2 tokens updated at decoder step
78+
"""
79+
expected = np.array([[0, 0, 0, 0], [0, 7, 8, 0]], dtype=np.int64)
80+
81+
np.testing.assert_array_equal(pre_ids_all.numpy(), expected)
82+
np.testing.assert_array_equal(seq_lens_this_time.numpy(), [2, 1])
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)