Skip to content

Commit b6cd3ae

Browse files
[Feature] support fd return decode response (#4407)
* [Feature] support fd return decode response * Resolving conflicts * fix * fix * fix * fix * fix --------- Co-authored-by: YuBaoku <[email protected]>
1 parent cd9195d commit b6cd3ae

File tree

4 files changed

+64
-22
lines changed

4 files changed

+64
-22
lines changed

fastdeploy/engine/common_engine.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from fastdeploy.engine.request import Request, RequestOutput, RequestType
3434
from fastdeploy.engine.resource_manager import ResourceManager
3535
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
36+
from fastdeploy.input.preprocess import InputPreprocessor
3637
from fastdeploy.inter_communicator import (
3738
EngineCacheQueue,
3839
EngineWorkerQueue,
@@ -149,6 +150,16 @@ def start(self):
149150
if self.cfg.scheduler_config.splitwise_role != "mixed":
150151
self.split_mode_get_tasks()
151152

153+
def create_data_processor(self):
154+
self.input_processor = InputPreprocessor(
155+
self.cfg.model_config,
156+
self.cfg.structured_outputs_config.reasoning_parser,
157+
self.cfg.limit_mm_per_prompt,
158+
self.cfg.mm_processor_kwargs,
159+
self.cfg.tool_parser,
160+
)
161+
self.data_processor = self.input_processor.create_processor()
162+
152163
def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进程感知是否有新Task需要处理
153164
current_suffix = int(
154165
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
@@ -831,9 +842,23 @@ def _insert_zmq_task_to_scheduler(self):
831842
f"traceback={traceback.format_exc()}"
832843
)
833844

845+
def _decode_token(self, token_ids, req_id, is_end):
846+
delta_text = ""
847+
if envs.FD_ENABLE_RETURN_TEXT:
848+
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
849+
if delta_text != "":
850+
prefix_offset = self.data_processor.decode_status[req_id][0]
851+
read_offset = self.data_processor.decode_status[req_id][1]
852+
token_ids = cum_tokens[prefix_offset:read_offset]
853+
else:
854+
token_ids = []
855+
if is_end:
856+
del self.data_processor.decode_status[req_id]
857+
return delta_text, token_ids
858+
834859
def _zmq_send_generated_tokens(self):
835860
"""
836-
Receive output for zmq
861+
Recieve output for zmq
837862
"""
838863
while self.running:
839864
try:
@@ -842,10 +867,31 @@ def _zmq_send_generated_tokens(self):
842867
time.sleep(0.005)
843868
continue
844869
for request_id, contents in results.items():
845-
self.send_response_server.send_response(request_id, contents)
846-
870+
new_contents = []
871+
for content in contents:
872+
decode_type = content.outputs.decode_type
873+
delta_text = ""
874+
if decode_type == 0:
875+
delta_text, token_ids = self._decode_token(
876+
token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
877+
)
878+
else:
879+
token_ids = content.outputs.token_ids
880+
if len(token_ids):
881+
content.outputs.token_ids = token_ids
882+
content.outputs.text = delta_text
883+
new_contents.append(content)
884+
elif content.finished:
885+
new_contents.append(content)
886+
else:
887+
llm_logger.warning(
888+
f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
889+
)
890+
if len(new_contents):
891+
llm_logger.info(f"Send response for request id: {request_id}")
892+
self.send_response_server.send_response(request_id, new_contents)
847893
except Exception as e:
848-
self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
894+
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
849895

850896
def split_mode_get_tasks(self):
851897
"""

fastdeploy/engine/engine.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from fastdeploy.engine.common_engine import EngineService
3939
from fastdeploy.engine.expert_service import start_data_parallel_service
4040
from fastdeploy.engine.request import Request
41-
from fastdeploy.input.preprocess import InputPreprocessor
4241
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
4342
from fastdeploy.metrics.metrics import main_process_metrics
4443
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -87,13 +86,6 @@ def __init__(self, cfg):
8786
self.running = True
8887
self.is_started = False
8988

90-
self.input_processor = InputPreprocessor(
91-
cfg.model_config,
92-
cfg.structured_outputs_config.reasoning_parser,
93-
cfg.limit_mm_per_prompt,
94-
cfg.mm_processor_kwargs,
95-
cfg.tool_parser,
96-
)
9789
self.engine = EngineService(cfg)
9890

9991
if self.cfg.cache_config.num_gpu_blocks_override is None:
@@ -117,12 +109,12 @@ def start(self, api_server_pid=None):
117109
self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
118110
self._init_worker_signals()
119111

120-
self.data_processor = self.input_processor.create_processor()
121-
self.engine.data_processor = self.data_processor
122112
# Launch components: scheduler, cache_manager, expert_service et.al.
123113
self.launch_components()
124114

125115
self.engine.start()
116+
self.engine.create_data_processor()
117+
self.data_processor = self.engine.data_processor
126118

127119
# If block numer is specified and model is deployed in mixed mode, start cache manager first
128120
if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -246,7 +238,7 @@ def add_requests(self, task, sampling_params=None, **kwargs):
246238
chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
247239
chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
248240
kwargs["chat_template_kwargs"] = chat_template_kwargs
249-
request = self.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs)
241+
request = self.engine.data_processor.process_request(request, self.cfg.model_config.max_model_len, **kwargs)
250242
request.prompt_token_ids_len = len(request.prompt_token_ids)
251243
request.need_prefill_tokens = request.prompt_token_ids_len
252244
input_ids_len = request.prompt_token_ids_len
@@ -482,9 +474,9 @@ def _start_worker_service(self):
482474
py_script = os.path.join(current_dir_path, worker_path)
483475

484476
ori_vocab_size = (
485-
len(self.data_processor.tokenizer.sp_model)
486-
if hasattr(self.data_processor.tokenizer, "sp_model")
487-
else len(self.data_processor.tokenizer.vocab)
477+
len(self.engine.data_processor.tokenizer.sp_model)
478+
if hasattr(self.engine.data_processor.tokenizer, "sp_model")
479+
else len(self.engine.data_processor.tokenizer.vocab)
488480
)
489481

490482
think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
@@ -511,8 +503,8 @@ def _start_worker_service(self):
511503
f" --total_block_num {self.cfg.cache_config.total_block_num}"
512504
f" --block_size {self.cfg.cache_config.block_size}"
513505
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
514-
f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
515-
f" --pad_token_id {self.data_processor.pad_token_id}"
506+
f" --eos_tokens_lens {self.engine.data_processor.eos_token_id_len}"
507+
f" --pad_token_id {self.engine.data_processor.pad_token_id}"
516508
f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
517509
f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
518510
f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
@@ -611,15 +603,15 @@ def generate(self, prompts, stream):
611603
for result in self._get_generated_tokens(req_id):
612604
is_end = result.finished
613605
if stream and not is_end:
614-
processed = self.data_processor.process_response(result)
606+
processed = self.engine.data_processor.process_response(result)
615607
if processed is None:
616608
continue
617609
output = processed.to_dict()
618610
yield output
619611

620612
# Exit loop if termination condition is met
621613
if is_end:
622-
processed = self.data_processor.process_response(result)
614+
processed = self.engine.data_processor.process_response(result)
623615
output = processed.to_dict()
624616
llm_logger.debug(f"Generate result: {output}")
625617
if not stream:

fastdeploy/engine/expert_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def start(
9090

9191
start_time = time.time()
9292
self.engine.start()
93+
if envs.FD_ENABLE_RETURN_TEXT:
94+
self.engine.create_data_processor()
9395
if self.cfg.scheduler_config.name == "dp":
9496
self.cfg.init_cache_info()
9597
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@
118118
"FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))),
119119
# Whether to clear cpu cache when clearing model weights.
120120
"FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")),
121+
# enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1
122+
"FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))),
121123
# Used to truncate the string inserted during thinking when reasoning in a model. (</think> for ernie4_5_vl, \n</think>\n\n for ernie_x1)
122124
"FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"),
123125
# Timeout for cache_transfer_manager process exit

0 commit comments

Comments (0)