Skip to content

Commit 2fa00dd

Browse files
maxdebayser authored and igmainc committed
Initialize the delta tool call fields explicitly
Previously the implementation relied on a combination of the default field values being set and `model_dump_json(exclude_unset=True)`. This led to bugs such as the "type" field missing from the response. The behavior of the OpenAI API is to set the "id" and "type" fields only in the first Delta object. The purpose of this commit is to make all the places where these fields need to be set explicit. Signed-off-by: Max de Bayser <[email protected]> Co-authored-by: igmainc <[email protected]>
1 parent cfe4532 commit 2fa00dd

File tree

11 files changed

+50
-33
lines changed

11 files changed

+50
-33
lines changed

vllm/entrypoints/chat_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from vllm.multimodal.utils import MediaConnector
4141
from vllm.transformers_utils.processor import cached_get_processor
4242
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
43+
from vllm.utils import random_uuid
4344

4445
logger = init_logger(__name__)
4546

@@ -1257,3 +1258,6 @@ def apply_mistral_chat_template(
12571258
"An error occurred in `mistral_common` while applying chat "
12581259
"template")
12591260
raise ValueError from e
1261+
1262+
def random_tool_call_id() -> str:
1263+
return f"chatcmpl-tool-{random_uuid()}"

vllm/entrypoints/openai/protocol.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
ValidationInfo, field_validator, model_validator)
1515
from typing_extensions import TypeAlias
1616

17-
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
17+
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
18+
random_tool_call_id)
1819
from vllm.logger import init_logger
1920
from vllm.pooling_params import PoolingParams
2021
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -1290,7 +1291,7 @@ class FunctionCall(OpenAIBaseModel):
12901291

12911292

12921293
class ToolCall(OpenAIBaseModel):
1293-
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
1294+
id: str = Field(default_factory=random_tool_call_id)
12941295
type: Literal["function"] = "function"
12951296
function: FunctionCall
12961297

@@ -1302,8 +1303,8 @@ class DeltaFunctionCall(BaseModel):
13021303

13031304
# a tool call delta where everything is optional
13041305
class DeltaToolCall(OpenAIBaseModel):
1305-
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
1306-
type: Literal["function"] = "function"
1306+
id: Optional[str] = None
1307+
type: Optional[Literal["function"]] = None
13071308
index: int
13081309
function: Optional[DeltaFunctionCall] = None
13091310

vllm/entrypoints/openai/serving_chat.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from vllm.config import ModelConfig
1717
from vllm.engine.protocol import EngineClient
1818
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
19-
ConversationMessage)
19+
ConversationMessage,
20+
random_tool_call_id)
2021
from vllm.entrypoints.logger import RequestLogger
2122
from vllm.entrypoints.openai.protocol import (
2223
ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -365,9 +366,10 @@ def extract_tool_call_required_streaming(
365366

366367
function_name_returned = True
367368
delta_message = DeltaMessage(tool_calls=[
368-
DeltaToolCall(function=DeltaFunctionCall(
369-
name=current_tool_call["name"],
370-
arguments=arguments),
369+
DeltaToolCall(id=random_tool_call_id(),
370+
function=DeltaFunctionCall(
371+
name=current_tool_call["name"],
372+
arguments=arguments),
371373
index=len(obj) - 1,
372374
type="function")
373375
])
@@ -384,8 +386,7 @@ def extract_tool_call_required_streaming(
384386
# instead of name every time
385387
name=None,
386388
arguments=delta_text),
387-
index=len(obj) - 1,
388-
type="function")
389+
index=len(obj) - 1)
389390
])
390391
else:
391392
delta_message = None
@@ -427,7 +428,7 @@ async def chat_completion_stream_generator(
427428
self._should_stream_with_reasoning_parsing(request))
428429

429430
all_previous_token_ids: Optional[list[list[int]]]
430-
function_name_returned: Optional[list[bool]] = None
431+
function_name_returned = [False] * num_choices
431432

