Skip to content

Commit 1f8fde2

Browse files
committed
Add unit tests for fastdeploy/input/
1 parent b6cd3ae commit 1f8fde2

File tree

4 files changed

+301
-0
lines changed

4 files changed

+301
-0
lines changed

tests/input/test_ernie_processor.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,70 @@ def test_process_request_dict(self):
7474
result = self.processor.process_request_dict(request_dict, 100)
7575
self.assertEqual(result["prompt_token_ids"], [1])
7676

77+
def test_process_response_dict(self):
    """Check that process_response_dict picks the sub-method matching ``stream``."""
    processor = self.processor
    payload = {"finished": True, "request_id": "req2", "outputs": {"token_ids": [4, 5]}}

    # Swap both branch handlers for mocks with distinguishable return values.
    streaming_handler = MagicMock(return_value={"result": "stream"})
    processor.process_response_dict_streaming = streaming_handler
    normal_handler = MagicMock(return_value={"result": "normal"})
    processor.process_response_dict_normal = normal_handler

    # Case 1: stream=True must go through the streaming handler.
    streamed = processor.process_response_dict(payload, stream=True)
    streaming_handler.assert_called_once_with(payload)
    self.assertEqual(streamed["result"], "stream")

    # Case 2: stream=False must go through the normal handler.
    completed = processor.process_response_dict(payload, stream=False)
    normal_handler.assert_called_once_with(payload)
    self.assertEqual(completed["result"], "normal")
94+
95+
def test_process_response(self):
    """Test how process_response handles a complete response object.

    The tokenizer, the reasoning parser and the tool parser are all mocked.
    The test checks that the trailing EOS token id is dropped before
    decoding, that the reasoning parser's output ends up in
    ``outputs.text`` / ``outputs.reasoning_content``, that
    ``usage["completion_tokens"]`` equals 3, and that the tool parser is
    invoked exactly once.
    """
    # Build a mock response whose token ids end with the EOS token id.
    mock_outputs = MagicMock()
    mock_outputs.token_ids = [10, 20, self.processor.tokenizer.eos_token_id]
    mock_outputs.index = 2
    mock_response_dict = MagicMock()
    mock_response_dict.request_id = "req3"
    mock_response_dict.outputs = mock_outputs

    # Mock tokenizer.decode.
    self.processor.tokenizer.decode = MagicMock(return_value="decoded_text")

    # Mock the reasoning parser.
    mock_reasoning_parser = MagicMock()
    mock_reasoning_parser.extract_reasoning_content.return_value = ("reasoning_content", "pure_text")
    self.processor.reasoning_parser = mock_reasoning_parser

    # Mock the tool parser; tool_parser_obj is a factory returning the parser.
    mock_tool_parser = MagicMock()
    mock_tool_parser.extract_tool_calls.return_value = MagicMock(
        tools_called=False, tool_calls=None, content="tool_text"
    )
    self.processor.tool_parser_obj = MagicMock(return_value=mock_tool_parser)

    # Call the method under test.
    result = self.processor.process_response(mock_response_dict)

    # Check for None first: every assertion below dereferences ``result``,
    # so running this check last could never be the one that fails.
    self.assertIsNotNone(result)

    # tokenizer.decode must be called with the EOS token id stripped.
    self.processor.tokenizer.decode.assert_called_once_with([10, 20])

    # The reasoning parser must be called and its output assigned.
    mock_reasoning_parser.extract_reasoning_content.assert_called_once()
    self.assertEqual(result.outputs.text, "pure_text")
    self.assertEqual(result.outputs.reasoning_content, "reasoning_content")

    # Usage must be populated with the completion token count.
    self.assertIn("completion_tokens", result.usage)
    self.assertEqual(result.usage["completion_tokens"], 3)

    # The tool parser must be invoked exactly once.
    mock_tool_parser.extract_tool_calls.assert_called_once()
140+
77141

