@@ -35,31 +35,36 @@
 from opentelemetry.trace.status import Status, StatusCode
 from wrapt import wrap_function_wrapper
 
-from mistralai.models.chat_completion import (
+from mistralai.models import (
     ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatMessage,
+    ChatCompletionChoice,
+    AssistantMessage,
+    UserMessage,
+    SystemMessage,
+    UsageInfo,
+    EmbeddingResponse,
 )
-from mistralai.models.common import UsageInfo
-from mistralai.models.embeddings import EmbeddingResponse
 
 logger = logging.getLogger(__name__)
 
-_instruments = ("mistralai >= 0.2.0, < 1",)
+_instruments = ("mistralai >= 1.0.0",)
 
 WRAPPED_METHODS = [
     {
-        "method": "chat",
+        "method": "complete",
+        "module": "chat",
         "span_name": "mistralai.chat",
         "streaming": False,
     },
     {
-        "method": "chat_stream",
+        "method": "stream",
+        "module": "chat",
         "span_name": "mistralai.chat",
         "streaming": True,
     },
     {
-        "method": "embeddings",
+        "method": "create",
+        "module": "embeddings",
         "span_name": "mistralai.embeddings",
         "streaming": False,
     },
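The new "module" key mirrors how the 1.x SDK namespaces its operations: methods now live on per-resource attributes of a single Mistral client rather than directly on MistralClient. A minimal sketch of the call surface these three entries intercept, as documented for the 1.x SDK (model names illustrative; a real API key is needed to actually run it):

from mistralai import Mistral

client = Mistral(api_key="...")

# "complete" on the chat module (sync, non-streaming)
client.chat.complete(
    model="mistral-small-latest",
    messages=[{"role": "user", "content": "Hi"}],
)

# "stream" on the chat module (sync, streaming) yields events lazily
for event in client.chat.stream(
    model="mistral-small-latest",
    messages=[{"role": "user", "content": "Hi"}],
):
    pass

# "create" on the embeddings module (note the plural "inputs" keyword)
client.embeddings.create(model="mistral-embed", inputs=["some text"])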
@@ -92,7 +97,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
                 message.role,
             )
     else:
-        input = kwargs.get("input")
+        input = kwargs.get("input") or kwargs.get("inputs")
 
         if isinstance(input, str):
             _set_span_attribute(
@@ -101,7 +106,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
             _set_span_attribute(
                 span, f"{SpanAttributes.LLM_PROMPTS}.0.content", input
             )
-        else:
+        elif input:
             for index, prompt in enumerate(input):
                 _set_span_attribute(
                     span,
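The two changes above are related: the 1.x embeddings API takes its texts under the plural "inputs" keyword (0.x used "input"), so the instrumentation falls back from one to the other, and the switch from else: to elif input: keeps the loop from iterating over None when neither keyword is present. A toy illustration with hypothetical kwargs:

# Hypothetical kwargs captured from a 1.x embeddings call.
kwargs = {"model": "mistral-embed", "inputs": ["a", "b"]}
input = kwargs.get("input") or kwargs.get("inputs")  # -> ["a", "b"]

# With neither key present, input would be None; the old bare `else:`
# branch would then hit enumerate(None) and raise TypeError, while
# `elif input:` simply skips the loop.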
@@ -205,20 +210,22 @@ def _accumulate_streaming_response(span, event_logger, llm_request_type, response):
     for res in response:
         yield res
 
-        if res.model:
-            accumulated_response.model = res.model
-        if res.usage:
-            accumulated_response.usage = res.usage
+        # Handle new CompletionEvent structure with .data attribute
+        chunk_data = res.data if hasattr(res, 'data') else res
+        if chunk_data.model:
+            accumulated_response.model = chunk_data.model
+        if chunk_data.usage:
+            accumulated_response.usage = chunk_data.usage
         # Id is the same for all chunks, so it's safe to overwrite it every time
-        if res.id:
-            accumulated_response.id = res.id
+        if chunk_data.id:
+            accumulated_response.id = chunk_data.id
 
-        for idx, choice in enumerate(res.choices):
+        for idx, choice in enumerate(chunk_data.choices):
             if len(accumulated_response.choices) <= idx:
                 accumulated_response.choices.append(
-                    ChatCompletionResponseChoice(
+                    ChatCompletionChoice(
                         index=idx,
-                        message=ChatMessage(role="assistant", content=""),
+                        message=AssistantMessage(role="assistant", content=""),
                         finish_reason=None,
                     )
                 )
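Context for the chunk_data unwrap: in the 1.x SDK, iterating a chat stream yields CompletionEvent objects carrying the actual chunk under .data, where 0.x yielded the chunk itself; the hasattr check lets one accumulator handle both shapes. A self-contained toy version of the pattern (class names are stand-ins, not the SDK's):

class FakeChunk:  # stands in for a raw streaming chunk
    def __init__(self, content):
        self.content = content

class FakeEvent:  # stands in for the 1.x CompletionEvent wrapper
    def __init__(self, data):
        self.data = data

parts = []
for res in [FakeEvent(FakeChunk("Hel")), FakeChunk("lo")]:
    chunk_data = res.data if hasattr(res, "data") else res  # unwrap if wrapped
    parts.append(chunk_data.content)

print("".join(parts))  # -> "Hello"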
@@ -247,20 +254,22 @@ async def _aaccumulate_streaming_response(
     async for res in response:
         yield res
 
-        if res.model:
-            accumulated_response.model = res.model
-        if res.usage:
-            accumulated_response.usage = res.usage
+        # Handle new CompletionEvent structure with .data attribute
+        chunk_data = res.data if hasattr(res, 'data') else res
+        if chunk_data.model:
+            accumulated_response.model = chunk_data.model
+        if chunk_data.usage:
+            accumulated_response.usage = chunk_data.usage
         # Id is the same for all chunks, so it's safe to overwrite it every time
-        if res.id:
-            accumulated_response.id = res.id
+        if chunk_data.id:
+            accumulated_response.id = chunk_data.id
 
-        for idx, choice in enumerate(res.choices):
+        for idx, choice in enumerate(chunk_data.choices):
             if len(accumulated_response.choices) <= idx:
                 accumulated_response.choices.append(
-                    ChatCompletionResponseChoice(
+                    ChatCompletionChoice(
                         index=idx,
-                        message=ChatMessage(role="assistant", content=""),
+                        message=AssistantMessage(role="assistant", content=""),
                         finish_reason=None,
                     )
                 )
@@ -287,9 +296,9 @@ def wrapper(wrapped, instance, args, kwargs):
 
 
 def _llm_request_type_by_method(method_name):
-    if method_name == "chat" or method_name == "chat_stream":
+    if method_name == "complete" or method_name == "stream":
         return LLMRequestTypeValues.CHAT
-    elif method_name == "embeddings":
+    elif method_name == "create":
         return LLMRequestTypeValues.EMBEDDING
     else:
         return LLMRequestTypeValues.UNKNOWN
@@ -301,7 +310,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
     if method_wrapped == "mistralai.chat":
         messages = args[0] if len(args) > 0 else kwargs.get("messages", [])
         for message in messages:
-            if isinstance(message, ChatMessage):
+            if isinstance(message, (UserMessage, AssistantMessage, SystemMessage)):
                 role = message.role
                 content = message.content
             elif isinstance(message, dict):
@@ -313,7 +322,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
 
     # Handle embedding events
     elif method_wrapped == "mistralai.embeddings":
-        embedding_input = args[0] if len(args) > 0 else kwargs.get("input", [])
+        embedding_input = args[0] if len(args) > 0 else (kwargs.get("input") or kwargs.get("inputs", []))
         if isinstance(embedding_input, str):
             emit_event(MessageEvent(content=embedding_input, role="user"), event_logger)
         elif isinstance(embedding_input, list):
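The isinstance change above reflects 1.x splitting the single ChatMessage class into role-specific models. A hedged sketch of the message shapes the updated check accepts, assuming the 1.x models default their role fields ("system", "user", "assistant") when only content is given:

from mistralai.models import AssistantMessage, SystemMessage, UserMessage

messages = [
    SystemMessage(content="You are terse."),    # role assumed to default to "system"
    UserMessage(content="Hi"),                  # role assumed to default to "user"
    {"role": "assistant", "content": "Hello"},  # plain dicts still pass through
]

for message in messages:
    if isinstance(message, (UserMessage, AssistantMessage, SystemMessage)):
        role, content = message.role, message.content
    elif isinstance(message, dict):
        role, content = message.get("role"), message.get("content")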
@@ -452,7 +461,7 @@ async def _awrap(
         _handle_input(span, event_logger, args, kwargs, to_wrap)
 
         if to_wrap.get("streaming"):
-            response = wrapped(*args, **kwargs)
+            response = await wrapped(*args, **kwargs)
         else:
             response = await wrapped(*args, **kwargs)
 
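The streaming branch now awaits because the call convention changed: the 0.x chat_stream returned an async generator directly, while the 1.x stream_async is a coroutine that must be awaited to obtain the async event stream. A usage sketch of the wrapped call, assuming the documented 1.x surface (model name illustrative):

import asyncio
from mistralai import Mistral

async def main():
    client = Mistral(api_key="...")
    stream = await client.chat.stream_async(   # awaiting yields the stream...
        model="mistral-small-latest",
        messages=[{"role": "user", "content": "Hi"}],
    )
    async for event in stream:                 # ...which is then iterated
        pass

asyncio.run(main())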
@@ -495,21 +504,23 @@ def _instrument(self, **kwargs):
 
         for wrapped_method in WRAPPED_METHODS:
             wrap_method = wrapped_method.get("method")
+            module_name = wrapped_method.get("module")
+            # Wrap sync methods on the class
             wrap_function_wrapper(
-                "mistralai.client",
-                f"MistralClient.{wrap_method}",
+                f"mistralai.{module_name}",
+                f"{module_name.capitalize()}.{wrap_method}",
                 _wrap(tracer, event_logger, wrapped_method),
             )
+            # Wrap async methods on the class
             wrap_function_wrapper(
-                "mistralai.async_client",
-                f"MistralAsyncClient.{wrap_method}",
+                f"mistralai.{module_name}",
+                f"{module_name.capitalize()}.{wrap_method}_async",
                 _awrap(tracer, event_logger, wrapped_method),
             )
 
     def _uninstrument(self, **kwargs):
         for wrapped_method in WRAPPED_METHODS:
-            unwrap("mistralai.client.MistralClient", wrapped_method.get("method"))
-            unwrap(
-                "mistralai.async_client.MistralAsyncClient",
-                wrapped_method.get("method"),
-            )
+            wrap_method = wrapped_method.get("method")
+            module_name = wrapped_method.get("module")
+            unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", wrap_method)
+            unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", f"{wrap_method}_async")