@@ -40,27 +40,100 @@ def test_LookaheadDecodingConfig():
     assert pybind_config.max_verification_set_size == 4


-def test_update_llm_args_with_extra_dict_with_speculative_config():
-    yaml_content = """
+class TestYaml:
+
+    def _yaml_to_dict(self, yaml_content: str) -> dict:
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            f.write(yaml_content.encode('utf-8'))
+            f.flush()
+            f.seek(0)
+            dict_content = yaml.safe_load(f)
+        return dict_content
+
+    def test_update_llm_args_with_extra_dict_with_speculative_config(self):
+        yaml_content = """
 speculative_config:
-    decoding_type: Lookahead
-    max_window_size: 4
-    max_ngram_size: 3
-    verification_set_size: 4
+  decoding_type: Lookahead
+  max_window_size: 4
+  max_ngram_size: 3
+  verification_set_size: 4
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.speculative_config.max_window_size == 4
+        assert llm_args.speculative_config.max_ngram_size == 3
+        assert llm_args.speculative_config.max_verification_set_size == 4
+
+    def test_llm_args_with_invalid_yaml(self):
+        yaml_content = """
+pytorch_backend_config: # this is deprecated
+  max_num_tokens: 1
+  max_seq_len: 1
 """
-    with tempfile.NamedTemporaryFile(delete=False) as f:
-        f.write(yaml_content.encode('utf-8'))
-        f.flush()
-        f.seek(0)
-        dict_content = yaml.safe_load(f)
-
-    llm_args = TrtLlmArgs(model=llama_model_path)
-    llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
-                                                    dict_content)
-    llm_args = TrtLlmArgs(**llm_args_dict)
-    assert llm_args.speculative_config.max_window_size == 4
-    assert llm_args.speculative_config.max_ngram_size == 3
-    assert llm_args.speculative_config.max_verification_set_size == 4
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        with pytest.raises(ValueError):
+            llm_args = TrtLlmArgs(**llm_args_dict)
+
+    def test_llm_args_with_build_config(self):
+        # build_config isn't a Pydantic model
+        yaml_content = """
+build_config:
+  max_beam_width: 4
+  max_batch_size: 8
+  max_num_tokens: 256
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.build_config.max_beam_width == 4
+        assert llm_args.build_config.max_batch_size == 8
+        assert llm_args.build_config.max_num_tokens == 256
+
+    def test_llm_args_with_kvcache_config(self):
+        yaml_content = """
+kv_cache_config:
+  enable_block_reuse: True
+  max_tokens: 1024
+  max_attention_window: [1024, 1024, 1024]
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.kv_cache_config.enable_block_reuse == True
+        assert llm_args.kv_cache_config.max_tokens == 1024
+        assert llm_args.kv_cache_config.max_attention_window == [
+            1024, 1024, 1024
+        ]
+
+    def test_llm_args_with_pydantic_options(self):
+        yaml_content = """
+max_batch_size: 16
+max_num_tokens: 256
+max_seq_len: 128
+"""
+        dict_content = self._yaml_to_dict(yaml_content)
+
+        llm_args = TrtLlmArgs(model=llama_model_path)
+        llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
+                                                        dict_content)
+        llm_args = TrtLlmArgs(**llm_args_dict)
+        assert llm_args.max_batch_size == 16
+        assert llm_args.max_num_tokens == 256
+        assert llm_args.max_seq_len == 128


 def check_defaults(py_config_cls, pybind_config_cls):
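
For reference, every test above follows the same pattern: parse a YAML snippet into a dict, merge it into the serialized arguments with update_llm_args_with_extra_dict, then rebuild TrtLlmArgs so validation runs on the merged result. Below is a minimal standalone sketch of that pattern, assuming TrtLlmArgs, update_llm_args_with_extra_dict, and llama_model_path are importable exactly as in the test module; note that yaml.safe_load accepts a plain string as well as a file object, so the temp-file round-trip in _yaml_to_dict is only one of several ways to produce the dict.

import yaml

# Sketch only: TrtLlmArgs, update_llm_args_with_extra_dict and
# llama_model_path are assumed to come from the same imports the
# test module above already uses.
yaml_content = """
max_batch_size: 16
max_seq_len: 128
"""
extra = yaml.safe_load(yaml_content)  # safe_load also takes a str directly

llm_args = TrtLlmArgs(model=llama_model_path)
merged = update_llm_args_with_extra_dict(llm_args.to_dict(), extra)
llm_args = TrtLlmArgs(**merged)  # rebuilding re-validates; the invalid-YAML
                                 # test above expects a ValueError here

assert llm_args.max_batch_size == 16
assert llm_args.max_seq_len == 128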