@@ -49,19 +49,34 @@ def with_default(value: Any, default: Any) -> Any:
 
 
 async def _handle_exception(e: Exception, func, *args, **kwargs):
-    # We don't log AbortErrors since these correspond to gRPC errors
-    # intentionally raised during handling of requests.
+    context = kwargs.get("context", None) or args[-1]
+    is_generate_fn = "generate" in func.__name__.lower()
+
+    # First just try to replicate the TGIS-style log messages
+    # for generate_* rpcs
+    if is_generate_fn:
+        if isinstance(e, AbortError):
+            # For things that we've already aborted, the relevant error
+            # string is already in the grpc context.
+            error_message = context.details()
+        else:
+            error_message = str(e)
+        request = kwargs.get("request", None) or args[-2]
+        logs.log_error(request=request,
+                       exception_str=error_message,
+                       logger=logger)
+
+    # AbortErrors likely correspond to things we've already explicitly handled,
+    # so we only add special handling for other types of errors
     if not isinstance(e, AbortError):
         if type(e).__name__ == "torch.cuda.OutOfMemoryError":  #TODO check
-            context = kwargs.get("context", None) or args[-1]
             logger.exception("%s caused GPU OOM error", func.__name__)
             service_metrics.count_request_failure(FailureReasonLabel.OOM)
             await context.abort(StatusCode.RESOURCE_EXHAUSTED, str(e))
+        elif is_generate_fn:
+            service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
         else:
-            if "generate" in func.__name__.lower():
-                service_metrics.count_request_failure(FailureReasonLabel.GENERATE)
-            else:
-                service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
+            service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN)
         logger.exception("%s failed", func.__name__)
     raise e
 
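For orientation, here is a minimal sketch of how _handle_exception is presumably invoked: the log_rpc_handler_errors decorator (visible in the final hunk below) is assumed to wrap each unary RPC handler and forward any caught exception here, relying on _handle_exception to log, count the failure metric, and re-raise. This is an illustration inferred from the handler's (e, func, *args, **kwargs) signature, not code from this change.

import functools

def log_rpc_handler_errors(func):
    # Assumed shape of the decorator: catch whatever the RPC handler raises
    # and delegate to _handle_exception, which re-raises after logging.
    # Streaming handlers (async generators) would need an async-gen variant;
    # this sketch only covers unary coroutine handlers.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            await _handle_exception(e, func, *args, **kwargs)
    return wrapper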
@@ -172,14 +187,9 @@ async def Generate(self, request: BatchedGenerationRequest,
             response = self._convert_input_details(res, resp_options,
                                                    sampling_params,
                                                    response)
-            if request_count == 1:
-                kind_log = "Request"
-            else:
-                kind_log = f"Sub-request {i} from batch of {request_count}"
-
-            self._log_unary_response(request=request, response=response,
-                                     start_time=start_time, engine_response=res,
-                                     kind_log=kind_log)
+            logs.log_response(request=request, response=response,
+                              start_time=start_time, engine_metrics=res.metrics,
+                              sub_request_num=i, logger=logger)
             responses[i] = response
 
         return BatchedGenerationResponse(responses=responses)
@@ -253,9 +263,11 @@ async def GenerateStream(
             return
         first_response.text = full_output
         first_response.generated_token_count = last_token_count
-        self._log_streaming_response(request=request, response=first_response,
-                                     start_time=start_time,
-                                     engine_response=last_engine_response)
+        logs.log_response(request=request, response=first_response,
+                          start_time=start_time,
+                          engine_metrics=last_engine_response.metrics
+                          if last_engine_response else None,
+                          logger=logger)
 
     def _convert_input_details(
             self, result: RequestOutput, resp_options: ResponseOptions,
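Both Generate and GenerateStream now call logs.log_response directly, replacing the per-method _log_unary_response / _log_streaming_response wrappers removed in the final hunk. Inferred purely from the two call sites above, the consolidated helper's signature might look roughly like the sketch below; the actual implementation lives in logs.py and may differ (the sub_request_num default and the parameter types are assumptions).

import logging

def log_response(*,
                 request,                   # BatchedGenerationRequest or SingleGenerationRequest
                 response,                  # GenerationResponse
                 start_time: float,
                 engine_metrics=None,       # e.g. RequestOutput.metrics; None if nothing was streamed
                 sub_request_num: int = 0,  # index within a batched unary request (assumed default)
                 logger: logging.Logger) -> None:
    # Presumably formats the TGIS-style response log line from the request
    # parameters, response stats, timing, and engine metrics.
    ...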
@@ -537,30 +549,6 @@ async def _validate_prompt_and_tokenize(
 
         return input_ids, max_is_token_limit
 
-    @staticmethod
-    def _log_unary_response(request: BatchedGenerationRequest,
-                            response: GenerationResponse,
-                            engine_response: RequestOutput,
-                            start_time: float, kind_log: str):
-        logs.log_response(inputs=[r.text for r in request.requests],
-                          response=response, params=request.params,
-                          prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log=kind_log,
-                          method_str="generate", logger=logger)
-
-    @staticmethod
-    def _log_streaming_response(request: SingleGenerationRequest,
-                                response: GenerationResponse,
-                                engine_response: RequestOutput,
-                                start_time: float):
-        logs.log_response(inputs=[request.request.text], response=response,
-                          params=request.params, prefix_id=request.prefix_id,
-                          engine_response=engine_response,
-                          start_time=start_time, kind_log="Streaming response",
-                          method_str="generate_stream", logger=logger)
-
-
     @log_rpc_handler_errors
     async def Tokenize(self, request: BatchedTokenizeRequest,
                        context: ServicerContext) -> BatchedTokenizeResponse: