diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py
index 5b41abf7e8e..e2dc543ac42 100644
--- a/tensorrt_llm/builder.py
+++ b/tensorrt_llm/builder.py
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import dataclasses
 import json
 import math
 import os
 import shutil
 import time
 from dataclasses import dataclass, field
+from functools import cache
 from pathlib import Path
 from typing import Dict, Optional, Union
 
@@ -557,53 +559,89 @@ def override_attri(attr_name, value):
         else:
             override_attri('paged_state', False)
 
+    @classmethod
+    @cache
+    def get_build_config_defaults(cls):
+        return {
+            field.name: field.default
+            for field in dataclasses.fields(cls)
+            if field.default is not dataclasses.MISSING
+        }
+
     @classmethod
     def from_dict(cls, config, plugin_config=None):
         config = copy.deepcopy(
             config
         )  # it just does not make sense to change the input arg `config`
-        max_input_len = config.pop('max_input_len')
-        max_seq_len = config.pop('max_seq_len')
-        max_batch_size = config.pop('max_batch_size')
-        max_beam_width = config.pop('max_beam_width')
-        max_num_tokens = config.pop('max_num_tokens')
-        opt_num_tokens = config.pop('opt_num_tokens')
-        opt_batch_size = config.pop('opt_batch_size', 8)
+
+        defaults = cls.get_build_config_defaults()
+        max_input_len = config.pop('max_input_len',
+                                   defaults.get('max_input_len'))
+        max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
+        max_batch_size = config.pop('max_batch_size',
+                                    defaults.get('max_batch_size'))
+        max_beam_width = config.pop('max_beam_width',
+                                    defaults.get('max_beam_width'))
+        max_num_tokens = config.pop('max_num_tokens',
+                                    defaults.get('max_num_tokens'))
+        opt_num_tokens = config.pop('opt_num_tokens',
+                                    defaults.get('opt_num_tokens'))
+        opt_batch_size = config.pop('opt_batch_size',
+                                    defaults.get('opt_batch_size'))
         max_prompt_embedding_table_size = config.pop(
-            'max_prompt_embedding_table_size', 0)
-
-        kv_cache_type = KVCacheType(
-            config.pop('kv_cache_type')) if 'plugin_config' in config else None
-        gather_context_logits = config.pop('gather_context_logits', False)
-        gather_generation_logits = config.pop('gather_generation_logits', False)
-        strongly_typed = config.pop('strongly_typed', True)
-        force_num_profiles = config.pop('force_num_profiles', None)
-        weight_sparsity = config.pop('weight_sparsity', False)
+            'max_prompt_embedding_table_size',
+            defaults.get('max_prompt_embedding_table_size'))
+
+        if "kv_cache_type" in config and config["kv_cache_type"] is not None:
+            kv_cache_type = KVCacheType(config.pop('kv_cache_type'))
+        else:
+            kv_cache_type = None
+        gather_context_logits = config.pop(
+            'gather_context_logits', defaults.get('gather_context_logits'))
+        gather_generation_logits = config.pop(
+            'gather_generation_logits',
+            defaults.get('gather_generation_logits'))
+        strongly_typed = config.pop('strongly_typed',
+                                    defaults.get('strongly_typed'))
+        force_num_profiles = config.pop('force_num_profiles',
+                                        defaults.get('force_num_profiles'))
+        weight_sparsity = config.pop('weight_sparsity',
+                                     defaults.get('weight_sparsity'))
         profiling_verbosity = config.pop('profiling_verbosity',
-                                         'layer_names_only')
-        enable_debug_output = config.pop('enable_debug_output', False)
-        max_draft_len = config.pop('max_draft_len', 0)
-        speculative_decoding_mode = config.pop('speculative_decoding_mode',
-                                               SpeculativeDecodingMode.NONE)
-        use_refit = config.pop('use_refit', False)
-        input_timing_cache = config.pop('input_timing_cache', None)
-        output_timing_cache = config.pop('output_timing_cache', None)
+                                         defaults.get('profiling_verbosity'))
+        enable_debug_output = config.pop('enable_debug_output',
+                                         defaults.get('enable_debug_output'))
+        max_draft_len = config.pop('max_draft_len',
+                                   defaults.get('max_draft_len'))
+        speculative_decoding_mode = config.pop(
+            'speculative_decoding_mode',
+            defaults.get('speculative_decoding_mode'))
+        use_refit = config.pop('use_refit', defaults.get('use_refit'))
+        input_timing_cache = config.pop('input_timing_cache',
+                                        defaults.get('input_timing_cache'))
+        output_timing_cache = config.pop('output_timing_cache',
+                                         defaults.get('output_timing_cache'))
         lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
         auto_parallel_config = AutoParallelConfig.from_dict(
             config.get('auto_parallel_config', {}))
