@@ -122,6 +122,7 @@ class BatchState:
 @dataclasses.dataclass
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
+    scheduled_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
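The new `scheduled_ctx_reqs` field defaults to `None` rather than an empty list: dataclasses reject bare mutable defaults at class-creation time, and `None` doubles as a cheap "nothing scheduled" sentinel. A minimal self-contained sketch of the pattern (the trimmed-down `BatchState` and stand-in `LlmRequest` below are illustrative assumptions, not the real classes); annotating the field as `Optional[...]` also keeps type checkers happy:

import dataclasses
from typing import Optional

class LlmRequest:  # stand-in for the real tensorrt_llm LlmRequest
    pass

@dataclasses.dataclass
class BatchState:  # simplified parent; the real one carries sample/iter state
    sample_state: object = None

@dataclasses.dataclass
class BatchStatePP(BatchState):
    microbatch_id: int = -1
    # None rather than [] on purpose: `scheduled_ctx_reqs: list = []` would
    # raise ValueError (mutable default); default_factory=list is the
    # stricter alternative if an always-present list is preferred.
    scheduled_ctx_reqs: Optional[list[LlmRequest]] = None

state = BatchStatePP(microbatch_id=0, scheduled_ctx_reqs=[LlmRequest()])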
@@ -656,6 +657,9 @@ def _executor_loop_pp(self):
                 if self.should_stop_processing:
                     break

+                if self.kv_cache_transceiver:
+                    self._check_disagg_gen_transfer_status()
+
                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
                         len(new_requests),
@@ -664,9 +668,28 @@ def _executor_loop_pp(self):

                 self._pad_attention_dp_dummy_request()

-                scheduled_batch, _, _ = self._schedule()
+                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+                )
+
+                if self.kv_cache_transceiver:
+                    # For disagg gen init requests that fit, also prepare
+                    # resources in the KV cache manager
+                    self._prepare_disagg_gen_init(
+                        fitting_disagg_gen_init_requests)
+
+                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
+                        logger.warning(
+                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
+                        )
+                        self.kv_cache_transceiver.check_context_transfer_status(
+                            1)
+                else:
+                    assert scheduled_batch.batch_size > 0, (
+                        "failed to schedule any pending request, "
+                        "probably ran out of resources.")

                 self.num_scheduled_requests = scheduled_batch.batch_size
+
                 logger.debug(
                     f'has {len(self.active_requests)} active requests, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
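The branch above encodes a backpressure rule for disaggregated serving: an empty schedule is survivable when a `kv_cache_transceiver` exists, because `check_context_transfer_status(1)` appears to block until at least one in-flight context transfer finishes and frees KV-cache blocks for the next scheduling attempt; without a transceiver, an empty schedule is treated as fatal. A minimal sketch of that rule with hypothetical `scheduler`/`transceiver` stand-ins (not the TensorRT-LLM API):

import logging
logger = logging.getLogger(__name__)

def schedule_with_kv_backpressure(scheduler, transceiver):
    """Sketch: retry-friendly scheduling under disaggregated serving."""
    batch, fitting_disagg_gen_init_reqs, num_fitting_reqs = scheduler.schedule()
    if transceiver is not None:
        if num_fitting_reqs == 0 and not fitting_disagg_gen_init_reqs:
            # Nothing fits right now; block until at least one in-flight
            # context transfer completes and releases KV-cache blocks, then
            # let the next loop iteration try to schedule again.
            logger.warning("no request fits; waiting on a KV-cache transfer")
            transceiver.check_context_transfer_status(1)
    else:
        # Without a transceiver no blocks will ever be freed by a transfer,
        # so an empty schedule means the executor is truly out of resources.
        assert batch.batch_size > 0, "failed to schedule any pending request"
    return batch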
@@ -688,8 +711,28 @@ def _executor_loop_pp(self):
                     self.micro_batches[microbatch_id] = None
                 else:
                     self._add_inflight_ids(scheduled_batch)

+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)

+                    # Generation requests that do not yet have a batch_idx
+                    # need to be at the front of the batch, due to the
+                    # assumptions made in model_engine.py::_forward_step.
+                    # This only matters for disaggregated serving; in
+                    # non-disaggregated serving, generation requests always
+                    # have a batch_idx.
+                    scheduled_batch.generation_requests = sorted(  # stable sort
+                        scheduled_batch.generation_requests,
+                        key=lambda req: int(req.py_batch_idx is not None),
+                    )
+
+                    if self.kv_cache_transceiver:
+                        # Return the first token to the client
+                        self._handle_first_token_response(scheduled_batch)
+
                     # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                     if not self.dist.is_last_pp_rank:
                         sample_state = self._forward_step_inter_pp(
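The sort key maps "has a `py_batch_idx`" to 1 and "no `py_batch_idx`" to 0, and Python's `sorted` is stable, so requests still lacking a batch index move to the front while each group keeps its original relative order. A quick self-contained check, with a dummy class standing in for `LlmRequest`:

import dataclasses
from typing import Optional

@dataclasses.dataclass
class DummyReq:  # stand-in for LlmRequest
    name: str
    py_batch_idx: Optional[int] = None

reqs = [DummyReq("a", 0), DummyReq("b"), DummyReq("c", 1), DummyReq("d")]
# Keys are 1, 0, 1, 0; stable sort keeps ties in their original order.
reqs = sorted(reqs, key=lambda r: int(r.py_batch_idx is not None))
print([r.name for r in reqs])  # ['b', 'd', 'a', 'c']: no-batch_idx first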
@@ -720,6 +763,7 @@ def _executor_loop_pp(self):
                         iter_start_time=iter_start_time,
                         iter_stats=iter_stats,
                         microbatch_id=microbatch_id,
+                        scheduled_ctx_reqs=scheduled_batch.context_requests,
                     )

                     self.micro_batches[microbatch_id] = batch_state
@@ -784,6 +828,12 @@ def _executor_loop_pp(self):
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                         self._update_requests(previous_batch.sample_state)

+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            ctx_transmission_reqs = self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs)
+
                         self._handle_canceled_requests()
                         finished_requests = self._handle_responses()
                         previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
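This consumes the `scheduled_ctx_reqs` stashed on `BatchStatePP` when the microbatch was launched: the context KV cache is sent only after `_update_requests` has finalized that microbatch's sampling, one full trip around the microbatch ring later. A compressed sketch of the stash-then-send lifecycle (the callables are placeholders, not the executor's real signatures):

def pipeline_loop(num_micro_batches, launch, update, send_ctx_cache, should_stop):
    """Sketch of the microbatch ring: each slot stashes the context requests
    whose KV cache is sent only after that microbatch's sampling is updated."""
    micro_batches = [None] * num_micro_batches
    microbatch_id = 0
    while not should_stop():
        batch_state = launch()                  # schedule + async forward step
        micro_batches[microbatch_id] = batch_state
        prev_id = (microbatch_id + 1) % num_micro_batches  # oldest in-flight slot
        previous_batch = micro_batches[prev_id]
        if previous_batch is not None:
            update(previous_batch)              # _update_requests analogue
            if previous_batch.scheduled_ctx_reqs:
                # Deferred disaggregated send: the KV cache for these context
                # requests is now complete and safe to ship to gen workers.
                send_ctx_cache(previous_batch.scheduled_ctx_reqs)
            micro_batches[prev_id] = None
        microbatch_id = (microbatch_id + 1) % num_micro_batches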
@@ -792,6 +842,9 @@ def _executor_loop_pp(self):
                         self._remove_inflight_ids(previous_scheduled_batch)
                     self.micro_batches[prev_microbatch_id] = None

+                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._terminate_ctx_finished_requests()
+
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches

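Context requests whose KV-cache transmission has completed are then torn down once per loop iteration. A hedged guess at the bookkeeping inside `_terminate_ctx_finished_requests` (the `is_disagg_context_transmission_complete` flag and `_terminate_request` helper are assumptions for illustration only):

def _terminate_ctx_finished_requests(self):
    # Keep requests whose KV cache is still in flight; release the rest.
    still_in_flight = []
    for req in self.ctx_in_transmission_requests:
        if req.is_disagg_context_transmission_complete:  # assumed flag
            self._terminate_request(req)  # frees KV blocks / sequence slot
        else:
            still_in_flight.append(req)
    self.ctx_in_transmission_requests = still_in_flight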