Skip to content

Commit bafc8a0

Browse files
sunxichen
and
sunxichen
authored
fix: tool call message role according to credentials (langgenius#5625)
Co-authored-by: sunxichen <[email protected]>
1 parent 92c56fd commit bafc8a0

File tree

1 file changed

+6
-6
lines changed
  • api/core/model_runtime/model_providers/openai_api_compatible/llm

1 file changed

+6
-6
lines changed

api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
8888
:param tools: tools for tool calling
8989
:return:
9090
"""
91-
return self._num_tokens_from_messages(model, prompt_messages, tools)
91+
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
9292

9393
def validate_credentials(self, model: str, credentials: dict) -> None:
9494
"""
@@ -305,7 +305,7 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
305305

306306
if completion_type is LLMMode.CHAT:
307307
endpoint_url = urljoin(endpoint_url, 'chat/completions')
308-
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
308+
data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
309309
elif completion_type is LLMMode.COMPLETION:
310310
endpoint_url = urljoin(endpoint_url, 'completions')
311311
data['prompt'] = prompt_messages[0].content
@@ -582,7 +582,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req
582582

583583
return result
584584

585-
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
585+
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: dict = None) -> dict:
586586
"""
587587
Convert PromptMessage to dict for OpenAI API format
588588
"""
@@ -636,7 +636,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
636636
# "tool_call_id": message.tool_call_id
637637
# }
638638
message_dict = {
639-
"role": "function",
639+
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
640640
"content": message.content,
641641
"name": message.tool_call_id
642642
}
@@ -675,7 +675,7 @@ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessag
675675
return num_tokens
676676

677677
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
678-
tools: Optional[list[PromptMessageTool]] = None) -> int:
678+
tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int:
679679
"""
680680
Approximate num tokens with GPT2 tokenizer.
681681
"""
@@ -684,7 +684,7 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
684684
tokens_per_name = 1
685685

686686
num_tokens = 0
687-
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
687+
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
688688
for message in messages_dict:
689689
num_tokens += tokens_per_message
690690
for key, value in message.items():

0 commit comments

Comments
 (0)