78142
# Run the tests via unittest when this module is executed directly as a script.
if __name__ == "__main__":
79143
unittest.main()
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from fastdeploy.engine.request import Request
5+
6+
# Import the class under test
7+
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
8+
9+
10+
class TestErnie4_5_VLProcessor(unittest.TestCase):
    """Unit tests for the main entry points of Ernie4_5_VLProcessor."""

    def _start_patch(self, target):
        """Create a patch for ``target``, register its stop as cleanup, start it and return the mock."""
        patcher = patch(target)
        self.addCleanup(patcher.stop)
        return patcher.start()

    def setUp(self):
        """Build a processor whose external dependencies are all mocked."""
        # Patch DataProcessor so that no real tokenizer or model gets loaded.
        data_processor_cls = self._start_patch("fastdeploy.input.ernie4_5_vl_processor.DataProcessor")

        # Canned tokenization result shared by text2ids and request2ids.
        tokenized = {
            "input_ids": [1, 2, 3],
            "token_type_ids": [0, 0, 0],
            "position_ids": [[0, 0, 0]],
            "images": [],
            "grid_thw": [],
            "image_type_ids": [],
            "cur_position": 3,
        }

        # Configure the behaviour of the mocked DataProcessor instance.
        data_processor = data_processor_cls.return_value
        self.mock_dp = data_processor
        data_processor.eval.return_value = None
        data_processor.text2ids.return_value = tokenized
        data_processor.request2ids.return_value = tokenized
        data_processor.image_patch_id = 999
        data_processor.spatial_conv_size = 64
        data_processor.tokenizer = MagicMock()
        data_processor.tokenizer.pad_token_id = 0
        data_processor.tokenizer.eos_token_id = 2

        # Patch GenerationConfig.from_pretrained.
        self._start_patch("fastdeploy.input.ernie4_5_vl_processor.GenerationConfig.from_pretrained")

        # Patch Request.from_dict to avoid the real dependency.
        self.mock_from_dict = self._start_patch("fastdeploy.input.ernie4_5_vl_processor.Request.from_dict")
        self.mock_from_dict.side_effect = lambda payload: Request(payload)

        # Create the processor instance under test.
        self.processor = Ernie4_5_VLProcessor(model_name_or_path="mock_path")

        # Replace the processor's tokenizer with a mock.
        tokenizer = MagicMock()
        self.processor.tokenizer = tokenizer
        tokenizer.eos_token_id = 2
        tokenizer.pad_token_id = 0
        tokenizer.decode = MagicMock(return_value="decoded text")

    # ----------------------------- #
    # process_request_dict
    # ----------------------------- #
    def test_process_request_dict_with_prompt(self):
        """A request carrying a ``prompt`` yields token ids and multimodal inputs."""
        request = {"prompt": "hello world"}
        processed = self.processor.process_request_dict(request, max_model_len=10)

        self.assertIsInstance(processed, dict)
        for key, expected_type in (("prompt_token_ids", list), ("multimodal_inputs", dict)):
            self.assertIn(key, processed)
            self.assertIsInstance(processed[key], expected_type)
        self.assertEqual(processed["prompt_token_ids_len"], len(processed["prompt_token_ids"]))

    def test_process_request_dict_with_messages(self):
        """A request carrying ``messages`` yields token ids and multimodal inputs."""
        user_turn = {"role": "user", "content": [{"type": "text", "text": "hi"}]}
        processed = self.processor.process_request_dict({"messages": [user_turn]})
        for key in ("prompt_token_ids", "multimodal_inputs"):
            self.assertIn(key, processed)

    # ----------------------------- #
    # process_request
    # ----------------------------- #
    def test_process_request(self):
        """process_request calls process_request_dict and returns the expected Request."""
        # Stand-in for the incoming Request object.
        incoming = MagicMock()
        incoming.to_dict.return_value = {"prompt": "test prompt"}

        dict_step = MagicMock(return_value={"prompt": "test prompt", "prompt_token_ids": [1, 2]})
        self.processor.process_request_dict = dict_step
        defaults_step = MagicMock(return_value=Request({"prompt": "test prompt", "prompt_token_ids": [1, 2]}))
        self.processor._apply_default_parameters = defaults_step

        outcome = self.processor.process_request(incoming, max_model_len=10)

        dict_step.assert_called_once()
        defaults_step.assert_called_once()
        self.assertIsInstance(outcome, Request)
        self.assertEqual(outcome.data["prompt_token_ids"], [1, 2])

    # ----------------------------- #
    # process_response
    # ----------------------------- #
    def test_process_response(self):
        """process_response (per the original note, inherited from the parent class) decodes the output text."""
        outputs = MagicMock()
        response = MagicMock()
        response.request_id = "123"
        response.outputs = outputs
        outputs.token_ids = [1, 2, 3]
        outputs.index = 2

        processed = self.processor.process_response(response)

        self.assertIsNotNone(processed)
        self.assertEqual(processed.outputs.text, "decoded text")
        self.processor.tokenizer.decode.assert_called_once()

    # ----------------------------- #
    # process_response_dict
    # ----------------------------- #
    def test_process_response_dict_non_stream(self):
        """With stream=False the normal handler is called once and its result returned."""
        normal_handler = MagicMock(return_value={"text": "done"})
        self.processor.process_response_dict_normal = normal_handler

        payload = {"outputs": {"token_ids": [1, 2, 3]}, "finished": True, "request_id": "req_1"}
        outcome = self.processor.process_response_dict(payload, stream=False)

        normal_handler.assert_called_once()
        self.assertEqual(outcome, {"text": "done"})

    def test_process_response_dict_stream(self):
        """With stream=True the streaming handler is called once and its result returned."""
        streaming_handler = MagicMock(return_value={"delta": "ok"})
        self.processor.process_response_dict_streaming = streaming_handler

        payload = {"outputs": {"token_ids": [1, 2, 3]}, "finished": True, "request_id": "req_2"}
        outcome = self.processor.process_response_dict(payload, stream=True)

        streaming_handler.assert_called_once()
        self.assertEqual(outcome, {"delta": "ok"})
