82 changes: 82 additions & 0 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -11,9 +11,11 @@
import jsonschema
import pytest
import regex as re
import torch
from pydantic import BaseModel

from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -721,3 +723,83 @@ def generate_with_backend(backend):
assert "a4" not in generated
assert "a5" not in generated
assert "a6" not in generated


@pytest.mark.parametrize("guided_decoding_backend",
["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
):
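    # Pin the engine to V1; this test exercises the V1 structured
    # output path.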
monkeypatch.setenv("VLLM_USE_V1", "1")

    # Don't enforce eager execution on TPUs, because we want to verify
    # that no recompilation happens at runtime.
    enforce_eager = not current_platform.is_tpu()

llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
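        # Disabling inter-token whitespace is only supported by the
        # xgrammar and guidance backends, not outlines.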
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
)

guided_prompt = (
"Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")

non_guided_prompt = "The diameter of the Earth in kilometers is "

prompts = [guided_prompt, non_guided_prompt]
sampling_params = [
SamplingParams(
temperature=1.0,
max_tokens=400,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
        # No guided decoding and no max_tokens cap; temp=0 with a
        # fixed seed so we can assert on the contents below
SamplingParams(
seed=42,
temperature=0,
top_p=1.0,
),
]

outputs = llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None

    # Free memory as soon as possible, because a failed assertion below
    # would short-circuit the test and leave the memory allocated
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

for index, output in enumerate(outputs):
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is guided, so expect schema-valid JSON
            if guided_decoding_backend != "outlines":
                # Whitespace is disabled for xgrammar and guidance, so
                # the JSON should be emitted on a single line
                assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json,
                                schema=sample_json_schema)
else:
            # Second prompt is not guided. The exact output cannot be
            # asserted, but with temp=0 and a fixed seed we expect a
            # factual, deterministic answer.
assert "12,742" in generated_text

            # The non-guided request should not have produced valid JSON
            # (json.JSONDecodeError is a subclass of ValueError)
            with pytest.raises(ValueError):
                json.loads(generated_text)
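
For quick reference, the batching pattern exercised above boils down to the following minimal sketch (standalone, not part of the diff; the model, prompts, and schema are illustrative placeholders):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", max_model_len=1024)

outputs = llm.generate(
    prompts=[
        "Return a short JSON object with a name field.",
        "The capital of France is ",
    ],
    sampling_params=[
        # Guided request: output is constrained to the schema above
        SamplingParams(max_tokens=100,
                       guided_decoding=GuidedDecodingParams(json=schema)),
        # Plain request in the same batch: no grammar constraint
        SamplingParams(max_tokens=20, temperature=0),
    ],
)

Because guided_decoding lives on each SamplingParams rather than on the engine, guided and plain requests can share one generate() call; only the requests that carry a GuidedDecodingParams are grammar-constrained.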