[P/D] layerwise connector support recompute scheduler #5900
Changes from all commits:

- 1b89fd7
- 3e0d74a
- bde95f9
- 5bfae63
- df51879
```diff
@@ -86,9 +86,11 @@
 import argparse
 import asyncio
+import copy
 import functools
+import heapq
 import ipaddress
 import json
 import os
 import sys
 import uuid
 
```
```diff
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
         request_length = len(req_body)
         request_id = await proxy_state.next_req_id()
         request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
+        proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
         req_data["kv_transfer_params"] = {
             "do_remote_decode": False,
             "do_remote_prefill": True,
```
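The deep copy matters because the very next statements mutate `req_data` in place (the `kv_transfer_params` assignment above), so storing the dict by reference would corrupt the recorded request. A self-contained illustration of the aliasing problem, using hypothetical values rather than the proxy's actual data:

```python
import copy

# Hypothetical request dict standing in for req_data.
req_data = {"prompt": "hello", "kv_transfer_params": None}

# Storing by reference: later mutations leak into the stored record.
stored_by_ref = req_data
req_data["kv_transfer_params"] = {"do_remote_prefill": True}
assert stored_by_ref["kv_transfer_params"] is not None

# Storing a deep copy: the record is a snapshot taken at storage time.
req_data = {"prompt": "hello", "kv_transfer_params": None}
stored_snapshot = copy.deepcopy(req_data)
req_data["kv_transfer_params"] = {"do_remote_prefill": True}
assert stored_snapshot["kv_transfer_params"] is None
```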
```diff
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):
         # Stream response from decoder
         released_kv = False
 
+        # Record request info for recompute
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+            if isinstance(origin_prompt, list):
+                origin_prompt = origin_prompt[0].get("text", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_token default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
             nonlocal released_kv
+            generated_token = ""
+            released_kv = False
```
**Contributor**, on lines 449 to +451: *(comment body not captured)*
```diff
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                    decoder.client,
-                    api,
-                    req_data,
-                    request_id=request_id,
-                    max_retries=global_args.max_retries,
-                    base_delay=global_args.retry_delay,
-                ):
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay,
+                    ):
+                        try:
+                            chunk_str = chunk.decode("utf-8").strip()
+                        except UnicodeDecodeError:
+                            logger.debug(f"Skipping chunk: {chunk}")
+                            yield chunk
+                            continue
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: ") :]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # if chunk is [done], skip it.
+                            logger.debug(f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = delta.get("content") or message.get("content") or choice.get("text") or ""
+                        generated_token += content
+
+                        stop_reason = choice.get("stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (
+                            (completion_tokens + 1)
+                            if stream_flag
+                            else (completion_tokens + usage.get("completion_tokens"))
+                        )
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0]["content"] = origin_prompt + generated_token
+                            else:
+                                req_data["prompt"] = origin_prompt + generated_token
+                            req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choice["message"]["content"] = generated_token
+                            else:
+                                choice["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
                     f"Error during streaming from decoder {decoder.url}: {str(e)} "
```
**Collaborator:** Consider including `request_id`/`retry_count` in the log messages to make troubleshooting easier.

**Author:** Thanks for reviewing the code. The number of retries is printed in `stream_service_response_with_retry`, and the `request_id` is printed on the next line.
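A hypothetical sketch of what the suggested logging could look like; the helper name and exact fields are illustrative, not code from this PR:

```python
import logging

logger = logging.getLogger(__name__)

def log_stream_error(decoder_url: str, request_id: str, retry_count: int, exc: Exception) -> None:
    # Illustrative only: folding request_id and retry_count into the error line
    # lets a failed stream be correlated with its request and retry history.
    logger.error(
        "Error during streaming from decoder %s (request_id=%s, retry_count=%d): %s",
        decoder_url, request_id, retry_count, exc,
    )
```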
```diff
@@ -451,7 +525,10 @@ async def generate_stream():
             # After streaming done, release tokens
             proxy_state.release_decoder(decoder_idx, decoder_score)
```
**Collaborator:** If an exception is thrown, should this also be called to release the resources?

**Author:** Reaching this line means the stream response has finished, at which point the KV-cache records of the D node are released.
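For reference, a minimal self-contained sketch of the pattern the reviewer is asking about: releasing the decoder slot in a `finally` block so the release also runs when streaming raises. All names here are stand-ins, not the PR's actual code:

```python
import asyncio

class ProxyState:
    def release_decoder(self, idx: int, score: float) -> None:
        print(f"released decoder {idx} (score={score})")

proxy_state, decoder_idx, decoder_score = ProxyState(), 0, 1.0

async def generate_stream():
    try:
        yield b"chunk-1"
        raise RuntimeError("decoder died mid-stream")
    finally:
        # Runs on normal completion, cancellation, and exceptions alike,
        # so the decoder slot is never leaked.
        proxy_state.release_decoder(decoder_idx, decoder_score)

async def main():
    try:
        async for chunk in generate_stream():
            print(chunk)
    except RuntimeError as exc:
        print("caught:", exc)

asyncio.run(main())
```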
```diff
 
-        return StreamingResponse(generate_stream(), media_type="application/json")
+        if stream_flag:
+            return StreamingResponse(generate_stream(), media_type="text/event-stream")
+        else:
+            return StreamingResponse(generate_stream(), media_type="application/json")
     except Exception as e:
         import traceback
```
The code at line 439 (`messages[0].get("content", "")`) assumes that the `messages` list is not empty. If an API request is sent with an empty `messages` list (e.g., `"messages": []`), this will raise an `IndexError`, causing an unhandled exception. While the OpenAI API spec requires at least one message, it's best to code defensively.

A similar issue exists on lines 507-508. You should add checks to ensure `messages` is not empty before accessing its elements.

Here's a suggested way to fix this:
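The suggestion block itself is not reproduced above. A minimal sketch of the defensive check the comment describes, reusing names from the diff with hypothetical input data:

```python
# Hypothetical degenerate request the reviewer is worried about.
req_data = {"messages": []}

messages = req_data.get("messages") or []
if messages:  # guard before indexing, so "messages": [] cannot raise IndexError
    origin_prompt = messages[0].get("content", "")
    if isinstance(origin_prompt, list):
        origin_prompt = origin_prompt[0].get("text", "") if origin_prompt else ""
else:
    origin_prompt = ""

print(repr(origin_prompt))  # -> '' instead of an unhandled IndexError
```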