Skip to content

Commit

Permalink
Merge pull request #706 from Mirascope/support-mistral-vision-models
Browse files Browse the repository at this point in the history
feat: Support mistral vision models
  • Loading branch information
koxudaxi authored Nov 21, 2024
2 parents 37afeaf + d6c72c3 commit 80eef5b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 24 deletions.
39 changes: 36 additions & 3 deletions mirascope/core/mistral/_utils/_convert_message_params.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Utility for converting `BaseMessageParam` to `ChatMessage`."""

import base64

from mistralai.models import (
AssistantMessage,
ImageURL,
ImageURLChunk,
SystemMessage,
TextChunk,
ToolMessage,
UserMessage,
)
Expand Down Expand Up @@ -37,9 +42,37 @@ def convert_message_params(
elif isinstance(content := message_param.content, str):
converted_message_params.append(_make_message(**message_param.model_dump()))
else:
if len(content) != 1 or content[0].type != "text":
raise ValueError("Mistral currently only supports text parts.")
converted_content = []
for part in content:
if part.type == "text":
converted_content.append(TextChunk(text=part.text))

elif part.type == "image":
if part.media_type not in [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
]:
raise ValueError(
f"Unsupported image media type: {part.media_type}. Mistral"
" currently only supports JPEG, PNG, GIF, and WebP images."
)
data = base64.b64encode(part.image).decode("utf-8")
converted_content.append(
ImageURLChunk(
image_url=ImageURL(
url=f"data:{part.media_type};base64,{data}",
detail=part.detail if part.detail else "auto",
)
)
)
else:
raise ValueError(
"Mistral currently only supports text and image parts. "
f"Part provided: {part.type}"
)
converted_message_params.append(
_make_message(role=message_param.role, content=content[0].text)
_make_message(role=message_param.role, content=converted_content)
)
return converted_message_params
66 changes: 45 additions & 21 deletions tests/core/mistral/_utils/test_convert_message_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import pytest
from mistralai.models import (
AssistantMessage,
ImageURL,
ImageURLChunk,
SystemMessage,
TextChunk,
ToolMessage,
UserMessage,
)
Expand All @@ -30,12 +33,21 @@ def test_convert_message_params() -> None:
),
SystemMessage(content="Hello", role="system"),
ToolMessage(content="Hello", tool_call_id=Unset(), name=Unset(), role="tool"),
BaseMessageParam(
role="user",
content=[
TextPart(type="text", text="Hello"),
ImagePart(
type="image", media_type="image/jpeg", image=b"image", detail="auto"
),
],
),
]
converted_message_params = convert_message_params(message_params)
assert converted_message_params == [
UserMessage(content="Hello"),
UserMessage(role="user", content="Hello"),
UserMessage(role="user", content="Hello"),
UserMessage(content=[TextChunk(text="Hello", TYPE="text")], role="user"),
AssistantMessage(content="Hello"),
SystemMessage(content="Hello"),
ToolMessage(content="Hello", tool_call_id=Unset(), name=Unset(), role="tool"),
Expand All @@ -44,53 +56,65 @@ def test_convert_message_params() -> None:
),
SystemMessage(content="Hello", role="system"),
ToolMessage(content="Hello"),
UserMessage(
role="user",
content=[
TextChunk(text="Hello"),
ImageURLChunk(
image_url=ImageURL(
url="data:image/jpeg;base64,aW1hZ2U=", detail="auto"
)
),
],
),
]

with pytest.raises(
ValueError,
match="Mistral currently only supports text parts.",
match="Mistral currently only supports text and image parts. Part provided: audio",
):
convert_message_params(
[
BaseMessageParam(
role="user",
content=[
ImagePart(
type="image",
media_type="image/jpeg",
image=b"image",
detail="auto",
AudioPart(
type="audio",
media_type="audio/wav",
audio=b"audio",
)
],
)
),
]
)

with pytest.raises(
ValueError,
match="Mistral currently only supports text parts.",
match="Invalid role: invalid_role",
):
convert_message_params(
[
BaseMessageParam(
role="user",
content=[
AudioPart(
type="audio",
media_type="audio/wav",
audio=b"audio",
)
],
),
BaseMessageParam(role="invalid_role", content="Hello"),
]
)

with pytest.raises(
ValueError,
match="Invalid role: invalid_role",
match="Unsupported image media type: image/svg."
" Mistral currently only supports JPEG, PNG, GIF, and WebP images.",
):
convert_message_params(
[
BaseMessageParam(role="invalid_role", content="Hello"),
BaseMessageParam(
role="user",
content=[
ImagePart(
type="image",
media_type="image/svg",
image=b"image",
detail="auto",
)
],
)
]
)

0 comments on commit 80eef5b

Please sign in to comment.