Skip to content

Commit

Permalink
Merge branch 'main' into release/3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Feb 7, 2025
2 parents c3e1da4 + 1701e17 commit adb6f8f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 16 deletions.
5 changes: 5 additions & 0 deletions swift/llm/template/template_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,5 +232,10 @@ def messages_join_observation(messages: Messages, tools_prompt='react_en') -> No
assert isinstance(pre_content, str)
pre_message['content'] = pre_content + content # assistant
messages.pop(i) # remove tool
elif (pre_role == 'assistant' and role == 'assistant' and isinstance(pre_content, str)
and isinstance(content, str)):
# Consecutive messages from the assistant role need to be merged to prevent errors.
pre_message['content'] = pre_content + content
messages.pop(i)
else:
i += 1
40 changes: 24 additions & 16 deletions swift/plugin/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def format_react_en(tool_names, tool_descs):
Begin!
"""
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))


Expand All @@ -49,7 +49,7 @@ def format_react_zh(tool_names, tool_descs):
开始!
"""
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))


Expand All @@ -59,7 +59,7 @@ def format_glm4(tool_names, tool_descs):
# 可用工具
{tool_list}"""
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
tool_list = ''
for name, tool in zip(tool_names, tool_descs):
tool_list += f'## {name}\n\n{tool}\n\n'
Expand Down Expand Up @@ -92,7 +92,7 @@ def format_toolbench(tool_names, tool_descs):
use function Finish->give_up_and_restart.
2.Do not use origin tool names, use only subfunctions' names.
Specifically, you have access to the following APIs: {tool_list}"""
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))


Expand All @@ -107,10 +107,20 @@ def format_qwen(tool_names, tool_descs):
{tool_list}
## 你可以在回复中插入以下命令以调用这些工具:
{format_list}
'''
## 你可以在回复中插入以下命令以并行调用N个工具:
✿FUNCTION✿: 工具1的名称,必须是[{tool_names}]之一
✿ARGS✿: 工具1的输入
✿FUNCTION✿: 工具2的名称
✿ARGS✿: 工具2的输入
...
✿FUNCTION✿: 工具N的名称
✿ARGS✿: 工具N的输入
✿RESULT✿: 工具1的结果
✿RESULT✿: 工具2的结果
...
✿RESULT✿: 工具N的结果
✿RETURN✿: 根据工具结果进行回复'''
# 定义星期映射
weekdays = {0: '星期一', 1: '星期二', 2: '星期三', 3: '星期四', 4: '星期五', 5: '星期六', 6: '星期日'}
now = dt.datetime.now()
Expand All @@ -122,15 +132,13 @@ def format_qwen(tool_names, tool_descs):
PROMPT = PROMPT.replace('{date}', formatted_date)
tool_list = ''
for name, tool in zip(tool_names, tool_descs):
tool_list += f'### {name} \n{name}: {tool["description"]} 输入参数: {json.dumps(tool["parameters"])}\n'
desc = tool.get('description', '')
parameters = json.dumps(params, ensure_ascii=False) if (params := tool.get('parameters')) else ''
tool_list += f'### {name}\n\n{name}: {desc} 输入参数: {parameters} 此工具的输入应为JSON对象。'

PROMPT = PROMPT.replace('{tool_list}', tool_list)

format_list = ''
for i, _ in enumerate(tool_names):
format_list += f'✿FUNCTION✿:工具{i+1}的名称\n✿ARGS✿:工具{i + 1}的输入\n✿RESULT✿:工具{i + 1}的结果\n'
PROMPT = PROMPT.replace('{format_list}', format_list)
return PROMPT
PROMPT = PROMPT.replace('{tool_names}', ','.join(tool_names))
return PROMPT.rstrip()


def format_custom(tool_names, tool_descs):
Expand All @@ -140,7 +148,7 @@ def format_custom(tool_names, tool_descs):
{tool_list}'''
tool_list = ''
tool_descs = [json.dumps(t) if not isinstance(t, str) else t for t in tool_descs]
tool_descs = [json.dumps(t, ensure_ascii=False) if not isinstance(t, str) else t for t in tool_descs]
for name, tool in zip(tool_names, tool_descs):
tool_list += f'## {name}\n\n{tool}\n\n'
return PROMPT.format(tool_list=tool_list)
Expand Down
46 changes: 46 additions & 0 deletions tests/llm/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,52 @@ def test_template(self):
response2 = _infer_model(pt_engine)
assert response == response2

def test_tool_message_join(self):
from copy import deepcopy

from swift.llm.template.template_inputs import StdTemplateInputs
from swift.plugin.tools import get_tools_keyword

messages = [
# first round
{
'role': 'user',
'content': 'testing_user_message'
},
{
'role': 'assistant',
'content': ''
},
{
'role': 'tool',
'content': ''
},
# second round
{
'role': 'assistant',
'content': ''
},
{
'role': 'tool',
'content': ''
},
]

# testing two template type.
for tool_prompt in ('react_en', 'qwen'):
tool_prompt = 'react_en'
test_messages = deepcopy(messages)
obs_word = get_tools_keyword(tool_prompt).get('observation')
test_messages[1]['content'] = f'{obs_word}'
test_messages[2]['content'] = 'first_round_result\n'
test_messages[3]['content'] = f'{obs_word}'
test_messages[4]['content'] = 'second_round_result\n'
StdTemplateInputs.messages_join_observation(test_messages, tools_prompt=tool_prompt)

# multi-round tool calling should be joined that only one assistant message left.
assert len(test_messages) == 2, f'Tool prompot {tool_prompt} join failed, {messages}'
assert test_messages[1]['content'] == f"""{obs_word}first_round_result\n{obs_word}second_round_result\n"""


if __name__ == '__main__':
unittest.main()

0 comments on commit adb6f8f

Please sign in to comment.