Fix exception handler for proxy server (#2901)
* Fix exception handler for proxy server

* fix

* fix exception handler

* Add more logs

* fix typo

* logger.warn -> logger.warning

* logger info

* add error log for exception

* better log

* rename variables & reset timeout to

* update docstring
AllentDan authored Dec 26, 2024
1 parent 9565505 commit f62b544
Showing 6 changed files with 83 additions and 39 deletions.
1 change: 1 addition & 0 deletions lmdeploy/cli/serve.py
@@ -238,6 +238,7 @@ def add_parser_proxy():
help='the strategy to dispatch requests to nodes')
ArgumentHelper.api_keys(parser)
ArgumentHelper.ssl(parser)
ArgumentHelper.log_level(parser)

@staticmethod
def gradio(args):
2 changes: 1 addition & 1 deletion lmdeploy/model.py
@@ -1921,5 +1921,5 @@ def best_match_model(query: str) -> Optional[str]:
for name, model in MODELS.module_dict.items():
if model.match(query):
return model.match(query)
logger.warn(f'Did not find a chat template matching {query}.')
logger.warning(f'Did not find a chat template matching {query}.')
return 'base'
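
The logger.warn -> logger.warning changes throughout this commit follow Python's logging API, where Logger.warn is only a deprecated alias of Logger.warning. A minimal sketch of the preferred call (illustrative, not part of the diff):

    import logging

    logger = logging.getLogger('lmdeploy.demo')
    logging.basicConfig(level=logging.WARNING)
    # logger.warn(...) still works in CPython but is deprecated;
    # logger.warning(...) is the supported spelling.
    logger.warning('Did not find a chat template matching %s.', 'my-model')
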
4 changes: 2 additions & 2 deletions lmdeploy/serve/async_engine.py
@@ -505,8 +505,8 @@ async def generate(
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
if not gen_config.do_sample:
logger.warn(f'GenerationConfig: {gen_config}')
logger.warn(
logger.warning(f'GenerationConfig: {gen_config}')
logger.warning(
'Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
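
For context on the warning text above: since lmdeploy v0.6.0 sampling must be requested explicitly through GenerationConfig. A hedged sketch, assuming the field names exposed by recent lmdeploy releases (not part of this diff):

    from lmdeploy import GenerationConfig

    # Greedy decoding is the default; temperature/top_p/top_k only take
    # effect once sampling is switched on.
    gen_config = GenerationConfig(do_sample=True, temperature=0.7, top_p=0.8, top_k=40)
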
4 changes: 2 additions & 2 deletions lmdeploy/serve/proxy/constants.py
@@ -2,8 +2,8 @@

import enum

LATENCY_DEEQUE_LEN = 15
API_TIMEOUT_LEN = 100
LATENCY_DEQUE_LEN = 15
API_READ_TIMEOUT = 100


class Strategy(enum.Enum):
102 changes: 72 additions & 30 deletions lmdeploy/serve/proxy/proxy.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import copy
import json
import os
@@ -18,14 +19,15 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from requests.exceptions import RequestException

from lmdeploy.serve.openai.api_server import (check_api_key,
create_error_response)
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, CompletionRequest, ModelCard, ModelList,
ModelPermission)
from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN,
LATENCY_DEEQUE_LEN, ErrorCodes,
from lmdeploy.serve.proxy.constants import (API_READ_TIMEOUT,
LATENCY_DEQUE_LEN, ErrorCodes,
Strategy, err_msg)
from lmdeploy.utils import get_logger

@@ -36,7 +38,7 @@ class Status(BaseModel):
"""Status protocol consists of models' information."""
models: Optional[List[str]] = Field(default=[], examples=[[]])
unfinished: int = 0
latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN),
latency: Deque = Field(default=deque(maxlen=LATENCY_DEQUE_LEN),
examples=[[]])
speed: Optional[int] = Field(default=None, examples=[None])

@@ -87,6 +89,9 @@ def __init__(self,
with open(self.config_path, 'r') as config_file:
self.nodes = yaml.safe_load(config_file)['nodes']
for url, status in self.nodes.items():
latency = deque(status.get('latency', []),
maxlen=LATENCY_DEQUE_LEN)
status['latency'] = latency
status = Status(**status)
self.nodes[url] = status
self.heart_beat_thread = threading.Thread(target=heart_beat_controller,
@@ -99,7 +104,7 @@ def update_config_file(self):
nodes = copy.deepcopy(self.nodes)
for url, status in nodes.items():
nodes[url] = status.model_dump()
nodes[url]['latency'] = list(status.latency)
nodes[url]['latency'] = list(status.latency)[-LATENCY_DEQUE_LEN:]
with open(self.config_path, 'w') as config_file: # update cfg yml
yaml.dump(dict(nodes=nodes), config_file)
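
Both changes above keep the per-node latency history bounded: the deque is capped when the config file is loaded, and only the most recent samples are written back. A standalone illustration of the behaviour:

    from collections import deque

    LATENCY_DEQUE_LEN = 15  # value from lmdeploy/serve/proxy/constants.py

    latency = deque([0.5] * 100, maxlen=LATENCY_DEQUE_LEN)
    print(len(latency))                        # 15, older samples are dropped automatically
    print(list(latency)[-LATENCY_DEQUE_LEN:])  # what update_config_file now dumps to YAML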

@@ -149,7 +154,8 @@ def remove_stale_nodes_by_expiration(self):
to_be_deleted.append(node_url)
for node_url in to_be_deleted:
self.remove(node_url)
logger.info(f'Removed node_url: {node_url}')
logger.info(f'Removed node_url: {node_url} '
'due to heart beat expiration')

@property
def model_list(self):
@@ -251,7 +257,7 @@ def handle_unavailable_model(self, model_name):
Args:
model_name (str): the model in the request.
"""
logger.info(f'no model name: {model_name}')
logger.warning(f'no model name: {model_name}')
ret = {
'error_code': ErrorCodes.MODEL_NOT_FOUND,
'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],
@@ -260,51 +266,54 @@ def handle_unavailable_model(self, model_name):

def handle_api_timeout(self, node_url):
"""Handle the api time out."""
logger.info(f'api timeout: {node_url}')
logger.warning(f'api timeout: {node_url}')
ret = {
'error_code': ErrorCodes.API_TIMEOUT,
'error_code': ErrorCodes.API_TIMEOUT.value,
'text': err_msg[ErrorCodes.API_TIMEOUT],
}
return json.dumps(ret).encode() + b'\n'

def stream_generate(self, request: Dict, node_url: str, node_path: str):
def stream_generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a generator to handle the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
node_path (str): the node path. Such as `/v1/chat/completions`.
endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
response = requests.post(
node_url + node_path,
node_url + endpoint,
json=request,
stream=request['stream'],
timeout=API_TIMEOUT_LEN,
stream=True,
timeout=(5, API_READ_TIMEOUT),
)
for chunk in response.iter_lines(decode_unicode=False,
delimiter=b'\n'):
if chunk:
yield chunk + b'\n\n'
except requests.exceptions.RequestException as e: # noqa
except (Exception, GeneratorExit, RequestException) as e: # noqa
logger.error(f'catched an exception: {e}')
# exception happened, reduce unfinished num
yield self.handle_api_timeout(node_url)

async def generate(self, request: Dict, node_url: str, node_path: str):
async def generate(self, request: Dict, node_url: str, endpoint: str):
"""Return a the response of the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
node_path (str): the node path. Such as `/v1/chat/completions`.
endpoint (str): the endpoint. Such as `/v1/chat/completions`.
"""
try:
import httpx
async with httpx.AsyncClient() as client:
response = await client.post(node_url + node_path,
response = await client.post(node_url + endpoint,
json=request,
timeout=API_TIMEOUT_LEN)
timeout=API_READ_TIMEOUT)
return response.text
except requests.exceptions.RequestException as e: # noqa
except (Exception, GeneratorExit, RequestException, asyncio.CancelledError) as e: # noqa # yapf: disable
logger.error(f'catched an exception: {e}')
return self.handle_api_timeout(node_url)

def pre_call(self, node_url):
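
stream_generate now always streams the upstream response and passes a (connect, read) timeout tuple instead of a single value, so an unreachable node fails fast while a long generation is still allowed to stream. A hedged sketch of the requests behaviour being relied on (node URL, model name and payload are made up for illustration):

    import requests

    API_READ_TIMEOUT = 100  # from lmdeploy/serve/proxy/constants.py

    # timeout=(connect, read): give up after 5 s if the node cannot be reached,
    # but wait up to API_READ_TIMEOUT seconds for data while it generates.
    response = requests.post(
        'http://localhost:23333/v1/chat/completions',  # hypothetical node URL
        json={'model': 'internlm2',
              'messages': [{'role': 'user', 'content': 'hi'}],
              'stream': True},
        stream=True,
        timeout=(5, API_READ_TIMEOUT),
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b'\n'):
        if chunk:
            print(chunk)
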
@@ -381,7 +390,11 @@ def add_node(node: Node, raw_request: Request = None):
RPM or other metric. All the values of nodes should be the same metric.
"""
try:
node_manager.add(node.url, node.status)
res = node_manager.add(node.url, node.status)
if res is not None:
logger.error(f'add node {node.url} failed, {res}')
return res
logger.info(f'add node {node.url} successfully')
return 'Added successfully'
except: # noqa
return 'Failed to add, please check the input url.'
@@ -392,8 +405,10 @@ def remove_node(node_url: str):
"""Show available models."""
try:
node_manager.remove(node_url)
logger.info(f'delete node {node_url} successfully')
return 'Deleted successfully'
except: # noqa
logger.error(f'delete node {node_url} failed.')
return 'Failed to delete, please check the input url.'


@@ -407,28 +422,50 @@ async def chat_completions_v1(request: ChatCompletionRequest,
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
- messages: string prompt or chat history in OpenAI format. A example
for chat history is `[{"role": "user", "content":"knock knock"}]`.
- messages: string prompt or chat history in OpenAI format. Chat history
example: `[{"role": "user", "content": "hi"}]`.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- max_tokens (int | None): output token nums. Default to None.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
tokens. Only accept stop words that's encoded to one token idex.
- response_format (Dict | None): Only pytorch backend support formatting
response. Examples: `{"type": "json_schema", "json_schema": {"name":
"test","schema": {"properties": {"name": {"type": "string"}},
"required": ["name"], "type": "object"}}}`
or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}`
- logit_bias (Dict): Bias to logits. Only supported in pytorch engine.
- tools (List): A list of tools the model may call. Currently, only
internlm2 functions are supported as a tool. Use this to specify a
list of functions for which the model can generate JSON inputs.
- tool_choice (str | object): Controls which (if any) tool is called by
the model. `none` means the model will not call any tool and instead
generates a message. Specifying a particular tool via {"type":
"function", "function": {"name": "my_function"}} forces the model to
call that tool. `auto` or `required` will put all the tools information
to the model.
Additional arguments supported by LMDeploy:
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
- min_new_tokens (int): To generate at least numbers of tokens.
- min_p (float): Minimum token probability, which will be scaled by the
probability of the most likely token. It must be a value between
0 and 1. Typical values are in the 0.01-0.2 range, comparably
selective as setting `top_p` in the 0.99-0.8 range (use the
opposite of normal `top_p` values)
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
@@ -439,6 +476,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
if not node_url:
return node_manager.handle_unavailable_model(request.model)

logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -465,13 +503,13 @@ async def completions_v1(request: CompletionRequest,
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- max_tokens (int): output token nums. Default to 16.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
message. **Only support one here**.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
@@ -481,7 +519,8 @@
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
@@ -497,6 +536,7 @@
if not node_url:
return node_manager.handle_unavailable_model(request.model)

logger.info(f'A request is dispatched to {node_url}')
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
@@ -517,6 +557,7 @@ def proxy(server_name: str = '0.0.0.0',
'min_observed_latency'] = 'min_expected_latency',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
log_level: str = 'INFO',
**kwargs):
"""To launch the proxy server.
@@ -540,6 +581,7 @@ def proxy(server_name: str = '0.0.0.0',
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
logger.setLevel(log_level)
uvicorn.run(app=app,
host=server_name,
port=server_port,
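
The new log_level argument is threaded from the CLI (see the serve.py change above) into proxy(), which now calls logger.setLevel before starting uvicorn. A sketch of launching the proxy programmatically with the new parameter (host and port values are illustrative):

    from lmdeploy.serve.proxy.proxy import proxy

    # Start the proxy with verbose logging.
    proxy(server_name='0.0.0.0', server_port=8000, log_level='DEBUG')
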
9 changes: 5 additions & 4 deletions lmdeploy/turbomind/deploy/converter.py
@@ -129,16 +129,17 @@ def get_output_model_registered_name_and_config(model_path: str,
] else 'float16'
elif dtype in ['float16', 'bfloat16']:
if weight_type == 'int4':
logger.warn(f'The model {model_path} is a quantized model, so the '
f'specified data type {dtype} is ignored')
logger.warning(
f'The model {model_path} is a quantized model, so the '
f'specified data type {dtype} is ignored')
else:
weight_type = dtype
else:
assert 0, f'unsupported specified data type {dtype}'

if weight_type == 'bfloat16' and not is_bf16_supported():
logger.warn('data type fallback to float16 since '
'torch.cuda.is_bf16_supported is False')
logger.warning('data type fallback to float16 since '
'torch.cuda.is_bf16_supported is False')
weight_type = 'float16'
config.model_config.model_arch = model_arch
config.model_config.weight_type = weight_type
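
A hedged sketch of the bfloat16 fallback the converter performs, using the PyTorch check that the log message refers to (the helper below is a rough stand-in for lmdeploy's is_bf16_supported, not its actual implementation):

    import torch

    def is_bf16_supported() -> bool:
        # bf16 needs a CUDA build of PyTorch and a device that supports it.
        return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    weight_type = 'bfloat16'
    if weight_type == 'bfloat16' and not is_bf16_supported():
        # Same fallback the converter logs with logger.warning.
        weight_type = 'float16'
    print(weight_type)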
