1515from  ..disaggregated_params  import  DisaggregatedParams 
1616from  ..llmapi .tracer  import  global_tracer 
1717from  ..llmapi .utils  import  AsyncQueue 
18+ from  ..metrics  import  MetricNames , MetricsCollector , RequestEventTiming 
1819from  ..sampling_params  import  LogprobParams , SamplingParams 
1920from  .utils  import  ErrorResponse , has_event_loop , is_llm_response 
2021
@@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
5051
5152
5253class  ResponseWrapper :
53-     """Wrapper of runtime response with optional outputs computed post runtime. 
54+     """ 
55+     1. Wrapper of runtime response with optional outputs computed post runtime. 
56+     2. A workaround to pass around RequestPerfMetrics. 
5457    """ 
5558
5659    def  __init__ (self ,
5760                 response : Union ["PostprocWorker.Output" , tllm .Response ],
58-                  logprobs : Optional [LogProbsResult ] =  None ):
61+                  logprobs : Optional [LogProbsResult ] =  None ,
62+                  request_perf_metrics : Optional [dict [str , float ]] =  None ):
5963        self ._response  =  response 
6064        self .logprobs  =  logprobs 
65+         self .request_perf_metrics  =  request_perf_metrics 
6166
6267    @property  
6368    def  _is_llm_response (self ):
@@ -68,6 +73,14 @@ def __getattr__(self, name):
6873        response  =  object .__getattribute__ (self , '_response' )
6974        return  getattr (response , name )
7075
76+     def  __getstate__ (self ):
77+         return  (self ._response , self .logprobs , self .request_perf_metrics )
78+ 
79+     def  __setstate__ (self , state ):
80+         self ._response  =  state [0 ]
81+         self .logprobs  =  state [1 ]
82+         self .request_perf_metrics  =  state [2 ]
83+ 
7184
7285@dataclass (slots = True ) 
7386class  CompletionOutput :
@@ -146,6 +159,7 @@ def __init__(self,
146159        self .disaggregated_params  =  None 
147160        self .decoding_iter  =  0 
148161        self ._done  =  False 
162+         self .metrics_dict  =  {}
149163
150164        if  has_event_loop ():
151165            self .aqueue  =  AsyncQueue ()
@@ -201,7 +215,9 @@ def _handle_sequence(self,
201215                         finish_reasons ,
202216                         response_tensors ,
203217                         sequence_index ,
204-                          logprobs_result = None ):
218+                          logprobs_result = None ,
219+                          req_perf_metrics_dict : Optional [dict [str ,
220+                                                               float ]] =  None ):
205221        """ Handle a single sequence in the response. """ 
206222
207223        seq_idx  =  sequence_index 
@@ -271,14 +287,17 @@ def _handle_sequence(self,
271287            else :
272288                raise  ValueError (
273289                    f"Unknown finish reason: { finish_reasons [src_idx ]}  )
290+             self .record_stats (output , req_perf_metrics_dict )
274291
275292    @nvtx_range_debug ("handle_response" , 
276293                      color = "red" , 
277294                      category = "GenerationResultBase" ) 
278295    def  _handle_response (self ,
279296                         response : Union ["PostprocWorker.Output" , tllm .Response ,
280297                                         ResponseWrapper , ErrorResponse ]):
298+         req_perf_metrics_dict  =  None 
281299        if  isinstance (response , ResponseWrapper ):
300+             req_perf_metrics_dict  =  response .request_perf_metrics 
282301            logprobs_result  =  response .logprobs 
283302            response  =  response ._response 
284303        else :
@@ -291,6 +310,8 @@ def _handle_response(self,
291310                self ._outputs [0 ] =  response .res 
292311            else :
293312                self ._outputs [0 ]._postprocess_result  =  response .res 
313+             if  response .metrics :
314+                 self .metrics_dict  =  response .metrics 
294315
295316            if  response .error :
296317                if  self ._background_error_handler  is  not None  and  (
@@ -303,7 +324,8 @@ def _handle_response(self,
303324                    handler (response .error_msg )
304325
305326            response_result  =  response .result 
306-             if  hasattr (response_result , "_result" ):
327+             if  hasattr (response_result , "_result" ) and  isinstance (
328+                     response_result ._result , bytes ):
307329                response_result .deserialize ()
308330
309331            self ._done  =  response_result .is_final 
@@ -322,11 +344,12 @@ def _handle_response(self,
322344            if  self .sampling_params .use_beam_search :
323345                for  beam_idx , _  in  enumerate (response_result .output_token_ids ):
324346                    self ._handle_sequence (finish_reasons , response_result ,
325-                                           beam_idx , logprobs_result )
347+                                           beam_idx , logprobs_result ,
348+                                           req_perf_metrics_dict )
326349            else :
327350                self ._handle_sequence (finish_reasons , response_result ,
328351                                      response_result .sequence_index ,
329-                                       logprobs_result )
352+                                       logprobs_result ,  req_perf_metrics_dict )
330353
331354            if  response_result .context_logits  is  not None :
332355                self ._context_logits  =  response_result .context_logits 
@@ -342,6 +365,29 @@ def _handle_response(self,
342365        else :
343366            raise  ValueError (f"Unknown response type: { response }  )
344367
368+     def  record_stats (self ,
369+                      output : CompletionOutput ,
370+                      stats : Optional [dict [str , float ]] =  None ) ->  None :
371+         """Record the stats of the generation result. 
372+ 
373+         Args: 
374+             output (CompletionOutput): The output of the generation result. 
375+             stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None. 
376+         """ 
377+         if  not  stats :
378+             return 
379+         metrics_stats  =  {}
380+         if  output .finish_reason :
381+             metrics_stats .update ({
382+                 MetricsCollector .labelname_finish_reason :
383+                 output .finish_reason 
384+             })
385+         processed_metrics_stat  =  _process_req_perf_metrics (
386+             stats , len (output .token_ids ), self .sampling_params .n  >  1 )
387+         if  processed_metrics_stat :
388+             metrics_stats .update (processed_metrics_stat )
389+         self .metrics_dict  =  metrics_stats 
390+ 
345391
346392class  DetokenizedGenerationResultBase (GenerationResultBase ):
347393    ''' The base class for the generation result with detokenization support. ''' 
@@ -688,3 +734,30 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
688734
689735    return  LogProbsResult (prompt = prompt_logprobs ,
690736                          generation = generation_logprobs )
737+ 
738+ 
def _process_req_perf_metrics(
        req_perf_metrics_dict: Optional[dict[str, float]],
        output_length: int,
        is_multiple_response: bool = False) -> dict[MetricNames, float]:
    """Convert raw request timing events into named latency metrics.

    Args:
        req_perf_metrics_dict: Mapping of RequestEventTiming events to
            timestamps; may be None or empty.
        output_length: Number of generated tokens for this sequence.
        is_multiple_response: True when sampling_params.n > 1, in which case
            TPOT is not computed.

    Returns:
        Metrics (TTFT, E2E, REQUEST_QUEUE_TIME, optionally TPOT) with
        non-positive values filtered out.
    """
    if not req_perf_metrics_dict:
        return {}

    def timestamp(event) -> float:
        # Missing events default to 0, which makes the derived metric
        # non-positive and filtered out below.
        return req_perf_metrics_dict.get(event, 0)

    arrival = timestamp(RequestEventTiming.ARRIVAL_TIME)
    first_token = timestamp(RequestEventTiming.FIRST_TOKEN_TIME)
    last_token = timestamp(RequestEventTiming.LAST_TOKEN_TIME)
    stat = {
        MetricNames.TTFT: first_token - arrival,
        MetricNames.E2E: last_token - arrival,
        MetricNames.REQUEST_QUEUE_TIME:
        timestamp(RequestEventTiming.FIRST_SCHEDULED_TIME) - arrival,
    }
    if output_length > 1 and not is_multiple_response:
        # Time-per-output-token over the decode phase only.
        stat[MetricNames.TPOT] = (last_token - first_token) / \
            (output_length - 1)
    # Drop metrics that could not be computed (missing timestamps yield <= 0).
    return {name: value for name, value in stat.items() if value > 0}
0 commit comments