@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
     cum_log_probs: torch.Tensor | None = None
+    gathered_ids: torch.Tensor | None = None


 @dataclass(kw_only=True)
 class SampleStateTRTLLM(SampleState):
+    finalize_events: dict[str, CudaEvent]
     host: SampleStateTensorsHostTRTLLM


@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             self.store["decoder_state"],
             self.store["decoding_input"][self.micro_batch_idx])

+        finalize_events = {}
+        gathered_ids = None
+        if beam_width > 1:
+            finished_sum_device = self.store["decoder_state"].finished_sum
+
+            for request in scheduled_requests.all_requests():
+                if request.is_context_init_state:
+                    continue
+                if finished_sum_device[request.seq_slot] == beam_width:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, False)
+                elif request.streaming:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, True)
+            gathered_ids = self.store["decoder_state"].gathered_ids.to(
+                'cpu', non_blocking=True)
         new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
             'cpu', non_blocking=True)
         finished_sum = self.store["decoder_state"].finished_sum.to(
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
                                             finish_reasons=finish_reasons,
                                             sequence_lengths=sequence_lengths,
                                             log_probs=log_probs,
-                                            cum_log_probs=cum_log_probs)
+                                            cum_log_probs=cum_log_probs,
+                                            gathered_ids=gathered_ids)

         sampler_event = torch.cuda.Event()
         sampler_event.record()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
                                  device=device,
                                  host=host,
-                                 sampler_event=sampler_event)
+                                 sampler_event=sampler_event,
+                                 finalize_events=finalize_events)

     @torch.inference_mode()
     def update_requests(self, state: SampleStateTRTLLM):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
         ) if state.host.cum_log_probs is not None else None
         log_probs_host = state.host.log_probs.tolist(
         ) if state.host.log_probs is not None else None
-        finalize_events = {}
+        finalize_events = state.finalize_events

         reqs = [
             r for r in state.scheduled_requests.context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,

             if finished_sum_host[seq_slot] == beam_width:
                 request.state = LlmRequestState.GENERATION_COMPLETE
-                if beam_width > 1:
-                    finalize_events[
-                        request.request_id] = self._finalize_request(
-                            request, False)
-            elif request.streaming and beam_width > 1:
-                finalize_events[request.request_id] = self._finalize_request(
-                    request, True)
-        # post process all requests if necessary
-        if beam_width > 1:
-            for request in reqs:
-                if request.request_id in finalize_events:
-                    self._post_process_request(
-                        request, finalize_events[request.request_id])
+        for request in reqs:
+            if request.request_id in finalize_events:
+                self._post_process_request(request, state)

     def _finalize_request(self, request: LlmRequest, streaming: bool):
         """ Finalizes the request. This is necessary for beam search. """
@@ -888,25 +900,24 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
         return event

     def _post_process_request(self, request: LlmRequest,
-                              finalize_event: CudaEvent):
+                              state: SampleStateTRTLLM):
         """ Post Process the request. Updates the sequence according to the beam search results.
         request: LlmRequest which shall be post processed
         finalize_event: CudaEvent to wait for the finalize step to finish
         """
         seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
         # synchronize on the finalize event before continuing the post processing.
-        finalize_event.synchronize()
+        # should be unnecessary, as update_requests already waits for the sampler event
+        state.finalize_events[request.request_id].synchronize()

         # Get these values again, as they might have changed during the finalize step
-        output_ids_host = self.store["decoder_state"].gathered_ids.to('cpu')
-        sequence_lengths_host = self.store["decoder_state"].sequence_lengths.to(
-            'cpu')
+        output_ids_host = state.host.gathered_ids
+        sequence_lengths_host = state.host.sequence_lengths

         if request.py_return_log_probs:
-            log_probs_host = self.store["decoder_state"].log_probs.to('cpu')
-            cum_log_probs_host = self.store["decoder_state"].cum_log_probs.to(
-                'cpu')
+            log_probs_host = state.host.log_probs
+            cum_log_probs_host = state.host.cum_log_probs

         generated_tokens = [[0]] * beam_width
         log_probs = [[] for _ in range(beam_width)]
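
The diff above moves the per-request finalize work into sample_async, records a CUDA event per finished request, and lets update_requests consume the already-copied host tensors. What follows is a minimal, standalone sketch of that pattern in plain PyTorch, not the TRT-LLM sampler API; names such as finalize_request_async and the request ids are illustrative assumptions:

import torch


def finalize_request_async(output_ids: torch.Tensor) -> torch.cuda.Event:
    """Launch (dummy) finalize work on the current stream and record an event for it."""
    output_ids.add_(0)  # stand-in for the real gather/finalize kernels
    event = torch.cuda.Event()
    event.record()
    return event


if torch.cuda.is_available():
    output_ids = torch.randint(0, 100, (2, 4, 8), device="cuda")  # [request, beam, token]

    # "sample_async" side: record per-request finalize events, then start async copies.
    finalize_events = {rid: finalize_request_async(output_ids[i])
                       for i, rid in enumerate(("req-0", "req-1"))}
    gathered_ids_host = output_ids.to("cpu", non_blocking=True)  # truly async only with pinned memory
    sampler_event = torch.cuda.Event()
    sampler_event.record()

    # "update_requests" side: wait for the copies first, then per request for its
    # finalize event, before reading the host tensor for that request.
    sampler_event.synchronize()
    for i, (rid, finalize_event) in enumerate(finalize_events.items()):
        finalize_event.synchronize()
        tokens_for_request = gathered_ids_host[i].tolist()

Because sampler_event is recorded after the finalize kernels and the non_blocking copies, waiting on it already covers that work, which is why the added comment in _post_process_request notes the per-request synchronize should be unnecessary.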