 import gc
 import glob
 import inspect
-import itertools
 import math
 import multiprocessing
 import os
...
 import torch._dynamo.config
 
 import tensorrt_llm.bindings.internal.userbuffers as ub
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,
@@ -319,6 +319,7 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
 
 
 class PyTorchModelEngine(ModelEngine):
+    BEAM_WIDTH = 1
 
     def __init__(
         self,
@@ -659,13 +660,12 @@ def get_autotune_warmup_request():
             return result
 
         @contextlib.contextmanager
-        def release_batch(result):
+        def release_batch(result: ScheduledRequests | None):
             try:
                 yield result
             finally:
                 if result is not None:
-                    for req in itertools.chain(result.generation_requests,
-                                               result.context_requests):
+                    for req in result.all_requests():
                         kv_cache_manager.free_resources(req)
                         if spec_resource_manager is not None:
                             spec_resource_manager.free_resources(req)
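For reference, a minimal sketch of the all_requests() helper the new code relies on, assuming ScheduledRequests simply concatenates its two request lists (hypothetical reconstruction, not the library source):

class ScheduledRequests:
    # Sketch only: the batch of requests picked for one forward pass.
    def __init__(self, context_requests=None, generation_requests=None):
        self.context_requests = list(context_requests or [])
        self.generation_requests = list(generation_requests or [])

    def all_requests(self) -> list:
        # Context requests first, then generation requests; the enumerate
        # loop in _execute_logit_post_processors pairs the index with
        # num_ctx_req, so this ordering matters there.
        return self.context_requests + self.generation_requests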
@@ -1153,7 +1153,15 @@ def _prepare_tp_inputs(
         draft_lens = []
         mrope_config = defaultdict(list)
 
-        batch_idx = 0
+        mtp_batch_idx = 0  # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially
+
+        def py_batch_idx(request: LlmRequest) -> int:
+            if not self.without_logits:
+                return request.seq_slot
+            nonlocal mtp_batch_idx
+            batch_idx = mtp_batch_idx
+            mtp_batch_idx += 1
+            return batch_idx
 
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
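The py_batch_idx helper added above switches between two indexing schemes: samplers that return logits index new_tokens by each request's persistent seq_slot, while MTP and Eagle3OneModel (the without_logits path) still fill rows serially. A standalone sketch of that closure pattern, with illustrative names only:

def make_py_batch_idx(without_logits: bool):
    serial_idx = 0

    def py_batch_idx(seq_slot: int) -> int:
        nonlocal serial_idx
        if not without_logits:
            # Slot-indexed: stable across iterations, independent of batch order.
            return seq_slot
        # Serial: position of the request within this batch, as before.
        idx = serial_idx
        serial_idx += 1
        return idx

    return py_batch_idx

slot_mode = make_py_batch_idx(without_logits=False)
print([slot_mode(s) for s in (5, 2, 7)])    # [5, 2, 7]
serial_mode = make_py_batch_idx(without_logits=True)
print([serial_mode(s) for s in (5, 2, 7)])  # [0, 1, 2]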
@@ -1184,10 +1192,9 @@ def _prepare_tp_inputs(
                 ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
                 mrope_config['mrope_rotary_cos_sin'].append(
                     mrope_rotary_cos_sin.to('cuda', non_blocking=True))
-            request.py_batch_idx = batch_idx
-            batch_idx += 1
+            request.py_batch_idx = py_batch_idx(request)
 
-        num_ctx_requests = batch_idx
+        num_ctx_requests = len(scheduled_requests.context_requests)
         num_ctx_tokens = len(input_ids)
         new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None
         if new_tensors_device is not None:
@@ -1227,7 +1234,7 @@ def _prepare_tp_inputs(
             assert spec_dec_mode.support_overlap_scheduler(
             ), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"
 
-        # will contain previous batch incices of generation requests
+        # will contain previous batch indices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
         for request in extend_requests:
@@ -1272,8 +1279,7 @@ def _prepare_tp_inputs(
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
-                request.py_batch_idx = batch_idx
-                batch_idx += 1
+                request.py_batch_idx = py_batch_idx(request)
                 # inputs
                 # overlap scheduler can only support the speculative decoding
                 # methods with a fixed number of draft tokens
@@ -1324,8 +1330,18 @@ def _prepare_tp_inputs(
             prompt_lengths.append(request.py_prompt_len)
             draft_lens.append(0)
 
-            request.py_batch_idx = batch_idx
-            batch_idx += 1
+            request.py_batch_idx = py_batch_idx(request)
+
+        previous_batch_len = len(previous_batch_indices)
+
+        def previous_seq_slots_device():
+            previous_batch_indices_host = torch.tensor(previous_batch_indices,
+                                                       dtype=torch.int,
+                                                       pin_memory=True)
+            previous_slots = self.previous_batch_indices_cuda[:
+                                                              previous_batch_len]
+            previous_slots.copy_(previous_batch_indices_host, non_blocking=True)
+            return previous_slots
 
         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)
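previous_seq_slots_device() factors out a host-to-device staging pattern: build a pinned CPU tensor from the Python list, then issue a non-blocking copy into a slice of a preallocated CUDA buffer so the transfer can overlap with other work on the stream. A self-contained sketch of the same pattern, with illustrative buffer names (requires a CUDA device):

import torch

MAX_BATCH = 64
indices_cuda = torch.empty(MAX_BATCH, dtype=torch.int, device='cuda')

def stage_indices(indices: list) -> torch.Tensor:
    # Pinned host memory lets copy_(..., non_blocking=True) be truly asynchronous.
    host = torch.tensor(indices, dtype=torch.int, pin_memory=True)
    slots = indices_cuda[:len(indices)]
    slots.copy_(host, non_blocking=True)
    return slots

slots = stage_indices([3, 0, 2])  # device view over the first 3 entries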
@@ -1347,29 +1363,22 @@ def _prepare_tp_inputs(
             self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                              non_blocking=True)
         if next_draft_tokens_device is not None:
-            if len(previous_batch_indices) > 0:
-                previous_batch_indices = torch.tensor(previous_batch_indices,
-                                                      dtype=torch.int,
-                                                      pin_memory=True)
-                self.previous_batch_indices_cuda[:previous_batchs].copy_(
-                    previous_batch_indices, non_blocking=True)
+            if previous_batch_len > 0:
+                previous_slots = previous_seq_slots_device()
                 # previous input ids
-                previous_batch_tokens = previous_batchs * (1 +
-                                                           self.max_draft_len)
-                self.input_ids_cuda[
-                    num_tokens:num_tokens +
-                    previous_batch_tokens].copy_(new_tokens_device[
-                        self.previous_batch_indices_cuda[:previous_batchs], :].
-                                                 flatten(),
-                                                 non_blocking=True)
+                previous_batch_tokens = previous_batch_len * (
+                    1 + self.max_draft_len)
+                new_tokens = new_tokens_device[previous_slots, :].flatten()
+                self.input_ids_cuda[num_tokens:num_tokens +
+                                    previous_batch_tokens].copy_(
+                                        new_tokens, non_blocking=True)
                 # previous draft tokens
-                previous_batch_draft_tokens = previous_batchs * self.max_draft_len
-                self.draft_tokens_cuda[
-                    num_draft_tokens:num_draft_tokens +
-                    previous_batch_draft_tokens].copy_(next_draft_tokens_device[
-                        self.previous_batch_indices_cuda[:previous_batchs], :].
-                                                       flatten(),
-                                                       non_blocking=True)
+                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
+                                       previous_batch_draft_tokens].copy_(
+                                           next_draft_tokens_device[
+                                               previous_slots, :].flatten(),
+                                           non_blocking=True)
             # prepare data for the preprocess inputs
             kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
             previous_pos_indices = torch.tensor(previous_pos_indices,
@@ -1398,16 +1407,13 @@ def _prepare_tp_inputs(
             self.previous_pos_id_offsets_cuda *= 0
             self.previous_kv_lens_offsets_cuda *= 0
         elif new_tokens_device is not None:
-            previous_batch_tokens = len(previous_batch_indices)
-            previous_batch_indices = torch.tensor(previous_batch_indices,
-                                                  dtype=torch.int,
-                                                  pin_memory=True)
-            self.previous_batch_indices_cuda[:previous_batch_tokens].copy_(
-                previous_batch_indices, non_blocking=True)
-            self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_(
-                new_tokens_device[
-                    self.previous_batch_indices_cuda[:previous_batchs]],
-                non_blocking=True)
+            seq_slots_device = previous_seq_slots_device()
+            max_draft_len = max(draft_lens)
+            new_tokens = new_tokens_device[:max_draft_len + 1,
+                                           seq_slots_device, :self.BEAM_WIDTH]
+            self.input_ids_cuda[num_tokens:num_tokens +
+                                previous_batch_len].copy_(new_tokens.flatten(),
+                                                          non_blocking=True)
 
         position_ids = torch.tensor(position_ids,
                                     dtype=torch.int,
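The new gather above reads previously sampled tokens by sequence slot instead of by batch position: new_tokens_device is assumed to be laid out as [max_draft_len + 1, num_seq_slots, beam_width], so selecting the previous batch's slots and beam 0 and flattening yields the tokens in slot order. A toy check of the indexing (shapes assumed for illustration):

import torch

max_draft_len, num_slots, beam = 1, 4, 1
new_tokens_device = torch.arange(
    (max_draft_len + 1) * num_slots * beam).reshape(max_draft_len + 1,
                                                    num_slots, beam)
seq_slots = torch.tensor([3, 1])          # slots owned by the previous batch
gathered = new_tokens_device[:max_draft_len + 1, seq_slots, :beam]
print(gathered.shape)                     # torch.Size([2, 2, 1])
print(gathered.flatten().tolist())        # [3, 1, 7, 5]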
@@ -1645,7 +1651,6 @@ def _prepare_star_attention_inputs(self,
         # for star attention, we need customized block ids
         block_ids_per_seq = []
         num_cached_tokens_per_seq = []
-        output_token_idx = 0
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
             prompt_lengths.append(request.py_prompt_len)
@@ -1702,8 +1707,6 @@ def _prepare_star_attention_inputs(self,
             sequence_lengths.append(len(input_id))
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1
         num_contexts = len(sequence_lengths)
         for request in scheduled_requests.context_requests:
             ctx_iter = request.ctx_iters
@@ -1743,8 +1746,6 @@ def _prepare_star_attention_inputs(self,
             sequence_lengths.append(len(input_id))
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1
         num_queries = len(sequence_lengths) - num_contexts
 
         # Requests with draft tokens are treated like extend requests.
@@ -1802,8 +1803,6 @@ def _prepare_star_attention_inputs(self,
             position_ids.append(last_query_pos_id + request.gen_iters + 1)
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1
 
         num_tokens = len(input_ids)
         assert num_tokens <= self.max_num_tokens, (
@@ -2171,9 +2170,7 @@ def _execute_logit_post_processors(self,
         num_ctx_req = len(scheduled_requests.context_requests)
         logits_tensor = outputs["logits"]
 
-        for idx, request in enumerate(
-                itertools.chain(scheduled_requests.context_requests,
-                                scheduled_requests.generation_requests)):
+        for idx, request in enumerate(scheduled_requests.all_requests()):
             logits_processors = getattr(request, "py_logits_post_processors",
                                         None)
             if not logits_processors:
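With the single enumerate over all_requests(), the loop presumably still uses the index together with num_ctx_req to separate context from generation entries, so the flattened order must match the old chain(context_requests, generation_requests). A small stand-in illustrating that invariant (plain lists in place of the request objects):

context_requests = ["ctx-0", "ctx-1"]
generation_requests = ["gen-0"]
all_requests = context_requests + generation_requests  # stand-in for scheduled_requests.all_requests()
num_ctx_req = len(context_requests)
for idx, request in enumerate(all_requests):
    kind = "context" if idx < num_ctx_req else "generation"
    print(idx, request, kind)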