@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]
 
-        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
-        # at the end of extend_requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the beginning of extend_requests.
         extend_requests = []
-        extend_dummy_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                if request.is_dummy:
-                    extend_dummy_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
                 else:
                     extend_requests.append(request)
             else:
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
                     pin_memory=True)
                 mrope_config['mrope_position_deltas'].append(
                     mrope_position_deltas.to('cuda', non_blocking=True))
-        extend_requests += extend_dummy_requests
 
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
             assert spec_dec_mode.support_overlap_scheduler(
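The net effect of the two hunks above: CUDA graph dummy requests are collected separately and prepended, so extend requests without previous-batch tensors sit in a contiguous prefix of the extend batch. A minimal, hypothetical sketch of that grouping (Req and order_extend_requests are simplified stand-ins, not the real LlmRequest or engine code):

from dataclasses import dataclass, field
from typing import List


@dataclass
class Req:
    """Simplified stand-in for the real request object."""
    py_request_id: int
    py_draft_tokens: List[int] = field(default_factory=list)
    is_cuda_graph_dummy: bool = False


def order_extend_requests(generation_requests, next_draft_tokens_device):
    # Mirror of the grouping above: CUDA graph dummies are split out and
    # prepended, so requests without previous-batch tensors form a
    # contiguous prefix of the extend batch.
    extend_requests, cuda_graph_dummies, generation_only = [], [], []
    for r in generation_requests:
        if len(r.py_draft_tokens) > 0 or next_draft_tokens_device is not None:
            (cuda_graph_dummies
             if r.is_cuda_graph_dummy else extend_requests).append(r)
        else:
            generation_only.append(r)
    return cuda_graph_dummies + extend_requests, generation_only

That prefix ordering is what allows the later offset code to zero a single leading slice of the persistent buffers instead of interleaved ranges.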
@@ -1229,18 +1229,18 @@ def _prepare_tp_inputs(
         # will contain previous batch indices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
+        request_ids_with_previous_batch = []
+        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
             # (3) the first step in the generation server of disaggregated serving
             if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
-                # get token ids, including input token ids and draft token ids. For these dummy requests,
-                # no need to copy the token ids.
-                if not request.is_dummy:
-                    input_ids.append(request.get_last_tokens(0))
-                    input_ids.extend(request.py_draft_tokens)
-                    draft_tokens.extend(request.py_draft_tokens)
+                # get token ids, including input token ids and draft token ids
+                input_ids.append(request.get_last_tokens(0))
+                input_ids.extend(request.py_draft_tokens)
+                draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1268,6 +1268,7 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
+                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1294,7 +1295,10 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids.append(request.py_request_id)
+                request_ids_with_previous_batch.append(request.py_request_id)
+
+        # move requests with previous batch to the end of the list
+        request_ids.extend(request_ids_with_previous_batch)
 
         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
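Resulting request_ids layout: ids of extend requests that reuse the previous batch are buffered in request_ids_with_previous_batch and only appended once the loop has finished, so they land after the ids of requests without one. A tiny illustration with made-up ids:

request_ids = [101, 102]  # extend requests without a previous batch (incl. CUDA graph dummies)
request_ids_with_previous_batch = [103, 104, 105]  # extend requests reusing the previous batch
request_ids.extend(request_ids_with_previous_batch)
assert request_ids == [101, 102, 103, 104, 105]
# generation request ids are appended after this by the surrounding code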
@@ -1329,7 +1333,6 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
-        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1371,27 +1374,31 @@ def _prepare_tp_inputs(
                 non_blocking=True)
             # prepare data for the preprocess inputs
             kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
+            pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
+                1 + self.max_draft_len)
+            pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
+            pre_batch_start_idx = num_extend_reqs_wo_previous_batch
+            pre_batch_end_idx = pre_batch_start_idx + previous_batchs
             previous_pos_indices = torch.tensor(previous_pos_indices,
                                                 dtype=torch.int,
                                                 pin_memory=True)
-            self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
-                previous_pos_indices, non_blocking=True)
+            self.previous_pos_indices_cuda[
+                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                    previous_pos_indices, non_blocking=True)
             self.previous_pos_id_offsets_cuda[
-                0:previous_batch_tokens].copy_(
+                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
                     new_tokens_lens_device[self.previous_pos_indices_cuda[
-                        0:previous_batch_tokens]],
+                        pre_tokens_start_idx:pre_tokens_end_idx]],
+                    non_blocking=True)
+            self.previous_kv_lens_offsets_cuda[
+                pre_batch_start_idx:pre_batch_end_idx].copy_(
+                    kv_len_offsets_device[
+                        self.previous_batch_indices_cuda[:previous_batchs]],
                     non_blocking=True)
-            self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
-                kv_len_offsets_device[
-                    self.previous_batch_indices_cuda[:previous_batchs]],
-                non_blocking=True)
             # for the requests that do not have previous batch, set the previous_pos_id_offsets and
             # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-            self.previous_pos_id_offsets_cuda[
-                previous_batch_tokens:num_requests *
-                (1 + self.max_draft_len)] *= 0
-            self.previous_kv_lens_offsets_cuda[
-                previous_batchs:num_requests] *= 0
+            self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
+            self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
         else:
             # change the data to zeros to skip the value changes in _preprocess_inputs
             self.previous_pos_id_offsets_cuda *= 0
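A worked example of the new offset arithmetic, assuming max_draft_len = 3, two extend requests without a previous batch (for example CUDA graph dummies) and three that reuse one; the tensors below are plain stand-ins for the persistent *_cuda buffers, not the engine's members:

import torch

max_draft_len = 3
num_extend_reqs_wo_previous_batch = 2   # prefix: requests with no previous tensors
previous_batchs = 3                     # requests that reuse the previous batch
previous_batch_tokens = previous_batchs * (1 + max_draft_len)  # 12

pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (1 + max_draft_len)  # 8
pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens               # 20
pre_batch_start_idx = num_extend_reqs_wo_previous_batch                         # 2
pre_batch_end_idx = pre_batch_start_idx + previous_batchs                       # 5

previous_pos_id_offsets = torch.ones(32, dtype=torch.int)  # stand-in buffer
previous_kv_lens_offsets = torch.ones(8, dtype=torch.int)  # stand-in buffer

# offsets for the previous-batch requests land in the suffix slice ...
previous_pos_id_offsets[pre_tokens_start_idx:pre_tokens_end_idx] = 7
previous_kv_lens_offsets[pre_batch_start_idx:pre_batch_end_idx] = 7
# ... while the prefix (requests without a previous batch) is zeroed so that
# _preprocess_inputs leaves their values unchanged
previous_pos_id_offsets[:pre_tokens_start_idx] *= 0
previous_kv_lens_offsets[:pre_batch_start_idx] *= 0

assert previous_pos_id_offsets[:8].sum() == 0
assert previous_pos_id_offsets[8:20].sum() == 7 * 12
assert previous_kv_lens_offsets[:2].sum() == 0
assert previous_kv_lens_offsets[2:5].sum() == 7 * 3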