
Commit 0753757

Merge branch 'develop' into fix9

2 parents fd0af67 + fe5d09f; commit 0753757

22 files changed, +785 -39 lines

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 0 deletions
@@ -986,6 +986,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("recv_expert_count"), py::arg("block_size"),
         "per token per block quant");
 
+#ifdef ENABLE_MACHETE
   /*machete/machete_mm.cu
    * machete_mm
    */
@@ -1004,6 +1005,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
    * machete_supported_schedules
    */
   m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
+#endif
 
   /**
    * moe/fused_moe/moe_topk_select.cu
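
Since the machete bindings are now registered only when the extension is built with ENABLE_MACHETE, downstream Python code cannot assume they exist. A minimal probe sketch, assuming the compiled module is importable under the pybind11 module name fastdeploy_ops and that an absent binding simply does not show up as an attribute:

```python
# Hedged sketch: detect whether the optional machete bindings were compiled in.
# The module name comes from PYBIND11_MODULE(fastdeploy_ops, m); how the built
# module is actually exposed on sys.path in a FastDeploy install is an assumption.
import importlib

ops = importlib.import_module("fastdeploy_ops")
has_machete = hasattr(ops, "machete_supported_schedules")

if has_machete:
    # The call signature is not shown in this hunk; calling with no arguments is an assumption.
    schedules = ops.machete_supported_schedules()
```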

custom_ops/setup_ops.py

Lines changed: 3 additions & 1 deletion
@@ -373,6 +373,7 @@ def find_end_files(directory, end_str):
     if not os.listdir(json_dir):
         raise ValueError("Git clone nlohmann_json failed!")
 
+    cc_compile_args = []
     nvcc_compile_args = get_gencode_flags(archs)
     nvcc_compile_args += ["-DPADDLE_DEV"]
     nvcc_compile_args += ["-DPADDLE_ON_INFERENCE"]
@@ -519,12 +520,13 @@ def find_end_files(directory, end_str):
         sources += find_end_files("gpu_ops/wfp8afp8_sparse_gemm", ".cu")
         os.system("python gpu_ops/machete/generate.py")
         sources += find_end_files("gpu_ops/machete", ".cu")
+        cc_compile_args += ["-DENABLE_MACHETE"]
 
     setup(
         name="fastdeploy_ops",
         ext_modules=CUDAExtension(
             sources=sources,
-            extra_compile_args={"nvcc": nvcc_compile_args},
+            extra_compile_args={"cxx": cc_compile_args, "nvcc": nvcc_compile_args},
             libraries=["cublasLt"],
             extra_link_args=["-lcuda"],
         ),
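
The new cc_compile_args list travels under the "cxx" key, so -DENABLE_MACHETE is seen by the host C++ compiler that builds cpp_extensions.cc, while the "nvcc" list keeps applying to the .cu sources. A minimal sketch of that split in isolation (file names and the extra flags are placeholders, not the project's full argument set):

```python
# Sketch of splitting compile flags between the C++ and CUDA compilers with
# paddle's CUDAExtension; sources and flags here are illustrative placeholders.
from paddle.utils.cpp_extension import CUDAExtension, setup

setup(
    name="example_ops",
    ext_modules=CUDAExtension(
        sources=["cpp_extensions.cc", "some_kernel.cu"],
        extra_compile_args={
            "cxx": ["-DENABLE_MACHETE"],  # applied when compiling .cc sources
            "nvcc": ["-DPADDLE_DEV"],     # applied when nvcc compiles .cu sources
        },
    ),
)
```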
New file: 74 additions & 0 deletions
@@ -0,0 +1,74 @@

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import asyncio
import json
import time

from fastdeploy.input.tokenzier_client import (
    AsyncTokenizerClient,
    ImageDecodeRequest,
    ImageEncodeRequest,
    VideoEncodeRequest,
)


async def main():
    """
    Exercise the AsyncTokenizerClient class.
    """
    base_url = "http://example.com/"

    client = AsyncTokenizerClient(base_url=base_url)

    # Test the image encode request
    image_encode_request = ImageEncodeRequest(
        version="v1", req_id="req_image_001", is_gen=False, resolution=512, image_url="http://example.com/image.jpg"
    )

    image_encode_ret = await client.encode_image(image_encode_request)
    print(f"Image encode result: {image_encode_ret}")

    # Test the video encode request
    video_encode_req = VideoEncodeRequest(
        version="v1",
        req_id="req_video_001",
        video_url="http://example.com/video.mp4",
        is_gen=False,
        resolution=1024,
        start_ts=0,
        end_ts=5,
        frames=1,
    )
    video_encode_result = await client.encode_video(video_encode_req)
    print(f"Video encode result: {video_encode_result}")

    # Test the image decode request
    with open("./image_decode_demo.json", "r", encoding="utf-8") as file:
        start_time = time.time()
        start_process_time = time.process_time()  # record the start time
        json_data = json.load(file)
        image_decoding_request = ImageDecodeRequest(req_id="req_image_001", data=json_data.get("data"))
        image_decode_result = await client.decode_image(image_decoding_request)
        print(f"Image decode result: {image_decode_result}")
        elapsed_time = time.time() - start_time
        elapsed_process_time = time.process_time() - start_process_time
        print(f"decode elapsed_time: {elapsed_time:.6f}s, elapsed_process_time: {elapsed_process_time:.6f}s")


if __name__ == "__main__":
    asyncio.run(main())
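
The demo reads ./image_decode_demo.json and forwards its "data" field to ImageDecodeRequest. That file is not part of this commit; a hypothetical placeholder with made-up token ids could be written like this:

```python
# Create a placeholder ./image_decode_demo.json for the demo above.
# The token ids below are arbitrary stand-ins; a real file would hold the image
# token ids produced by the remote tokenizer service.
import json

with open("./image_decode_demo.json", "w", encoding="utf-8") as f:
    json.dump({"data": [101, 102, 103]}, f)
```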

fastdeploy/engine/args_utils.py

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,10 @@ class EngineArgs:
     """
     The name or path of the tokenizer (defaults to model path if not provided).
     """
+    tokenizer_base_url: str = None
+    """
+    The base URL of the remote tokenizer service (used instead of local tokenizer if provided).
+    """
     max_model_len: int = 2048
     """
     Maximum context length supported by the model.
@@ -426,6 +430,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.tokenizer,
             help="Tokenizer name or path (defaults to model path if not specified).",
         )
+        model_group.add_argument(
+            "--tokenizer-base-url",
+            type=nullable_str,
+            default=EngineArgs.tokenizer_base_url,
+            help="The base URL of the remote tokenizer service (used instead of local tokenizer if provided).",
+        )
         model_group.add_argument(
             "--max-model-len",
             type=int,
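
The same setting is also reachable programmatically through the new EngineArgs field. A hedged sketch (the model path and URL are placeholders, and constructing EngineArgs directly like this is an assumption about typical usage):

```python
# Sketch: pointing the engine at a remote tokenizer service.
# Field names come from the diff above; values are placeholders.
from fastdeploy.engine.args_utils import EngineArgs

args = EngineArgs(
    model="/path/to/model",
    tokenizer_base_url="http://tokenizer-service:8080",  # remote tokenizer endpoint
)
```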

fastdeploy/engine/engine.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def _setting_environ_variables(self):
             "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
             "FLAGS_use_append_attn": 1,
             "NCCL_ALGO": "Ring",
-            "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 32768)),
+            "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
             "FLAGS_hardamard_moe_block_size": int(os.getenv("FLAGS_hardamard_moe_block_size", 128)),
             "FLAGS_hardamard_use_diagonal_block_matrix": int(
                 os.getenv("FLAGS_hardamard_use_diagonal_block_matrix", 0)
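
The default for FLAGS_max_partition_size drops from 32768 to 1024, but the value is still read through os.getenv, so the previous behaviour can be restored from the environment before the engine process starts. For example:

```python
# Restore the old partition size by exporting the flag before the engine
# spawns its workers; _setting_environ_variables reads it via os.getenv.
import os

os.environ["FLAGS_max_partition_size"] = "32768"
```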

fastdeploy/entrypoints/chat_utils.py

Lines changed: 7 additions & 0 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import os
 import uuid
 from copy import deepcopy
 from pathlib import Path
@@ -162,9 +163,15 @@ def parse_chat_messages(messages):
 
 
 def load_chat_template(
     chat_template: Union[Path, str],
+    model_path: Path = None,
     is_literal: bool = False,
 ) -> Optional[str]:
     if chat_template is None:
+        if model_path:
+            chat_template_file = os.path.join(model_path, "chat_template.jinja")
+            if os.path.exists(chat_template_file):
+                with open(chat_template_file) as f:
+                    return f.read()
         return None
     if is_literal:
         if isinstance(chat_template, Path):
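
With the extra model_path argument, load_chat_template now falls back to a chat_template.jinja file inside the model directory when no template is passed. A short sketch of the resolution order (paths are placeholders):

```python
from fastdeploy.entrypoints.chat_utils import load_chat_template

# Explicit template argument: the model directory is not consulted.
tpl = load_chat_template("/path/to/custom_template.jinja", "/path/to/model")

# No template: <model>/chat_template.jinja is read if it exists.
tpl = load_chat_template(None, "/path/to/model")

# No template and no chat_template.jinja in the model directory: returns None.
tpl = load_chat_template(None, "/path/to/model_without_template")
```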

fastdeploy/entrypoints/llm.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def __init__(
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = load_chat_template(chat_template, model)
 
     def _check_master(self):
         """

fastdeploy/entrypoints/openai/api_server.py

Lines changed: 12 additions & 2 deletions
@@ -77,10 +77,13 @@
     help="max waiting time for connection, if set value -1 means no waiting time limit",
 )
 parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
+parser.add_argument(
+    "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. "
+)
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-chat_template = load_chat_template(args.chat_template)
+chat_template = load_chat_template(args.chat_template, args.model)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None
@@ -176,7 +179,14 @@ async def lifespan(app: FastAPI):
     )
     app.state.model_handler = model_handler
     chat_handler = OpenAIServingChat(
-        engine_client, app.state.model_handler, pid, args.ips, args.max_waiting_time, chat_template
+        engine_client,
+        app.state.model_handler,
+        pid,
+        args.ips,
+        args.max_waiting_time,
+        chat_template,
+        args.enable_mm_output,
+        args.tokenizer_base_url,
     )
     completion_handler = OpenAIServingCompletion(
         engine_client,

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 4 additions & 2 deletions
@@ -163,8 +163,9 @@ class ChatMessage(BaseModel):
     Chat message.
     """
 
-    role: str
-    content: str
+    role: Optional[str] = None
+    content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
     prompt_token_ids: Optional[List[int]] = None
@@ -226,6 +227,7 @@ class DeltaMessage(BaseModel):
 
     role: Optional[str] = None
     content: Optional[str] = None
+    multimodal_content: Optional[List[Any]] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     reasoning_content: Optional[str] = None
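
Both ChatMessage and DeltaMessage gain an optional multimodal_content list. Mirroring the {"type": "text"} / {"type": "image"} parts assembled by the new response processor below, a populated message might look like this (values are illustrative only):

```python
# Illustrative only: a ChatMessage carrying mixed text/image output.
from fastdeploy.entrypoints.openai.protocol import ChatMessage

msg = ChatMessage(
    role="assistant",
    multimodal_content=[
        {"type": "text", "text": "Here is the generated image:"},
        {"type": "image", "url": "http://example.com/generated.png"},  # placeholder URL
    ],
)
```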
New file: 145 additions & 0 deletions
@@ -0,0 +1,145 @@

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Any, List, Optional

from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest


class ChatResponseProcessor:
    """
    A decoder class to build multimodal content (text/image) from token_ids.

    Attributes:
        eoi_token_id: Token ID indicating the end of an image (<eoi>).
    """

    def __init__(
        self,
        data_processor,
        enable_mm_output: Optional[bool] = False,
        eoi_token_id: Optional[int] = 101032,
        eos_token_id: Optional[int] = 2,
        decoder_base_url: Optional[str] = None,
    ):
        self.data_processor = data_processor
        self.enable_mm_output = enable_mm_output
        self.eoi_token_id = eoi_token_id
        self.eos_token_id = eos_token_id
        self.decoder_client = None
        if decoder_base_url is not None:
            self.decoder_client = AsyncTokenizerClient(base_url=decoder_base_url)
        self._mm_buffer: List[Any] = []  # Buffer for accumulating image token_ids
        self._end_image_code_request_output: Optional[Any] = None
        self._multipart_buffer = []

    def enable_multimodal_content(self):
        return self.enable_mm_output

    def accumulate_token_ids(self, request_output):
        decode_type = request_output["outputs"].get("decode_type", 0)

        if not self._multipart_buffer:
            self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
        else:
            last_part = self._multipart_buffer[-1]

            if last_part["decode_type"] == decode_type:
                last_token_ids = last_part["request_output"]["outputs"]["token_ids"]
                last_token_ids.extend(request_output["outputs"]["token_ids"])
                request_output["outputs"]["token_ids"] = last_token_ids
                last_part["request_output"] = request_output
            else:
                self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})

    async def process_response_chat(self, request_outputs, stream, enable_thinking, include_stop_str_in_output):
        """
        Process a list of responses into a generator that yields each processed response as it's generated.

        Args:
            request_outputs: The list of outputs to be processed.
            stream: Whether or not to stream the output.
            enable_thinking: Whether or not to show thinking messages.
            include_stop_str_in_output: Whether or not to include stop strings in the output.
        """
        for request_output in request_outputs:
            if not self.enable_mm_output:
                yield self.data_processor.process_response_dict(
                    response_dict=request_output,
                    stream=stream,
                    enable_thinking=enable_thinking,
                    include_stop_str_in_output=include_stop_str_in_output,
                )
            elif stream:
                decode_type = request_output["outputs"].get("decode_type", 0)
                token_ids = request_output["outputs"]["token_ids"]
                if decode_type == 0:
                    if self.eoi_token_id and self.eoi_token_id in token_ids:
                        if self._mm_buffer:
                            all_tokens = self._mm_buffer
                            self._mm_buffer = []
                            image = {"type": "image"}
                            if self.decoder_client:
                                req_id = request_output["request_id"]
                                image_ret = await self.decoder_client.decode_image(
                                    request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
                                )
                                image["url"] = image_ret["http_url"]
                            image_output = self._end_image_code_request_output
                            image_output["outputs"]["multipart"] = [image]
                            image_output["outputs"]["token_ids"] = all_tokens
                            yield image_output

                    self.data_processor.process_response_dict(
                        response_dict=request_output,
                        stream=stream,
                        enable_thinking=enable_thinking,
                        include_stop_str_in_output=include_stop_str_in_output,
                    )
                    text = {"type": "text", "text": request_output["outputs"]["text"]}
                    request_output["outputs"]["multipart"] = [text]
                    yield request_output

                elif decode_type == 1:
                    self._mm_buffer.extend(token_ids)
                    self._end_image_code_request_output = request_output
            else:
                self.accumulate_token_ids(request_output)
                token_ids = request_output["outputs"]["token_ids"]
                if token_ids[-1] == self.eos_token_id:
                    multipart = []
                    for part in self._multipart_buffer:
                        if part["decode_type"] == 0:
                            self.data_processor.process_response_dict(
                                response_dict=part["request_output"],
                                stream=False,
                                enable_thinking=enable_thinking,
                                include_stop_str_in_output=include_stop_str_in_output,
                            )
                            text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
                            multipart.append(text)
                        elif part["decode_type"] == 1:
                            image = {"type": "image"}
                            if self.decoder_client:
                                req_id = part["request_output"]["request_id"]
                                all_tokens = part["request_output"]["outputs"]["token_ids"]
                                image_ret = await self.decoder_client.decode_image(
                                    request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
                                )
                                image["url"] = image_ret["http_url"]
                            multipart.append(image)

                    last_request_output = self._multipart_buffer[-1]["request_output"]
                    last_request_output["outputs"]["multipart"] = multipart
                    yield last_request_output
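
A hedged usage sketch for the non-streaming path: request_output dicts carrying the keys the processor reads ("request_id", "outputs", "token_ids", "decode_type") are fed in, and the multipart result is printed. The _EchoDataProcessor stub and the token ids are placeholders, and the import path is an assumption since the new file's location is not shown in this commit view.

```python
# Usage sketch only; data_processor stub, token ids, and module path are assumed.
import asyncio

from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor  # path assumed


class _EchoDataProcessor:
    """Minimal stand-in that turns token ids into a short text string."""

    def process_response_dict(self, response_dict, **kwargs):
        response_dict["outputs"]["text"] = f"<{len(response_dict['outputs']['token_ids'])} tokens>"
        return response_dict


async def demo():
    processor = ChatResponseProcessor(data_processor=_EchoDataProcessor(), enable_mm_output=True, eos_token_id=2)
    outputs = [
        {"request_id": "req-1", "outputs": {"decode_type": 0, "token_ids": [11, 12]}},
        {"request_id": "req-1", "outputs": {"decode_type": 1, "token_ids": [900, 901]}},  # image token ids
        {"request_id": "req-1", "outputs": {"decode_type": 0, "token_ids": [13, 2]}},     # ends with eos
    ]
    async for out in processor.process_response_chat(
        outputs, stream=False, enable_thinking=False, include_stop_str_in_output=False
    ):
        print(out["outputs"]["multipart"])


asyncio.run(demo())
```

Without a decoder_base_url the image part stays a bare {"type": "image"}; with one configured, the processor asks the remote tokenizer service to decode the buffered image tokens and attaches the returned URL.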
