Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tool role in BaseChatTemplate as tool response in messages #2979

Merged
merged 2 commits on Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ChatTemplateConfig:
eoh (str | None): end of the user prompt
assistant (str | None): begin of the assistant prompt
eoa (str | None): end of the assistant prompt
tool (str | None): begin of the tool prompt
eotool (str | None): end of the tool prompt
capability: ('completion' | 'infilling' | 'chat' | 'python') = None
""" # noqa: E501

Expand All @@ -57,6 +59,8 @@ class ChatTemplateConfig:
eoh: Optional[str] = None
assistant: Optional[str] = None
eoa: Optional[str] = None
tool: Optional[str] = None
eotool: Optional[str] = None
separator: Optional[str] = None
capability: Optional[Literal['completion', 'infilling', 'chat',
'python']] = None
Expand Down Expand Up @@ -173,6 +177,8 @@ def __init__(self,
assistant='',
eoa='',
separator='',
tool='',
eotool='',
**kwargs):
super().__init__(**kwargs)
self.system = system
Expand All @@ -183,6 +189,8 @@ def __init__(self,
self.separator = separator
self.eosys = eosys
self.assistant = assistant
self.tool = tool
self.eotool = eotool

def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
Expand Down Expand Up @@ -223,10 +231,12 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
return self.get_prompt(messages, sequence_start)
box_map = dict(user=self.user,
assistant=self.assistant,
system=self.system)
system=self.system,
tool=self.tool)
eox_map = dict(user=self.eoh,
assistant=self.eoa + self.separator,
system=self.eosys)
system=self.eosys,
tool=self.eotool)
ret = ''
if self.meta_instruction is not None and sequence_start:
if len(messages) and messages[0]['role'] != 'system':
Expand Down Expand Up @@ -819,7 +829,7 @@ class Llama3_1(Llama3):

def __init__(
self,
tools="""# Tool Instructions
tool="""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search

Expand All @@ -828,7 +838,7 @@ def __init__(
You have access to the following functions:

""", # noqa
eotools="""
eotool="""

If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
Expand Down Expand Up @@ -858,8 +868,8 @@ def __init__(
**kwargs)
self.ipython = ipython
self.eoi = eoi
self.tools = tools
self.eotools = eotools
self.tool = tool
self.eotool = eotool
self.knowledge = knowledge

def messages2prompt(self,
Expand Down Expand Up @@ -899,15 +909,15 @@ def messages2prompt(self,
if tools is None:
ret += f'{self.system}{self.knowledge}{self.meta_instruction}{self.eosys}'
else:
ret += f'{self.system}{self.knowledge}{self.tools}{tool_prompt}{self.eotools}{self.meta_instruction}{self.eosys}'
ret += f'{self.system}{self.knowledge}{self.tool}{tool_prompt}{self.eotool}{self.meta_instruction}{self.eosys}'
for message in messages:
role = message['role']
content = get_text(message['content'])
if role == 'assistant' and ('<|python_tag|>' in content
or '</function>' in content):
ret += f'{box_map[role]}{content}<|eom_id|>'
elif role == 'system' and tools is not None:
ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}'
ret += f'{box_map[role]}{self.tool}{tool_prompt}{self.eotool}{content}{eox_map[role]}'
else:
ret += f'{box_map[role]}{content}{eox_map[role]}'
if sequence_start and not isinstance(messages, str):
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
for call_info in call_info_list
]
except Exception as e:
logger.error(f'Exception: {e}')
logger.error(f'Failed to parse {text}. Exception: {e}.')
return create_error_response(
HTTPStatus.BAD_REQUEST,
'Failed to parse fc related info to json format!')
Expand Down
Loading