Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -1258,3 +1259,6 @@ def apply_mistral_chat_template(
"An error occurred in `mistral_common` while applying chat "
"template")
raise ValueError from e

def random_tool_call_id() -> str:
    """Generate a unique tool-call ID in OpenAI's ``chatcmpl-tool-<uuid>`` format."""
    return "chatcmpl-tool-" + random_uuid()
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from typing_extensions import TypeAlias

from vllm import envs
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
random_tool_call_id)
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
Expand Down Expand Up @@ -1314,7 +1315,7 @@ class FunctionCall(OpenAIBaseModel):


class ToolCall(OpenAIBaseModel):
    """A completed tool invocation reported in a chat completion response."""

    # Auto-generated "chatcmpl-tool-<uuid>" identifier when the caller
    # does not supply one.
    id: str = Field(default_factory=random_tool_call_id)
    # Only "function"-type tool calls are supported.
    type: Literal["function"] = "function"
    function: FunctionCall

Expand All @@ -1326,8 +1327,8 @@ class DeltaFunctionCall(BaseModel):

# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    """A streamed (delta) tool call where everything except ``index`` is optional.

    The pasted diff interleaved the removed pre-diff field definitions with
    the added ones, leaving duplicate conflicting fields; this is the final
    post-diff form of the class.
    """

    # None by default (instead of an auto-generated ID) so that unset
    # fields can be omitted from streamed chunks when the response is
    # serialized with ``exclude_none``.
    id: Optional[str] = None
    type: Optional[Literal["function"]] = None
    index: int
    function: Optional[DeltaFunctionCall] = None

Expand Down
39 changes: 25 additions & 14 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage)
ConversationMessage,
random_tool_call_id)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
Expand Down Expand Up @@ -363,9 +364,10 @@ def extract_tool_call_required_streaming(

function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
DeltaToolCall(id=random_tool_call_id(),
function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
Expand All @@ -382,8 +384,7 @@ def extract_tool_call_required_streaming(
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
index=len(obj) - 1)
])
else:
delta_message = None
Expand Down Expand Up @@ -422,7 +423,7 @@ async def chat_completion_stream_generator(
and self._should_stream_with_auto_tool_parsing(request))

all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
function_name_returned = [False] * num_choices

# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
Expand All @@ -435,7 +436,6 @@ async def chat_completion_stream_generator(
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else:
previous_texts, all_previous_token_ids = None, None
Expand Down Expand Up @@ -623,16 +623,27 @@ async def chat_completion_stream_generator(
delta_text = previous_text + delta_text
current_text = ""

if function_name_returned[i]:
delta_tool_call = DeltaToolCall(
function=DeltaFunctionCall(
arguments=delta_text),
index=i)
else:
delta_tool_call = DeltaToolCall(
id=random_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
function_name_returned[i] = True

delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
delta_tool_call,
])

elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
Expand Down Expand Up @@ -835,7 +846,7 @@ async def chat_completion_stream_generator(
total_tokens=num_prompt_tokens + completion_tokens,
)

data = chunk.model_dump_json(exclude_unset=True)
data = chunk.model_dump_json(exclude_none=True)
yield f"data: {data}\n\n"

# once the final token is handled, if stream_options.include_usage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -22,7 +23,6 @@
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -200,7 +200,7 @@ def extract_tool_calls_streaming(
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -20,7 +21,6 @@
partial_json_loads)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -182,7 +182,7 @@ def extract_tool_calls_streaming(
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -17,7 +18,6 @@
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -259,7 +259,7 @@ def extract_tool_calls_streaming(
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -18,7 +19,6 @@
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -106,7 +106,7 @@ def extract_tool_calls_streaming(
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -19,7 +20,6 @@
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -220,7 +220,7 @@ def extract_tool_calls_streaming(
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -21,7 +22,6 @@
is_complete_json,
partial_json_loads)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -208,7 +208,7 @@ def extract_tool_calls_streaming(
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -73,7 +73,7 @@ def extract_tool_calls(

tool_calls: list[ToolCall] = [
ToolCall(
id=f"chatcmpl-tool-{random_uuid()}",
id=random_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
new_call_args = new_call_args[:-len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
Expand All @@ -288,5 +289,5 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,

arg_diff = new_call_args[len(previously_sent_args):]
return DeltaToolCall(
id="", index=index, function=DeltaFunctionCall(
id=None, index=index, function=DeltaFunctionCall(
arguments=arg_diff)) if arg_diff else None