17 changes: 15 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -5,6 +5,7 @@
import torch
from torch._prims_common import DeviceLikeType

from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
from tensorrt_llm._utils import nvtx_range

from ...._utils import mpi_rank, mpi_world_size
@@ -256,6 +257,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
ad_config: LlmArgs = executor_config.pytorch_backend_config

max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
# some derivative properties
max_draft_tokens = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
@@ -272,7 +274,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
)
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
resource_manager = ResourceManager(
{
ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
}
)
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

# scheduling
@@ -287,10 +295,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
# https://github.com/NVIDIA/TensorRT-LLM/issues/5254
# We should expose mixed_sample to our build_and_run_ad script so we can configure this
# correctly for models as needed.
sampler = TorchSampler(
sampler_args = TorchSampler.Args(
max_seq_len=ad_config.max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=ad_config.mixed_sampler,
)
sampler = TorchSampler(sampler_args)

# creating the executor object
py_executor = PyExecutor(
@@ -299,6 +311,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
model_engine=engine,
sampler=sampler,
dist=mpi_dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
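The hunk above registers a `SeqSlotManager` alongside the KV-cache manager and then calls `move_to_end` so the KV-cache manager is processed last. A minimal sketch of that ordering behavior, assuming `ResourceManager.resource_managers` behaves like an `OrderedDict` (the keys below are illustrative stand-ins, not the real `ResourceManagerType` enum values):

```python
from collections import OrderedDict

# Illustrative stand-ins for ResourceManagerType.KV_CACHE_MANAGER and
# ResourceManagerType.SEQ_SLOT_MANAGER; only the ordering semantics of
# move_to_end() are demonstrated here.
resource_managers = OrderedDict(
    kv_cache_manager="KVCacheManager",
    seq_slot_manager="SeqSlotManager",
)

# As in the diff: push the KV-cache manager to the back so sequence slots
# are assigned before KV-cache resources are prepared for each step.
resource_managers.move_to_end("kv_cache_manager", last=True)

assert list(resource_managers) == ["seq_slot_manager", "kv_cache_manager"]
```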
45 changes: 31 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -26,8 +26,7 @@
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
TRTLLMSampler)
from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
from .seq_slot_manager import SeqSlotManager
@@ -514,6 +513,7 @@ def create_py_executor_instance(
sampler=sampler,
drafter=drafter,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
@@ -525,27 +525,44 @@
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


def instantiate_sampler(model_engine: PyTorchModelEngine,
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_tokens = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_tokens)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_tokens=max_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
mixed_sampler=mixed_sampler,
)


def instantiate_sampler(engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
if mapping.cp_config.get('cp_type') == 'star_attention':
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
spec_config = model_engine.spec_config
if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
spec_config=spec_config)
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
return get_spec_decoder(sampler_args, engine.spec_config)
if pytorch_backend_config.enable_trtllm_sampler:
return TRTLLMSampler(executor_config, model_engine.model,
model_engine.dtype, mapping,
get_decoding_mode(executor_config),
decoding_mode = get_decoding_mode(executor_config)
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
elif not model_engine.model.model_config.is_generation:
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
return TorchSampler(max_seq_len=model_engine.max_seq_len,
mixed_sampler=pytorch_backend_config.mixed_sampler)
return TorchSampler(sampler_args)


def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
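This refactor moves `TorchSampler`'s keyword arguments into a single `TorchSampler.Args` bundle built by `create_torch_sampler_args`, so `instantiate_sampler` and the auto-deploy executor construct it the same way. A rough sketch of the pattern, assuming `Args` is a plain dataclass-style container (the real class may differ):

```python
from dataclasses import dataclass


@dataclass
class SamplerArgs:
    """Stand-in for TorchSampler.Args shared by every call site."""
    max_seq_len: int
    max_draft_tokens: int
    max_num_sequences: int
    max_beam_width: int
    mixed_sampler: bool


def build_sampler_args(max_batch_size: int, pp_size: int, max_seq_len: int,
                       max_beam_width: int, mixed_sampler: bool,
                       max_draft_tokens: int = 0) -> SamplerArgs:
    # Mirrors create_torch_sampler_args: one sequence slot per batch entry
    # on each pipeline-parallel stage.
    return SamplerArgs(
        max_seq_len=max_seq_len,
        max_draft_tokens=max_draft_tokens,
        max_num_sequences=max_batch_size * pp_size,
        max_beam_width=max_beam_width,
        mixed_sampler=mixed_sampler,
    )


args = build_sampler_args(max_batch_size=8, pp_size=2, max_seq_len=4096,
                          max_beam_width=1, mixed_sampler=False)
assert args.max_num_sequences == 16
```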
8 changes: 2 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -1,4 +1,3 @@
import itertools
import math
from typing import List, Optional

@@ -52,8 +51,7 @@ def bitmask_size(self) -> int:

def build(self, scheduled_requests: ScheduledRequests,
resource_manager: SeqSlotManager) -> None:
for llm_req in itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
for llm_req in scheduled_requests.all_requests():
if llm_req.guided_decoding_params is None:
continue
slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
torch.cuda.current_stream().wait_stream(self._stream)

batched_logits, batched_bitmask = [], []
for i, llm_req in enumerate(
itertools.chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)):
for i, llm_req in enumerate(scheduled_requests.all_requests()):
if llm_req.guided_decoding_params is None:
continue
if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
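Both hunks replace the explicit `itertools.chain(context_requests, generation_requests)` with a `scheduled_requests.all_requests()` helper. Its implementation is not shown in this diff; presumably it is equivalent to the chained iteration it replaces, roughly:

```python
import itertools
from typing import Iterable, List


class ScheduledRequestsSketch:
    """Illustrative stand-in for ScheduledRequests with an all_requests() helper."""

    def __init__(self, context_requests: List[object],
                 generation_requests: List[object]) -> None:
        self.context_requests = context_requests
        self.generation_requests = generation_requests

    def all_requests(self) -> Iterable[object]:
        # Same order as the removed itertools.chain(...): context requests
        # first, then generation requests.
        return itertools.chain(self.context_requests, self.generation_requests)
```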
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -254,6 +254,7 @@ def __init__(
exclude_last_generation_logits: bool = False,
return_perf_metrics: bool = False,
stop_words_list: list[list[int]] | None = None,
is_draft: bool = False,
**kwargs):
self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
None)
@@ -288,6 +289,7 @@ def __init__(
self.py_return_context_logits = return_context_logits
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft

# TODO: remove this when use DynamicDecodeOp in pytorch flow.
# currently, keep py_stop_words_list as python list, rather than tensor.
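The `llm_request.py` change threads a new `is_draft` constructor flag through to a `py_is_draft` attribute; given the draft-token handling elsewhere in this PR, it presumably marks requests issued on the draft path of speculative decoding. A toy sketch of how such a flag is stored and checked (not the actual `LlmRequest` constructor):

```python
class RequestSketch:
    """Minimal stand-in showing how an is_draft flag is stored and consumed."""

    def __init__(self, request_id: int, is_draft: bool = False) -> None:
        self.request_id = request_id
        # Mirrors LlmRequest.py_is_draft: lets downstream code distinguish
        # draft-model requests from regular target-model requests.
        self.py_is_draft = is_draft


requests = [RequestSketch(0), RequestSketch(1, is_draft=True)]
draft_ids = [r.request_id for r in requests if r.py_is_draft]
assert draft_ids == [1]
```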