@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
                 new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
                 next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]
 
-        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
-        # requests should be at the end of extend_requests.
+        # Requests with draft tokens are treated like extend requests. Dummy extend requests should be
+        # at the end of extend_requests.
         extend_requests = []
-        extend_cuda_graph_dummy_requests = []
+        extend_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                if request.is_cuda_graph_dummy:
-                    extend_cuda_graph_dummy_requests.append(request)
+                if request.is_dummy:
+                    extend_dummy_requests.append(request)
                 else:
                     extend_requests.append(request)
             else:
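For orientation, here is a minimal sketch of the new classification rule (the `Request` dataclass and the `split_generation_requests` helper are illustrative stand-ins, not the real TensorRT-LLM classes): a generation request is treated as an extend request when it carries draft tokens or when the overlap scheduler supplies `next_draft_tokens_device`, and every dummy request, not only CUDA graph padding, goes into its own list.

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class Request:  # simplified stand-in for the real request object
        py_draft_tokens: List[int] = field(default_factory=list)
        is_dummy: bool = False  # covers CUDA graph padding and other dummy requests

    def split_generation_requests(requests: List[Request],
                                  next_draft_tokens_device: Optional[object]):
        extend, extend_dummy, generation = [], [], []
        for req in requests:
            if len(req.py_draft_tokens) > 0 or next_draft_tokens_device is not None:
                # draft tokens present (or overlap scheduler active): extend-style request
                (extend_dummy if req.is_dummy else extend).append(req)
            else:
                generation.append(req)
        return extend, extend_dummy, generation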
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
                     pin_memory=True)
                 mrope_config['mrope_position_deltas'].append(
                     mrope_position_deltas.to('cuda', non_blocking=True))
+        extend_requests += extend_dummy_requests
 
-        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
             assert spec_dec_mode.support_overlap_scheduler(
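The concatenation order is the behavioral change here: dummy extend requests are now appended after the real ones instead of being prepended, so real requests occupy a contiguous prefix of the flattened batch that the buffer indexing further down relies on. A rough before/after illustration with placeholder values:

    # old: extend_requests = extend_cuda_graph_dummy_requests + extend_requests
    # new: real extend requests first, dummy padding at the tail
    extend_requests = ["req0", "req1"]        # real requests with draft tokens
    extend_dummy_requests = ["dummy0"]        # padding, e.g. for CUDA graphs
    extend_requests += extend_dummy_requests
    assert extend_requests == ["req0", "req1", "dummy0"]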
@@ -1229,18 +1229,18 @@ def _prepare_tp_inputs(
         # will contain previous batch indices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
-        request_ids_with_previous_batch = []
-        num_extend_reqs_wo_previous_batch = 0
         for request in extend_requests:
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
             # (3) the first step in the generation server of disaggregated serving
             if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
-                # get token ids, including input token ids and draft token ids
-                input_ids.append(request.get_last_tokens(0))
-                input_ids.extend(request.py_draft_tokens)
-                draft_tokens.extend(request.py_draft_tokens)
+                # get token ids, including input token ids and draft token ids. For these dummy requests,
+                # no need to copy the token ids.
+                if not request.is_dummy:
+                    input_ids.append(request.get_last_tokens(0))
+                    input_ids.extend(request.py_draft_tokens)
+                    draft_tokens.extend(request.py_draft_tokens)
                 # get other ids and lengths
                 num_draft_tokens = len(request.py_draft_tokens)
                 past_seen_token_num = request.max_beam_num_tokens - 1
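Since dummy requests only pad the batch, the new guard skips the host-side token copies for them while still counting their draft length. A small sketch of that guard (the helper name and the simplified request interface are assumptions for illustration):

    def gather_extend_tokens(request, input_ids, draft_tokens):
        """Collect last token + draft tokens for real requests only."""
        if not request.is_dummy:
            input_ids.append(request.get_last_tokens(0))
            input_ids.extend(request.py_draft_tokens)
            draft_tokens.extend(request.py_draft_tokens)
        # dummy requests still contribute their draft length so the
        # per-request bookkeeping keeps the same shape
        return len(request.py_draft_tokens)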
@@ -1268,7 +1268,6 @@ def _prepare_tp_inputs(
                 # update batch index
                 request.py_batch_idx = batch_idx
                 batch_idx += 1
-                num_extend_reqs_wo_previous_batch += 1
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
@@ -1295,10 +1294,7 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num +
                                                  self.max_draft_len + 1)
                 prompt_lengths.append(request.py_prompt_len)
-                request_ids_with_previous_batch.append(request.py_request_id)
-
-        # move requests with previous batch to the end of the list
-        request_ids.extend(request_ids_with_previous_batch)
+                request_ids.append(request.py_request_id)
 
         sequence_lengths.extend([1] * len(generation_requests))
         gather_ids.extend(
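With dummies already pushed to the tail by the classification change, request ids can now be recorded in plain iteration order; the old pass that buffered the ids of previous-batch requests and spliced them onto the end of `request_ids` goes away. A sketch with made-up ids:

    # made-up batch: (request_id, has_previous_batch)
    extend = [("r0", False), ("r1", True), ("r2", False), ("dummy", False)]

    # old behaviour: previous-batch ids were buffered and appended after the loop
    old_ids = [rid for rid, prev in extend if not prev] + \
              [rid for rid, prev in extend if prev]
    # new behaviour: ids simply follow iteration order
    new_ids = [rid for rid, _ in extend]

    assert old_ids == ["r0", "r2", "dummy", "r1"]
    assert new_ids == ["r0", "r1", "r2", "dummy"]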
@@ -1333,6 +1329,7 @@ def _prepare_tp_inputs(
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
         previous_batchs = len(previous_batch_indices)
+        num_requests = len(request_ids)
         total_num_tokens = len(position_ids)
         assert total_num_tokens <= self.max_num_tokens, (
             "total_num_tokens should be less than or equal to max_num_tokens")
@@ -1374,31 +1371,27 @@ def _prepare_tp_inputs(
                 non_blocking=True)
             # prepare data for the preprocess inputs
             kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
-            pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
-                1 + self.max_draft_len)
-            pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
-            pre_batch_start_idx = num_extend_reqs_wo_previous_batch
-            pre_batch_end_idx = pre_batch_start_idx + previous_batchs
             previous_pos_indices = torch.tensor(previous_pos_indices,
                                                 dtype=torch.int,
                                                 pin_memory=True)
-            self.previous_pos_indices_cuda[
-                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
-                    previous_pos_indices, non_blocking=True)
+            self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
+                previous_pos_indices, non_blocking=True)
             self.previous_pos_id_offsets_cuda[
-                pre_tokens_start_idx:pre_tokens_end_idx].copy_(
+                0:previous_batch_tokens].copy_(
                     new_tokens_lens_device[self.previous_pos_indices_cuda[
-                        pre_tokens_start_idx:pre_tokens_end_idx]],
-                    non_blocking=True)
-            self.previous_kv_lens_offsets_cuda[
-                pre_batch_start_idx:pre_batch_end_idx].copy_(
-                    kv_len_offsets_device[
-                        self.previous_batch_indices_cuda[:previous_batchs]],
+                        0:previous_batch_tokens]],
                     non_blocking=True)
+            self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
+                kv_len_offsets_device[
+                    self.previous_batch_indices_cuda[:previous_batchs]],
+                non_blocking=True)
             # for the requests that do not have previous batch, set the previous_pos_id_offsets and
             # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-            self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
-            self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
+            self.previous_pos_id_offsets_cuda[
+                previous_batch_tokens:num_requests *
+                (1 + self.max_draft_len)] *= 0
+            self.previous_kv_lens_offsets_cuda[
+                previous_batchs:num_requests] *= 0
         else:
             # change the data to zeros to skip the value changes in _preprocess_inputs
             self.previous_pos_id_offsets_cuda *= 0
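To make the new buffer arithmetic concrete, here is a self-contained sketch with made-up sizes (plain Python lists stand in for CUDA buffers such as `previous_pos_id_offsets_cuda`): previous-batch data now always starts at offset 0, and the remainder of the region used by this batch, up to `num_requests * (1 + max_draft_len)` token slots (and `num_requests` entries for the per-request KV-length offsets), is zeroed so `_preprocess_inputs` leaves those positions unchanged; the old `pre_*_start_idx` / `pre_*_end_idx` bookkeeping is no longer needed.

    max_draft_len = 3
    num_requests = 4                 # extend requests in this step
    previous_batchs = 2              # of those, how many have a previous batch
    tokens_per_request = 1 + max_draft_len
    previous_batch_tokens = previous_batchs * tokens_per_request

    # stand-in for self.previous_pos_id_offsets_cuda, sized for the full batch
    pos_id_offsets = [-1] * (num_requests * tokens_per_request)

    # tokens of requests with a previous batch fill the leading slice ...
    for i in range(previous_batch_tokens):
        pos_id_offsets[i] = 1        # pretend offsets gathered from new_tokens_lens

    # ... and the rest of the batch (dummy requests etc.) is zeroed
    for i in range(previous_batch_tokens, num_requests * tokens_per_request):
        pos_id_offsets[i] = 0

    assert pos_id_offsets == ([1] * previous_batch_tokens +
                              [0] * ((num_requests - previous_batchs) *
                                     tokens_per_request))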