Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
emit_response_events,
)
from opentelemetry.instrumentation.cohere.span_utils import (
set_input_attributes,
set_response_attributes,
set_input_content_attributes,
set_response_content_attributes,
set_span_request_attributes,
set_span_response_attributes,
)
from opentelemetry.instrumentation.cohere.streaming import (
process_chat_v1_streaming_response,
aprocess_chat_v1_streaming_response,
process_chat_v2_streaming_response,
aprocess_chat_v2_streaming_response,
)
from opentelemetry.instrumentation.cohere.utils import dont_throw, should_emit_events
from opentelemetry.instrumentation.cohere.version import __version__
Expand All @@ -27,7 +34,7 @@
LLMRequestTypeValues,
SpanAttributes,
)
from opentelemetry.trace import SpanKind, Tracer, get_tracer
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer, get_tracer, use_span
from wrapt import wrap_function_wrapper

logger = logging.getLogger(__name__)
Expand All @@ -36,20 +43,121 @@

# Methods instrumented with the synchronous wrapper. Each entry gives the
# module, class ("object") and method to patch plus the span name to use.
# An optional "stream_process_func" is called with
# (span, event_logger, llm_request_type, response) and its result is returned
# to the caller in place of the raw response.
WRAPPED_METHODS = [
    # --- cohere.client.Client (v1) ---
    {
        "module": "cohere.client",
        "object": "Client",
        "method": "generate",
        "span_name": "cohere.completion",
    },
    {
        "module": "cohere.client",
        "object": "Client",
        "method": "chat",
        "span_name": "cohere.chat",
    },
    {
        "module": "cohere.client",
        "object": "Client",
        "method": "chat_stream",
        "span_name": "cohere.chat",
        "stream_process_func": process_chat_v1_streaming_response,
    },
    {
        "module": "cohere.client",
        "object": "Client",
        "method": "rerank",
        "span_name": "cohere.rerank",
    },
    {
        "module": "cohere.client",
        "object": "Client",
        "method": "embed",
        "span_name": "cohere.embed",
    },
    # --- cohere.client_v2.ClientV2 ---
    {
        "module": "cohere.client_v2",
        "object": "ClientV2",
        "method": "chat",
        "span_name": "cohere.chat",
    },
    {
        "module": "cohere.client_v2",
        "object": "ClientV2",
        "method": "chat_stream",
        "span_name": "cohere.chat",
        "stream_process_func": process_chat_v2_streaming_response,
    },
    {
        "module": "cohere.client_v2",
        "object": "ClientV2",
        "method": "rerank",
        "span_name": "cohere.rerank",
    },
    {
        "module": "cohere.client_v2",
        "object": "ClientV2",
        "method": "embed",
        "span_name": "cohere.embed",
    },
    # --- async clients, streaming only ---
    # Async methods that return an AsyncIterator must go through the sync
    # wrapper, so they are listed here and paired with the async stream
    # processors.
    {
        "module": "cohere.client",
        "object": "AsyncClient",
        "method": "chat_stream",
        "span_name": "cohere.chat",
        "stream_process_func": aprocess_chat_v1_streaming_response,
    },
    {
        "module": "cohere.client_v2",
        "object": "AsyncClientV2",
        "method": "chat_stream",
        "span_name": "cohere.chat",
        "stream_process_func": aprocess_chat_v2_streaming_response,
    },
]