-        max_encoder_input_len = config.pop('max_encoder_input_len', 1024)
-        weight_streaming = config.pop('weight_streaming', False)
-        use_strip_plan = config.pop('use_strip_plan', False)
+        max_encoder_input_len = config.pop(
+            'max_encoder_input_len', defaults.get('max_encoder_input_len'))
+        weight_streaming = config.pop('weight_streaming',
+                                      defaults.get('weight_streaming'))
+        use_strip_plan = config.pop('use_strip_plan',
+                                    defaults.get('use_strip_plan'))
 
         if plugin_config is None:
             plugin_config = PluginConfig()
         if "plugin_config" in config.keys():
             plugin_config.update_from_dict(config["plugin_config"])
-        dry_run = config.pop('dry_run', False)
-        visualize_network = config.pop('visualize_network', None)
-        monitor_memory = config.pop('monitor_memory', False)
-        use_mrope = config.pop('use_mrope', False)
+        dry_run = config.pop('dry_run', defaults.get('dry_run'))
+        visualize_network = config.pop('visualize_network',
+                                       defaults.get('visualize_network'))
+        monitor_memory = config.pop('monitor_memory',
+                                    defaults.get('monitor_memory'))
+        use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))
 
         return cls(
             max_input_len=max_input_len,
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index de4aa4365ae..5ae0b085ed1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -2014,7 +2014,8 @@ def update_llm_args_with_extra_dict(
     }
     for field_name, field_type in field_mapping.items():
         if field_name in llm_args_dict:
-            if field_name == "speculative_config":
+            # Some fields need to be converted manually.
+            if field_name in ["speculative_config", "build_config"]:
                 llm_args_dict[field_name] = field_type.from_dict(
                     llm_args_dict[field_name])
             else:
diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py
index 83e3b73809c..cd68bbc5ca8 100644
--- a/tests/unittest/llmapi/test_llm_args.py
+++ b/tests/unittest/llmapi/test_llm_args.py
@@ -5,10 +5,15 @@
 import yaml
 
 import tensorrt_llm.bindings.executor as tle
+from tensorrt_llm import AutoParallelConfig
 from tensorrt_llm._torch.llm import LLM as TorchLLM
+from tensorrt_llm.builder import LoraConfig
+from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
+                                 SchedulerConfig)
 from tensorrt_llm.llmapi.llm import LLM
 from tensorrt_llm.llmapi.llm_args import *
 from tensorrt_llm.llmapi.utils import print_traceback_on_error
+from tensorrt_llm.plugin import PluginConfig
 
 from .test_llm import llama_model_path
 
@@ -179,6 +184,61 @@ def test_PeftCacheConfig_declaration():
     assert pybind_config.lora_prefetch_dir == "."
 
 
+def test_update_llm_args_with_extra_dict_with_nested_dict():
+    llm_api_args_dict = {
+        "model":
+        "dummy-model",
+        "build_config":
+        None,  # Will override later.
+        "extended_runtime_perf_knob_config":
+        ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
+        "kv_cache_config":
+        KvCacheConfig(enable_block_reuse=False),
+        "peft_cache_config":
+        PeftCacheConfig(num_host_module_layer=0),
+        "scheduler_config":
+        SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy.
+                        GUARANTEED_NO_EVICT)
+    }
+    plugin_config_dict = {
+        "_dtype": 'float16',
+        "nccl_plugin": None,
+    }
+    plugin_config = PluginConfig.from_dict(plugin_config_dict)
+    build_config = BuildConfig(max_input_len=1024,
+                               lora_config=LoraConfig(lora_ckpt_source='hf'),
+                               auto_parallel_config=AutoParallelConfig(
+                                   world_size=1,
+                                   same_buffer_io={},
+                                   debug_outputs=[]),
+                               plugin_config=plugin_config)
+    extra_llm_args_dict = {
+        "build_config": build_config.to_dict(),
+    }
+
+    llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
+                                                        extra_llm_args_dict,
+                                                        "build_config")
+    initialized_llm_args = TrtLlmArgs(**llm_api_args_dict)
+
+    def check_nested_dict_equality(dict1, dict2, path=""):
+        if not isinstance(dict1, dict) or not isinstance(dict2, dict):
+            if dict1 != dict2:
+                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
+            return True
+        if dict1.keys() != dict2.keys():
+            raise ValueError(f"Different keys at {path}:")
+        for key in dict1:
+            new_path = f"{path}.{key}" if path else key
+            if not check_nested_dict_equality(dict1[key], dict2[key], new_path):
+                raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
+        return True
+
+    build_config_dict1 = build_config.to_dict()
+    build_config_dict2 = initialized_llm_args.build_config.to_dict()
+    check_nested_dict_equality(build_config_dict1, build_config_dict2)
+
+
 class TestTorchLlmArgsCudaGraphSettings:
 
     def test_cuda_graph_batch_sizes_case_0(self):