diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 497e4dd5b9..0380d51c82 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -308,6 +308,7 @@ class CompletionOutput: decode_type: int = 0 logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None draft_token_ids: list[int] = None text: Optional[str] = None @@ -322,9 +323,9 @@ def to_dict(self): "index": self.index, "send_idx": self.send_idx, "token_ids": self.token_ids, - "decode_type": self.decode_type, "logprob": self.logprob, "top_logprobs": self.top_logprobs, + "draft_top_logprobs": self.draft_top_logprobs, "logprobs": self.logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, @@ -350,6 +351,8 @@ def __repr__(self) -> str: f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " f"logprobs={self.logprobs}, " + f"top_logprobs={self.top_logprobs}, " + f"draft_top_logprobs={self.draft_top_logprobs}, " ) @@ -434,6 +437,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -444,6 +448,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.output_type = output_type self.outputs = outputs self.finished = finished self.metrics = metrics @@ -472,12 +477,21 @@ def add(self, next_output: RequestOutput) -> None: self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids) self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs) self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks) + if next_output.outputs.draft_top_logprobs is not None: + self.outputs.draft_top_logprobs.logprob_token_ids.extend( + next_output.outputs.draft_top_logprobs.logprob_token_ids + ) + self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs) + self.outputs.draft_top_logprobs.sampled_token_ranks.extend( + next_output.outputs.draft_top_logprobs.sampled_token_ranks + ) def __repr__(self) -> str: return ( f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -498,6 +512,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index d6c0c74bb5..97f52f5797 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -205,6 +205,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -265,6 +266,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -295,6 +297,7 @@ class CompletionResponseChoice(BaseModel): completion_tokens: Optional[str] = None arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None reasoning_content: Optional[str] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -333,6 +336,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None prompt_tokens: Optional[str] = None @@ -420,6 +424,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2) logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -555,6 +560,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(None, le=2, ge=-2) logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 7363572905..aaa534228f 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -316,12 +316,18 @@ async def chat_completion_stream_generator( output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] previous_num_tokens[idx] += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -348,6 +354,7 @@ async def chat_completion_stream_generator( index=idx, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -444,7 +451,9 @@ async def chat_completion_full_generator( dealer.write([b"", rid.encode("utf-8")]) previous_num_tokens = [0] * num_choices current_waiting_time = 0 + logprob_contents = [[] for _ in range(num_choices)] + draft_logprob_contents = [[] for _ in range(num_choices)] completion_token_ids = [[] for _ in range(num_choices)] num_cached_tokens = [0] * num_choices response_processor = ChatResponseProcessor( @@ -492,12 +501,23 @@ async def chat_completion_full_generator( # The logprob for handling the response output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents[idx].extend(logprobs_res.content) + + # draft_logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprob_contents[idx].extend(draft_logprobs_res.content) + if data["finished"]: num_choices -= 1 choice = await self._create_chat_completion_choice( diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 94fcc57abb..fcf31f9e5c 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -234,6 +234,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -266,12 +267,19 @@ async def completion_full_generator( raise ValueError("{}".format(data["error_msg"])) output = data["outputs"] - output_top_logprobs = output["top_logprobs"] + output_top_logprobs = output.get("top_logprobs") or None + output_draft_top_logprobs = output.get("draft_top_logprobs") or None if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -282,6 +290,7 @@ async def completion_full_generator( if data.get("finished", False): data["output_token_ids"] = output_tokens[rid] data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] + data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] valid_results[rid] = data num_choices -= 1 @@ -437,10 +446,17 @@ async def completion_stream_generator( await self._process_echo_logic(request, idx, res["outputs"]) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -452,6 +468,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = output["delta_message"] @@ -541,15 +558,23 @@ def request_output_to_completion_response( final_res = final_res_batch[idx] prompt_token_ids = prompt_batched_token_ids[idx // (1 if request.n is None else request.n)] assert prompt_token_ids is not None + prompt_text = request.prompt completion_token_ids = completion_batched_token_ids[idx] output = final_res["outputs"] - output_top_logprobs = output["top_logprobs"] + output_top_logprobs = output.get("top_logprobs") or None + output_draft_top_logprobs = output.get("draft_top_logprobs") or None aggregated_logprobs: Optional[CompletionLogprobs] = None if output_top_logprobs is not None: aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_draft_logprobs: Optional[CompletionLogprobs] = None + if output_draft_top_logprobs is not None: + aggregated_draft_logprobs = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) + if request.echo: prompt_text = self._echo_back_prompt(request, idx // (1 if request.n is None else request.n)) token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -574,6 +599,7 @@ def request_output_to_completion_response( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, + draft_logprobs=aggregated_draft_logprobs, finish_reason=finish_reason, ) choices.append(choice_data) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 7d87033e88..a0894a6f85 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -22,6 +22,7 @@ import weakref from collections import Counter from concurrent.futures import ThreadPoolExecutor +from typing import List import numpy as np import paddle @@ -67,11 +68,20 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: - self.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) + if self.use_logprobs: + self.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64" + ) + self.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32" + ) + self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64") + else: + self.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") @@ -107,6 +117,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -312,6 +323,7 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, + speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -319,15 +331,27 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - if ( - self.cfg.parallel_config.enable_expert_parallel - and self.cfg.parallel_config.data_parallel_size > 1 - ): - speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + if self.use_logprobs: + speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + is_blocking, + ) + if self.output_tokens[0, 0] == -2: + continue else: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: get_output_topk( @@ -372,7 +396,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result: List[RequestOutput], mtype=3): """ single post-processing function @@ -380,7 +404,28 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.use_logprobs: + if mtype == 3: # target + finished_batch_result, unfinished_batch_result = [], [] + for r in batch_result: + (finished_batch_result if r.finished else unfinished_batch_result).append(r) + if finished_batch_result: + self.cached_generated_tokens.put_results(batch_result) + else: + self._batch_result_buffer = unfinished_batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + if self._batch_result_buffer is not None: + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -471,9 +516,25 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = int(self.output_tokens[1, 0].item()) + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + [batch, MAX_DRAFT_TOKENS, K + 1] + ) + scores = ( + self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] + .numpy() + .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + ) + ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -501,6 +562,8 @@ def _process_batch_output(self): if recovery_stop: llm_logger.info(f"recovery stop signal found at task {task_id}") token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] else: token_ids = tokens[ 2 @@ -556,6 +619,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -575,29 +639,54 @@ def _process_batch_output(self): if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) - for token_id in token_ids: + for batch_token_index in range(len(token_ids)): + token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids): result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - if token_id in task.eos_token_ids or is_prefill or recovery_stop: + if self.cfg.speculative_config.method: + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + else: + result.outputs.logprob = float(scores[i, 0]) + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop): result.finished = True if recovery_stop: result.error_msg = "Recover is not supported, the result is incomplete!" llm_logger.info( - f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}." + f"Request: {task_id} finished, number of " + f"generated tokens: {self.tokens_counter[task_id]}, token_id:{token_id},is_prefill:{is_prefill},recovery_stop:{recovery_stop}" ) llm_logger.info( f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" @@ -616,7 +705,7 @@ def _process_batch_output(self): ): batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 7311f358f4..3454c83407 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -94,43 +94,43 @@ async def test_integration_with_chat_stream_generator(self, mock_processor_class response_data = [ { "request_id": "test_request_id_0", - "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None}, + "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None}, + "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.2, "first_token_time": None}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [3], "text": "c", "top_logprobs": None}, + "outputs": {"token_ids": [3], "text": "c", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.3, "first_token_time": None}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [4], "text": "d", "top_logprobs": None}, + "outputs": {"token_ids": [4], "text": "d", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.4, "first_token_time": None}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [5], "text": "e", "top_logprobs": None}, + "outputs": {"token_ids": [5], "text": "e", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.5, "first_token_time": None}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [6], "text": "f", "top_logprobs": None}, + "outputs": {"token_ids": [6], "text": "f", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.6, "first_token_time": None}, "finished": False, }, { "request_id": "test_request_id_0", - "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None}, + "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1}, "finished": True, }, @@ -190,9 +190,9 @@ async def mock_process_response_chat_single(response, stream, enable_thinking, i chunk_dict = json.loads(json_part) parsed_chunks.append(chunk_dict) except json.JSONDecodeError as e: - self.fail(f"Cannot parser {i+1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}") + self.fail(f"Cannot parser {i + 1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}") else: - self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}") + self.fail(f"{i + 1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}") for chunk_dict in parsed_chunks: choices_list = chunk_dict["choices"] if choices_list[-1].get("finish_reason") is not None: @@ -209,13 +209,13 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): [ { "request_id": "test-request-id_0", - "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None}, + "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1}, "finished": False, }, { "request_id": "test-request-id_0", - "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None}, + "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.2, "first_token_time": None}, "finished": False, }, @@ -223,7 +223,7 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): [ { "request_id": "test-request-id_0", - "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None}, + "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None, "draft_top_logprobs": None}, "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1}, "finished": True, } @@ -269,11 +269,12 @@ async def test_integration_with_completion_stream_generator(self, mock_logger): chunk_dict = json.loads(json_part) parsed_chunks.append(chunk_dict) except json.JSONDecodeError as e: - self.fail(f"Cannot parser {i+1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}") + self.fail(f"Cannot parser {i + 1} chunk, JSON: {e}\n origin string: {repr(chunk_str)}") else: - self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}") + self.fail(f"{i + 1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}") self.assertEqual(len(parsed_chunks), 1) for chunk_dict in parsed_chunks: + print(f"======>{chunk_dict}") choices_list = chunk_dict["choices"] self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should has three choices") self.assertEqual( diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py new file mode 100644 index 0000000000..4fe689e07d --- /dev/null +++ b/tests/output/test_process_batch_output.py @@ -0,0 +1,217 @@ +import random +import time +import unittest +from unittest.mock import Mock + +import paddle + +from fastdeploy.engine.request import RequestOutput +from fastdeploy.output.token_processor import TokenProcessor + +paddle.set_device("cpu") + + +# Mock classes and constants needed for the test +class MockConfig: + class ParallelConfig: + local_data_parallel_id = 0 + + class SpeculativeConfig: + method = None + + class ModelConfig: + enable_logprob = False + + class SchedulerConfig: + name = "default" + + parallel_config = ParallelConfig() + speculative_config = SpeculativeConfig() + model_config = ModelConfig() + scheduler_config = SchedulerConfig() + + +class MockTask: + def __init__(self): + self.request_id = "test_request_1" + self.arrival_time = time.time() + self.inference_start_time = time.time() + self.schedule_start_time = time.time() + self.preprocess_end_time = time.time() - 0.1 + self.preprocess_start_time = time.time() - 0.2 + self.eos_token_ids = [2] + self.output_token_ids = [] + self.messages = "Test prompt" + self.num_cached_tokens = 0 + self.disaggregate_info = None + self.prefill_chunk_info = None + self.prefill_chunk_num = 0 + + def get(self, key: str, default_value=None): + if hasattr(self, key): + return getattr(self, key) + elif hasattr(self.sampling_params, key): + return getattr(self.sampling_params, key) + else: + return default_value + + +class MockResourceManager: + def __init__(self): + self.stop_flags = [False] + self.tasks_list = [MockTask()] + self.to_be_rescheduled_request_id_set = set() + + def info(self): + return "Mock resource manager info" + + def reschedule_preempt_task(self, task_id): + pass + + +class MockCachedGeneratedTokens: + def __init__(self): + self.cache = [] + + def put_results(self, results): + self.cache.extend(results) + + +# Constants +RECOVERY_STOP_SIGNAL = -3 +MAX_BSZ = 512 +K = 20 +MAX_DRAFT_TOKENS = 6 +SPECULATE_MAX_BSZ = 256 + + +class TestTokenProcessorProcessBatchOutput(unittest.TestCase): + def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): + """Helper method to setup TokenProcessor with different configurations""" + cfg = MockConfig() + cfg.speculative_config.method = "mtp" if speculative_decoding else None + cfg.speculative_config.num_speculative_tokens = 1 + cfg.model_config.enable_logprob = use_logprobs + + processor = TokenProcessor.__new__(TokenProcessor) + processor.cfg = cfg + processor.cached_generated_tokens: MockCachedGeneratedTokens = MockCachedGeneratedTokens() + processor.executor = Mock() + processor.engine_worker_queue = Mock() + processor.split_connector = Mock() + processor.resource_manager = MockResourceManager() + task1 = MockTask() + task2 = MockTask() + processor.resource_manager.tasks_list = [task1, task2] + processor.resource_manager.stop_flags = [False, False] + processor.tokens_counter = {task1.request_id: 0, task2.request_id: 0} + processor.total_step = 0 + processor.number_of_output_tokens = 0 + processor.prefill_result_status = {} + processor.use_logprobs = use_logprobs + processor.num_draft_tokens = 0 + processor.num_accepted_tokens = 0 + processor.num_emitted_tokens = 0 + processor.max_num_emitted_tokens = 0 + processor.num_rest_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.num_accept_requests_per_head = [ + 0, + ] * MAX_DRAFT_TOKENS + processor.speculative_stats_step = 0 + + # processor._recycle_resources = Mock() + + if speculative_decoding: + if use_logprobs: + processor.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], + fill_value=2, + dtype="int64", + ) + processor.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], + fill_value=0.0, + dtype="float32", + ) + processor.output_ranks = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS], + fill_value=0, + dtype="int64", + ) + else: + processor.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) + elif use_logprobs: + processor.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") + processor.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") + processor.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64") + else: + processor.output_tokens = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + + return processor + + def test_speculative_decoding_use_logprobs(self): + """Test basic speculative decoding scenario""" + processor = self.setup_token_processor(speculative_decoding=True, use_logprobs=True) + + # stop_flag + processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) + # mtype target = 3, decode = 4 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # batch + processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) + # accept_num + processor.output_tokens[3, 0].set_tensor(paddle.to_tensor(3)) + processor.output_tokens[4, 0].set_tensor(paddle.to_tensor(3)) + + batch = processor.output_tokens[2, 0] + mtype = processor.output_tokens[3, 0] + accept_num = [int(num[0]) for num in processor.output_tokens[3 : batch + 3]] + + # init + print(f"batch:{batch}, mtype:{mtype} accept_num: {accept_num}") + for i in range(batch): + for j in range(accept_num[i]): + token_index = 3 + MAX_BSZ + i * MAX_DRAFT_TOKENS * (K + 1) + j * (K + 1) + score_index = i * MAX_DRAFT_TOKENS * (K + 1) + j * (K + 1) + print(f"batch:{i}, accept:{j} token_index: {token_index} score_index: {score_index}") + for k in range(K + 1): + processor.output_tokens[token_index + k].set_tensor(paddle.to_tensor(random.randint(100, 100000))) + processor.output_scores[score_index + k].set_tensor(paddle.to_tensor(random.random())) + processor.output_ranks[j].set_tensor(paddle.to_tensor(1)) + + processor._process_batch_output() + + batch_result_buffer: list[RequestOutput] = processor._batch_result_buffer + + for i, request_output in enumerate(batch_result_buffer): + assert isinstance(request_output, RequestOutput) + assert len(request_output.outputs.token_ids) == accept_num[i] + assert len(request_output.outputs.top_logprobs) == 3 + # tokens, scores, ranks + assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 + assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.top_logprobs[2]) == accept_num[i] + + # mtype = 4 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4)) + processor._process_batch_output() + cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens + for c in cached_generated_tokens.cache: + assert isinstance(request_output, RequestOutput) + assert len(request_output.outputs.token_ids) == accept_num[i] + assert len(request_output.outputs.top_logprobs) == 3 + assert len(request_output.outputs.draft_top_logprobs) == 3 + # tokens, scores, ranks + assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 + assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] + + +if __name__ == "__main__": + unittest.main(verbosity=2, buffer=False)