diff --git a/autogen/oai/mistral.py b/autogen/oai/mistral.py index 8017e3536324..10d0f926ffbf 100644 --- a/autogen/oai/mistral.py +++ b/autogen/oai/mistral.py @@ -15,28 +15,32 @@ Resources: - https://docs.mistral.ai/getting-started/quickstart/ -""" -# Important notes when using the Mistral.AI API: -# The first system message can greatly affect whether the model returns a tool call, including text that references the ability to use functions will help. -# Changing the role on the first system message to 'user' improved the chances of the model recommending a tool call. +NOTE: Requires mistralai package version >= 1.0.1 +""" import inspect import json import os import time import warnings -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Union # Mistral libraries # pip install mistralai -from mistralai.client import MistralClient -from mistralai.exceptions import MistralAPIException -from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ToolCall +from mistralai import ( + AssistantMessage, + Function, + FunctionCall, + Mistral, + SystemMessage, + ToolCall, + ToolMessage, + UserMessage, +) from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage -from typing_extensions import Annotated from autogen.oai.client_utils import should_hide_tools, validate_parameter @@ -50,6 +54,7 @@ def __init__(self, **kwargs): Args: api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set) """ + # Ensure we have the api_key upon instantiation self.api_key = kwargs.get("api_key", None) if not self.api_key: @@ -59,7 +64,9 @@ def __init__(self, **kwargs): self.api_key ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable." - def message_retrieval(self, response: ChatCompletionResponse) -> Union[List[str], List[ChatCompletionMessage]]: + self._client = Mistral(api_key=self.api_key) + + def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]: """Retrieve the messages from the response.""" return [choice.message for choice in response.choices] @@ -86,34 +93,52 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: ) mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None) + # TODO + if params.get("stream", False): + warnings.warn( + "Streaming is not currently supported, streaming will be disabled.", + UserWarning, + ) + # 3. Convert messages to Mistral format mistral_messages = [] tool_call_ids = {} # tool call ids to function name mapping for message in params["messages"]: if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None: # Convert OAI ToolCall to Mistral ToolCall - openai_toolcalls = message["tool_calls"] - mistral_toolcalls = [] - for toolcall in openai_toolcalls: - mistral_toolcall = ToolCall(id=toolcall["id"], function=toolcall["function"]) - mistral_toolcalls.append(mistral_toolcall) - mistral_messages.append( - ChatMessage(role=message["role"], content=message["content"], tool_calls=mistral_toolcalls) - ) + mistral_messages_tools = [] + for toolcall in message["tool_calls"]: + mistral_messages_tools.append( + ToolCall( + id=toolcall["id"], + function=FunctionCall( + name=toolcall["function"]["name"], + arguments=json.loads(toolcall["function"]["arguments"]), + ), + ) + ) + + mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools)) # Map tool call id to the function name for tool_call in message["tool_calls"]: tool_call_ids[tool_call["id"]] = tool_call["function"]["name"] - elif message["role"] in ("system", "user", "assistant"): - # Note this ChatMessage can take a 'name' but it is rejected by the Mistral API if not role=tool, so, no, the 'name' field is not used. - mistral_messages.append(ChatMessage(role=message["role"], content=message["content"])) + elif message["role"] == "system": + if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant": + # System messages can't appear after an Assistant message, so use a UserMessage + mistral_messages.append(UserMessage(content=message["content"])) + else: + mistral_messages.append(SystemMessage(content=message["content"])) + elif message["role"] == "assistant": + mistral_messages.append(AssistantMessage(content=message["content"])) + elif message["role"] == "user": + mistral_messages.append(UserMessage(content=message["content"])) elif message["role"] == "tool": # Indicates the result of a tool call, the name is the function name called mistral_messages.append( - ChatMessage( - role="tool", + ToolMessage( name=tool_call_ids[message["tool_call_id"]], content=message["content"], tool_call_id=message["tool_call_id"], @@ -122,21 +147,20 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: else: warnings.warn(f"Unknown message role {message['role']}", UserWarning) - # If a 'system' message follows an 'assistant' message, change it to 'user' - # This can occur when using LLM summarisation - for i in range(1, len(mistral_messages)): - if mistral_messages[i - 1].role == "assistant" and mistral_messages[i].role == "system": - mistral_messages[i].role = "user" + # 4. Last message needs to be user or tool, if not, add a "please continue" message + if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage): + mistral_messages.append(UserMessage(content="Please continue.")) mistral_params["messages"] = mistral_messages - # 4. Add tools to the call if we have them and aren't hiding them + # 5. Add tools to the call if we have them and aren't hiding them if "tools" in params: hide_tools = validate_parameter( params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"] ) if not should_hide_tools(params["messages"], params["tools"], hide_tools): - mistral_params["tools"] = params["tools"] + mistral_params["tools"] = tool_def_to_mistral(params["tools"]) + return mistral_params def create(self, params: Dict[str, Any]) -> ChatCompletion: @@ -144,8 +168,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: mistral_params = self.parse_params(params) # 2. Call Mistral.AI API - client = MistralClient(api_key=self.api_key) - mistral_response = client.chat(**mistral_params) + mistral_response = self._client.chat.complete(**mistral_params) # TODO: Handle streaming # 3. Convert Mistral response to OAI compatible format @@ -191,7 +214,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: return response_oai @staticmethod - def get_usage(response: ChatCompletionResponse) -> Dict: + def get_usage(response: ChatCompletion) -> Dict: return { "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0, "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0, @@ -203,25 +226,48 @@ def get_usage(response: ChatCompletionResponse) -> Dict: } +def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Converts AutoGen tool definition to a mistral tool format""" + + mistral_tools = [] + + for autogen_tool in tool_definitions: + mistral_tool = { + "type": "function", + "function": Function( + name=autogen_tool["function"]["name"], + description=autogen_tool["function"]["description"], + parameters=autogen_tool["function"]["parameters"], + ), + } + + mistral_tools.append(mistral_tool) + + return mistral_tools + + def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float: """Calculate the cost of the mistral response.""" - # Prices per 1 million tokens + # Prices per 1 thousand tokens # https://mistral.ai/technology/ model_cost_map = { - "open-mistral-7b": {"input": 0.25, "output": 0.25}, - "open-mixtral-8x7b": {"input": 0.7, "output": 0.7}, - "open-mixtral-8x22b": {"input": 2.0, "output": 6.0}, - "mistral-small-latest": {"input": 1.0, "output": 3.0}, - "mistral-medium-latest": {"input": 2.7, "output": 8.1}, - "mistral-large-latest": {"input": 4.0, "output": 12.0}, + "open-mistral-7b": {"input": 0.00025, "output": 0.00025}, + "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007}, + "open-mixtral-8x22b": {"input": 0.002, "output": 0.006}, + "mistral-small-latest": {"input": 0.001, "output": 0.003}, + "mistral-medium-latest": {"input": 0.00275, "output": 0.0081}, + "mistral-large-latest": {"input": 0.0003, "output": 0.0003}, + "mistral-large-2407": {"input": 0.0003, "output": 0.0003}, + "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003}, + "codestral-2405": {"input": 0.001, "output": 0.003}, } # Ensure we have the model they are using and return the total cost if model_name in model_cost_map: costs = model_cost_map[model_name] - return (input_tokens * costs["input"] / 1_000_000) + (output_tokens * costs["output"] / 1_000_000) + return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000) else: warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning) return 0 diff --git a/setup.py b/setup.py index 13a88be5f0a4..b94227420dec 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor, "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], - "mistral": ["mistralai>=0.2.0"], + "mistral": ["mistralai>=1.0.1"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], } diff --git a/test/oai/test_mistral.py b/test/oai/test_mistral.py index 5236f71d7b7d..f89c3d304d90 100644 --- a/test/oai/test_mistral.py +++ b/test/oai/test_mistral.py @@ -3,7 +3,16 @@ import pytest try: - from mistralai.models.chat_completion import ChatMessage + from mistralai import ( + AssistantMessage, + Function, + FunctionCall, + Mistral, + SystemMessage, + ToolCall, + ToolMessage, + UserMessage, + ) from autogen.oai.mistral import MistralAIClient, calculate_mistral_cost @@ -66,17 +75,16 @@ def test_cost_calculation(mock_response): cost=None, model="mistral-large-latest", ) - assert ( - calculate_mistral_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model) - == 0.0001 - ), "Cost for this should be $0.0001" + assert calculate_mistral_cost( + response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model + ) == (15 / 1000 * 0.0003), "Cost for this should be $0.0000045" # Test text generation @pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed") -@patch("autogen.oai.mistral.MistralClient.chat") +@patch("autogen.oai.mistral.MistralAIClient.create") def test_create_response(mock_chat, mistral_client): - # Mock MistralClient.chat response + # Mock `mistral_response = client.chat.complete(**mistral_params)` mock_mistral_response = MagicMock() mock_mistral_response.choices = [ MagicMock(finish_reason="stop", message=MagicMock(content="Example Mistral response", tool_calls=None)) @@ -108,9 +116,9 @@ def test_create_response(mock_chat, mistral_client): # Test functions/tools @pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed") -@patch("autogen.oai.mistral.MistralClient.chat") +@patch("autogen.oai.mistral.MistralAIClient.create") def test_create_response_with_tool_call(mock_chat, mistral_client): - # Mock `mistral_response = client.chat(**mistral_params)` + # Mock `mistral_response = client.chat.complete(**mistral_params)` mock_function = MagicMock(name="currency_calculator") mock_function.name = "currency_calculator" mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}' @@ -159,7 +167,7 @@ def test_create_response_with_tool_call(mock_chat, mistral_client): {"role": "assistant", "content": "World"}, ] - # Call the create method + # Call the chat method response = mistral_client.create( {"messages": mistral_messages, "tools": converted_functions, "model": "mistral-medium-latest"} )