Update Mistral client class to support new Mistral v1.0.1 package #3356

Merged · 4 commits · Aug 16, 2024
Changes from all commits
128 changes: 87 additions & 41 deletions autogen/oai/mistral.py
@@ -15,28 +15,32 @@

Resources:
- https://docs.mistral.ai/getting-started/quickstart/
"""

-# Important notes when using the Mistral.AI API:
-# The first system message can greatly affect whether the model returns a tool call, including text that references the ability to use functions will help.
-# Changing the role on the first system message to 'user' improved the chances of the model recommending a tool call.
+NOTE: Requires mistralai package version >= 1.0.1
+"""

import inspect
import json
import os
import time
import warnings
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union

# Mistral libraries
# pip install mistralai
-from mistralai.client import MistralClient
-from mistralai.exceptions import MistralAPIException
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ToolCall
+from mistralai import (
+    AssistantMessage,
+    Function,
+    FunctionCall,
+    Mistral,
+    SystemMessage,
+    ToolCall,
+    ToolMessage,
+    UserMessage,
+)
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
from typing_extensions import Annotated

from autogen.oai.client_utils import should_hide_tools, validate_parameter

@@ -50,6 +54,7 @@ def __init__(self, **kwargs):

        Args:
            api_key (str): The API key for using Mistral.AI (or environment variable MISTRAL_API_KEY needs to be set)
        """

        # Ensure we have the api_key upon instantiation
        self.api_key = kwargs.get("api_key", None)
        if not self.api_key:
@@ -59,7 +64,9 @@
            self.api_key
        ), "Please specify the 'api_key' in your config list entry for Mistral or set the MISTRAL_API_KEY env variable."

-    def message_retrieval(self, response: ChatCompletionResponse) -> Union[List[str], List[ChatCompletionMessage]]:
+        self._client = Mistral(api_key=self.api_key)
+
+    def message_retrieval(self, response: ChatCompletion) -> Union[List[str], List[ChatCompletionMessage]]:
        """Retrieve the messages from the response."""

        return [choice.message for choice in response.choices]
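
Reviewer note (not part of the diff): the core API change is that the v0.x pair MistralClient(...).chat(...) becomes a single Mistral client, created once in __init__ and invoked through chat.complete. A minimal sketch of the new call pattern, assuming MISTRAL_API_KEY is set; the model name and prompt are illustrative:

# Sketch of the mistralai>=1.0.1 call pattern this PR adopts.
import os

from mistralai import Mistral, UserMessage

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])  # one client, reused per request
response = client.chat.complete(
    model="mistral-small-latest",  # illustrative model name
    messages=[UserMessage(content="Hello")],
)
print(response.choices[0].message.content)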
@@ -86,34 +93,52 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        )
        mistral_params["random_seed"] = validate_parameter(params, "random_seed", int, True, None, False, None)

+        # TODO
+        if params.get("stream", False):
+            warnings.warn(
+                "Streaming is not currently supported, streaming will be disabled.",
+                UserWarning,
+            )
+
        # 3. Convert messages to Mistral format
        mistral_messages = []
        tool_call_ids = {}  # tool call ids to function name mapping
        for message in params["messages"]:
            if message["role"] == "assistant" and "tool_calls" in message and message["tool_calls"] is not None:
                # Convert OAI ToolCall to Mistral ToolCall
-                openai_toolcalls = message["tool_calls"]
-                mistral_toolcalls = []
-                for toolcall in openai_toolcalls:
-                    mistral_toolcall = ToolCall(id=toolcall["id"], function=toolcall["function"])
-                    mistral_toolcalls.append(mistral_toolcall)
-                mistral_messages.append(
-                    ChatMessage(role=message["role"], content=message["content"], tool_calls=mistral_toolcalls)
-                )
+                mistral_messages_tools = []
+                for toolcall in message["tool_calls"]:
+                    mistral_messages_tools.append(
+                        ToolCall(
+                            id=toolcall["id"],
+                            function=FunctionCall(
+                                name=toolcall["function"]["name"],
+                                arguments=json.loads(toolcall["function"]["arguments"]),
+                            ),
+                        )
+                    )
+
+                mistral_messages.append(AssistantMessage(content="", tool_calls=mistral_messages_tools))

                # Map tool call id to the function name
                for tool_call in message["tool_calls"]:
                    tool_call_ids[tool_call["id"]] = tool_call["function"]["name"]

-            elif message["role"] in ("system", "user", "assistant"):
-                # Note this ChatMessage can take a 'name' but it is rejected by the Mistral API if not role=tool, so, no, the 'name' field is not used.
-                mistral_messages.append(ChatMessage(role=message["role"], content=message["content"]))
+            elif message["role"] == "system":
+                if len(mistral_messages) > 0 and mistral_messages[-1].role == "assistant":
+                    # System messages can't appear after an Assistant message, so use a UserMessage
+                    mistral_messages.append(UserMessage(content=message["content"]))
+                else:
+                    mistral_messages.append(SystemMessage(content=message["content"]))
+            elif message["role"] == "assistant":
+                mistral_messages.append(AssistantMessage(content=message["content"]))
+            elif message["role"] == "user":
+                mistral_messages.append(UserMessage(content=message["content"]))

            elif message["role"] == "tool":
                # Indicates the result of a tool call, the name is the function name called
                mistral_messages.append(
-                    ChatMessage(
-                        role="tool",
+                    ToolMessage(
                        name=tool_call_ids[message["tool_call_id"]],
                        content=message["content"],
                        tool_call_id=message["tool_call_id"],
@@ -122,30 +147,28 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
            else:
                warnings.warn(f"Unknown message role {message['role']}", UserWarning)

-        # If a 'system' message follows an 'assistant' message, change it to 'user'
-        # This can occur when using LLM summarisation
-        for i in range(1, len(mistral_messages)):
-            if mistral_messages[i - 1].role == "assistant" and mistral_messages[i].role == "system":
-                mistral_messages[i].role = "user"
+        # 4. Last message needs to be user or tool, if not, add a "please continue" message
+        if not isinstance(mistral_messages[-1], UserMessage) and not isinstance(mistral_messages[-1], ToolMessage):
+            mistral_messages.append(UserMessage(content="Please continue."))

        mistral_params["messages"] = mistral_messages

-        # 4. Add tools to the call if we have them and aren't hiding them
+        # 5. Add tools to the call if we have them and aren't hiding them
        if "tools" in params:
            hide_tools = validate_parameter(
                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
            )
            if not should_hide_tools(params["messages"], params["tools"], hide_tools):
-                mistral_params["tools"] = params["tools"]
+                mistral_params["tools"] = tool_def_to_mistral(params["tools"])

        return mistral_params

    def create(self, params: Dict[str, Any]) -> ChatCompletion:
        # 1. Parse parameters to Mistral.AI API's parameters
        mistral_params = self.parse_params(params)

        # 2. Call Mistral.AI API
-        client = MistralClient(api_key=self.api_key)
-        mistral_response = client.chat(**mistral_params)
+        mistral_response = self._client.chat.complete(**mistral_params)
+        # TODO: Handle streaming

        # 3. Convert Mistral response to OAI compatible format
@@ -191,7 +214,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
        return response_oai

    @staticmethod
-    def get_usage(response: ChatCompletionResponse) -> Dict:
+    def get_usage(response: ChatCompletion) -> Dict:
        return {
            "prompt_tokens": response.usage.prompt_tokens if response.usage is not None else 0,
            "completion_tokens": response.usage.completion_tokens if response.usage is not None else 0,
@@ -203,25 +226,48 @@ def get_usage(response: ChatCompletionResponse) -> Dict:
        }


+def tool_def_to_mistral(tool_definitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Converts AutoGen tool definition to a mistral tool format"""
+
+    mistral_tools = []
+
+    for autogen_tool in tool_definitions:
+        mistral_tool = {
+            "type": "function",
+            "function": Function(
+                name=autogen_tool["function"]["name"],
+                description=autogen_tool["function"]["description"],
+                parameters=autogen_tool["function"]["parameters"],
+            ),
+        }
+
+        mistral_tools.append(mistral_tool)
+
+    return mistral_tools
+
+
def calculate_mistral_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Calculate the cost of the mistral response."""

-    # Prices per 1 million tokens
+    # Prices per 1 thousand tokens
    # https://mistral.ai/technology/
    model_cost_map = {
-        "open-mistral-7b": {"input": 0.25, "output": 0.25},
-        "open-mixtral-8x7b": {"input": 0.7, "output": 0.7},
-        "open-mixtral-8x22b": {"input": 2.0, "output": 6.0},
-        "mistral-small-latest": {"input": 1.0, "output": 3.0},
-        "mistral-medium-latest": {"input": 2.7, "output": 8.1},
-        "mistral-large-latest": {"input": 4.0, "output": 12.0},
+        "open-mistral-7b": {"input": 0.00025, "output": 0.00025},
+        "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007},
+        "open-mixtral-8x22b": {"input": 0.002, "output": 0.006},
+        "mistral-small-latest": {"input": 0.001, "output": 0.003},
+        "mistral-medium-latest": {"input": 0.00275, "output": 0.0081},
+        "mistral-large-latest": {"input": 0.0003, "output": 0.0003},
+        "mistral-large-2407": {"input": 0.0003, "output": 0.0003},
+        "open-mistral-nemo-2407": {"input": 0.0003, "output": 0.0003},
+        "codestral-2405": {"input": 0.001, "output": 0.003},
    }

    # Ensure we have the model they are using and return the total cost
    if model_name in model_cost_map:
        costs = model_cost_map[model_name]

-        return (input_tokens * costs["input"] / 1_000_000) + (output_tokens * costs["output"] / 1_000_000)
+        return (input_tokens * costs["input"] / 1000) + (output_tokens * costs["output"] / 1000)
    else:
        warnings.warn(f"Cost calculation is not implemented for model {model_name}, will return $0.", UserWarning)
        return 0
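
For reviewers, a quick sanity check of the two helpers added above; the tool definition shape follows the AutoGen/OpenAI function-tool schema, and the names and values are illustrative, not from this PR:

# Hypothetical check of tool_def_to_mistral and calculate_mistral_cost.
from autogen.oai.mistral import calculate_mistral_cost, tool_def_to_mistral

tools = [
    {
        "type": "function",
        "function": {
            "name": "currency_calculator",
            "description": "Convert an amount between two currencies",
            "parameters": {"type": "object", "properties": {}},
        },
    }
]
# Each entry keeps the {"type": "function", ...} wrapper but wraps the inner
# definition in the SDK's Function model.
mistral_tools = tool_def_to_mistral(tools)

# 10 prompt + 5 completion tokens on mistral-large-latest at $0.0003 per 1K
# tokens each way: (10 + 5) / 1000 * 0.0003 = $0.0000045, matching the updated test.
assert abs(calculate_mistral_cost(10, 5, "mistral-large-latest") - 0.0000045) < 1e-12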
2 changes: 1 addition & 1 deletion setup.py
@@ -88,7 +88,7 @@
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
"long-context": ["llmlingua<0.3"],
"anthropic": ["anthropic>=0.23.1"],
"mistral": ["mistralai>=0.2.0"],
"mistral": ["mistralai>=1.0.1"],
"groq": ["groq>=0.9.0"],
"cohere": ["cohere>=5.5.8"],
}
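
With the floor raised to mistralai>=1.0.1, reinstalling the extra pulls in the v1 SDK; assuming the pyautogen distribution name in use at the time, something like: pip install -U "pyautogen[mistral]"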
28 changes: 18 additions & 10 deletions test/oai/test_mistral.py
@@ -3,7 +3,16 @@
import pytest

try:
-    from mistralai.models.chat_completion import ChatMessage
+    from mistralai import (
+        AssistantMessage,
+        Function,
+        FunctionCall,
+        Mistral,
+        SystemMessage,
+        ToolCall,
+        ToolMessage,
+        UserMessage,
+    )

    from autogen.oai.mistral import MistralAIClient, calculate_mistral_cost

@@ -66,17 +75,16 @@ def test_cost_calculation(mock_response):
        cost=None,
        model="mistral-large-latest",
    )
-    assert (
-        calculate_mistral_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model)
-        == 0.0001
-    ), "Cost for this should be $0.0001"
+    assert calculate_mistral_cost(
+        response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model
+    ) == (15 / 1000 * 0.0003), "Cost for this should be $0.0000045"


# Test text generation
@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
-@patch("autogen.oai.mistral.MistralClient.chat")
+@patch("autogen.oai.mistral.MistralAIClient.create")
def test_create_response(mock_chat, mistral_client):
-    # Mock MistralClient.chat response
+    # Mock `mistral_response = client.chat.complete(**mistral_params)`
    mock_mistral_response = MagicMock()
    mock_mistral_response.choices = [
        MagicMock(finish_reason="stop", message=MagicMock(content="Example Mistral response", tool_calls=None))
@@ -108,9 +116,9 @@ def test_create_response(mock_chat, mistral_client):

# Test functions/tools
@pytest.mark.skipif(skip, reason="Mistral.AI dependency is not installed")
-@patch("autogen.oai.mistral.MistralClient.chat")
+@patch("autogen.oai.mistral.MistralAIClient.create")
def test_create_response_with_tool_call(mock_chat, mistral_client):
-    # Mock `mistral_response = client.chat(**mistral_params)`
+    # Mock `mistral_response = client.chat.complete(**mistral_params)`
    mock_function = MagicMock(name="currency_calculator")
    mock_function.name = "currency_calculator"
    mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
@@ -159,7 +167,7 @@ def test_create_response_with_tool_call(mock_chat, mistral_client):
{"role": "assistant", "content": "World"},
]

# Call the create method
# Call the chat method
response = mistral_client.create(
{"messages": mistral_messages, "tools": converted_functions, "model": "mistral-medium-latest"}
)
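
As a closing illustration of the message normalization introduced in parse_params (a sketch, not code from this PR; the dummy api_key only satisfies the constructor, and no network call is made):

from autogen.oai.mistral import MistralAIClient

client = MistralAIClient(api_key="dummy-key")
params = client.parse_params(
    {
        "model": "mistral-small-latest",
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            # Follows an assistant message, so it is converted to a UserMessage,
            # which also satisfies the "last message must be user/tool" rule.
            {"role": "system", "content": "Summarize the chat."},
        ],
    }
)
# params["messages"] -> [UserMessage, AssistantMessage, UserMessage]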