140+
141+
142+
# Run the tests via unittest when this module is executed directly as a script.
if __name__ == "__main__":
143+
unittest.main()

tests/input/test_qwen_vl_processor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,62 @@ def test_process_request_dict(self):
207207
self.assertEqual(result["multimodal_inputs"]["pic_cnt"], 1)
208208
self.assertEqual(result["multimodal_inputs"]["video_cnt"], 1)
209209

210+
def test_process_response_dict(self):
    """
    Run a response dictionary through process_response_dict.

    Verifies that:
    1. The expected top-level keys are present in the result
    2. The multimodal outputs entry is still present
    3. The response text equals the text supplied in the input
    """
    # Simulated model output.
    images = np.random.randint(0, 256, (1, 3, 224, 224))
    payload = {
        "request_id": "12345",
        "responses": [{"text": "This is a test response."}],
        "multimodal_outputs": {"images": images},
    }
    decoded_tokens = ["This", "is", "a", "test", "response", "."]

    # Stub ids2tokens on the inner processor so it returns the expected tokens.
    with patch.object(self.processor.processor, "ids2tokens", return_value=decoded_tokens):
        processed = self.processor.process_response_dict(payload)

    for key in ("request_id", "responses", "multimodal_outputs"):
        self.assertIn(key, processed)
    self.assertEqual(processed["responses"][0]["text"], "This is a test response.")
237+
238+
def test_process_response(self):
    """
    Test processing of a Response object through the processor.

    Ensures:
    1. The result is a Request instance
    2. The response text equals the text supplied in the input
    3. Multimodal outputs are preserved (attribute present, same array shape)
    """
    from fastdeploy.engine.response import Response

    # Build the Response object fed to the processor.
    response = Response(
        request_id="12345",
        responses=[{"text": "Another test response"}],
        multimodal_outputs={"images": np.random.randint(0, 256, (1, 3, 224, 224))},
    )

    # Stub ids2tokens on the inner processor.
    with patch.object(self.processor.processor, "ids2tokens", return_value=["Another", "test", "response"]):
        processed = self.processor.process_response(response)

    # Assertions
    self.assertIsInstance(processed, Request)
    self.assertEqual(processed.responses[0]["text"], "Another test response")
    # assertIn reports the searched container on failure, unlike assertTrue(x in y).
    self.assertIn("multimodal_outputs", processed.__dict__)
    self.assertEqual(processed.multimodal_outputs["images"].shape, response.multimodal_outputs["images"].shape)
265+
210266
def test_prompt(self):
211267
"""
212268
Test processing of prompt with image and video placeholders

tests/input/test_text_processor.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,44 @@ def test_process_request_dict(self):
5858
result = self.processor.process_request_dict(request_dict, 100)
5959
self.assertEqual(result["prompt_token_ids"], [1])
6060

61+
def test_process_response_dict(self):
    """Exercise both the streaming and the non-streaming path of process_response_dict."""
    processor = self.processor

    # ===== streaming branch =====
    streaming_response = {
        "request_id": "req_stream",
        "outputs": {"token_ids": [5, 6, 7]},
        "finished": False,
    }
    # Stub ids2tokens so the decoded pieces are known in advance.
    processor.ids2tokens = MagicMock(return_value=("delta", [5, 6], "prev"))
    streamed = processor.process_response_dict(streaming_response, stream=True)
    self.assertIn("outputs", streamed)
    self.assertEqual(streamed["outputs"]["raw_prediction"], "delta")

    # ===== normal (non-streaming) branch =====
    full_response = {
        "request_id": "req_normal",
        "outputs": {"token_ids": [8, 9, 1]},  # original note: includes eos_token_id
        "finished": True,
    }
    # Stub ids2tokens again and seed the decode status for this request id.
    processor.ids2tokens = MagicMock(return_value=("delta", [8, 9], "prev"))
    processor.decode_status["req_normal"] = [0, 0, [], ""]
    completed = processor.process_response_dict(full_response, stream=False)
    self.assertIn("text", completed["outputs"])
    self.assertEqual(completed["outputs"]["text"], "prevdelta")
87+
88+
def test_process_response(self):
    """Check that process_response decodes the token ids without the trailing id."""
    # Mock response object; per the original note the last id is eos_token_id.
    outputs = MagicMock()
    fake_response = MagicMock()
    fake_response.request_id = "req1"
    fake_response.outputs = outputs
    outputs.token_ids = [2, 3, 1]

    processed = self.processor.process_response(fake_response)

    # decode must receive the ids with the last one removed and yield "decoded text".
    self.processor.tokenizer.decode.assert_called_with([2, 3])
    self.assertEqual(processed.outputs.text, "decoded text")
98+
6199

62100
# Run the tests via unittest when this module is executed directly as a script.
if __name__ == "__main__":
63101
unittest.main()

0 commit comments

Comments
 (0)