Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,36 @@
from opentelemetry.trace.status import Status, StatusCode
from wrapt import wrap_function_wrapper

from mistralai.models.chat_completion import (
from mistralai.models import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
ChatCompletionChoice,
AssistantMessage,
UserMessage,
SystemMessage,
UsageInfo,
EmbeddingResponse,
)
from mistralai.models.common import UsageInfo
from mistralai.models.embeddings import EmbeddingResponse

logger = logging.getLogger(__name__)

_instruments = ("mistralai >= 0.2.0, < 1",)
_instruments = ("mistralai >= 1.0.0",)

WRAPPED_METHODS = [
{
"method": "chat",
"method": "complete",
"module": "chat",
"span_name": "mistralai.chat",
"streaming": False,
},
{
"method": "chat_stream",
"method": "stream",
"module": "chat",
"span_name": "mistralai.chat",
"streaming": True,
},
{
"method": "embeddings",
"method": "create",
"module": "embeddings",
"span_name": "mistralai.embeddings",
"streaming": False,
},
Expand Down Expand Up @@ -92,7 +97,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
message.role,
)
else:
input = kwargs.get("input")
input = kwargs.get("input") or kwargs.get("inputs")

if isinstance(input, str):
_set_span_attribute(
Expand All @@ -101,7 +106,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
_set_span_attribute(
span, f"{SpanAttributes.LLM_PROMPTS}.0.content", input
)
else:
elif input:
for index, prompt in enumerate(input):
_set_span_attribute(
span,
Expand Down Expand Up @@ -205,20 +210,22 @@ def _accumulate_streaming_response(span, event_logger, llm_request_type, respons
for res in response:
yield res

if res.model:
accumulated_response.model = res.model
if res.usage:
accumulated_response.usage = res.usage
# Handle new CompletionEvent structure with .data attribute
chunk_data = res.data if hasattr(res, 'data') else res
if chunk_data.model:
accumulated_response.model = chunk_data.model
if chunk_data.usage:
accumulated_response.usage = chunk_data.usage
# Id is the same for all chunks, so it's safe to overwrite it every time
if res.id:
accumulated_response.id = res.id
if chunk_data.id:
accumulated_response.id = chunk_data.id

for idx, choice in enumerate(res.choices):
for idx, choice in enumerate(chunk_data.choices):
if len(accumulated_response.choices) <= idx:
accumulated_response.choices.append(
ChatCompletionResponseChoice(
ChatCompletionChoice(
index=idx,
message=ChatMessage(role="assistant", content=""),
message=AssistantMessage(role="assistant", content=""),
finish_reason=None,
)
)
Expand Down Expand Up @@ -247,20 +254,22 @@ async def _aaccumulate_streaming_response(
async for res in response:
yield res

if res.model:
accumulated_response.model = res.model
if res.usage:
accumulated_response.usage = res.usage
# Handle new CompletionEvent structure with .data attribute
chunk_data = res.data if hasattr(res, 'data') else res
if chunk_data.model:
accumulated_response.model = chunk_data.model
if chunk_data.usage:
accumulated_response.usage = chunk_data.usage
# Id is the same for all chunks, so it's safe to overwrite it every time
if res.id:
accumulated_response.id = res.id
if chunk_data.id:
accumulated_response.id = chunk_data.id

for idx, choice in enumerate(res.choices):
for idx, choice in enumerate(chunk_data.choices):
if len(accumulated_response.choices) <= idx:
accumulated_response.choices.append(
ChatCompletionResponseChoice(
ChatCompletionChoice(
index=idx,
message=ChatMessage(role="assistant", content=""),
message=AssistantMessage(role="assistant", content=""),
finish_reason=None,
)
)
Expand All @@ -287,9 +296,9 @@ def wrapper(wrapped, instance, args, kwargs):


def _llm_request_type_by_method(method_name):
if method_name == "chat" or method_name == "chat_stream":
if method_name == "complete" or method_name == "stream":
return LLMRequestTypeValues.CHAT
elif method_name == "embeddings":
elif method_name == "create":
return LLMRequestTypeValues.EMBEDDING
else:
return LLMRequestTypeValues.UNKNOWN
Expand All @@ -301,7 +310,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
if method_wrapped == "mistralai.chat":
messages = args[0] if len(args) > 0 else kwargs.get("messages", [])
for message in messages:
if isinstance(message, ChatMessage):
if isinstance(message, (UserMessage, AssistantMessage, SystemMessage)):
role = message.role
content = message.content
elif isinstance(message, dict):
Expand All @@ -313,7 +322,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):

# Handle embedding events
elif method_wrapped == "mistralai.embeddings":
embedding_input = args[0] if len(args) > 0 else kwargs.get("input", [])
embedding_input = args[0] if len(args) > 0 else (kwargs.get("input") or kwargs.get("inputs", []))
if isinstance(embedding_input, str):
emit_event(MessageEvent(content=embedding_input, role="user"), event_logger)
elif isinstance(embedding_input, list):
Expand Down Expand Up @@ -452,7 +461,7 @@ async def _awrap(
_handle_input(span, event_logger, args, kwargs, to_wrap)

if to_wrap.get("streaming"):
response = wrapped(*args, **kwargs)
response = await wrapped(*args, **kwargs)
else:
response = await wrapped(*args, **kwargs)

Expand Down Expand Up @@ -495,21 +504,23 @@ def _instrument(self, **kwargs):

for wrapped_method in WRAPPED_METHODS:
wrap_method = wrapped_method.get("method")
module_name = wrapped_method.get("module")
# Wrap sync methods on the class
wrap_function_wrapper(
"mistralai.client",
f"MistralClient.{wrap_method}",
f"mistralai.{module_name}",
f"{module_name.capitalize()}.{wrap_method}",
_wrap(tracer, event_logger, wrapped_method),
)
# Wrap async methods on the class
wrap_function_wrapper(
"mistralai.async_client",
f"MistralAsyncClient.{wrap_method}",
f"mistralai.{module_name}",
f"{module_name.capitalize()}.{wrap_method}_async",
_awrap(tracer, event_logger, wrapped_method),
)

Comment on lines 507 to 520

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Unwrap target mismatch prevents proper uninstrumentation

You wrap with wrap_function_wrapper("mistralai.{module}", "{Class}.{method}", ...), but unwrap uses "mistralai.{module}.{Class}", "method". This won’t remove the wrapper, risking double-wrapping.

Apply:

-            unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", wrap_method)
-            unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", f"{wrap_method}_async")
+            unwrap(f"mistralai.{module_name}", f"{module_name.capitalize()}.{wrap_method}")
+            unwrap(f"mistralai.{module_name}", f"{module_name.capitalize()}.{wrap_method}_async")

Also applies to: 524-531

🤖 Prompt for AI Agents
In
packages/opentelemetry-instrumentation-mistralai/opentelemetry/instrumentation/mistralai/__init__.py
around lines 507-523 (and similarly 524-531), the wrap targets use
wrap_function_wrapper("mistralai.{module}", "{Class}.{method}", ...) but
unwrapping expects ("mistralai.{module}.{Class}", "method"), so change the wrap
calls to use the same module/attribute split as unwrap: call
wrap_function_wrapper(f"mistralai.{module_name}.{module_name.capitalize()}",
f"{wrap_method}", _wrap(...)) and for async use f"{wrap_method}_async" as the
attribute, ensuring the module string includes the Class and the attribute is
only the method name so wrapping and unwrapping targets match.

def _uninstrument(self, **kwargs):
for wrapped_method in WRAPPED_METHODS:
unwrap("mistralai.client.MistralClient", wrapped_method.get("method"))
unwrap(
"mistralai.async_client.MistralAsyncClient",
wrapped_method.get("method"),
)
wrap_method = wrapped_method.get("method")
module_name = wrapped_method.get("module")
unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", wrap_method)
unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", f"{wrap_method}_async")
Loading