Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
cd3b3bb
first pass for JSON and regex
felixzhu555 Feb 8, 2024
e589bd0
tiny refactor
felixzhu555 Feb 8, 2024
5f55e6a
Added support for guided decoding in `api_server` by integrating _out…
br3no Feb 8, 2024
3a051cf
refactor/combine breno's PR with mine
felixzhu555 Feb 9, 2024
54217ba
fix type check
felixzhu555 Feb 9, 2024
c9c6f4f
fix try-except
felixzhu555 Feb 9, 2024
b82dedb
fix import bug
felixzhu555 Feb 9, 2024
ba92cb2
add outlines v0.0.27 requirement
felixzhu555 Feb 9, 2024
da2f5b8
fix dummy_llm
felixzhu555 Feb 10, 2024
b090c18
start adding tests
felixzhu555 Feb 11, 2024
9093c5e
add more tests
felixzhu555 Feb 11, 2024
1efd64d
fix pytest fixtures scope
felixzhu555 Feb 13, 2024
736ca31
remove guided decoding from vllm api server
felixzhu555 Feb 13, 2024
4d1b049
refactor + add guided_choice
felixzhu555 Feb 14, 2024
a46684e
add caching for logit processors
felixzhu555 Feb 14, 2024
058fce6
use re.escape
felixzhu555 Feb 14, 2024
dc601c7
revert logits processor 2 vs 3 arg fix
felixzhu555 Feb 14, 2024
09d2a9c
copy on cache hit
felixzhu555 Feb 14, 2024
d774cf6
add separate thread for creating logits processor
felixzhu555 Feb 15, 2024
782b1da
add simple cache test
felixzhu555 Feb 15, 2024
df3c774
use asyncio
felixzhu555 Feb 16, 2024
c74f6bb
add grammar support
felixzhu555 Feb 18, 2024
cf8494d
remove grammar
felixzhu555 Feb 21, 2024
c30fed0
copy outlines' logits processors code
felixzhu555 Feb 21, 2024
33dc082
breno PR comments
felixzhu555 Feb 22, 2024
8699039
resolve PR comments
felixzhu555 Feb 22, 2024
6bf277e
add unit test for logits processors
felixzhu555 Feb 24, 2024
b45942e
PR comments
felixzhu555 Feb 27, 2024
3169b63
Merge branch 'main' of github.com:vllm-project/vllm into add_structur…
felixzhu555 Feb 27, 2024
c176dab
fix
felixzhu555 Feb 27, 2024
bfbbce3
format with yapf and ruff
felixzhu555 Feb 28, 2024
9655b6e
fix global pool
simon-mo Feb 28, 2024
0c3d475
Apply suggestions from code review
simon-mo Feb 28, 2024
0793fc7
Merge branch 'main' of github.com:vllm-project/vllm into add_structur…
simon-mo Feb 29, 2024
ce9c07a
some minor fix
simon-mo Feb 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
75 changes: 75 additions & 0 deletions tests/entrypoints/test_guided_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.

from transformers import AutoTokenizer
import torch

from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
JSONLogitsProcessor)

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"


def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)

regex_LP.init_state()
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

json_LP.init_state()
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
237 changes: 237 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,64 @@
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests

# imports for guided decoding tests
import json
import jsonschema
import re

from vllm.transformers_utils.tokenizer import get_tokenizer

MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

TEST_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"

TEST_CHOICE = [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
"Swift", "Kotlin"
]

pytestmark = pytest.mark.asyncio


Expand Down Expand Up @@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token_id): 100},
seed=42,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
Expand Down Expand Up @@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=
f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
for i in range(3):
assert completion.choices[i].text is not None
output_json = json.loads(completion.choices[i].text)
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "Give an example JSON for an employee profile that " + \
f"fits this schema: {TEST_SCHEMA}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

messages.append({"role": "assistant", "content": message.content})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
for i in range(3):
assert completion.choices[i].text is not None
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example IP address with this regex: {TEST_REGEX}"
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None

messages.append({"role": "assistant", "content": ip1})
messages.append({"role": "user", "content": "Give me a different one"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))

assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE

messages.append({"role": "assistant", "content": choice1})
messages.append({
"role": "user",
"content": "I disagree, pick another one"
})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42))

messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
"The best language for type-safe systems programming is "
}]
with pytest.raises(openai.BadRequestError):
_ = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
extra_body=dict(guided_regex={
1: "Python",
2: "C++"
}))

with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example string that fits this regex",
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


if __name__ == "__main__":
pytest.main([__file__])
3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())

def get_tokenizer(self):
return self.engine.tokenizer.tokenizer

def start_background_loop(self) -> None:
"""Start the background loop."""
if self.is_running:
Expand Down
Loading