@@ -86,9 +86,11 @@

import argparse
import asyncio
import copy
import functools
import heapq
import ipaddress
import json
import os
import sys
import uuid
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
        request_length = len(req_body)
        request_id = await proxy_state.next_req_id()
        request_id_api = get_api_request_id(api, request_id)
        proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
        proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
        req_data["kv_transfer_params"] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):
        # Stream response from decoder
        released_kv = False

        # Record request info for recompute
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data
        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
Comment on lines +439 to +440
Contributor

critical

The code at line 439 (messages[0].get("content", "")) assumes that the messages list is not empty. If an API request is sent with an empty messages list (e.g., "messages": []), this will raise an IndexError, causing an unhandled exception. While the OpenAI API spec requires at least one message, it's best to code defensively.

A similar issue exists on lines 507-508. You should add checks to ensure messages is not empty before accessing its elements.

Here's a suggested way to fix this:

# At lines 438-439
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "") if messages else ""

# And at lines 506-508
                            if chat_flag and messages:
                                messages[0][
                                    "content"] = origin_prompt + generated_token
Suggested change
-            messages = req_data["messages"]
-            origin_prompt = messages[0].get("content", "")
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "") if messages else ""

            if isinstance(origin_prompt, list):
                origin_prompt = origin_prompt[0].get("text", "")
        else:
            origin_prompt = ""
        # refer to vLLM sampling_params: max_token default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        async def generate_stream():
            nonlocal released_kv
            generated_token = ""
            released_kv = False
Comment on lines 449 to +451
Contributor

high

The released_kv variable is declared nonlocal and then reassigned, but it's never read within the generate_stream function. This indicates dead code. It seems that logic for releasing the KV cache, which is present in load_balance_proxy_server_example.py, is missing here. This could lead to a resource leak if the KV cache is not freed. Please either add the KV cache release logic or remove the unused released_kv variable and its related declarations.
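For comparison, a minimal sketch of the release path in load_balance_proxy_server_example.py, which frees the prefiller's KV accounting once the first decoder chunk arrives; the helper name release_prefiller_kv and the prefiller_idx/prefiller_score variables are assumptions modeled on that example and are not part of this diff:

            # Sketch only: free the prefiller's KV accounting as soon as the first
            # decoder chunk arrives, instead of holding it for the whole stream.
            async for chunk in stream_service_response_with_retry(
                decoder.client,
                api,
                req_data,
                request_id=request_id,
                max_retries=global_args.max_retries,
                base_delay=global_args.retry_delay,
            ):
                if not released_kv and chunk:
                    proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
                    released_kv = True
                yield chunk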

            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                    decoder.client,
                    api,
                    req_data,
                    request_id=request_id,
                    max_retries=global_args.max_retries,
                    base_delay=global_args.retry_delay,
                ):
                    yield chunk
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay,
                    ):
                        try:
                            chunk_str = chunk.decode("utf-8").strip()
                        except UnicodeDecodeError:
                            logger.debug(f"Skipping chunk: {chunk}")
                            yield chunk
                            continue
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: ") :]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # if chunk is [done], skip it.
                            logger.debug(f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue

                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = delta.get("content") or message.get("content") or choice.get("text") or ""
                        generated_token += content

                        stop_reason = choice.get("stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (
                            (completion_tokens + 1)
                            if stream_flag
                            else (completion_tokens + usage.get("completion_tokens"))
                        )
                        if stop_reason == "recomputed":
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0]["content"] = origin_prompt + generated_token
                            else:
                                req_data["prompt"] = origin_prompt + generated_token
                            req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                            break
                        if retry_count > 0 and not stream_flag:
                            if chat_flag:
                                choice["message"]["content"] = generated_token
                            else:
                                choice["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)} "
@winson-00178005 Feb 6, 2026

Consider adding request_id/retry_count to the log message to make it easier to locate problems.

Collaborator Author

Thanks for reviewing the code. The number of retries will be printed in stream_service_response_with_retry, and the request_id will be printed on the next line.
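For illustration only, the proxy-side error log could also carry that context; a minimal sketch, assuming the request_id and retry_count variables already in scope in generate_stream:

            except Exception as e:
                # Illustrative sketch: include request_id and retry_count so a failed
                # stream can be correlated with the retry helper's own log lines.
                logger.error(
                    f"Error during streaming from decoder {decoder.url} "
                    f"(request_id={request_id}, retry_count={retry_count}): {e}")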

@@ -451,7 +525,10 @@ async def generate_stream():
            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)

If an exception is thrown, should this also be called to release the resources?

Collaborator Author

Reaching this line of the function means the streamed response has finished, and it is here that the KV cache records of the D node are released.
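For illustration, a minimal sketch of a defensive variant (not part of this change) that wraps the streaming loop in try/finally, so the decoder is released even if the stream raises or the client disconnects:

        async def generate_stream():
            try:
                async for chunk in stream_service_response_with_retry(
                    decoder.client,
                    api,
                    req_data,
                    request_id=request_id,
                    max_retries=global_args.max_retries,
                    base_delay=global_args.retry_delay,
                ):
                    yield chunk
            finally:
                # Sketch only: always release the decoder's token/score accounting,
                # even when the streaming loop exits with an exception.
                proxy_state.release_decoder(decoder_idx, decoder_score)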


return StreamingResponse(generate_stream(), media_type="application/json")
if stream_flag:
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback

2 changes: 1 addition & 1 deletion vllm_ascend/core/recompute_scheduler.py
@@ -661,7 +661,7 @@ def update_from_output(
                EngineCoreOutput(
                    request_id=req_info.request_id,
                    finish_reason=FinishReason.STOP,
                    new_token_ids=[req_info.output_token_ids[-1]],
                    new_token_ids=[],
                    stop_reason="recomputed",
                ))
