Skip to content

Commit 746a8f6

Browse files
committed
修复 MistralTool 格式
1 parent 558b983 commit 746a8f6

File tree

4 files changed

+199
-11
lines changed

4 files changed

+199
-11
lines changed

src/llamafactory/data/formatter.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def apply(self, **kwargs) -> SLOTS:
147147

148148
elements = []
149149
for name, arguments in functions:
150-
elements.append(f""""{{"name":"{name}","arguments":{arguments}}}""")
150+
elements.append(f"""{{"name": "{name}", "arguments": {arguments}}}""")
151151
elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"]
152152

153153
return elements
@@ -163,14 +163,14 @@ def apply(self, **kwargs) -> SLOTS:
163163
content = kwargs.pop("content")
164164
tool_results: List[Tuple[str, str]]
165165
try:
166-
tool_results = [json.dumps(result) for result in json.loads(content)]
166+
tool_results = json.loads(content)
167167
except json.JSONDecodeError:
168168
tool_results = []
169169

170170
elements = []
171-
for content in tool_results:
172-
elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]")
173-
return ["".join(elements)]
171+
for result in tool_results:
172+
elements.append(f"[TOOL_RESULTS] {{\"content\": {result}}}[/TOOL_RESULTS]")
173+
return elements
174174

175175

176176
@dataclass

src/llamafactory/data/template.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
723723

724724
_register_template(
725725
name="mistral",
726-
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
726+
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
727727
format_assistant=StringFormatter(slots=[" {{content}}"]), # mistral add space here
728728
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
729729
format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"),

src/llamafactory/data/tool_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
3939
)
4040

41-
MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools} [/AVAILABLE_TOOLS]"
41+
MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools}[/AVAILABLE_TOOLS]"
4242

4343
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
4444

@@ -176,7 +176,7 @@ def get_function_slots() -> SLOTS:
176176
@override
177177
@staticmethod
178178
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
179-
tools = [{"type": "function", "function": tool} for tool in tools]
179+
tools = json.dumps([{"type": "function", "function": tool} for tool in tools],ensure_ascii=False)
180180
return MISTRAL_TOOL_PROMPT.format(tools=tools)
181181

182182
@override

tests/data/test_template.py

+191-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import json
1515
import os
1616
from typing import TYPE_CHECKING, List, Sequence
1717

@@ -21,11 +21,9 @@
2121
from llamafactory.data import get_template_and_fix_tokenizer
2222
from llamafactory.hparams import DataArguments
2323

24-
2524
if TYPE_CHECKING:
2625
from transformers import PreTrainedTokenizer
2726

28-
2927
HF_TOKEN = os.environ.get("HF_TOKEN", None)
3028

3129
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
@@ -37,6 +35,81 @@
3735
{"role": "assistant", "content": "很高兴认识你!"},
3836
]
3937

38+
TOOL_MESSAGES = {
39+
"tools": [
40+
{
41+
"type": "function",
42+
"function": {
43+
"name": "get_news",
44+
"description": "获取最新新闻文章",
45+
"parameters": {
46+
"type": "object",
47+
"properties": {
48+
"category": {"type": "string", "description": "要检索的新闻文章类别"},
49+
"country": {"type": "string", "description": "获取新闻文章的国家"}
50+
},
51+
"required": ["category"]
52+
}
53+
}
54+
},
55+
{
56+
"type": "function",
57+
"function": {
58+
"name": "search_books",
59+
"description": "根据提供的标准搜索书籍",
60+
"parameters": {
61+
"type": "object",
62+
"properties": {
63+
"title": {"type": "string", "description": "这本书的标题"},
64+
"author": {"type": "string", "description": "这本书的作者"},
65+
"genre": {"type": "string", "description": "这本书的类型"}
66+
}
67+
}
68+
}
69+
}
70+
],
71+
"messages": [
72+
{
73+
"role": "user",
74+
"content": "你能帮我找到最新的美国体育新闻吗?"
75+
},
76+
{
77+
"role": "tool_calls",
78+
"content": [
79+
{
80+
"type": "function",
81+
"function": {"name": "get_news", "arguments": {"category": "运动", "country": "美国"}}
82+
}
83+
]
84+
},
85+
{
86+
"role": "tool",
87+
"content": json.dumps(
88+
{"title": "NBA总决赛:湖人队对阵热火队", "link": "NBA官方网站"},
89+
ensure_ascii=False
90+
),
91+
},
92+
{
93+
"role": "tool",
94+
"content": json.dumps(
95+
{"title": "NFL:爱国者队击败酋长队", "link": "https://www.nfl.com/新闻"},
96+
ensure_ascii=False
97+
),
98+
},
99+
{
100+
"role": "tool",
101+
"content": json.dumps(
102+
{"title": "MLB:道奇队赢得世界系列赛", "link": "https://www.mlb.com/新闻"},
103+
ensure_ascii=False
104+
)
105+
},
106+
{
107+
"role": "assistant",
108+
"content": "1. NBA总决赛:湖人队对阵热火队\n2. NFL:爱国者队击败酋长队\n3. MLB:道奇队赢得世界系列赛"
109+
}
110+
],
111+
}
112+
40113

