|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| - |
| 14 | +import json |
15 | 15 | import os
|
16 | 16 | from typing import TYPE_CHECKING, List, Sequence
|
17 | 17 |
|
|
21 | 21 | from llamafactory.data import get_template_and_fix_tokenizer
|
22 | 22 | from llamafactory.hparams import DataArguments
|
23 | 23 |
|
24 |
| - |
25 | 24 | if TYPE_CHECKING:
|
26 | 25 | from transformers import PreTrainedTokenizer
|
27 | 26 |
|
28 |
| - |
29 | 27 | HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
30 | 28 |
|
31 | 29 | TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
|
37 | 35 | {"role": "assistant", "content": "很高兴认识你!"},
|
38 | 36 | ]
|
39 | 37 |
|
| 38 | +TOOL_MESSAGES = { |
| 39 | + "tools": [ |
| 40 | + { |
| 41 | + "type": "function", |
| 42 | + "function": { |
| 43 | + "name": "get_news", |
| 44 | + "description": "获取最新新闻文章", |
| 45 | + "parameters": { |
| 46 | + "type": "object", |
| 47 | + "properties": { |
| 48 | + "category": {"type": "string", "description": "要检索的新闻文章类别"}, |
| 49 | + "country": {"type": "string", "description": "获取新闻文章的国家"} |
| 50 | + }, |
| 51 | + "required": ["category"] |
| 52 | + } |
| 53 | + } |
| 54 | + }, |
| 55 | + { |
| 56 | + "type": "function", |
| 57 | + "function": { |
| 58 | + "name": "search_books", |
| 59 | + "description": "根据提供的标准搜索书籍", |
| 60 | + "parameters": { |
| 61 | + "type": "object", |
| 62 | + "properties": { |
| 63 | + "title": {"type": "string", "description": "这本书的标题"}, |
| 64 | + "author": {"type": "string", "description": "这本书的作者"}, |
| 65 | + "genre": {"type": "string", "description": "这本书的类型"} |
| 66 | + } |
| 67 | + } |
| 68 | + } |
| 69 | + } |
| 70 | + ], |
| 71 | + "messages": [ |
| 72 | + { |
| 73 | + "role": "user", |
| 74 | + "content": "你能帮我找到最新的美国体育新闻吗?" |
| 75 | + }, |
| 76 | + { |
| 77 | + "role": "tool_calls", |
| 78 | + "content": [ |
| 79 | + { |
| 80 | + "type": "function", |
| 81 | + "function": {"name": "get_news", "arguments": {"category": "运动", "country": "美国"}} |
| 82 | + } |
| 83 | + ] |
| 84 | + }, |
| 85 | + { |
| 86 | + "role": "tool", |
| 87 | + "content": json.dumps( |
| 88 | + {"title": "NBA总决赛:湖人队对阵热火队", "link": "NBA官方网站"}, |
| 89 | + ensure_ascii=False |
| 90 | + ), |
| 91 | + }, |
| 92 | + { |
| 93 | + "role": "tool", |
| 94 | + "content": json.dumps( |
| 95 | + {"title": "NFL:爱国者队击败酋长队", "link": "https://www.nfl.com/新闻"}, |
| 96 | + ensure_ascii=False |
| 97 | + ), |
| 98 | + }, |
| 99 | + { |
| 100 | + "role": "tool", |
| 101 | + "content": json.dumps( |
| 102 | + {"title": "MLB:道奇队赢得世界系列赛", "link": "https://www.mlb.com/新闻"}, |
| 103 | + ensure_ascii=False |
| 104 | + ) |
| 105 | + }, |
| 106 | + { |
| 107 | + "role": "assistant", |
| 108 | + "content": "1. NBA总决赛:湖人队对阵热火队\n2. NFL:爱国者队击败酋长队\n3. MLB:道奇队赢得世界系列赛" |
| 109 | + } |
| 110 | + ], |
| 111 | +} |
| 112 | + |
40 | 113 |
|
41 | 114 | def _check_tokenization(
|
42 | 115 | tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
|
@@ -168,3 +241,118 @@ def test_yi_template():
|
168 | 241 | )
|
169 | 242 | answer_str = "很高兴认识你!<|im_end|>"
|
170 | 243 | _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str)
|
| 244 | + |
| 245 | + |
| 246 | +@pytest.mark.xfail(reason="The fast tokenizer of mistral model is corrupted.") |
| 247 | +def test_mistral_template(): |
| 248 | + TEMPLATE = r""" |
| 249 | +{%- if not tools is defined %} |
| 250 | + {%- set tools = none %} |
| 251 | +{%- endif %} |
| 252 | +{%- set user_messages = messages | selectattr("role", "equalto", "user") | list %} |
| 253 | +
|
| 254 | +{%- for message in messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} |
| 255 | + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} |
| 256 | + {{- raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} |
| 257 | + {%- endif %} |
| 258 | +{%- endfor %} |
| 259 | +
|
| 260 | +{{- bos_token }} |
| 261 | +{%- for message in messages %} |
| 262 | + {%- if message["role"] == "user" %} |
| 263 | + {%- if tools is not none and (message == user_messages[-1]) %} |
| 264 | + {{- "[AVAILABLE_TOOLS] [" }} |
| 265 | + {%- for tool in tools %} |
| 266 | + {%- set tool = tool.function %} |
| 267 | + {{- '{"type": "function", "function": {' }} |
| 268 | + {%- for key, val in tool.items() if key != "return" %} |
| 269 | + {%- if val is string %} |
| 270 | + {{- '"' + key + '": "' + val + '"' }} |
| 271 | + {%- else %} |
| 272 | + {{- '"' + key + '": ' + val|tojson }} |
| 273 | + {%- endif %} |
| 274 | + {%- if not loop.last %} |
| 275 | + {{- ", " }} |
| 276 | + {%- endif %} |
| 277 | + {%- endfor %} |
| 278 | + {{- "}}" }} |
| 279 | + {%- if not loop.last %} |
| 280 | + {{- ", " }} |
| 281 | + {%- else %} |
| 282 | + {{- "]" }} |
| 283 | + {%- endif %} |
| 284 | + {%- endfor %} |
| 285 | + {{- "[/AVAILABLE_TOOLS]" }} |
| 286 | + {%- endif %} |
| 287 | + {{- "[INST] " + message["content"] + "[/INST]" }} |
| 288 | + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} |
| 289 | + {%- if message.tool_calls is defined %} |
| 290 | + {%- set tool_calls = message.tool_calls %} |
| 291 | + {%- else %} |
| 292 | + {%- set tool_calls = message.content %} |
| 293 | + {%- endif %} |
| 294 | + {{- "[TOOL_CALLS] [" }} |
| 295 | + {%- for tool_call in tool_calls %} |
| 296 | + {%- set out = tool_call.function|tojson %} |
| 297 | + {{- out }} |
| 298 | + {%- if not loop.last %} |
| 299 | + {{- ", " }} |
| 300 | + {%- else %} |
| 301 | + {{- "]" }} |
| 302 | + {%- endif %} |
| 303 | + {%- endfor %} |
| 304 | + {%- elif message["role"] == "assistant" %} |
| 305 | + {{- " " + message["content"] }} |
| 306 | + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} |
| 307 | + {%- if message.content is defined and message.content.content is defined %} |
| 308 | + {%- set content = message.content.content %} |
| 309 | + {%- else %} |
| 310 | + {%- set content = message.content %} |
| 311 | + {%- endif %} |
| 312 | + {{- '[TOOL_RESULTS] {"content": ' + content|string + "}[/TOOL_RESULTS]" }} |
| 313 | + {%- else %} |
| 314 | + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} |
| 315 | + {%- endif %} |
| 316 | +{%- endfor %} |
| 317 | +""" |
| 318 | + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("/home/share/models/Mistral-7B-v0.3") |
| 319 | + template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="mistral")) |
| 320 | + |
| 321 | + content_str = tokenizer.apply_chat_template( |
| 322 | + conversation=TOOL_MESSAGES['messages'], |
| 323 | + tools=TOOL_MESSAGES['tools'], |
| 324 | + chat_template=TEMPLATE, |
| 325 | + tokenize=False |
| 326 | + ) |
| 327 | + content_ids = tokenizer.apply_chat_template( |
| 328 | + conversation=TOOL_MESSAGES['messages'], |
| 329 | + tools=TOOL_MESSAGES['tools'], |
| 330 | + chat_template=TEMPLATE, |
| 331 | + tokenize=True |
| 332 | + ) |
| 333 | + encoded_pairs = template.encode_multiturn( |
| 334 | + tokenizer, |
| 335 | + [ |
| 336 | + TOOL_MESSAGES['messages'][0], |
| 337 | + { |
| 338 | + "role": "function", |
| 339 | + "content": json.dumps([function['function'] for function in TOOL_MESSAGES['messages'][1]['content']]) |
| 340 | + }, |
| 341 | + { |
| 342 | + "role": "observation", |
| 343 | + "content": json.dumps([item['content'] for item in TOOL_MESSAGES['messages'][2:-1]]) |
| 344 | + }, |
| 345 | + TOOL_MESSAGES['messages'][-1], |
| 346 | + ], |
| 347 | + tools=json.dumps([tool['function'] for tool in TOOL_MESSAGES['tools']]) |
| 348 | + ) |
| 349 | + |
| 350 | + final_ids = [] |
| 351 | + for prompt, response in encoded_pairs: |
| 352 | + final_ids.extend(prompt) |
| 353 | + final_ids.extend(response) |
| 354 | + |
| 355 | + final_str = tokenizer.decode(final_ids) |
| 356 | + |
| 357 | + assert content_str == final_str |
| 358 | + assert content_ids == final_ids |
0 commit comments