tensorrt_llm/_torch/auto_deploy/llm_args.py (6 additions, 1 deletion)
@@ -8,7 +8,7 @@

from tensorrt_llm.models.modeling_utils import QuantConfig

from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
from .models import ModelFactory, ModelFactoryRegistry
from .utils._config import DynamicYamlMixInForSettings
from .utils.logger import ad_logger
@@ -130,6 +130,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
"supported in AutoDeploy.",
)

sampler_type: Union[str, SamplerType] = Field(
default=SamplerType.TorchSampler,
description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
)

# NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
# see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
kv_cache_config: KvCacheConfig = Field(
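Usage note (not part of the diff): a minimal sketch of selecting the new sampler through AutoDeployConfig. The model id and the assumption that the remaining fields have workable defaults are illustrative, not taken from this PR.

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm.llmapi.llm_args import SamplerType

# Hypothetical configuration; sampler_type is the field added by this PR.
ad_config = AutoDeployConfig(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed model id (matches the smoke test below)
    # Default is SamplerType.TorchSampler; the field is Union[str, SamplerType],
    # so the plain string "TRTLLMSampler" should also be accepted.
    sampler_type=SamplerType.TRTLLMSampler,
)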
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (70 additions, 11 deletions)
@@ -20,7 +20,11 @@
from torch._prims_common import DeviceLikeType

from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls
from tensorrt_llm._torch.pyexecutor._util import (
_create_kv_cache_manager,
get_decoding_mode,
get_kv_cache_manager_cls,
)
from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder
from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length
from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
@@ -30,6 +34,7 @@
from tensorrt_llm.llmapi.llm_args import (
ContextChunkingPolicy,
LoadFormat,
SamplerType,
SpeculativeConfig,
TorchLlmArgs,
)
@@ -42,7 +47,7 @@
from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
from ...pyexecutor.py_executor import PyExecutor
from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
from ...pyexecutor.sampler import TorchSampler
from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
from ...pyexecutor.scheduler import (
BindCapacityScheduler,
BindMicroBatchScheduler,
@@ -283,9 +288,9 @@ def __init__(
self.llm_args.batch_wait_timeout_iters = 0
self.llm_args.batch_wait_max_tokens_ratio = 0.0
self.llm_args.max_num_tokens = seq_info.max_num_tokens
self.llm_args.max_seq_len = seq_info.max_seq_len
self.iter_counter = 0
self.iter_states = {}
self.llm_args.max_seq_len = seq_info.max_seq_len

# NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
self.max_beam_width = max_beam_width
@@ -487,6 +492,9 @@ def _compute_logits(self) -> List[torch.Tensor]:
# run the model
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
logits = logits.float()

# return a list of tensors
return self.cache_seq_interface.info.unnest_sequences(logits)

@@ -574,6 +582,59 @@ def create_draft_model_engine_maybe(
return draft_model_engine


class TRTLLMSamplerModelConfig:
def __init__(self, vocab_size_padded: int):
self.config = SimpleNamespace()
self.config.vocab_size = vocab_size_padded

# Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler.
self.config.num_hidden_layers = 42
self.config.hidden_size = 42
self.config.num_attention_heads = 42


def instantiate_sampler(
ad_config: LlmArgs,
max_num_sequences: int,
max_draft_len: int,
max_total_draft_tokens: int,
dist_mapping: Mapping,
engine: ADEngine,
):
if ad_config.sampler_type == SamplerType.TorchSampler:
# search sampler with speculative decoding
sampler_args = TorchSampler.Args(
max_seq_len=ad_config.max_seq_len,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=ad_config.max_beam_width,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
)
sampler = TorchSampler(sampler_args)

elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded
sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded)
decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
sampler = TRTLLMSampler(
model=sampler_model_config,
model_dtype=torch.bfloat16,  # hardcoded to bfloat16; the dtype does not appear to be used by the underlying C++ code.
mapping=dist_mapping,
decoding_mode=decoding_mode,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
max_seq_len=ad_config.max_seq_len,
max_batch_size=ad_config.max_batch_size,
max_beam_width=ad_config.max_beam_width,
decoding_config=ad_config.decoding_config,
kv_cache_config=ad_config.kv_cache_config,
)
else:
raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")

return sampler


def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
"""Create an AutoDeploy executor from the given configuration and tokenizer.
The tokenizer is required for guided decoding.
@@ -695,23 +756,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
)
scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

# search sampler with speculative decoding
sampler_args = TorchSampler.Args(
max_seq_len=ad_config.max_seq_len,
vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
sampler = instantiate_sampler(
ad_config=ad_config,
max_num_sequences=max_num_sequences,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
max_num_sequences=max_num_sequences,
max_beam_width=ad_config.max_beam_width,
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
dist_mapping=dist_mapping,
engine=engine,
)
sampler = TorchSampler(sampler_args)

# Guided (structured) decoding.
guided_decoder = None
if (
(guided_decoding_backend := ad_config.guided_decoding_backend) is not None
) and dist_mapping.is_last_pp_rank():
vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
if vocab_size_padded is None:
raise RuntimeError(
"Could not determine the vocabulary size. Required for guided decoding."
New file (path not shown in this view)
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main

from tensorrt_llm.llmapi.llm_args import SamplerType


def test_ad_trtllm_sampler_smoke():
"""Test TRTLLMSampler in AutoDeploy smoke test."""
# Get small model config
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
experiment_config = get_small_model_config(model_id)

# Configure for TRTLLMSampler
experiment_config["args"]["runtime"] = "trtllm"
experiment_config["args"]["world_size"] = 1
experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler

# Setup simple prompt
experiment_config["prompt"]["batch_size"] = 1
experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
experiment_config["prompt"]["sp_kwargs"] = {
"max_tokens": 10,
"temperature": 1.0,
"top_k": 1,
}

print(f"Experiment config: {experiment_config}")
cfg = ExperimentConfig(**experiment_config)

print("Running smoke test with TRTLLMSampler...")
results = main(cfg)

# Basic assertion that we got some output
prompts_and_outputs = results["prompts_and_outputs"]
assert len(prompts_and_outputs) == 1
assert len(prompts_and_outputs[0][1]) > 0
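
For comparison, a minimal sketch (not part of this PR) of the same smoke run with the default TorchSampler; with temperature 1.0 and top_k=1, the two samplers would be expected to behave similarly, though that is not asserted here.

from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig, main

from tensorrt_llm.llmapi.llm_args import SamplerType


def run_torch_sampler_baseline():
    """Hypothetical baseline run using the default TorchSampler, mirroring the smoke test above."""
    experiment_config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
    experiment_config["args"]["runtime"] = "trtllm"
    experiment_config["args"]["world_size"] = 1
    # TorchSampler is the default; set explicitly here for clarity.
    experiment_config["args"]["sampler_type"] = SamplerType.TorchSampler

    experiment_config["prompt"]["batch_size"] = 1
    experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
    experiment_config["prompt"]["sp_kwargs"] = {"max_tokens": 10, "temperature": 1.0, "top_k": 1}

    return main(ExperimentConfig(**experiment_config))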