6 changes: 2 additions & 4 deletions examples/llm-eval/lm-eval-harness/lm_eval_tensorrt_llm.py
@@ -34,10 +34,10 @@
import tensorrt_llm
from tensorrt_llm import LLM as TORCH_LLM
from tensorrt_llm._tensorrt_engine import LLM as TRT_LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import DecodingConfig
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig
from tensorrt_llm.llmapi import RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import MoeConfig

logger = logging.getLogger(__name__)

@@ -98,10 +98,8 @@ def __init__(
pytorch_config_params = {
'cuda_graph_config': {} if use_cuda_graph else None,
"print_iter_log": False,
'moe_config': MoeConfig(backend=self.moe_backend)
}
if hasattr(PyTorchConfig, "moe_backend"):
pytorch_config_params["moe_backend"] = self.moe_backend
print(f"Info: moe_backend is set to {self.moe_backend}")

# stop words not currently supported by torch backend
self.use_stop_words = False
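For reference, a minimal sketch of the MoE-backend selection path used by the updated harness: the `MoeConfig` object goes straight into the keyword arguments of the torch-backend `LLM`, replacing the old `hasattr(PyTorchConfig, "moe_backend")` probe. The model path and backend string are placeholders, and the dict is assumed to be forwarded to the constructor the same way the harness forwards `pytorch_config_params`.

```python
from tensorrt_llm import LLM as TORCH_LLM
from tensorrt_llm.llmapi.llm_args import MoeConfig

# Placeholder values for illustration; the harness fills these from its CLI args.
pytorch_config_params = {
    "cuda_graph_config": {},                     # enable CUDA graphs with default settings
    "print_iter_log": False,
    "moe_config": MoeConfig(backend="CUTLASS"),  # MoE backend now travels inside MoeConfig
}

llm = TORCH_LLM(model="/path/to/model", **pytorch_config_params)
```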
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/llm.py
@@ -175,7 +175,7 @@ def __init__(self, **kwargs):
self._executor = DemoGenerationExecutor(
world_size=self.args.world_size,
tokenizer=self.tokenizer,
ad_config=self.args.get_pytorch_backend_config(),
ad_config=self.args,
)

def __del__(self):
7 changes: 0 additions & 7 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -403,13 +403,6 @@ def validate_and_init_tokenizer(self):
"""Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
return self

### UTILITY METHODS ############################################################################
# TODO: Remove this after the PyTorch backend is fully migrated to LlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "LlmArgs":
"""Return the LlmArgs (self) object."""
# TODO: can we just pass through self directly??
return type(self)(**self.to_llm_kwargs())

def to_dict(self) -> Dict:
"""Convert model to a dictionary such that cls(**self.to_dict()) == self."""
self_dict = super().to_dict()
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -326,8 +326,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
dist.initialize_or_skip(rank, world_size, port)

# some config
msg = "pytorch_backend_config must be an AD LlmArgs object"
assert isinstance(ad_config, LlmArgs), msg
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
44 changes: 22 additions & 22 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -25,7 +25,6 @@
from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@@ -73,7 +72,7 @@ def __init__(
max_seq_len: int,
max_batch_size: int,
kv_cache_config: KvCacheConfig,
pytorch_backend_config: PyTorchConfig,
llm_args: TorchLlmArgs,
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
@@ -86,7 +85,7 @@ def __init__(
self._max_num_tokens = max_num_tokens
self._max_beam_width = max_beam_width
self._kv_connector_manager = kv_connector_manager
self._pytorch_backend_config = pytorch_backend_config
self._llm_args = llm_args
self._speculative_config = speculative_config
self._sparse_attention_config = sparse_attention_config
self._tokens_per_block = tokens_per_block
@@ -248,9 +247,8 @@ def _get_token_num_for_estimation(self) -> int:
# estimate_max_kv_cache_tokens submits self._dummy_reqs
num_cache_blocks = 0
num_extra_tokens_per_seq = 1 # account for generated tokens
pytorch_backend_config = self._pytorch_backend_config
spec_cfg = self._speculative_config
if not pytorch_backend_config.disable_overlap_scheduler:
if not self._llm_args.disable_overlap_scheduler:
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
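The per-sequence extra-token accounting in this hunk is small enough to restate as a standalone sketch; the function name and the example values below are hypothetical, only the arithmetic mirrors the hunk.

```python
from typing import Optional


def extra_tokens_per_seq(disable_overlap_scheduler: bool,
                         max_total_draft_tokens: Optional[int]) -> int:
    """Illustrative mirror of the accounting in _get_token_num_for_estimation."""
    extra = 1  # always reserve room for the token generated in the current step
    if not disable_overlap_scheduler:
        extra += 1  # the overlap scheduler runs one iteration ahead
    if max_total_draft_tokens is not None:
        extra += max_total_draft_tokens  # speculative draft tokens also occupy KV cache
    return extra


# e.g. overlap scheduling enabled and 3 draft tokens per step: 1 + 1 + 3 = 5
assert extra_tokens_per_seq(disable_overlap_scheduler=False, max_total_draft_tokens=3) == 5
```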
@@ -653,7 +651,7 @@ def create_py_executor_instance(
dist,
resources,
mapping,
pytorch_backend_config,
llm_args,
ctx_chunk_config,
model_engine,
start_worker,
@@ -679,7 +677,7 @@ def create_py_executor_instance(
f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
)

for key, value in pytorch_backend_config.extra_resource_managers.items():
for key, value in llm_args.extra_resource_managers.items():
if key in resources:
raise ValueError(
f"Cannot overwrite existing resource manager {key}.")
@@ -804,8 +802,7 @@ def create_py_executor_instance(
drafter=drafter,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
@@ -840,13 +837,11 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
)


def instantiate_sampler(engine: PyTorchModelEngine,
pytorch_backend_config: PyTorchConfig, mapping: Mapping,
max_batch_size: int, max_beam_width: int,
max_seq_len: int, mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
kv_cache_config: KvCacheConfig):
def instantiate_sampler(
engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
max_batch_size: int, max_beam_width: int, max_seq_len: int,
mm_encoder_only: bool, speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
sampler_args = create_torch_sampler_args(
mapping,
max_seq_len=engine.max_seq_len,
@@ -856,7 +851,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
max_beam_width=max_beam_width)
if mapping.cp_config.get('cp_type') == CpType.STAR:
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
assert llm_args.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
@@ -865,15 +860,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
if mm_encoder_only:
# NOTE: handle model outputs specially for mm encoder executor/engine
return EarlyStopWithMMResult()
if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
pytorch_backend_config.sampler_type == SamplerType.auto
if llm_args.sampler_type == SamplerType.TRTLLMSampler or (
llm_args.sampler_type == SamplerType.auto
and decoding_mode.isBeamSearch()):
logger.debug(f"DecodingMode: {decoding_mode.name}")
return TRTLLMSampler(engine.model,
engine.dtype,
mapping,
decoding_mode,
pytorch_backend_config.disable_overlap_scheduler,
llm_args.disable_overlap_scheduler,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
@@ -935,7 +930,12 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
return num_experts


def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
def _adjust_torch_mem_fraction():
# If true, adjust PyTorch CUDA memory fraction to correspond to the
# total GPU memory minus the statically allocated engine memory.
# If false, set the PyTorch CUDA memory fraction to 1.0.
_limit_torch_cuda_mem_fraction: bool = True

# FIXME: PyTorch only uses the garbage_collection_threshold setting
# if a memory fraction is set, cf.
# https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
@@ -964,7 +964,7 @@ def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
# lead PyTorch to release all unused memory before hitting the set fraction. This
# still mitigates OOM, although at a higher performance impact, because it
# effectively resets the allocator cache.
if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
if not _limit_torch_cuda_mem_fraction:
return
mem_reserved = torch.cuda.memory_reserved()
mem_free, mem_total = torch.cuda.mem_get_info()
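The comments in the last hunk describe the intended behaviour of the memory-fraction adjustment; a rough self-contained sketch of that idea follows. The arithmetic and the absence of any safety margin are assumptions for illustration, not the exact formula in `_util.py`.

```python
import torch


def sketch_adjust_torch_mem_fraction(limit_fraction: bool = True) -> None:
    """Illustrative only: cap PyTorch's CUDA allocator at the memory that is not
    already statically allocated outside PyTorch, so that the allocator's
    garbage_collection_threshold setting actually takes effect."""
    if not limit_fraction:
        torch.cuda.set_per_process_memory_fraction(1.0)
        return
    mem_reserved = torch.cuda.memory_reserved()      # bytes held by PyTorch's caching allocator
    mem_free, mem_total = torch.cuda.mem_get_info()  # free / total device memory in bytes
    # Memory that is neither free nor PyTorch-reserved is treated as statically
    # allocated (e.g. engine weights, KV cache) and excluded from PyTorch's budget.
    mem_torch_max = mem_free + mem_reserved
    torch.cuda.set_per_process_memory_fraction(min(1.0, mem_torch_max / mem_total))
```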
139 changes: 0 additions & 139 deletions tensorrt_llm/_torch/pyexecutor/config.py

This file was deleted.

3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -54,13 +54,12 @@
from ..utils import (get_model_extra_attrs,
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
from .config import _construct_checkpoint_loader
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
from .model_loader import ModelLoader
from .model_loader import ModelLoader, _construct_checkpoint_loader
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
from .sampler import SampleStateTensors