Commit 066041a

✨ log all errored requests (vllm-project#30)

This PR logs all errors raised during validation or generation of a request, like TGIS does.

Signed-off-by: Joe Runde <[email protected]>

1 parent f022464 · commit 066041a

File tree

2 files changed (+115, -56 lines):

vllm/entrypoints/grpc/grpc_server.py
vllm/tgis_utils/logs.py


vllm/entrypoints/grpc/grpc_server.py

Lines changed: 30 additions & 42 deletions
@@ -49,19 +49,34 @@ def with_default(value: Any, default: Any) -> Any:
 async def _handle_exception(e: Exception, func, *args, **kwargs):
-    # We don't log AbortErrors since these correspond to gRPC errors
-    # intentionally raised during handling of requests.
+    context = kwargs.get("context", None) or args[-1]
+    is_generate_fn = "generate" in func.__name__.lower()
+
+    # First just try to replicate the TGIS-style log messages
+    # for generate_* rpcs
+    if is_generate_fn:
+        if isinstance(e, AbortError):
+            # For things that we've already aborted, the relevant error
+            # string is already in the grpc context.
+            error_message = context.details()
+        else:
+            error_message = str(e)
+        request = kwargs.get("request", None) or args[-2]
+        logs.log_error(request=request,
+                       exception_str=error_message,
+                       logger=logger)
+
+    # AbortErrors likely correspond to things we've already explicitly handled,
+    # So we only add special handling for other types of errors
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError": #TODO check
-            context = kwargs.get("context", None) or args[-1]
             logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+        elif is_generate_fn:
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
         else:
-            if "generate" in func.__name__.lower():
-                service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
-            else:
-                service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
         logger.exception("%s failed", func.__name__)
     raise e

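For context, _handle_exception is called from the log_rpc_handler_errors decorator that wraps every RPC handler; that decorator is unchanged and not shown in this commit. The following is only a hypothetical sketch of how such a wrapper plausibly routes errors here, assuming unary async handlers with the usual (self, request, context) signature:

import functools

def log_rpc_handler_errors(func):
    # Hypothetical sketch only: catch anything raised by the handler and hand
    # it to _handle_exception, which logs it and re-raises.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # For a servicer method the positional args are
            # (self, request, context), which is why _handle_exception can
            # recover them as args[-2] and args[-1].
            await _handle_exception(e, func, *args, **kwargs)
    return wrapper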
@@ -172,14 +187,9 @@ async def Generate(self, request: BatchedGenerationRequest,
             response = self._convert_input_details(res, resp_options,
                                                    sampling_params,
                                                    response)
-            if request_count == 1:
-                kind_log = "Request"
-            else:
-                kind_log = f"Sub-request {i} from batch of {request_count}"
-
-            self._log_unary_response(request=request, response=response,
-                                     start_time=start_time, engine_response=res,
-                                     kind_log=kind_log)
+            logs.log_response(request=request, response=response,
+                              start_time=start_time, engine_metrics=res.metrics,
+                              sub_request_num=i, logger=logger)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
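The per-sub-request labelling that used to live here now happens inside logs.log_response. Purely as an illustration of the labels that produces (the logic is copied from the new code in vllm/tgis_utils/logs.py below), for a batch of three prompts:

request_count = 3
for sub_request_num in range(request_count):
    if request_count == 1:
        kind_log = "Request"
    else:
        kind_log = (f"Sub-request {sub_request_num} from batch of "
                    f"{request_count}")
    print(kind_log)
# Prints "Sub-request 0 from batch of 3" through "Sub-request 2 from batch of 3";
# a single-prompt request is labelled just "Request".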
@@ -253,9 +263,11 @@ async def GenerateStream(
             return
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
-        self._log_streaming_response(request=request, response=first_response,
-                                     start_time=start_time,
-                                     engine_response=last_engine_response)
+        logs.log_response(request=request, response=first_response,
+                          start_time=start_time,
+                          engine_metrics=last_engine_response.metrics
+                          if last_engine_response else None,
+                          logger=logger)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
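Note the guarded engine_metrics argument: a stream can finish without ever receiving an engine response, in which case None is passed and the new _log_response (below) logs a warning and falls back to zeroed timings. A tiny standalone illustration of that fallback, reusing the same variable names:

import logging

logger = logging.getLogger("grpc_server")
last_engine_response = None  # e.g. the stream ended before any output arrived

engine_metrics = (last_engine_response.metrics
                  if last_engine_response else None)
if engine_metrics is None:
    # mirrors the warning branch added in vllm/tgis_utils/logs.py
    logger.warning("No engine metrics for request, cannot log timing info")
    tokenization_time = inference_time = queue_time = time_per_token = \
        total_time = 0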
@@ -537,30 +549,6 @@ async def _validate_prompt_and_tokenize(
         return input_ids, max_is_token_limit
 
-    @staticmethod
-    def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse,
-                            engine_response: RequestOutput,
-                            start_time: float, kind_log: str):
-        logs.log_response(inputs=[r.text for r in request.requests],
-                          response=response, params=request.params,
-                          prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log=kind_log,
-                          method_str="generate", logger=logger)
-
-    @staticmethod
-    def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse,
-                                engine_response: RequestOutput,
-                                start_time: float):
-        logs.log_response(inputs=[request.request.text], response=response,
-                          params=request.params, prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log="Streaming response",
-                          method_str="generate_stream", logger=logger)
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse:

vllm/tgis_utils/logs.py

Lines changed: 85 additions & 14 deletions
@@ -1,26 +1,97 @@
 """Some methods for producing logs similar to TGIS"""
 import logging
-from typing import List
+from typing import List, Optional, Union
 
 from google.protobuf import text_format
 
-from vllm import RequestOutput
-from vllm.entrypoints.grpc.pb.generation_pb2 import (GenerationResponse,
-                                                     Parameters, StopReason)
+from vllm.entrypoints.grpc.pb.generation_pb2 import (BatchedGenerationRequest,
+                                                     GenerationResponse,
+                                                     Parameters,
+                                                     SingleGenerationRequest,
+                                                     StopReason)
+from vllm.sequence import RequestMetrics
 
 
-def log_response(inputs: List[str], params: Parameters, prefix_id: str,
-                 response: GenerationResponse, engine_response: RequestOutput,
-                 start_time: float, kind_log: str, method_str: str,
-                 logger: logging.Logger):
+def log_response(
+    request: Union[BatchedGenerationRequest, SingleGenerationRequest],
+    response: GenerationResponse,
+    engine_metrics: Optional[RequestMetrics],
+    start_time: float,
+    logger: logging.Logger,
+    sub_request_num: int = 0,
+):
+    if isinstance(request, BatchedGenerationRequest):
+        # unary case
+        request_count = len(request.requests)
+        if request_count == 1:
+            kind_log = "Request"
+        else:
+            kind_log = (f"Sub-request {sub_request_num} from batch of "
+                        f"{request_count}")
+        inputs = [r.text for r in request.requests]
+        method_str = "generate"
+    else:
+        # streaming case
+        inputs = [request.request.text]
+        kind_log = "Streaming response"
+        method_str = "generate_stream"
+
+    _log_response(
+        inputs=inputs,
+        response=response,
+        params=request.params,
+        prefix_id=request.prefix_id,
+        engine_metrics=engine_metrics,
+        start_time=start_time,
+        kind_log=kind_log,
+        method_str=method_str,
+        logger=logger,
+    )
+
+
+def log_error(request: Union[BatchedGenerationRequest,
+                             SingleGenerationRequest], exception_str: str,
+              logger: logging.Logger):
+    """Logs errors similar to how the TGIS server does"""
+    # NB: We don't actually log the `Exception` here to match the TGIS behavior
+    # of just logging the simple string representation of the error
+    param_str = text_format.MessageToString(request.params, as_one_line=True)
+    prefix_id = request.prefix_id
+
+    if isinstance(request, BatchedGenerationRequest):
+        method_str = "generate"
+        inputs = [r.text for r in request.requests]
+    else:
+        method_str = "generate_stream"
+        inputs = [request.request.text]
+
+    short_input = [_truncate(input_, 32) for input_ in inputs]
+    input_chars = sum(len(input_) for input_ in inputs)
+
+    span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
+                f"input_chars=[{input_chars}] params={param_str}")
+
+    logger.error("%s: %s", span_str, exception_str)
+
+
+def _log_response(inputs: List[str], params: Parameters, prefix_id: str,
+                  response: GenerationResponse,
+                  engine_metrics: Optional[RequestMetrics], start_time: float,
+                  kind_log: str, method_str: str, logger: logging.Logger):
     """Logs responses similar to how the TGIS server does"""
     # This time contains both request validation and tokenization
-    tokenization_time = engine_response.metrics.arrival_time - start_time
-    inference_time = (engine_response.metrics.last_token_time -
-                      engine_response.metrics.first_scheduled_time)
-    queue_time = engine_response.metrics.time_in_queue
-    time_per_token = _safe_div(inference_time, response.generated_token_count)
-    total_time = engine_response.metrics.last_token_time - start_time
+    if engine_metrics is not None:
+        tokenization_time = engine_metrics.arrival_time - start_time
+        inference_time = (engine_metrics.last_token_time -
+                          engine_metrics.first_scheduled_time)
+        queue_time = engine_metrics.time_in_queue
+        time_per_token = _safe_div(inference_time,
+                                   response.generated_token_count)
+        total_time = engine_metrics.last_token_time - start_time
+    else:
+        logger.warning("No engine metrics for request, cannot log timing info")
+        tokenization_time = inference_time = queue_time = time_per_token =\
+            total_time = 0
     output_len = len(response.text)
     short_output = _truncate(response.text, 32)
     short_input = [_truncate(input_, 32) for input_ in inputs]
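For a sense of what the new log_error output looks like, here is an illustrative, standalone approximation using plain strings in place of the protobuf request. The _truncate helper already exists in logs.py but is not shown in this diff, so the version below is only an assumed stand-in, and the params string is an invented example:

import logging

logging.basicConfig(format="%(levelname)s %(name)s: %(message)s")
logger = logging.getLogger("grpc_server")

def _truncate(text: str, len_: int) -> str:
    # Assumed stand-in for the existing helper in logs.py
    return text if len(text) <= len_ else text[:len_] + "..."

method_str = "generate"                       # unary rpc
inputs = ["Write a haiku about logging"]      # request texts
prefix_id = ""
param_str = "stopping { max_new_tokens: 0 }"  # invented example Parameters
short_input = [_truncate(input_, 32) for input_ in inputs]
input_chars = sum(len(input_) for input_ in inputs)

span_str = (f"{method_str}{{input={short_input} prefix_id={prefix_id} "
            f"input_chars=[{input_chars}] params={param_str}")
logger.error("%s: %s", span_str, "ValueError: max_new_tokens must be > 0")
# Emits roughly:
# ERROR grpc_server: generate{input=['Write a haiku about logging'] prefix_id=
#   input_chars=[27] params=stopping { max_new_tokens: 0 }: ValueError: ...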
