Commit f14d1ae

netanel-haber authored and dominicshanshan committed
Reintroduce with perf fixes: feature: unify new_tokens format sample state to trtllm sampler tokens format (NVIDIA#5513)
58a8a8f: these changes were previously merged to main. 6aef149: the changes were temporarily reverted in main due to a significant perf regression in models using the TorchSampler (observed by @byshiue). This PR re-merges those changes along with a fix that prevents the regression. The first commit of this PR is just the reverted revert; filter it out of the changes to see the previously unmerged changes. Signed-off-by: Netanel Haber <[email protected]>
1 parent e44bb56 commit f14d1ae
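
For orientation: the model_engine.py hunks below read sampler output as new_tokens_device[:max_draft_len + 1, seq_slots, :BEAM_WIDTH], i.e. the unified "TRT-LLM sampler" layout of (token step, sequence slot, beam) addressed by each request's seq_slot rather than by a serially assigned batch index. A minimal, illustrative torch sketch of that layout follows; the names and sizes here are made up for the example, not taken from the PR.

    # Illustrative only: gather the tokens for a batch out of a new_tokens tensor
    # laid out as [max_draft_tokens + 1, max_num_sequences, beam_width] and
    # indexed by sequence slot.
    import torch

    max_draft_tokens, max_num_sequences, beam_width = 2, 8, 1
    new_tokens = torch.arange(
        (max_draft_tokens + 1) * max_num_sequences * beam_width
    ).reshape(max_draft_tokens + 1, max_num_sequences, beam_width)

    # Requests occupy arbitrary sequence slots; their tokens are read by slot,
    # not by the order in which they were scheduled this iteration.
    seq_slots = torch.tensor([5, 0, 3])                       # slots of the active requests
    tokens_for_batch = new_tokens[:, seq_slots, :beam_width]  # (steps, batch, beam)
    print(tokens_for_batch.shape)                             # torch.Size([3, 3, 1])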

File tree

12 files changed (+427, -422 lines)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 15 additions & 2 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch._prims_common import DeviceLikeType

+from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range

 from ...._utils import mpi_rank, mpi_world_size

@@ -256,6 +257,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
     assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
     ad_config: LlmArgs = executor_config.pytorch_backend_config

+    max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
     # some derivative properties
     max_draft_tokens = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens

@@ -272,7 +274,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
         max_seq_len=ad_config.max_seq_len,
         max_batch_size=ad_config.max_batch_size,
     )
-    resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
+    seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
+    resource_manager = ResourceManager(
+        {
+            ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
+            ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
+        }
+    )
     resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

     # scheduling

@@ -287,10 +295,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
     # https://github.com/NVIDIA/TensorRT-LLM/issues/5254
     # We should expose mixed_sample to our build_and_run_ad script so we can configure this
     # correctly for models as needed.
-    sampler = TorchSampler(
+    sampler_args = TorchSampler.Args(
         max_seq_len=ad_config.max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
         mixed_sampler=ad_config.mixed_sampler,
     )
+    sampler = TorchSampler(sampler_args)

     # creating the executor object
     py_executor = PyExecutor(

@@ -299,6 +311,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
         model_engine=engine,
         sampler=sampler,
         dist=mpi_dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
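
The hunks above switch the AutoDeploy path to the new TorchSampler.Args bundle. The real Args class lives in tensorrt_llm/_torch/pyexecutor/sampler.py and is not part of this excerpt; the following sketch only illustrates the general pattern, assuming a frozen dataclass nested inside the sampler, with field names taken from the call sites in this diff.

    # Sketch of the "Args bundle" pattern (illustrative, not the TRT-LLM class).
    from dataclasses import dataclass


    class MySampler:

        @dataclass(frozen=True)
        class Args:
            max_seq_len: int
            max_draft_tokens: int
            max_num_sequences: int
            max_beam_width: int
            mixed_sampler: bool = False

        def __init__(self, args: "MySampler.Args"):
            # All sizing/config travels in one object, so constructor signatures
            # stay stable as fields are added.
            self.args = args


    args = MySampler.Args(max_seq_len=4096,
                          max_draft_tokens=0,
                          max_num_sequences=16,
                          max_beam_width=1)
    sampler = MySampler(args)

Bundling the arguments this way lets both call sites in this PR (create_autodeploy_executor and instantiate_sampler) build the same argument set once and hand it to either TorchSampler or a spec decoder.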

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 31 additions & 14 deletions
@@ -26,8 +26,7 @@
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
                                PeftCacheManager, ResourceManager,
                                ResourceManagerType)
-from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
-                      TRTLLMSampler)
+from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
                         SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager

@@ -514,6 +513,7 @@ def create_py_executor_instance(
         sampler=sampler,
         drafter=drafter,
         dist=dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
         max_batch_size=executor_config.max_batch_size,

@@ -525,27 +525,44 @@
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


-def instantiate_sampler(model_engine: PyTorchModelEngine,
+def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
+                              *, max_seq_len: int, mixed_sampler: bool):
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_draft_tokens = (0 if executor_config.speculative_config is None else
+                        executor_config.speculative_config.max_draft_tokens)
+    return TorchSampler.Args(
+        max_seq_len=max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
+        mixed_sampler=mixed_sampler,
+    )
+
+
+def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
                         mapping: Mapping):
+    sampler_args = create_torch_sampler_args(
+        executor_config,
+        mapping,
+        max_seq_len=engine.max_seq_len,
+        mixed_sampler=pytorch_backend_config.mixed_sampler)
     if mapping.cp_config.get('cp_type') == 'star_attention':
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
-        return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
-    spec_config = model_engine.spec_config
-    if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
-        return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
-                                spec_config=spec_config)
+        return TorchSampler(sampler_args)
+    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
+    ):
+        return get_spec_decoder(sampler_args, engine.spec_config)
     if pytorch_backend_config.enable_trtllm_sampler:
-        return TRTLLMSampler(executor_config, model_engine.model,
-                             model_engine.dtype, mapping,
-                             get_decoding_mode(executor_config),
+        decoding_mode = get_decoding_mode(executor_config)
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
                              pytorch_backend_config.disable_overlap_scheduler)
-    elif not model_engine.model.model_config.is_generation:
+    if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
-    return TorchSampler(max_seq_len=model_engine.max_seq_len,
-                        mixed_sampler=pytorch_backend_config.mixed_sampler)
+    return TorchSampler(sampler_args)


 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
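
create_torch_sampler_args centralizes the derived sizes: max_num_sequences = max_batch_size * pp_size and max_draft_tokens taken from the speculative config (or 0). A rough sizing sketch follows, under the assumption that the sampler preallocates its new_tokens buffer as [max_draft_tokens + 1, max_num_sequences, max_beam_width]; the exact buffer layout lives in sampler.py and is not shown in this excerpt.

    # Illustrative sizing only (numbers are made up).
    import torch

    max_batch_size, pp_size = 32, 2
    max_num_sequences = max_batch_size * pp_size   # as in create_torch_sampler_args
    max_draft_tokens = 3                           # from speculative_config, else 0
    max_beam_width = 1

    new_tokens = torch.zeros(max_draft_tokens + 1,  # accepted token + draft tokens
                             max_num_sequences,     # one row per sequence slot
                             max_beam_width,
                             dtype=torch.int32)
    print(new_tokens.shape)                         # torch.Size([4, 64, 1])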

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 2 additions & 6 deletions
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import List, Optional

@@ -52,8 +51,7 @@ def bitmask_size(self) -> int:

     def build(self, scheduled_requests: ScheduledRequests,
               resource_manager: SeqSlotManager) -> None:
-        for llm_req in itertools.chain(scheduled_requests.context_requests,
-                                       scheduled_requests.generation_requests):
+        for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
             slot = resource_manager.slot_manager.get_slot(llm_req.request_id)

@@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
         torch.cuda.current_stream().wait_stream(self._stream)

         batched_logits, batched_bitmask = [], []
-        for i, llm_req in enumerate(
-                itertools.chain(scheduled_requests.context_requests,
-                                scheduled_requests.generation_requests)):
+        for i, llm_req in enumerate(scheduled_requests.all_requests()):
             if llm_req.guided_decoding_params is None:
                 continue
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
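
Both loops now go through scheduled_requests.all_requests() instead of chaining the two request lists inline. The helper itself is defined elsewhere in the PR; the sketch below is an assumption about its shape (it may just as well return a concatenated list rather than an iterator), kept here only to show the intended iteration order.

    # Presumed shape of the helper the call sites switch to (illustrative).
    import itertools
    from dataclasses import dataclass, field
    from typing import Iterable, List


    @dataclass
    class ScheduledRequestsSketch:
        context_requests: List[object] = field(default_factory=list)
        generation_requests: List[object] = field(default_factory=list)

        def all_requests(self) -> Iterable[object]:
            # Context requests first, then generation requests — the same order
            # the replaced itertools.chain(...) call sites relied on.
            return itertools.chain(self.context_requests,
                                   self.generation_requests)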

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -254,6 +254,7 @@ def __init__(
             exclude_last_generation_logits: bool = False,
             return_perf_metrics: bool = False,
             stop_words_list: list[list[int]] | None = None,
+            is_draft: bool = False,
             **kwargs):
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
                                                     None)

@@ -288,6 +289,7 @@ def __init__(
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
+        self.py_is_draft = is_draft

         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
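
The new is_draft constructor flag is only stored as py_is_draft in this file; how it is consumed is not visible in this excerpt. A purely hypothetical consumer, just to show the kind of branch such a flag enables:

    # Hypothetical example — not code from the PR.
    def should_emit_output(llm_req) -> bool:
        # A downstream component could skip per-request output handling for
        # internal draft-model requests.
        return not getattr(llm_req, "py_is_draft", False)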

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 51 additions & 54 deletions
@@ -4,7 +4,6 @@
 import gc
 import glob
 import inspect
-import itertools
 import math
 import multiprocessing
 import os

@@ -21,6 +20,7 @@
 import torch._dynamo.config

 import tensorrt_llm.bindings.internal.userbuffers as ub
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank,

@@ -319,6 +319,7 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],


 class PyTorchModelEngine(ModelEngine):
+    BEAM_WIDTH = 1

     def __init__(
         self,

@@ -659,13 +660,12 @@ def get_autotune_warmup_request():
             return result

         @contextlib.contextmanager
-        def release_batch(result):
+        def release_batch(result: ScheduledRequests | None):
             try:
                 yield result
             finally:
                 if result is not None:
-                    for req in itertools.chain(result.generation_requests,
-                                               result.context_requests):
+                    for req in result.all_requests():
                         kv_cache_manager.free_resources(req)
                         if spec_resource_manager is not None:
                             spec_resource_manager.free_resources(req)

@@ -1153,7 +1153,15 @@ def _prepare_tp_inputs(
         draft_lens = []
         mrope_config = defaultdict(list)

-        batch_idx = 0
+        mtp_batch_idx = 0  # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially
+
+        def py_batch_idx(request: LlmRequest) -> int:
+            if not self.without_logits:
+                return request.seq_slot
+            nonlocal mtp_batch_idx
+            batch_idx = mtp_batch_idx
+            mtp_batch_idx += 1
+            return batch_idx

         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)

@@ -1184,10 +1192,9 @@
                 ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin
                 mrope_config['mrope_rotary_cos_sin'].append(
                     mrope_rotary_cos_sin.to('cuda', non_blocking=True))
-            request.py_batch_idx = batch_idx
-            batch_idx += 1
+            request.py_batch_idx = py_batch_idx(request)

-        num_ctx_requests = batch_idx
+        num_ctx_requests = len(scheduled_requests.context_requests)
         num_ctx_tokens = len(input_ids)
         new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None
         if new_tensors_device is not None:

@@ -1227,7 +1234,7 @@
             assert spec_dec_mode.support_overlap_scheduler(
             ), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"

-        # will contain previous batch incices of generation requests
+        # will contain previous batch indices of generation requests
         previous_batch_indices = []
         previous_pos_indices = []
         for request in extend_requests:

@@ -1272,8 +1279,7 @@
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
-                request.py_batch_idx = batch_idx
-                batch_idx += 1
+                request.py_batch_idx = py_batch_idx(request)
                 # inputs
                 # overlap scheduler can only support the speculative decoding
                 # methods with a fixed number of draft tokens

@@ -1324,8 +1330,18 @@
             prompt_lengths.append(request.py_prompt_len)
             draft_lens.append(0)

-            request.py_batch_idx = batch_idx
-            batch_idx += 1
+            request.py_batch_idx = py_batch_idx(request)
+
+        previous_batch_len = len(previous_batch_indices)
+
+        def previous_seq_slots_device():
+            previous_batch_indices_host = torch.tensor(previous_batch_indices,
+                                                       dtype=torch.int,
+                                                       pin_memory=True)
+            previous_slots = self.previous_batch_indices_cuda[:
+                                                              previous_batch_len]
+            previous_slots.copy_(previous_batch_indices_host, non_blocking=True)
+            return previous_slots

         num_tokens = len(input_ids)
         num_draft_tokens = len(draft_tokens)

@@ -1347,29 +1363,22 @@
             self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
                                                              non_blocking=True)
         if next_draft_tokens_device is not None:
-            if len(previous_batch_indices) > 0:
-                previous_batch_indices = torch.tensor(previous_batch_indices,
-                                                      dtype=torch.int,
-                                                      pin_memory=True)
-                self.previous_batch_indices_cuda[:previous_batchs].copy_(
-                    previous_batch_indices, non_blocking=True)
+            if previous_batch_len > 0:
+                previous_slots = previous_seq_slots_device()
                 # previous input ids
-                previous_batch_tokens = previous_batchs * (1 +
-                                                           self.max_draft_len)
-                self.input_ids_cuda[
-                    num_tokens:num_tokens +
-                    previous_batch_tokens].copy_(new_tokens_device[
-                        self.previous_batch_indices_cuda[:previous_batchs], :].
-                                                 flatten(),
-                                                 non_blocking=True)
+                previous_batch_tokens = previous_batch_len * (
+                    1 + self.max_draft_len)
+                new_tokens = new_tokens_device[previous_slots, :].flatten()
+                self.input_ids_cuda[num_tokens:num_tokens +
+                                    previous_batch_tokens].copy_(
+                                        new_tokens, non_blocking=True)
                 # previous draft tokens
-                previous_batch_draft_tokens = previous_batchs * self.max_draft_len
-                self.draft_tokens_cuda[
-                    num_draft_tokens:num_draft_tokens +
-                    previous_batch_draft_tokens].copy_(next_draft_tokens_device[
-                        self.previous_batch_indices_cuda[:previous_batchs], :].
-                                                       flatten(),
-                                                       non_blocking=True)
+                previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
+                self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
+                                       previous_batch_draft_tokens].copy_(
+                                           next_draft_tokens_device[
+                                               previous_slots, :].flatten(),
+                                           non_blocking=True)
                 # prepare data for the preprocess inputs
                 kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
                 previous_pos_indices = torch.tensor(previous_pos_indices,

@@ -1398,16 +1407,13 @@
                 self.previous_pos_id_offsets_cuda *= 0
                 self.previous_kv_lens_offsets_cuda *= 0
         elif new_tokens_device is not None:
-            previous_batch_tokens = len(previous_batch_indices)
-            previous_batch_indices = torch.tensor(previous_batch_indices,
-                                                  dtype=torch.int,
-                                                  pin_memory=True)
-            self.previous_batch_indices_cuda[:previous_batch_tokens].copy_(
-                previous_batch_indices, non_blocking=True)
-            self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_(
-                new_tokens_device[
-                    self.previous_batch_indices_cuda[:previous_batchs]],
-                non_blocking=True)
+            seq_slots_device = previous_seq_slots_device()
+            max_draft_len = max(draft_lens)
+            new_tokens = new_tokens_device[:max_draft_len + 1,
+                                           seq_slots_device, :self.BEAM_WIDTH]
+            self.input_ids_cuda[num_tokens:num_tokens +
+                                previous_batch_len].copy_(new_tokens.flatten(),
+                                                          non_blocking=True)

         position_ids = torch.tensor(position_ids,
                                     dtype=torch.int,

@@ -1645,7 +1651,6 @@ def _prepare_star_attention_inputs(self,
         # for star attention, we need customized block ids
         block_ids_per_seq = []
         num_cached_tokens_per_seq = []
-        output_token_idx = 0
         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
             prompt_lengths.append(request.py_prompt_len)

@@ -1702,8 +1707,6 @@
             sequence_lengths.append(len(input_id))
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1
         num_contexts = len(sequence_lengths)
         for request in scheduled_requests.context_requests:
             ctx_iter = request.ctx_iters

@@ -1743,8 +1746,6 @@
             sequence_lengths.append(len(input_id))
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1
         num_queries = len(sequence_lengths) - num_contexts

         # Requests with draft tokens are treated like extend requests.

@@ -1802,8 +1803,6 @@
             position_ids.append(last_query_pos_id + request.gen_iters + 1)
             block_ids_per_seq.extend([all_cache_indices])
             num_cached_tokens_per_seq.append(past_seen_token_num)
-            request.output_token_idx = output_token_idx
-            output_token_idx += 1

         num_tokens = len(input_ids)
         assert num_tokens <= self.max_num_tokens, (

@@ -2171,9 +2170,7 @@ def _execute_logit_post_processors(self,
         num_ctx_req = len(scheduled_requests.context_requests)
         logits_tensor = outputs["logits"]

-        for idx, request in enumerate(
-                itertools.chain(scheduled_requests.context_requests,
-                                scheduled_requests.generation_requests)):
+        for idx, request in enumerate(scheduled_requests.all_requests()):
             logits_processors = getattr(request, "py_logits_post_processors",
                                         None)
             if not logits_processors:
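
The perf-relevant core of this file's changes: py_batch_idx now returns request.seq_slot whenever the engine produces logits (the general TorchSampler path), while MTP / Eagle3OneModel (self.without_logits) keep the old serial numbering, and previous-batch tokens are gathered on-device by sequence slot. Below is a condensed, standalone rendering of that indexing policy with simplified names; the real code closes over self and the scheduled requests.

    # Standalone sketch of the serial-vs-slot indexing introduced in _prepare_tp_inputs.
    from dataclasses import dataclass


    @dataclass
    class Req:
        seq_slot: int


    def make_py_batch_idx(without_logits: bool):
        mtp_batch_idx = 0  # only MTP / Eagle3OneModel samplers still index serially

        def py_batch_idx(request: Req) -> int:
            nonlocal mtp_batch_idx
            if not without_logits:
                return request.seq_slot      # TorchSampler path: row == sequence slot
            idx = mtp_batch_idx              # MTP path: dense, order-of-scheduling rows
            mtp_batch_idx += 1
            return idx

        return py_batch_idx


    idx_fn = make_py_batch_idx(without_logits=False)
    print([idx_fn(Req(seq_slot=s)) for s in (7, 2, 4)])   # [7, 2, 4]

    idx_fn = make_py_batch_idx(without_logits=True)
    print([idx_fn(Req(seq_slot=s)) for s in (7, 2, 4)])   # [0, 1, 2]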