# Coroutine methods instrumented with the async wrapper. Same entry layout as
# the sync table: module, class ("object"), method and span name.
WRAPPED_AMETHODS = [
    # --- cohere.client.AsyncClient (v1) ---
    {
        "module": "cohere.client",
        "object": "AsyncClient",
        "method": "generate",
        "span_name": "cohere.completion",
    },
    {
        "module": "cohere.client",
        "object": "AsyncClient",
        "method": "chat",
        "span_name": "cohere.chat",
    },
    {
        "module": "cohere.client",
        "object": "AsyncClient",
        "method": "rerank",
        "span_name": "cohere.rerank",
    },
    {
        "module": "cohere.client",
        "object": "AsyncClient",
        "method": "embed",
        "span_name": "cohere.embed",
    },
    # --- cohere.client_v2.AsyncClientV2 ---
    {
        "module": "cohere.client_v2",
        "object": "AsyncClientV2",
        "method": "chat",
        "span_name": "cohere.chat",
    },
    {
        "module": "cohere.client_v2",
        "object": "AsyncClientV2",
        "method": "rerank",
        "span_name": "cohere.rerank",
    },
    {
        "module": "cohere.client_v2",
        "object": "AsyncClientV2",
        "method": "embed",
        "span_name": "cohere.embed",
    },
]


Expand All @@ -66,30 +174,30 @@ def wrapper(wrapped, instance, args, kwargs):


def _llm_request_type_by_method(method_name):
    """Map a wrapped client method name to its ``LLMRequestTypeValues`` member.

    Streaming variants share the request type of their non-streaming
    counterpart. Any name that is not recognised maps to ``UNKNOWN``.

    Args:
        method_name: The "method" value of a wrapped-method table entry.

    Returns:
        The matching ``LLMRequestTypeValues`` member.
    """
    if method_name in ("chat", "chat_stream"):
        return LLMRequestTypeValues.CHAT
    if method_name in ("generate", "generate_stream"):
        return LLMRequestTypeValues.COMPLETION
    if method_name == "rerank":
        return LLMRequestTypeValues.RERANK
    if method_name == "embed":
        return LLMRequestTypeValues.EMBEDDING
    return LLMRequestTypeValues.UNKNOWN


@dont_throw
def _handle_input_content(span, event_logger, llm_request_type, kwargs):
    """Record the request content for a wrapped call.

    Always passes the request to ``set_input_content_attributes``; when
    ``should_emit_events()`` is true, it additionally emits the input event
    through ``event_logger``.

    Args:
        span: The span created for the wrapped call.
        event_logger: Logger handed to ``emit_input_event``.
        llm_request_type: Request type resolved from the wrapped method name.
        kwargs: Keyword arguments of the wrapped client call.
    """
    # NOTE(review): `dont_throw` presumably keeps instrumentation errors from
    # reaching the caller - confirm against its definition in utils.
    set_input_content_attributes(span, llm_request_type, kwargs)
    if should_emit_events():
        emit_input_event(event_logger, llm_request_type, kwargs)


@dont_throw
def _handle_response_content(span, event_logger, llm_request_type, response):
    """Record the response content for a wrapped call.

    Always passes the response to ``set_response_content_attributes``; when
    ``should_emit_events()`` is true, it additionally emits the response
    events through ``event_logger``.

    Args:
        span: The span created for the wrapped call.
        event_logger: Logger handed to ``emit_response_events``.
        llm_request_type: Request type resolved from the wrapped method name.
        response: The value returned by the wrapped client call.
    """
    # NOTE(review): `dont_throw` presumably keeps instrumentation errors from
    # reaching the caller - confirm against its definition in utils.
    set_response_content_attributes(span, llm_request_type, response)
    if should_emit_events():
        emit_response_events(event_logger, llm_request_type, response)


@_with_tracer_wrapper
Expand All @@ -108,6 +216,55 @@ def _wrap(
):
return wrapped(*args, **kwargs)

name = to_wrap.get("span_name")
llm_request_type = _llm_request_type_by_method(to_wrap.get("method"))
span = tracer.start_span(
name,
kind=SpanKind.CLIENT,
attributes={
SpanAttributes.LLM_SYSTEM: "Cohere",
SpanAttributes.LLM_REQUEST_TYPE: llm_request_type.value,
},
)

with use_span(span, end_on_exit=False):
set_span_request_attributes(span, kwargs)
_handle_input_content(span, event_logger, llm_request_type, kwargs)

