Skip to content

Commit a673b7c

Browse files
committed
init
Signed-off-by: Superjomn <[email protected]>
1 parent 34212e2 commit a673b7c

File tree

2 files changed

+92
-25
lines changed

2 files changed

+92
-25
lines changed

tensorrt_llm/llmapi/llm_args.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,15 +1937,9 @@ def update_llm_args_with_extra_dict(
19371937
"quant_config": QuantConfig,
19381938
"calib_config": CalibConfig,
19391939
"build_config": BuildConfig,
1940-
"kv_cache_config": KvCacheConfig,
19411940
"decoding_config": DecodingConfig,
19421941
"enable_build_cache": BuildCacheConfig,
1943-
"peft_cache_config": PeftCacheConfig,
1944-
"scheduler_config": SchedulerConfig,
19451942
"speculative_config": DecodingBaseConfig,
1946-
"batching_type": BatchingType,
1947-
"extended_runtime_perf_knob_config": ExtendedRuntimePerfKnobConfig,
1948-
"cache_transceiver_config": CacheTransceiverConfig,
19491943
"lora_config": LoraConfig,
19501944
}
19511945
for field_name, field_type in field_mapping.items():

tests/unittest/llmapi/test_llm_args.py

Lines changed: 92 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
4040
assert pybind_config.max_verification_set_size == 4
4141

4242

43-
def test_update_llm_args_with_extra_dict_with_speculative_config():
44-
yaml_content = """
43+
class TestYaml:
44+
45+
def _yaml_to_dict(self, yaml_content: str) -> dict:
46+
with tempfile.NamedTemporaryFile(delete=False) as f:
47+
f.write(yaml_content.encode('utf-8'))
48+
f.flush()
49+
f.seek(0)
50+
dict_content = yaml.safe_load(f)
51+
return dict_content
52+
53+
def test_update_llm_args_with_extra_dict_with_speculative_config(self):
54+
yaml_content = """
4555
speculative_config:
46-
decoding_type: Lookahead
47-
max_window_size: 4
48-
max_ngram_size: 3
49-
verification_set_size: 4
56+
decoding_type: Lookahead
57+
max_window_size: 4
58+
max_ngram_size: 3
59+
verification_set_size: 4
60+
"""
61+
dict_content = self._yaml_to_dict(yaml_content)
62+
63+
llm_args = TrtLlmArgs(model=llama_model_path)
64+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
65+
dict_content)
66+
llm_args = TrtLlmArgs(**llm_args_dict)
67+
assert llm_args.speculative_config.max_window_size == 4
68+
assert llm_args.speculative_config.max_ngram_size == 3
69+
assert llm_args.speculative_config.max_verification_set_size == 4
70+
71+
def test_llm_args_with_invalid_yaml(self):
72+
yaml_content = """
73+
pytorch_backend_config: # this is deprecated
74+
max_num_tokens: 1
75+
max_seq_len: 1
5076
"""
51-
with tempfile.NamedTemporaryFile(delete=False) as f:
52-
f.write(yaml_content.encode('utf-8'))
53-
f.flush()
54-
f.seek(0)
55-
dict_content = yaml.safe_load(f)
56-
57-
llm_args = TrtLlmArgs(model=llama_model_path)
58-
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
59-
dict_content)
60-
llm_args = TrtLlmArgs(**llm_args_dict)
61-
assert llm_args.speculative_config.max_window_size == 4
62-
assert llm_args.speculative_config.max_ngram_size == 3
63-
assert llm_args.speculative_config.max_verification_set_size == 4
77+
dict_content = self._yaml_to_dict(yaml_content)
78+
79+
llm_args = TrtLlmArgs(model=llama_model_path)
80+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
81+
dict_content)
82+
with pytest.raises(ValueError):
83+
llm_args = TrtLlmArgs(**llm_args_dict)
84+
85+
def test_llm_args_with_build_config(self):
86+
# build_config isn't a Pydantic model
87+
yaml_content = """
88+
build_config:
89+
max_beam_width: 4
90+
max_batch_size: 8
91+
max_num_tokens: 256
92+
"""
93+
dict_content = self._yaml_to_dict(yaml_content)
94+
95+
llm_args = TrtLlmArgs(model=llama_model_path)
96+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
97+
dict_content)
98+
llm_args = TrtLlmArgs(**llm_args_dict)
99+
assert llm_args.build_config.max_beam_width == 4
100+
assert llm_args.build_config.max_batch_size == 8
101+
assert llm_args.build_config.max_num_tokens == 256
102+
103+
def test_llm_args_with_kvcache_config(self):
104+
yaml_content = """
105+
kv_cache_config:
106+
enable_block_reuse: True
107+
max_tokens: 1024
108+
max_attention_window: [1024, 1024, 1024]
109+
"""
110+
dict_content = self._yaml_to_dict(yaml_content)
111+
112+
llm_args = TrtLlmArgs(model=llama_model_path)
113+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
114+
dict_content)
115+
llm_args = TrtLlmArgs(**llm_args_dict)
116+
assert llm_args.kv_cache_config.enable_block_reuse == True
117+
assert llm_args.kv_cache_config.max_tokens == 1024
118+
assert llm_args.kv_cache_config.max_attention_window == [
119+
1024, 1024, 1024
120+
]
121+
122+
def test_llm_args_with_pydantic_options(self):
123+
yaml_content = """
124+
max_batch_size: 16
125+
max_num_tokens: 256
126+
max_seq_len: 128
127+
"""
128+
dict_content = self._yaml_to_dict(yaml_content)
129+
130+
llm_args = TrtLlmArgs(model=llama_model_path)
131+
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
132+
dict_content)
133+
llm_args = TrtLlmArgs(**llm_args_dict)
134+
assert llm_args.max_batch_size == 16
135+
assert llm_args.max_num_tokens == 256
136+
assert llm_args.max_seq_len == 128
64137

65138

66139
def check_defaults(py_config_cls, pybind_config_cls):

0 commit comments

Comments
 (0)