Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
099ddb3
WIP added tool classes
br3no May 23, 2024
5934e22
added correct models. Tests still missing
br3no May 23, 2024
1248bc1
fix implementation and tests
br3no May 24, 2024
1b2b453
fix formatting
br3no May 24, 2024
07af0cc
fix test
br3no May 24, 2024
49b560c
Merge branch '5008-chat-logprobs' into 1869-tools-support-step-1
br3no May 24, 2024
755625f
named tool working
br3no May 24, 2024
193e6ec
fix formatting complaint
br3no May 24, 2024
46d5f27
correct output format and support streaming
br3no May 24, 2024
b59e1b3
fix ruff complaint
br3no May 24, 2024
f0dc5b8
fix mypy complaint
br3no May 24, 2024
80e66cf
reverting removal of
br3no May 28, 2024
3ca5fce
refactoring – move 'create_logprobs' for completion out of serving_en…
br3no May 28, 2024
06519c7
fix formatting
br3no May 28, 2024
3c5457a
adding changes after review from @DarkLight1337
br3no May 28, 2024
e388194
Merge branch 'main' into 5008-chat-logprobs
br3no May 28, 2024
c37d5a9
review iteration 2
br3no May 28, 2024
adcdc31
formatting – isort breaks it again..?
br3no May 28, 2024
91b4cfa
disable yapf in import to avoid conflict with isort
br3no May 29, 2024
825e0ad
Merge branch 'main' into 5008-chat-logprobs
br3no May 29, 2024
496fb25
fix formatting
br3no May 29, 2024
111548a
formatting
br3no May 29, 2024
2d59282
Merge branch 'main' into 1869-tools-support-step-1
br3no May 30, 2024
e7c7450
remove tool_choice 'required'
br3no May 30, 2024
b77e60a
add sad path test
br3no May 30, 2024
9f33687
add more sad path tests
br3no May 30, 2024
5f0c3ae
fix test
br3no May 31, 2024
15da872
fix test
br3no May 31, 2024
bdf0dcf
after review
br3no Jun 3, 2024
37130f7
adding docs for named function calling in tool use
br3no Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:prog: -m vllm.entrypoints.openai.api_server
```
```

## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.

To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter.

It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.**

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

Please refer to the OpenAI API reference documentation for more information.
185 changes: 185 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token in top_logprobs)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

# non-streaming

chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
})
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

messages.append({"role": "assistant", "content": json_string})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})

# streaming

stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
},
stream=True)

output = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
assert delta.content is None or len(delta.content) == 0
if delta.tool_calls:
output.append(delta.tool_calls[0].function.arguments)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="required")

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tool_choice={
"type": "function",
"function": {
"name":
"dummy_function_name"
}
})

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "nondefined_function_name"
}
})


@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
Expand Down
57 changes: 55 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
Expand All @@ -121,6 +141,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None

# doc: begin-chat-completion-sampling-params
Expand Down Expand Up @@ -244,10 +267,27 @@ def check_guided_decoding_count(cls, data):
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data

@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data

@model_validator(mode="before")
Expand Down Expand Up @@ -505,9 +545,21 @@ class EmbeddingResponse(BaseModel):
usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
name: str
arguments: str


class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall


class ChatMessage(OpenAIBaseModel):
role: str
content: str
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
Expand All @@ -534,7 +586,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):

class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
Expand All @@ -544,6 +596,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
Expand All @@ -556,7 +609,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
Expand Down
Loading