Skip to content

Commit

Permalink
Bump mistralai to > 1.0.0 in preparation for latest models such as Pixtral
Browse files Browse the repository at this point in the history

Related to #521
  • Loading branch information
skylarbpayne committed Oct 7, 2024
1 parent 7c1f6e3 commit 861f53c
Show file tree
Hide file tree
Showing 27 changed files with 387 additions and 295 deletions.
5 changes: 3 additions & 2 deletions examples/learn/calls/basic_call/mistral/official_sdk_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mistralai.client import MistralClient
from mistralai import Mistral
import os

client = MistralClient()
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", ""))


def recommend_book(genre: str) -> str:
Expand Down
15 changes: 11 additions & 4 deletions examples/learn/calls/custom_client/mistral/base_message_param.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os

from mirascope.core import BaseMessageParam, mistral
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
def recommend_book(genre: str) -> list[BaseMessageParam]:
return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")]


@mistral.call("mistral-large-latest", client=MistralAsyncClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
async def recommend_book_async(genre: str) -> list[BaseMessageParam]:
return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")]
15 changes: 11 additions & 4 deletions examples/learn/calls/custom_client/mistral/messages.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os

from mirascope.core import Messages, mistral
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
def recommend_book(genre: str) -> Messages.Type:
return Messages.User(f"Recommend a {genre} book")


@mistral.call("mistral-large-latest", client=MistralAsyncClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
async def recommend_book_async(genre: str) -> Messages.Type:
return Messages.User(f"Recommend a {genre} book")
15 changes: 11 additions & 4 deletions examples/learn/calls/custom_client/mistral/shorthand.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os

from mirascope.core import mistral
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
def recommend_book(genre: str) -> str:
return f"Recommend a {genre} book"


@mistral.call("mistral-large-latest", client=MistralAsyncClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
async def recommend_book_async(genre: str) -> str:
return f"Recommend a {genre} book"
15 changes: 11 additions & 4 deletions examples/learn/calls/custom_client/mistral/string_template.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import os

from mirascope.core import mistral, prompt_template
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
@prompt_template("Recommend a {genre} book")
def recommend_book(genre: str): ...


@mistral.call("mistral-large-latest", client=MistralAsyncClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")),
)
@prompt_template("Recommend a {genre} book")
async def recommend_book_async(genre: str): ...
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ToolChoice
import os

from mistralai.client import Mistral
from mistralai.models import ToolChoice
from pydantic import BaseModel

client = MistralClient()
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", ""))


class Book(BaseModel):
Expand Down
11 changes: 9 additions & 2 deletions mirascope/core/mistral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from typing import TypeAlias

from mistralai.models.chat_completion import ChatMessage
from mistralai.models import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)

from ..base import BaseMessageParam
from ._call import mistral_call
Expand All @@ -14,7 +19,9 @@
from .stream import MistralStream
from .tool import MistralTool

MistralMessageParam: TypeAlias = ChatMessage | BaseMessageParam
MistralMessageParam: TypeAlias = (
AssistantMessage | SystemMessage | ToolMessage | UserMessage | BaseMessageParam
)

__all__ = [
"call",
Expand Down
24 changes: 18 additions & 6 deletions mirascope/core/mistral/_utils/_convert_message_params.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
"""Utility for converting `BaseMessageParam` to `ChatMessage`."""

from mistralai.models.chat_completion import ChatMessage
from mistralai.models import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)

from ...base import BaseMessageParam


def convert_message_params(
message_params: list[BaseMessageParam | ChatMessage],
) -> list[ChatMessage]:
message_params: list[
BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage
],
) -> list[BaseMessageParam]:
converted_message_params = []
for message_param in message_params:
if isinstance(message_param, ChatMessage):
if not isinstance(
message_param,
BaseMessageParam,
):
converted_message_params.append(message_param)
elif isinstance(content := message_param.content, str):
converted_message_params.append(ChatMessage(**message_param.model_dump()))
converted_message_params.append(
BaseMessageParam(**message_param.model_dump())
)
else:
if len(content) != 1 or content[0].type != "text":
raise ValueError("Mistral currently only supports text parts.")
converted_message_params.append(
ChatMessage(role=message_param.role, content=content[0].text)
BaseMessageParam(role=message_param.role, content=content[0].text)
)
return converted_message_params
17 changes: 8 additions & 9 deletions mirascope/core/mistral/_utils/_handle_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@

from collections.abc import AsyncGenerator, Generator

from mistralai.models.chat_completion import (
ChatCompletionStreamResponse,
from mistralai.models import (
CompletionEvent,
FunctionCall,
ToolCall,
ToolType,
)

from ..call_response_chunk import MistralCallResponseChunk
from ..tool import MistralTool


def _handle_chunk(
chunk: ChatCompletionStreamResponse,
chunk: CompletionEvent,
current_tool_call: ToolCall,
current_tool_type: type[MistralTool] | None,
tool_types: list[type[MistralTool]] | None,
Expand All @@ -38,7 +37,7 @@ def _handle_chunk(
arguments="",
name=tool_call.function.name if tool_call.function.name else "",
),
type=ToolType.function,
type="function",
)
current_tool_type = None
for tool_type in tool_types:
Expand All @@ -64,12 +63,12 @@ def _handle_chunk(


def handle_stream(
stream: Generator[ChatCompletionStreamResponse, None, None],
stream: Generator[CompletionEvent, None, None],
tool_types: list[type[MistralTool]] | None,
) -> Generator[tuple[MistralCallResponseChunk, MistralTool | None], None, None]:
"""Iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ToolCall(
id="", function=FunctionCall(arguments="", name=""), type=ToolType.function
id="", function=FunctionCall(arguments="", name=""), type="function"
)
current_tool_type = None
for chunk in stream:
Expand All @@ -93,12 +92,12 @@ def handle_stream(


async def handle_stream_async(
stream: AsyncGenerator[ChatCompletionStreamResponse, None],
stream: AsyncGenerator[CompletionEvent, None],
tool_types: list[type[MistralTool]] | None,
) -> AsyncGenerator[tuple[MistralCallResponseChunk, MistralTool | None], None]:
"""Async iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ToolCall(
id="", function=FunctionCall(arguments="", name=""), type=ToolType.function
id="", function=FunctionCall(arguments="", name=""), type="function"
)
current_tool_type = None
async for chunk in stream:
Expand Down
61 changes: 32 additions & 29 deletions mirascope/core/mistral/_utils/_setup_call.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
"""This module contains the setup_call function for Mistral tools."""

import inspect
import os
from collections.abc import (
Awaitable,
Callable,
)
from typing import Any, cast, overload

from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.models.chat_completion import (
from mistralai import Mistral
from mistralai.models import (
AssistantMessage,
ChatCompletionResponse,
ChatCompletionStreamResponse,
ChatMessage,
CompletionEvent,
ResponseFormat,
ResponseFormats,
ToolChoice,
SystemMessage,
ToolChoiceEnum,
ToolMessage,
UserMessage,
)

from mirascope.core.base._utils._protocols import fn_is_async

from ...base import BaseMessageParam, BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn
from ..call_kwargs import MistralCallKwargs
Expand All @@ -31,7 +34,7 @@
def setup_call(
*,
model: str,
client: MistralClient | MistralAsyncClient | None,
client: Mistral | None,
fn: Callable[..., Awaitable[MistralDynamicConfig]],
fn_args: dict[str, Any],
dynamic_config: MistralDynamicConfig,
Expand All @@ -40,9 +43,9 @@ def setup_call(
call_params: MistralCallParams,
extract: bool,
) -> tuple[
AsyncCreateFn[ChatCompletionResponse, ChatCompletionStreamResponse],
AsyncCreateFn[ChatCompletionResponse, CompletionEvent],
str | None,
list[ChatMessage],
list[AssistantMessage | SystemMessage | ToolMessage | UserMessage],
list[type[MistralTool]] | None,
MistralCallKwargs,
]: ...
Expand All @@ -52,7 +55,7 @@ def setup_call(
def setup_call(
*,
model: str,
client: MistralClient | MistralAsyncClient | None,
client: Mistral | None,
fn: Callable[..., MistralDynamicConfig],
fn_args: dict[str, Any],
dynamic_config: MistralDynamicConfig,
Expand All @@ -61,9 +64,9 @@ def setup_call(
call_params: MistralCallParams,
extract: bool,
) -> tuple[
CreateFn[ChatCompletionResponse, ChatCompletionStreamResponse],
CreateFn[ChatCompletionResponse, CompletionEvent],
str | None,
list[ChatMessage],
list[AssistantMessage | SystemMessage | ToolMessage | UserMessage],
list[type[MistralTool]] | None,
MistralCallKwargs,
]: ...
Expand All @@ -72,7 +75,7 @@ def setup_call(
def setup_call(
*,
model: str,
client: MistralClient | MistralAsyncClient | None,
client: Mistral | None,
fn: Callable[..., MistralDynamicConfig | Awaitable[MistralDynamicConfig]],
fn_args: dict[str, Any],
dynamic_config: MistralDynamicConfig,
Expand All @@ -81,42 +84,42 @@ def setup_call(
call_params: MistralCallParams,
extract: bool,
) -> tuple[
CreateFn[ChatCompletionResponse, ChatCompletionStreamResponse]
| AsyncCreateFn[ChatCompletionResponse, ChatCompletionStreamResponse],
CreateFn[ChatCompletionResponse, CompletionEvent]
| AsyncCreateFn[ChatCompletionResponse, CompletionEvent],
str | None,
list[ChatMessage],
list[AssistantMessage | SystemMessage | ToolMessage | UserMessage],
list[type[MistralTool]] | None,
MistralCallKwargs,
]:
prompt_template, messages, tool_types, base_call_kwargs = _utils.setup_call(
fn, fn_args, dynamic_config, tools, MistralTool, call_params
)
call_kwargs = cast(MistralCallKwargs, base_call_kwargs)
messages = cast(list[BaseMessageParam | ChatMessage], messages)
messages = cast(
list[AssistantMessage | SystemMessage | ToolMessage | UserMessage], messages
)
messages = convert_message_params(messages)
if json_mode:
call_kwargs["response_format"] = ResponseFormat(
type=ResponseFormats("json_object")
)
call_kwargs["response_format"] = ResponseFormat(type="json_object")
json_mode_content = _utils.json_mode_content(
tool_types[0] if tool_types else None
)
if messages[-1].role == "user":
messages[-1].content += json_mode_content
else:
messages.append(ChatMessage(role="user", content=json_mode_content.strip()))
messages.append(UserMessage(content=json_mode_content.strip()))
call_kwargs.pop("tools", None)
elif extract:
assert tool_types, "At least one tool must be provided for extraction."
call_kwargs["tool_choice"] = cast(ToolChoice, ToolChoice.any)
call_kwargs["tool_choice"] = cast(ToolChoiceEnum, "any")
call_kwargs |= {"model": model, "messages": messages}

if client is None:
client = (
MistralAsyncClient() if inspect.iscoroutinefunction(fn) else MistralClient()
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", ""))
if fn_is_async(fn):
create_or_stream = get_async_create_fn(
client.chat.complete_async, client.chat.stream_async
)
if isinstance(client, MistralAsyncClient):
create_or_stream = get_async_create_fn(client.chat, client.chat_stream)
else:
create_or_stream = get_create_fn(client.chat, client.chat_stream)
create_or_stream = get_create_fn(client.chat.complete, client.chat.stream)
return create_or_stream, prompt_template, messages, tool_types, call_kwargs
Loading

0 comments on commit 861f53c

Please sign in to comment.