diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index cdb632c6cbe..40eb227f957 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -8,7 +8,7 @@
 
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
+from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger
@@ -130,6 +130,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "supported in AutoDeploy.",
     )
 
+    sampler_type: Union[str, SamplerType] = Field(
+        default=SamplerType.TorchSampler,
+        description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
+    )
+
     # NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
     # see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
     kv_cache_config: KvCacheConfig = Field(
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 1dca0cc90c8..d4eab7131a8 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -20,7 +20,11 @@
 from torch._prims_common import DeviceLikeType
 
 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
-from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls
+from tensorrt_llm._torch.pyexecutor._util import (
+    _create_kv_cache_manager,
+    get_decoding_mode,
+    get_kv_cache_manager_cls,
+)
 from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder
 from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
@@ -30,6 +34,7 @@
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
     LoadFormat,
+    SamplerType,
     SpeculativeConfig,
     TorchLlmArgs,
 )
@@ -42,7 +47,7 @@
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
-from ...pyexecutor.sampler import TorchSampler
+from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
@@ -283,9 +288,9 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.llm_args.max_seq_len = seq_info.max_seq_len
         self.iter_counter = 0
         self.iter_states = {}
-        self.llm_args.max_seq_len = seq_info.max_seq_len
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width
@@ -487,6 +492,9 @@ def _compute_logits(self) -> List[torch.Tensor]:
         # run the model
         logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
 
+        # TRTLLMSampler expects float32 logits; PyTorchModelEngine always casts to float32 as well.
+        logits = logits.float()
+
         # return a list of tensors
         return self.cache_seq_interface.info.unnest_sequences(logits)
 
@@ -574,6 +582,59 @@ def create_draft_model_engine_maybe(
     return draft_model_engine
 
 
+class TRTLLMSamplerModelConfig:
+    def __init__(self, vocab_size_padded: int):
+        self.config = SimpleNamespace()
+        self.config.vocab_size = vocab_size_padded
+
+        # Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler.
+        self.config.num_hidden_layers = 42
+        self.config.hidden_size = 42
+        self.config.num_attention_heads = 42
+
+
+def instantiate_sampler(
+    ad_config: LlmArgs,
+    max_num_sequences: int,
+    max_draft_len: int,
+    max_total_draft_tokens: int,
+    dist_mapping: Mapping,
+    engine: ADEngine,
+):
+    if ad_config.sampler_type == SamplerType.TorchSampler:
+        # search sampler with speculative decoding
+        sampler_args = TorchSampler.Args(
+            max_seq_len=ad_config.max_seq_len,
+            max_draft_len=max_draft_len,
+            max_total_draft_tokens=max_total_draft_tokens,
+            max_num_sequences=max_num_sequences,
+            max_beam_width=ad_config.max_beam_width,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        )
+        sampler = TorchSampler(sampler_args)
+
+    elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
+        vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded
+        sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded)
+        decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
+        sampler = TRTLLMSampler(
+            model=sampler_model_config,
+            model_dtype=torch.bfloat16,  # hardcoded to bfloat16; the dtype does not appear to be used by the C++ sampler code.
+            mapping=dist_mapping,
+            decoding_mode=decoding_mode,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+            max_seq_len=ad_config.max_seq_len,
+            max_batch_size=ad_config.max_batch_size,
+            max_beam_width=ad_config.max_beam_width,
+            decoding_config=ad_config.decoding_config,
+            kv_cache_config=ad_config.kv_cache_config,
+        )
+    else:
+        raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")
+
+    return sampler
+
+
 def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
     """Create an AutoDeploy executor from the given configuration and tokenizer.
     The tokenizer is required for guided decoding.
@@ -695,23 +756,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
     scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)
 
-    # search sampler with speculative decoding
-    sampler_args = TorchSampler.Args(
-        max_seq_len=ad_config.max_seq_len,
+    vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
+    sampler = instantiate_sampler(
+        ad_config=ad_config,
+        max_num_sequences=max_num_sequences,
         max_draft_len=max_draft_len,
         max_total_draft_tokens=max_total_draft_tokens,
-        max_num_sequences=max_num_sequences,
-        max_beam_width=ad_config.max_beam_width,
-        disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        dist_mapping=dist_mapping,
+        engine=engine,
     )
-    sampler = TorchSampler(sampler_args)
 
     # Guided (structured) decoding.
     guided_decoder = None
     if (
         (guided_decoding_backend := ad_config.guided_decoding_backend) is not None
     ) and dist_mapping.is_last_pp_rank():
-        vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
         if vocab_size_padded is None:
             raise RuntimeError(
                 "Could not determine the vocabulary size. Required for guided decoding."
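
Note: for reference, the default TorchSampler branch of the new instantiate_sampler() helper can be exercised in isolation. The sketch below is illustrative only: the absolute import path is inferred from the relative import added in this diff, and the numeric limits are placeholders; in AutoDeploy they come from ad_config and seq_info rather than literals.

from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler

# Placeholder limits for illustration; not values used anywhere in this change.
sampler_args = TorchSampler.Args(
    max_seq_len=2048,
    max_draft_len=0,
    max_total_draft_tokens=0,
    max_num_sequences=8,
    max_beam_width=1,
    disable_overlap_scheduler=True,
)
sampler = TorchSampler(sampler_args)
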
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_sampler.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_sampler.py
new file mode 100644
index 00000000000..41e96ae1cf3
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_sampler.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _model_test_utils import get_small_model_config
+from build_and_run_ad import ExperimentConfig, main
+
+from tensorrt_llm.llmapi.llm_args import SamplerType
+
+
+def test_ad_trtllm_sampler_smoke():
+    """Smoke test for TRTLLMSampler in AutoDeploy."""
+    # Get a small model config
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    experiment_config = get_small_model_config(model_id)
+
+    # Configure for TRTLLMSampler
+    experiment_config["args"]["runtime"] = "trtllm"
+    experiment_config["args"]["world_size"] = 1
+    experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler
+
+    # Set up a simple prompt
+    experiment_config["prompt"]["batch_size"] = 1
+    experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
+    experiment_config["prompt"]["sp_kwargs"] = {
+        "max_tokens": 10,
+        "temperature": 1.0,
+        "top_k": 1,
+    }
+
+    print(f"Experiment config: {experiment_config}")
+    cfg = ExperimentConfig(**experiment_config)
+
+    print("Running smoke test with TRTLLMSampler...")
+    results = main(cfg)
+
+    # Basic assertion that we got some output
+    prompts_and_outputs = results["prompts_and_outputs"]
+    assert len(prompts_and_outputs) == 1
+    assert len(prompts_and_outputs[0][1]) > 0
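
Note: the smoke test drives the TRTLLMSampler path end to end through build_and_run_ad.main. For quick local iteration, a runner snippet along the following lines should work; it is a convenience sketch (not part of the diff) that assumes pytest is invoked from the repository root and uses the test file path added above.

import pytest

# Run only the new AutoDeploy TRTLLMSampler smoke test, verbose and with stdout shown.
if __name__ == "__main__":
    raise SystemExit(
        pytest.main(
            ["tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_sampler.py", "-v", "-s"]
        )
    )
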