41114
def _check_tokenization(
42115
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
@@ -168,3 +241,118 @@ def test_yi_template():
168241
)
169242
answer_str = "很高兴认识你!<|im_end|>"
170243
_check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
244+
245+
246+
@pytest.mark.xfail(reason="The fast tokenizer of mistral model is corrupted.")
247+
def test_mistral_template():
248+
TEMPLATE = r"""
249+
{%- if not tools is defined %}
250+
{%- set tools = none %}
251+
{%- endif %}
252+
{%- set user_messages = messages | selectattr("role", "equalto", "user") | list %}
253+
254+
{%- for message in lmessages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
255+
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
256+
{{- raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
257+
{%- endif %}
258+
{%- endfor %}
259+
260+
{{- bos_token }}
261+
{%- for message in messages %}
262+
{%- if message["role"] == "user" %}
263+
{%- if tools is not none and (message == user_messages[-1]) %}
264+
{{- "[AVAILABLE_TOOLS] [" }}
265+
{%- for tool in tools %}
266+
{%- set tool = tool.function %}
267+
{{- '{"type": "function", "function": {' }}
268+
{%- for key, val in tool.items() if key != "return" %}
269+
{%- if val is string %}
270+
{{- '"' + key + '": "' + val + '"' }}
271+
{%- else %}
272+
{{- '"' + key + '": ' + val|tojson }}
273+
{%- endif %}
274+
{%- if not loop.last %}
275+
{{- ", " }}
276+
{%- endif %}
277+
{%- endfor %}
278+
{{- "}}" }}
279+
{%- if not loop.last %}
280+
{{- ", " }}
281+
{%- else %}
282+
{{- "]" }}
283+
{%- endif %}
284+
{%- endfor %}
285+
{{- "[/AVAILABLE_TOOLS]" }}
286+
{%- endif %}
287+
{{- "[INST] " + message["content"] + "[/INST]" }}
288+
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
289+
{%- if message.tool_calls is defined %}
290+
{%- set tool_calls = message.tool_calls %}
291+
{%- else %}
292+
{%- set tool_calls = message.content %}
293+
{%- endif %}
294+
{{- "[TOOL_CALLS] [" }}
295+
{%- for tool_call in tool_calls %}
296+
{%- set out = tool_call.function|tojson %}
297+
{{- out }}
298+
{%- if not loop.last %}
299+
{{- ", " }}
300+
{%- else %}
301+
{{- "]" }}
302+
{%- endif %}
303+
{%- endfor %}
304+
{%- elif message["role"] == "assistant" %}
305+
{{- " " + message["content"] }}
306+
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
307+
{%- if message.content is defined and message.content.content is defined %}
308+
{%- set content = message.content.content %}
309+
{%- else %}
310+
{%- set content = message.content %}
311+
{%- endif %}
312+
{{- '[TOOL_RESULTS] {"content": ' + content|string + "}[/TOOL_RESULTS]" }}
313+
{%- else %}
314+
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
315+
{%- endif %}
316+
{%- endfor %}
317+
"""
318+
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("/home/share/models/Mistral-7B-v0.3")
319+
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="mistral"))
320+
321+
content_str = tokenizer.apply_chat_template(
322+
conversation=TOOL_MESSAGES['messages'],
323+
tools=TOOL_MESSAGES['tools'],
324+
chat_template=TEMPLATE,
325+
tokenize=False
326+
)
327+
content_ids = tokenizer.apply_chat_template(
328+
conversation=TOOL_MESSAGES['messages'],
329+
tools=TOOL_MESSAGES['tools'],
330+
chat_template=TEMPLATE,
331+
tokenize=True
332+
)
333+
encoded_pairs = template.encode_multiturn(
334+
tokenizer,
335+
[
336+
TOOL_MESSAGES['messages'][0],
337+
{
338+
"role": "function",
339+
"content": json.dumps([function['function'] for function in TOOL_MESSAGES['messages'][1]['content']])
340+
},
341+
{
342+
"role": "observation",
343+
"content": json.dumps([item['content'] for item in TOOL_MESSAGES['messages'][2:-1]])
344+
},
345+
TOOL_MESSAGES['messages'][-1],
346+
],
347+
tools=json.dumps([tool['function'] for tool in TOOL_MESSAGES['tools']])
348+
)
349+
350+
final_ids = []
351+
for prompt, response in encoded_pairs:
352+
final_ids.extend(prompt)
353+
final_ids.extend(response)
354+
355+
final_str = tokenizer.decode(final_ids)
356+
357+
assert content_str == final_str
358+
assert content_ids == final_ids

0 commit comments

Comments
 (0)