Support parallel function calls with tool_choice #1503

Open · wants to merge 4 commits into main
150 changes: 91 additions & 59 deletions llama_cpp/llama_chat_format.py
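In summary, the diff below swaps the single-call grammar for one that constrains the model to emit a JSON array of `{"name": ..., "arguments": ...}` objects, so one completion can encode several calls to the tool selected by `tool_choice`. A new helper, `_completion_text_to_tool_calls`, parses that array into OpenAI-style `tool_calls`. A minimal sketch of the completion text the new grammar permits, assuming a hypothetical `get_weather` tool (not part of this diff):

```python
import json

# Hypothetical completion text matching the new grammar: an array of
# {"name", "arguments"} objects, one entry per call to the chosen tool.
completion_text = (
    '[{"name": "get_weather", "arguments": {"city": "Paris"}},'
    ' {"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
)

function_calls = json.loads(completion_text)
assert isinstance(function_calls, list)
for i, call in enumerate(function_calls):
    print(i, call["name"], json.dumps(call["arguments"]))
```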
@@ -351,10 +351,49 @@ def _convert_completion_to_chat_function(
     ],
     stream: bool,
 ):
+    def _completion_text_to_tool_calls(
+        tool_name: str,
+        completion_text: str,
+        completion_id: str,
+        stream: bool,
+    ) -> Union[
+        llama_types.ChatCompletionMessageToolCalls, List[llama_types.ChatCompletionMessageToolCallChunk]
+    ]:
+        try:
+            function_calls = json.loads(completion_text)
+            assert isinstance(function_calls, list)
+        except Exception as e:
+            function_calls = []
+
+        i = 0
+        tool_calls = []
+        for function_call in function_calls:
+            function_name = function_call.get("name")
+            function_arguments = function_call.get("arguments")
+            if function_name == tool_name and function_arguments:
+                tool_id = f'call__{i}_{tool_name}_{completion_id}'
+                tool_call = {
+                    "id": tool_id,
+                    "type": "function",
+                    "function": {
+                        "name": tool_name,
+                        "arguments": json.dumps(function_arguments, ensure_ascii=False),
+                    },
+                }
+                if stream:
+                    tool_call["index"] = i
+                    typed_call: llama_types.ChatCompletionMessageToolCallChunk = tool_call
+                else:
+                    typed_call: llama_types.ChatCompletionMessageToolCall = tool_call
+                tool_calls.append(typed_call)
+                i += 1
+
+        return tool_calls
+
     if not stream:
         completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
         assert "usage" in completion
-        tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
+        tool_calls: llama_types.ChatCompletionMessageToolCalls = _completion_text_to_tool_calls(tool_name, completion["choices"][0]["text"], completion["id"], stream)  # type: ignore
         # TODO: Fix for legacy function calls
         chat_completion: llama_types.CreateChatCompletionResponse = {
             "id": "chat" + completion["id"],
@@ -366,24 +405,12 @@ def _convert_completion_to_chat_function(
                     "index": 0,
                     "message": {
                         "role": "assistant",
-                        "content": None,
-                        "function_call": {
-                            "name": tool_name,
-                            "arguments": completion["choices"][0]["text"],
-                        },
-                        "tool_calls": [
-                            {
-                                "id": tool_id,
-                                "type": "function",
-                                "function": {
-                                    "name": tool_name,
-                                    "arguments": completion["choices"][0]["text"],
-                                },
-                            }
-                        ],
+                        "content": None if tool_calls else completion["choices"][0]["text"],
+                        "function_call": tool_calls[0]["function"] if tool_calls else None,
+                        "tool_calls": tool_calls or None,
                     },
                     "logprobs": completion["choices"][0]["logprobs"],
-                    "finish_reason": "tool_calls",
+                    "finish_reason": "tool_calls" if tool_calls else completion["choices"][0]["finish_reason"],
                 }
             ],
             "usage": completion["usage"],
@@ -400,13 +427,15 @@ def _stream_response_to_function_stream(
         id_ = None
         created = None
         model = None
-        tool_id = None
+        finish = None
+        tools_called = ""
         for chunk in chunks:
+            tools_called += chunk["choices"][0]["text"]
+            finish = chunk["choices"][0]["finish_reason"]
             if first:
                 id_ = "chat" + chunk["id"]
                 created = chunk["created"]
                 model = chunk["model"]
-                tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
                 yield {
                     "id": id_,
                     "object": "chat.completion.chunk",
@@ -438,31 +467,15 @@ def _stream_response_to_function_stream(
                             "logprobs": chunk["choices"][0]["logprobs"],
                             "delta": {
                                 "role": None,
-                                "content": None,
-                                "function_call": {
-                                    "name": tool_name,
-                                    "arguments": chunk["choices"][0]["text"],
-                                },
-                                "tool_calls": [
-                                    {
-                                        "index": 0,
-                                        "id": tool_id,
-                                        "type": "function",
-                                        "function": {
-                                            "name": tool_name,
-                                            "arguments": chunk["choices"][0][
-                                                "text"
-                                            ],
-                                        },
-                                    }
-                                ],
+                                "content": chunk["choices"][0]["text"],
+                                "function_call": None,
+                                "tool_calls": None,
                             },
                         }
                     ],
                 }
                 first = False
                 continue
-            assert tool_id is not None
             yield {
                 "id": "chat" + chunk["id"],
                 "object": "chat.completion.chunk",
@@ -475,28 +488,16 @@ def _stream_response_to_function_stream(
                         "logprobs": chunk["choices"][0]["logprobs"],
                         "delta": {
                             "role": None,
-                            "content": None,
-                            "function_call": {
-                                "name": tool_name,
-                                "arguments": chunk["choices"][0]["text"],
-                            },
-                            "tool_calls": [
-                                {
-                                    "index": 0,
-                                    "id": tool_id,
-                                    "type": "function",
-                                    "function": {
-                                        "name": tool_name,
-                                        "arguments": chunk["choices"][0]["text"],
-                                    },
-                                }
-                            ],
+                            "content": chunk["choices"][0]["text"],
+                            "function_call": None,
+                            "tool_calls": None,
                         },
                     }
                 ],
             }

         if id_ is not None and created is not None and model is not None:
+            tool_calls: List[llama_types.ChatCompletionMessageToolCallChunk] = _completion_text_to_tool_calls(tool_name, tools_called, id_, stream)  # type: ignore
             yield {
                 "id": id_,
                 "object": "chat.completion.chunk",
@@ -505,13 +506,13 @@ def _stream_response_to_function_stream(
                 "choices": [
                     {
                         "index": 0,
-                        "finish_reason": "tool_calls",
+                        "finish_reason": "tool_calls" if tool_calls else finish,
                         "logprobs": None,
                         "delta": {
                             "role": None,
                             "content": None,
-                            "function_call": None,
-                            "tool_calls": None,
+                            "function_call": tool_calls[0]["function"] if tool_calls else None,
+                            "tool_calls": tool_calls or None,
                         },
                     }
                 ],
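So a client consumes the stream by accumulating `content` deltas and then reading the aggregated calls from the final chunk. A minimal sketch, assuming `chunks` is the chat completion stream returned with `stream=True`:

```python
# Accumulate streamed text, then pick up the parsed tool calls that the
# final chunk delivers in its delta (as indexed tool-call chunks).
text = ""
for chunk in chunks:
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        text += delta["content"]
    if delta.get("tool_calls"):
        for tc in delta["tool_calls"]:
            print(tc["index"], tc["function"]["name"], tc["function"]["arguments"])
```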
@@ -621,7 +622,22 @@ def chat_completion_handler(
         tool = next((t for t in tools if t["function"]["name"] == name), None)
         if tool is None:
             raise ValueError(f"Tool choice '{name}' not found in tools.")
-        schema = tool["function"]["parameters"]
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": {
+                        "type": "string"
+                    },
+                    "arguments": tool["function"]["parameters"]
+                },
+                "required": [
+                    "name",
+                    "arguments"
+                ]
+            }
+        }
         try:
             # create grammar from json schema
             grammar = llama_grammar.LlamaGrammar.from_json_schema(
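The wrapper leaves the tool's own parameter schema untouched and nests it under `arguments`. For a hypothetical tool whose parameters are a single string `city`, the schema handed to `LlamaGrammar.from_json_schema` would come out as:

```python
import json

# Hypothetical tool parameters; only the wrapping comes from this diff.
parameters = {"type": "object", "properties": {"city": {"type": "string"}}}
schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "arguments": parameters},
        "required": ["name", "arguments"],
    },
}
print(json.dumps(schema, indent=2))
```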
@@ -3486,9 +3502,25 @@ def chatml_function_calling(
             add_generation_prompt=True,
         )
         prompt += f"functions.{tool_name}:\n"
+        schema = {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "name": {
+                        "type": "string"
+                    },
+                    "arguments": tool["function"]["parameters"]
+                },
+                "required": [
+                    "name",
+                    "arguments"
+                ]
+            }
+        }
         try:
             grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+                json.dumps(schema), verbose=llama.verbose
             )
         except Exception as e:
             grammar = llama_grammar.LlamaGrammar.from_string(
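End to end, the change means a pinned `tool_choice` can now yield several tool calls in one response. A usage sketch against the public API, with a hypothetical model path and tool (neither is part of this diff):

```python
from llama_cpp import Llama

llama = Llama(model_path="model.gguf", chat_format="chatml-function-calling")
response = llama.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather in Paris and Tokyo?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
# With this PR, tool_calls may hold more than one entry.
for tc in response["choices"][0]["message"]["tool_calls"] or []:
    print(tc["id"], tc["function"]["arguments"])
```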