432433
# Only one of these will be used, thus previous_texts and
433434
# all_previous_token_ids will not be used twice in the same iteration.
@@ -440,7 +441,6 @@ async def chat_completion_stream_generator(
440441
reasoning_end_arr = [False] * num_choices
441442
elif request.tool_choice == "required":
442443
previous_texts = [""] * num_choices
443-
function_name_returned = [False] * num_choices
444444
all_previous_token_ids = None
445445
else:
446446
previous_texts, all_previous_token_ids = None, None
@@ -634,16 +634,27 @@ async def chat_completion_stream_generator(
634634
delta_text = previous_text + delta_text
635635
current_text = ""
636636

637+
if function_name_returned[i]:
638+
delta_tool_call = DeltaToolCall(
639+
function=DeltaFunctionCall(
640+
arguments=delta_text),
641+
index=i)
642+
else:
643+
delta_tool_call = DeltaToolCall(
644+
id=random_tool_call_id(),
645+
type="function",
646+
function=DeltaFunctionCall(
647+
name=tool_choice_function_name,
648+
arguments=delta_text),
649+
index=i)
650+
function_name_returned[i] = True
651+
637652
delta_message = DeltaMessage(tool_calls=[
638-
DeltaToolCall(function=DeltaFunctionCall(
639-
name=tool_choice_function_name,
640-
arguments=delta_text),
641-
index=i)
653+
delta_tool_call,
642654
])
643655

644656
elif request.tool_choice == "required":
645657
assert previous_texts is not None
646-
assert function_name_returned is not None
647658
previous_text = previous_texts[i]
648659
current_text = previous_text + delta_text
649660
fn_name_returned = function_name_returned[i]
@@ -847,7 +858,7 @@ async def chat_completion_stream_generator(
847858
total_tokens=num_prompt_tokens + completion_tokens,
848859
)
849860

850-
data = chunk.model_dump_json(exclude_unset=True)
861+
data = chunk.model_dump_json()
851862
yield f"data: {data}\n\n"
852863

853864
# once the final token is handled, if stream_options.include_usage

vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import partial_json_parser
1010
from partial_json_parser.core.options import Allow
1111

12+
from vllm.entrypoints.chat_utils import random_tool_call_id
1213
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1314
DeltaFunctionCall, DeltaMessage,
1415
DeltaToolCall,
@@ -22,7 +23,6 @@
2223
partial_json_loads)
2324
from vllm.logger import init_logger
2425
from vllm.transformers_utils.tokenizer import AnyTokenizer
25-
from vllm.utils import random_uuid
2626

2727
logger = init_logger(__name__)
2828

@@ -200,7 +200,7 @@ def extract_tool_calls_streaming(
200200
delta = DeltaMessage(tool_calls=[
201201
DeltaToolCall(index=self.current_tool_id,
202202
type="function",
203-
id=f"chatcmpl-tool-{random_uuid()}",
203+
id=random_tool_call_id(),
204204
function=DeltaFunctionCall(
205205
name=function_name).model_dump(
206206
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import partial_json_parser
88
from partial_json_parser.core.options import Allow
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaFunctionCall, DeltaMessage,
1213
DeltaToolCall,
@@ -20,7 +21,6 @@
2021
partial_json_loads)
2122
from vllm.logger import init_logger
2223
from vllm.transformers_utils.tokenizer import AnyTokenizer
23-
from vllm.utils import random_uuid
2424

2525
logger = init_logger(__name__)
2626

@@ -182,7 +182,7 @@ def extract_tool_calls_streaming(
182182
delta = DeltaMessage(tool_calls=[
183183
DeltaToolCall(index=self.current_tool_id,
184184
type="function",
185-
id=f"chatcmpl-tool-{random_uuid()}",
185+
id=random_tool_call_id(),
186186
function=DeltaFunctionCall(
187187
name=function_name).model_dump(
188188
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import partial_json_parser
99
from partial_json_parser.core.options import Allow
1010

11+
from vllm.entrypoints.chat_utils import random_tool_call_id
1112
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1213
DeltaFunctionCall, DeltaMessage,
1314
DeltaToolCall,
@@ -17,7 +18,6 @@
1718
ToolParser, ToolParserManager)
1819
from vllm.logger import init_logger
1920
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
20-
from vllm.utils import random_uuid
2121

2222
logger = init_logger(__name__)
2323

@@ -259,7 +259,7 @@ def extract_tool_calls_streaming(
259259
return DeltaMessage(tool_calls=[
260260
DeltaToolCall(index=self.current_tool_id,
261261
type="function",
262-
id=f"chatcmpl-tool-{random_uuid()}",
262+
id=random_tool_call_id(),
263263
function=DeltaFunctionCall(
264264
name=function_name).model_dump(
265265
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import partial_json_parser
88
from partial_json_parser.core.options import Allow
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaFunctionCall, DeltaMessage,
1213
DeltaToolCall,
@@ -18,7 +19,6 @@
1819
extract_intermediate_diff)
1920
from vllm.logger import init_logger
2021
from vllm.transformers_utils.tokenizer import AnyTokenizer
21-
from vllm.utils import random_uuid
2222

2323
logger = init_logger(__name__)
2424

@@ -106,7 +106,7 @@ def extract_tool_calls_streaming(
106106
delta = DeltaMessage(tool_calls=[
107107
DeltaToolCall(index=self.current_tool_id,
108108
type="function",
109-
id=f"chatcmpl-tool-{random_uuid()}",
109+
id=random_tool_call_id(),
110110
function=DeltaFunctionCall(
111111
name=function_name).model_dump(
112112
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import partial_json_parser
99
from partial_json_parser.core.options import Allow
1010

11+
from vllm.entrypoints.chat_utils import random_tool_call_id
1112
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1213
DeltaFunctionCall, DeltaMessage,
1314
DeltaToolCall,
@@ -19,7 +20,6 @@
1920
from vllm.logger import init_logger
2021
from vllm.transformers_utils.tokenizer import AnyTokenizer
2122
from vllm.transformers_utils.tokenizers import MistralTokenizer
22-
from vllm.utils import random_uuid
2323

2424
logger = init_logger(__name__)
2525

@@ -220,7 +220,7 @@ def extract_tool_calls_streaming(
220220
delta = DeltaMessage(tool_calls=[
221221
DeltaToolCall(index=self.current_tool_id,
222222
type="function",
223-
id=f"chatcmpl-tool-{random_uuid()}",
223+
id=random_tool_call_id(),
224224
function=DeltaFunctionCall(
225225
name=function_name).model_dump(
226226
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from partial_json_parser.core.options import Allow
1111
from transformers import PreTrainedTokenizerBase
1212

13+
from vllm.entrypoints.chat_utils import random_tool_call_id
1314
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1415
DeltaFunctionCall, DeltaMessage,
1516
DeltaToolCall,
@@ -21,7 +22,6 @@
2122
is_complete_json,
2223
partial_json_loads)
2324
from vllm.logger import init_logger
24-
from vllm.utils import random_uuid
2525

2626
logger = init_logger(__name__)
2727

@@ -208,7 +208,7 @@ def extract_tool_calls_streaming(
208208
delta = DeltaMessage(tool_calls=[
209209
DeltaToolCall(index=self.current_tool_id,
210210
type="function",
211-
id=f"chatcmpl-tool-{random_uuid()}",
211+
id=random_tool_call_id(),
212212
function=DeltaFunctionCall(
213213
name=function_name).model_dump(
214214
exclude_none=True))

vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
from transformers import PreTrainedTokenizerBase
99

10+
from vllm.entrypoints.chat_utils import random_tool_call_id
1011
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
1112
DeltaMessage,
1213
ExtractedToolCallInformation,
1314
FunctionCall, ToolCall)
1415
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
1516
ToolParser, ToolParserManager)
1617
from vllm.logger import init_logger
17-
from vllm.utils import random_uuid
1818

1919
logger = init_logger(__name__)
2020

@@ -73,7 +73,7 @@ def extract_tool_calls(
7373

7474
tool_calls: list[ToolCall] = [
7575
ToolCall(
76-
id=f"chatcmpl-tool-{random_uuid()}",
76+
id=random_tool_call_id(),
7777
type="function",
7878
function=FunctionCall(
7979
name=raw_function_call["name"],

0 commit comments

Comments (0)