try:
response = wrapped(*args, **kwargs)
except Exception as e:
if span.is_recording():
span.set_status(Status(StatusCode.ERROR, str(e)))
span.record_exception(e)
span.end()
raise

if to_wrap.get("stream_process_func"):
return to_wrap.get("stream_process_func")(span, event_logger, llm_request_type, response)
Comment on lines +243 to +244

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify stream processor functions handle all edge cases

The streaming delegation correctly passes control to specialized stream processors. Ensure these processors handle early cancellation, connection drops, and other streaming edge cases.


🏁 Script executed:

#!/bin/bash
# Check if stream processors handle early cancellation and exceptions
ast-grep --pattern 'def $_(span, $_, $_, response):
  $$$
  try:
    $$$
  finally:
    span.end()
  $$$'

# Check test coverage for streaming edge cases
rg -A5 -B5 "test.*stream.*cancel|test.*stream.*error|test.*stream.*interrupt" --type=py

Length of output: 0


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Show file around lines 200-320 ==="
nl -ba packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/__init__.py | sed -n '1,400p' || true

echo "=== Search for 'stream_process_func' occurrences ==="
rg -n "stream_process_func" -S || true

echo "=== Search for 'stream' references in the package ==="
rg -n --hidden --no-ignore -S "stream" packages/opentelemetry-instrumentation-cohere || true

echo "=== Search for function defs containing 'stream' ==="
rg -n --type py "def\s+.*stream" -S || true

echo "=== Search tests for streaming/cancel/error keywords ==="
rg -n --type py "stream.*cancel|stream.*error|stream.*interrupt|cancel.*stream|interrupt.*stream|stream.*exception" -S || true

Length of output: 58215


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== streaming.py (cohere) ==="
awk '{printf("%5d: %s\n", NR, $0)}' packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/streaming.py || true

echo "=== __init__.py (cohere) ==="
awk '{printf("%5d: %s\n", NR, $0)}' packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/__init__.py || true

echo "=== Relevant test snippets (tests/test_chat.py head) ==="
awk '{printf("%5d: %s\n", NR, $0)}' packages/opentelemetry-instrumentation-cohere/tests/test_chat.py | sed -n '1,260p' || true

echo "=== Search for try/finally/exception handling in streaming.py ==="
rg -n --hidden --no-ignore -S "try:|except |finally|span.end\(|stream-end|is_finished" packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/streaming.py || true

echo "=== Search for tests covering streaming exceptions/cancellation in cohere tests ==="
rg -n --hidden --no-ignore -S "exception|cancel|not_consumed|partial_consumption|interrupt|disconnect|connection|broken|stream.*exception" packages/opentelemetry-instrumentation-cohere/tests || true

Length of output: 39167


Guarantee span.end() and robust cancellation/error handling for async v2 streaming

  • aprocess_chat_v2_streaming_response is missing a try/finally and only calls span.end() on the success path — file packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/streaming.py (async def aprocess_chat_v2_streaming_response, ~lines 111–144). Wrap the async iteration in try/finally to always end the span and record exceptions consistently (match the other processors that use finally).
  • Add unit tests that simulate early cancellation/connection drops and exceptions during iteration in packages/opentelemetry-instrumentation-cohere/tests/test_chat.py (no cohere tests currently cover these edge cases).


Comment on lines +243 to +245

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Streaming span lifecycle can leak on early-cancel; verify v2 error branch ends the span.
If the caller stops consuming the stream early, the code that follows the loop in the process_chat_* generators never runs, so the span is never ended. Also, the v2 streaming processor sets the span status to ERROR but doesn’t end the span in the error path.

  • In streaming.py, wrap the yield loop in try/finally and always end the span in finally — on normal completion, on error, and on early cancellation; only set the response attributes when final_response is populated.
  • Also end the span in the error branch for v2 processors.

Example (streaming.py) pattern:

