 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
+GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
 MODELS_TO_TEST = [
     "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
 ]
@@ -30,12 +30,13 @@ def test_guided_json_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
     outputs = llm.generate(prompts=[
         f"Give an example JSON for an employee profile "
         f"that fits this schema: {sample_json_schema}"
@@ -111,13 +112,14 @@ def test_guided_json_object(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=100,
-                                     n=2,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json_object=True,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=100,
+        n=2,
+        guided_decoding=GuidedDecodingParams(json_object=True))
 
     outputs = llm.generate(
         prompts=("Generate a JSON object with curly braces for a person with "
@@ -137,12 +139,20 @@ def test_guided_json_object(
 
     # Parse to verify it is valid JSON
     parsed_json = json.loads(generated_text)
-    assert isinstance(parsed_json, dict)
+    allowed_types: tuple[type, ...] = (dict, )
+    if guided_decoding_backend == "xgrammar":
+        # TODO - we are currently too permissive with xgrammar and
+        # allow any valid json (typically comes back as a list or
+        # object). We can fix this by specifying a jsonschema of
+        # {"type": "object"}, but we need this fix in a release
+        # first: https://github.com/mlc-ai/xgrammar/pull/264
+        allowed_types = (dict, list)
+    assert isinstance(parsed_json, allowed_types)
 
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
+                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
 def test_guided_json_unsupported_schema(
     monkeypatch: pytest.MonkeyPatch,
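(As the TODO in the hunk above notes, the eventual tightening is to request an object-only schema instead of free-form JSON. A hypothetical sketch of that stricter call once the upstream xgrammar fix ships; the inline schema is illustrative and not part of this diff:

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        n=2,
        guided_decoding=GuidedDecodingParams(json={"type": "object"}))
)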
@@ -151,21 +161,43 @@ def test_guided_json_unsupported_schema(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json_schema,
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="The provided JSON schema contains features "
-                       "not supported by xgrammar."):
-        llm.generate(prompts=[
-            f"Give an example JSON for an employee profile "
-            f"that fits this schema: {unsupported_json_schema}"
-        ] * 2,
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+    if guided_decoding_backend == "xgrammar":
+        with pytest.raises(ValueError,
+                           match="The provided JSON schema contains features "
+                           "not supported by xgrammar."):
+            llm.generate(prompts=[
+                f"Give an example JSON for an employee profile "
+                f"that fits this schema: {unsupported_json_schema}"
+            ] * 2,
+                         sampling_params=sampling_params,
+                         use_tqdm=True)
+    else:
+        # This should work for both "guidance" and "auto".
+
+        outputs = llm.generate(
+            prompts=("Give an example JSON object for a grade "
+                     "that fits this schema: "
+                     f"{unsupported_json_schema}"),
+            sampling_params=sampling_params,
+            use_tqdm=True)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(generated_text)
+
+            # Parse to verify it is valid JSON
+            parsed_json = json.loads(generated_text)
+            assert isinstance(parsed_json, dict)
 
 
 @pytest.mark.skip_global_cleanup
@@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_ebnf,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -222,13 +255,14 @@ def test_guided_grammar_lark(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar=sample_sql_lark,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
     outputs = llm.generate(
         prompts=("Generate a sql statement that selects col_1 from "
                  "table_1 where it is equal to 1"),
@@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=1000,
-                                     guided_decoding=GuidedDecodingParams(
-                                         grammar="not a grammar",
-                                         backend=guided_decoding_backend))
-    with pytest.raises(ValueError,
-                       match="Failed to convert the grammar "
-                       "from Lark to EBNF."):
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+    with pytest.raises(ValueError, match="Failed to convert the grammar "):
         llm.generate(
             prompts=("Generate a sql statement that selects col_1 from "
                      "table_1 where it is equal to 1"),
@@ -298,12 +331,13 @@ def test_guided_regex(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(regex=sample_regex))
     outputs = llm.generate(
         prompts=[
             f"Give an example IPv4 address with this regex: {sample_regex}"
@@ -335,12 +369,13 @@ def test_guided_choice_completion(
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name, max_model_len=1024)
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         choice=sample_guided_choice,
-                                         backend=guided_decoding_backend))
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend=guided_decoding_backend)
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
     outputs = llm.generate(
         prompts="The best language for type-safe systems programming is ",
         sampling_params=sampling_params,
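(Taken together, the pattern this diff standardizes on is: select the structured-output backend once, at engine construction, and keep the per-request GuidedDecodingParams schema-only. A minimal end-to-end sketch of that pattern; the model choice and prompt are illustrative, and the top-level `from vllm import LLM` import is assumed from the surrounding test module:

    from vllm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance")  # or "xgrammar" / "auto"
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        guided_decoding=GuidedDecodingParams(json_object=True))
    outputs = llm.generate(prompts=["Give an example JSON object."],
                           sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)
)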