100 changes: 69 additions & 31 deletions tensorrt_llm/builder.py
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import dataclasses
import json
import math
import os
import shutil
import time
from dataclasses import dataclass, field
from functools import cache
from pathlib import Path
from typing import Dict, Optional, Union

@@ -557,53 +559,89 @@ def override_attri(attr_name, value):
else:
override_attri('paged_state', False)

@classmethod
@cache
def get_build_config_defaults(cls):
return {
field.name: field.default
for field in dataclasses.fields(cls)
if field.default is not dataclasses.MISSING
}

@classmethod
def from_dict(cls, config, plugin_config=None):
config = copy.deepcopy(
config
) # it just does not make sense to change the input arg `config`
max_input_len = config.pop('max_input_len')
max_seq_len = config.pop('max_seq_len')
max_batch_size = config.pop('max_batch_size')
max_beam_width = config.pop('max_beam_width')
max_num_tokens = config.pop('max_num_tokens')
opt_num_tokens = config.pop('opt_num_tokens')
opt_batch_size = config.pop('opt_batch_size', 8)

defaults = cls.get_build_config_defaults()
max_input_len = config.pop('max_input_len',
defaults.get('max_input_len'))
max_seq_len = config.pop('max_seq_len', defaults.get('max_seq_len'))
max_batch_size = config.pop('max_batch_size',
defaults.get('max_batch_size'))
max_beam_width = config.pop('max_beam_width',
defaults.get('max_beam_width'))
max_num_tokens = config.pop('max_num_tokens',
defaults.get('max_num_tokens'))
opt_num_tokens = config.pop('opt_num_tokens',
defaults.get('opt_num_tokens'))
opt_batch_size = config.pop('opt_batch_size',
defaults.get('opt_batch_size'))
max_prompt_embedding_table_size = config.pop(
'max_prompt_embedding_table_size', 0)

kv_cache_type = KVCacheType(
config.pop('kv_cache_type')) if 'plugin_config' in config else None
gather_context_logits = config.pop('gather_context_logits', False)
gather_generation_logits = config.pop('gather_generation_logits', False)
strongly_typed = config.pop('strongly_typed', True)
force_num_profiles = config.pop('force_num_profiles', None)
weight_sparsity = config.pop('weight_sparsity', False)
'max_prompt_embedding_table_size',
defaults.get('max_prompt_embedding_table_size'))

if "kv_cache_type" in config and config["kv_cache_type"] is not None:
kv_cache_type = KVCacheType(config.pop('kv_cache_type'))
else:
kv_cache_type = None
gather_context_logits = config.pop(
'gather_context_logits', defaults.get('gather_context_logits'))
gather_generation_logits = config.pop(
'gather_generation_logits',
defaults.get('gather_generation_logits'))
strongly_typed = config.pop('strongly_typed',
defaults.get('strongly_typed'))
force_num_profiles = config.pop('force_num_profiles',
defaults.get('force_num_profiles'))
weight_sparsity = config.pop('weight_sparsity',
defaults.get('weight_sparsity'))
profiling_verbosity = config.pop('profiling_verbosity',
'layer_names_only')
enable_debug_output = config.pop('enable_debug_output', False)
max_draft_len = config.pop('max_draft_len', 0)
speculative_decoding_mode = config.pop('speculative_decoding_mode',
SpeculativeDecodingMode.NONE)
use_refit = config.pop('use_refit', False)
input_timing_cache = config.pop('input_timing_cache', None)
output_timing_cache = config.pop('output_timing_cache', None)
defaults.get('profiling_verbosity'))
enable_debug_output = config.pop('enable_debug_output',
defaults.get('enable_debug_output'))
max_draft_len = config.pop('max_draft_len',
defaults.get('max_draft_len'))
speculative_decoding_mode = config.pop(
'speculative_decoding_mode',
defaults.get('speculative_decoding_mode'))
use_refit = config.pop('use_refit', defaults.get('use_refit'))
input_timing_cache = config.pop('input_timing_cache',
defaults.get('input_timing_cache'))
output_timing_cache = config.pop('output_timing_cache',
defaults.get('output_timing_cache'))
lora_config = LoraConfig.from_dict(config.get('lora_config', {}))
auto_parallel_config = AutoParallelConfig.from_dict(
config.get('auto_parallel_config', {}))
max_encoder_input_len = config.pop('max_encoder_input_len', 1024)
weight_streaming = config.pop('weight_streaming', False)
use_strip_plan = config.pop('use_strip_plan', False)
max_encoder_input_len = config.pop(
'max_encoder_input_len', defaults.get('max_encoder_input_len'))
weight_streaming = config.pop('weight_streaming',
defaults.get('weight_streaming'))
use_strip_plan = config.pop('use_strip_plan',
defaults.get('use_strip_plan'))

if plugin_config is None:
plugin_config = PluginConfig()
if "plugin_config" in config.keys():
plugin_config.update_from_dict(config["plugin_config"])

dry_run = config.pop('dry_run', False)
visualize_network = config.pop('visualize_network', None)
monitor_memory = config.pop('monitor_memory', False)
use_mrope = config.pop('use_mrope', False)
dry_run = config.pop('dry_run', defaults.get('dry_run'))
visualize_network = config.pop('visualize_network',
defaults.get('visualize_network'))
monitor_memory = config.pop('monitor_memory',
defaults.get('monitor_memory'))
use_mrope = config.pop('use_mrope', defaults.get('use_mrope'))

return cls(
max_input_len=max_input_len,
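
The change above replaces `from_dict`'s hard-coded fallbacks with defaults read off the `BuildConfig` dataclass itself via the new cached `get_build_config_defaults` classmethod. Below is a minimal, self-contained sketch of that pattern; the `ExampleConfig` class and its fields are illustrative stand-ins, not the real `BuildConfig`:

```python
import dataclasses
from dataclasses import dataclass, field
from functools import cache
from typing import List


@dataclass
class ExampleConfig:
    """Stand-in for BuildConfig; field names are illustrative only."""
    max_input_len: int = 1024
    max_batch_size: int = 2048
    output_timing_cache: str = 'model.cache'
    # A default_factory field has no plain `default`, so the MISSING check
    # below skips it; such fields still need explicit handling in from_dict.
    tags: List[str] = field(default_factory=list)

    @classmethod
    @cache
    def get_defaults(cls):
        # {field_name: default} for every field declaring a plain default;
        # cached so the mapping is built only once per class.
        return {
            f.name: f.default
            for f in dataclasses.fields(cls)
            if f.default is not dataclasses.MISSING
        }

    @classmethod
    def from_dict(cls, config: dict) -> "ExampleConfig":
        config = dict(config)  # don't mutate the caller's dict
        defaults = cls.get_defaults()
        # Missing keys fall back to the declared defaults instead of
        # raising KeyError, which is the behaviour the diff introduces.
        return cls(
            max_input_len=config.pop('max_input_len',
                                     defaults.get('max_input_len')),
            max_batch_size=config.pop('max_batch_size',
                                      defaults.get('max_batch_size')),
        )


print(ExampleConfig.from_dict({'max_input_len': 2048}))
# ExampleConfig(max_input_len=2048, max_batch_size=2048,
#               output_timing_cache='model.cache', tags=[])
```

Deriving the fallbacks from `dataclasses.fields` keeps a single source of truth for default values, and `functools.cache` makes the lookup table effectively free since the field list never changes at runtime.
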
3 changes: 2 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -2014,7 +2014,8 @@ def update_llm_args_with_extra_dict(
}
for field_name, field_type in field_mapping.items():
if field_name in llm_args_dict:
if field_name == "speculative_config":
# Some fields need to be converted manually.
if field_name in ["speculative_config", "build_config"]:
llm_args_dict[field_name] = field_type.from_dict(
llm_args_dict[field_name])
else:
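
With `build_config` added to the manual-conversion list, a nested `build_config` supplied as a plain dict (for example parsed from a YAML overrides file) is now turned back into a `BuildConfig` instance. A hedged usage sketch, mirroring the call pattern of the new unit test below (the third argument and the exact merge semantics are assumptions based on that test):

```python
from tensorrt_llm.llmapi import BuildConfig
from tensorrt_llm.llmapi.llm_args import update_llm_args_with_extra_dict

# Base LLM-API arguments; build_config will be filled in from the extra dict.
llm_args = {"model": "dummy-model"}

# Extra options as they would arrive from a YAML overrides file: the nested
# build_config is the plain dict produced by BuildConfig.to_dict().
extra = {"build_config": BuildConfig(max_input_len=1024).to_dict()}

llm_args = update_llm_args_with_extra_dict(llm_args, extra, "build_config")

# After this change the nested dict is converted via BuildConfig.from_dict
# rather than being kept as a raw dict.
assert isinstance(llm_args["build_config"], BuildConfig)
```
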
60 changes: 60 additions & 0 deletions tests/unittest/llmapi/test_llm_args.py
@@ -5,10 +5,15 @@
import yaml

import tensorrt_llm.bindings.executor as tle
from tensorrt_llm import AutoParallelConfig
from tensorrt_llm._torch.llm import LLM as TorchLLM
from tensorrt_llm.builder import LoraConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
SchedulerConfig)
from tensorrt_llm.llmapi.llm import LLM
from tensorrt_llm.llmapi.llm_args import *
from tensorrt_llm.llmapi.utils import print_traceback_on_error
from tensorrt_llm.plugin import PluginConfig

from .test_llm import llama_model_path

@@ -179,6 +184,61 @@ def test_PeftCacheConfig_declaration():
assert pybind_config.lora_prefetch_dir == "."


def test_update_llm_args_with_extra_dict_with_nested_dict():
llm_api_args_dict = {
"model":
"dummy-model",
"build_config":
None, # Will override later.
"extended_runtime_perf_knob_config":
ExtendedRuntimePerfKnobConfig(multi_block_mode=True),
"kv_cache_config":
KvCacheConfig(enable_block_reuse=False),
"peft_cache_config":
PeftCacheConfig(num_host_module_layer=0),
"scheduler_config":
SchedulerConfig(capacity_scheduler_policy=CapacitySchedulerPolicy.
GUARANTEED_NO_EVICT)
}
plugin_config_dict = {
"_dtype": 'float16',
"nccl_plugin": None,
}
plugin_config = PluginConfig.from_dict(plugin_config_dict)
build_config = BuildConfig(max_input_len=1024,
lora_config=LoraConfig(lora_ckpt_source='hf'),
auto_parallel_config=AutoParallelConfig(
world_size=1,
same_buffer_io={},
debug_outputs=[]),
plugin_config=plugin_config)
extra_llm_args_dict = {
"build_config": build_config.to_dict(),
}

llm_api_args_dict = update_llm_args_with_extra_dict(llm_api_args_dict,
extra_llm_args_dict,
"build_config")
initialized_llm_args = TrtLlmArgs(**llm_api_args_dict)

def check_nested_dict_equality(dict1, dict2, path=""):
if not isinstance(dict1, dict) or not isinstance(dict2, dict):
if dict1 != dict2:
raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
return True
if dict1.keys() != dict2.keys():
raise ValueError(f"Different keys at {path}:")
for key in dict1:
new_path = f"{path}.{key}" if path else key
if not check_nested_dict_equality(dict1[key], dict2[key], new_path):
raise ValueError(f"Mismatch at {path}: {dict1} != {dict2}")
return True

build_config_dict1 = build_config.to_dict()
build_config_dict2 = initialized_llm_args.build_config.to_dict()
check_nested_dict_equality(build_config_dict1, build_config_dict2)


class TestTorchLlmArgsCudaGraphSettings:

def test_cuda_graph_batch_sizes_case_0(self):