def process_chat_v2_streaming_response(span, event_logger, llm_request_type, response):
    final_response = {...}
    try:
        for item in response:
            span.add_event(name=f"{SpanAttributes.LLM_CONTENT_COMPLETION_CHUNK}")
            _accumulate_stream_item(..., final_response)
            yield item
    except Exception as e:
        if span.is_recording():
            span.set_status(Status(StatusCode.ERROR, str(e)))
            span.record_exception(e)
        raise
    finally:
        # set attributes only if we have something to set; still end the span
        if final_response:
            set_span_response_attributes(span, final_response)
            if should_emit_events():
                emit_response_events(event_logger, llm_request_type, final_response)
            elif should_send_prompts():
                _set_span_chat_response(span, final_response)
        if span.status is None or span.status.status_code == StatusCode.UNSET:
            span.set_status(Status(StatusCode.OK))
        span.end()

set_span_response_attributes(span, response)
_handle_response_content(span, event_logger, llm_request_type, response)
span.end()
return response


@_with_tracer_wrapper
async def _awrap(
tracer: Tracer,
event_logger: Union[EventLogger, None],
to_wrap,
wrapped,
instance,
args,
kwargs,
):
"""Instruments and calls every function defined in TO_WRAP."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
):
return await wrapped(*args, **kwargs)

name = to_wrap.get("span_name")
llm_request_type = _llm_request_type_by_method(to_wrap.get("method"))
with tracer.start_as_current_span(
Expand All @@ -119,12 +276,19 @@ def _wrap(
},
) as span:
set_span_request_attributes(span, kwargs)
_handle_input(span, event_logger, llm_request_type, kwargs)
_handle_input_content(span, event_logger, llm_request_type, kwargs)

response = wrapped(*args, **kwargs)
try:
response = await wrapped(*args, **kwargs)
except Exception as e:
if span.is_recording():
span.set_status(Status(StatusCode.ERROR, str(e)))
span.record_exception(e)
span.end()
raise

if response:
_handle_response(span, event_logger, llm_request_type, response)
set_span_response_attributes(span, response)
_handle_response_content(span, event_logger, llm_request_type, response)

return response

Expand All @@ -151,18 +315,51 @@ def _instrument(self, **kwargs):
__name__, __version__, event_logger_provider=event_logger_provider
)
for wrapped_method in WRAPPED_METHODS:
wrap_module = wrapped_method.get("module")
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrap_function_wrapper(
"cohere.client",
f"{wrap_object}.{wrap_method}",
_wrap(tracer, event_logger, wrapped_method),
)
try:
wrap_function_wrapper(
wrap_module,
f"{wrap_object}.{wrap_method}",
_wrap(tracer, event_logger, wrapped_method),
)
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to instrument {wrap_module}.{wrap_object}.{wrap_method}")

for wrapped_method in WRAPPED_AMETHODS:
wrap_module = wrapped_method.get("module")
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
try:
wrap_function_wrapper(
wrap_module,
f"{wrap_object}.{wrap_method}",
_awrap(tracer, event_logger, wrapped_method),
)
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to instrument {wrap_module}.{wrap_object}.{wrap_method}")

def _uninstrument(self, **kwargs):
    """Remove the wrappers from every method listed in the wrap tables.

    Each entry of ``WRAPPED_METHODS`` and ``WRAPPED_AMETHODS`` is unwrapped
    using its own module path. A target that cannot be resolved is logged at
    debug level and skipped, so one missing class or method does not stop
    the rest from being unwrapped.
    """
    # Both tables share the same entry layout, so one loop handles them.
    for wrapped_method in (*WRAPPED_METHODS, *WRAPPED_AMETHODS):
        wrap_module = wrapped_method.get("module")
        wrap_object = wrapped_method.get("object")
        wrap_method = wrapped_method.get("method")
        try:
            unwrap(f"{wrap_module}.{wrap_object}", wrap_method)
        except (ImportError, ModuleNotFoundError, AttributeError):
            logger.debug(
                "Failed to uninstrument %s.%s.%s",
                wrap_module,
                wrap_object,
                wrap_method,
            )
Loading