8
8
from vllm .entrypoints .openai .protocol import (DeltaFunctionCall , DeltaMessage ,
9
9
DeltaToolCall ,
10
10
ExtractedToolCallInformation ,
11
- FunctionCall ,
12
- InitialDeltaToolCall , ToolCall )
11
+ FunctionCall , ToolCall )
13
12
from vllm .entrypoints .openai .tool_parsers .abstract_tool_parser import (
14
13
ToolParser )
15
14
from vllm .entrypoints .openai .tool_parsers .utils import (
16
15
extract_intermediate_diff )
17
16
from vllm .logger import init_logger
18
17
from vllm .transformers_utils .tokenizer import AnyTokenizer , MistralTokenizer
18
+ from vllm .utils import random_uuid
19
19
20
20
logger = init_logger (__name__ )
21
21
@@ -25,7 +25,7 @@ class MistralToolParser(ToolParser):
25
25
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
26
26
examples/tool_chat_template_mistral.jinja template.
27
27
28
- Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set
28
+ Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
29
29
"""
30
30
31
31
def __init__ (self , tokenizer : AnyTokenizer ):
@@ -42,7 +42,6 @@ def __init__(self, tokenizer: AnyTokenizer):
42
42
self .prev_tool_call_arr : List [Dict ] = []
43
43
self .current_tool_id : int = - 1
44
44
self .current_tool_name_sent : bool = False
45
- self .current_tool_initial_sent : bool = False
46
45
self .streamed_args_for_tool : List [str ] = [
47
46
] # map what has been streamed for each tool so far to a list
48
47
self .bot_token = "[TOOL_CALLS]"
@@ -91,7 +90,6 @@ def extract_tool_calls(self,
91
90
92
91
except Exception as e :
93
92
logger .error ("Error in extracting tool call from response: %s" , e )
94
- print ("ERROR" , e )
95
93
# return information to just treat the tool call as regular JSON
96
94
return ExtractedToolCallInformation (tools_called = False ,
97
95
tool_calls = [],
@@ -109,7 +107,7 @@ def extract_tool_calls_streaming(
109
107
110
108
# if the tool call token is not in the tokens generated so far, append
111
109
# output to contents since it's not a tool
112
- if self .bot_token_id not in current_token_ids :
110
+ if self .bot_token not in current_text :
113
111
return DeltaMessage (content = delta_text )
114
112
115
113
# if the tool call token ID IS in the tokens generated so far, that
@@ -134,7 +132,7 @@ def extract_tool_calls_streaming(
134
132
# replace BOT token with empty string, and convert single quotes
135
133
# to double to allow parsing as JSON since mistral uses single
136
134
# quotes instead of double for tool calls
137
- parsable_arr = current_text .split (self .bot_token )[1 ]
135
+ parsable_arr = current_text .split (self .bot_token )[- 1 ]
138
136
139
137
# tool calls are generated in an array, so do partial JSON
140
138
# parsing on the entire array
@@ -186,31 +184,22 @@ def extract_tool_calls_streaming(
186
184
# re-set stuff pertaining to progress in the current tool
187
185
self .current_tool_id = len (tool_call_arr ) - 1
188
186
self .current_tool_name_sent = False
189
- self .current_tool_initial_sent = False
190
187
self .streamed_args_for_tool .append ("" )
191
188
logger .debug ("starting on new tool %d" , self .current_tool_id )
192
189
return delta
193
190
194
191
# case: update an existing tool - this is handled below
195
192
196
- # if the current tool initial data incl. the id, type=function
197
- # and idx not sent, send that
198
- if not self .current_tool_initial_sent :
199
- self .current_tool_initial_sent = True
200
- delta = DeltaMessage (tool_calls = [
201
- InitialDeltaToolCall (
202
- index = self .current_tool_id ).model_dump (
203
- exclude_none = True )
204
- ])
205
-
206
193
# if the current tool name hasn't been sent, send if available
207
194
# - otherwise send nothing
208
- elif not self .current_tool_name_sent :
195
+ if not self .current_tool_name_sent :
209
196
function_name = current_tool_call .get ("name" )
210
197
if function_name :
211
198
212
199
delta = DeltaMessage (tool_calls = [
213
200
DeltaToolCall (index = self .current_tool_id ,
201
+ type = "function" ,
202
+ id = f"chatcmpl-tool-{ random_uuid ()} " ,
214
203
function = DeltaFunctionCall (
215
204
name = function_name ).model_dump (
216
205
exclude_none = True ))
0 commit comments