Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import pydantic
import re
import threading
import time

from openai import AsyncStream, Stream
from wrapt import ObjectProxy

# Conditional imports for backward compatibility
try:
Expand Down Expand Up @@ -190,11 +192,20 @@ def set_data_attributes(traced_response: TracedData, span: Span):
span, SpanAttributes.LLM_USAGE_TOTAL_TOKENS, usage.total_tokens
)
if usage.input_tokens_details:
_set_span_attribute(
span,
GenAIAttributes.GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS,
usage.input_tokens_details.cached_tokens,
)
# Check if the attribute exists (it may not in older semconv versions)
if hasattr(GenAIAttributes, 'GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS'):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need it?

_set_span_attribute(
span,
GenAIAttributes.GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS,
usage.input_tokens_details.cached_tokens,
)
# Fallback to older attribute name if it exists
elif hasattr(GenAIAttributes, 'GEN_AI_USAGE_INPUT_TOKENS_CACHED'):
_set_span_attribute(
span,
GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
usage.input_tokens_details.cached_tokens,
)

# Usage - count of reasoning tokens
reasoning_tokens = None
Expand Down Expand Up @@ -433,7 +444,19 @@ def responses_get_or_create_wrapper(tracer: Tracer, wrapped, instance, args, kwa
try:
response = wrapped(*args, **kwargs)
if isinstance(response, Stream):
return response
span = tracer.start_span(
SPAN_NAME,
kind=SpanKind.CLIENT,
start_time=start_time,
)

return ResponseStream(
span=span,
response=response,
start_time=start_time,
request_kwargs=kwargs,
tracer=tracer,
)
except Exception as e:
response_id = kwargs.get("response_id")
existing_data = {}
Expand Down Expand Up @@ -563,7 +586,21 @@ async def async_responses_get_or_create_wrapper(
try:
response = await wrapped(*args, **kwargs)
if isinstance(response, (Stream, AsyncStream)):
return response
# Create a span for the streaming response
span = tracer.start_span(
SPAN_NAME,
kind=SpanKind.CLIENT,
start_time=start_time,
)

# Wrap the stream with ResponseStream to capture telemetry
return ResponseStream(
span=span,
response=response,
start_time=start_time,
request_kwargs=kwargs,
tracer=tracer,
)
except Exception as e:
response_id = kwargs.get("response_id")
existing_data = {}
Expand Down Expand Up @@ -728,4 +765,188 @@ async def async_responses_cancel_wrapper(
return response


# TODO: build streaming responses
class ResponseStream(ObjectProxy):
"""Proxy class for streaming responses to capture telemetry data"""

_span = None
_start_time = None
_request_kwargs = None
_tracer = None
_traced_data = None

def __init__(
self,
span,
response,
start_time=None,
request_kwargs=None,
tracer=None,
traced_data=None,
):
super().__init__(response)
self._span = span
self._start_time = start_time
self._request_kwargs = request_kwargs or {}
self._tracer = tracer
self._traced_data = traced_data or TracedData(
start_time=start_time,
response_id="",
input=process_input(self._request_kwargs.get("input", [])),
instructions=self._request_kwargs.get("instructions"),
tools=get_tools_from_kwargs(self._request_kwargs),
output_blocks={},
usage=None,
output_text="",
request_model=self._request_kwargs.get("model", ""),
response_model="",
request_reasoning_summary=self._request_kwargs.get("reasoning", {}).get(
"summary"
),
request_reasoning_effort=self._request_kwargs.get("reasoning", {}).get("effort"),
response_reasoning_effort=None,
)

self._complete_response_data = None
self._output_text = ""

self._cleanup_completed = False
self._cleanup_lock = threading.Lock()

def __del__(self):
"""Cleanup when object is garbage collected"""
if hasattr(self, "_cleanup_completed") and not self._cleanup_completed:
self._ensure_cleanup()

def __enter__(self):
"""Context manager entry"""
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit"""
if exc_type is not None:
self._handle_exception(exc_val)
else:
self._process_complete_response()
return False

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def __iter__(self):
"""Synchronous iterator"""
return self

def __next__(self):
"""Synchronous iteration"""
try:
chunk = self.__wrapped__.__next__()
except StopIteration:
self._process_complete_response()
raise
except Exception as e:
self._handle_exception(e)
raise
else:
self._process_chunk(chunk)
return chunk

def __aiter__(self):
"""Async iterator"""
return self

async def __anext__(self):
"""Async iteration"""
try:
chunk = await self.__wrapped__.__anext__()
except StopAsyncIteration:
self._process_complete_response()
raise
except Exception as e:
self._handle_exception(e)
raise
else:
self._process_chunk(chunk)
return chunk

def _process_chunk(self, chunk):
"""Process a streaming chunk"""
if hasattr(chunk, "type"):
if chunk.type == "response.output_text.delta":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider accumulating streaming text into a list and then joining the list once (e.g. using ''.join(list)) rather than performing repeated string concatenation. This can improve performance when processing many small chunks.

if hasattr(chunk, "delta") and chunk.delta:
self._output_text += chunk.delta
elif chunk.type == "response.completed" and hasattr(chunk, "response"):
self._complete_response_data = chunk.response

if hasattr(chunk, "delta"):
if hasattr(chunk.delta, "text") and chunk.delta.text:
self._output_text += chunk.delta.text

if hasattr(chunk, "response") and chunk.response:
self._complete_response_data = chunk.response

@dont_throw
def _process_complete_response(self):
"""Process the complete response and emit span"""
with self._cleanup_lock:
if self._cleanup_completed:
return

try:
if self._complete_response_data:
parsed_response = parse_response(self._complete_response_data)

self._traced_data.response_id = parsed_response.id
self._traced_data.response_model = parsed_response.model
self._traced_data.output_text = self._output_text

if parsed_response.usage:
self._traced_data.usage = parsed_response.usage

if parsed_response.output:
self._traced_data.output_blocks = {
block.id: block for block in parsed_response.output
}

responses[parsed_response.id] = self._traced_data

set_data_attributes(self._traced_data, self._span)
self._span.set_status(StatusCode.OK)
self._span.end()
self._cleanup_completed = True

except Exception as e:
if self._span and self._span.is_recording():
self._span.set_attribute(ERROR_TYPE, e.__class__.__name__)
self._span.set_status(StatusCode.ERROR, str(e))
self._span.end()
self._cleanup_completed = True

@dont_throw
def _handle_exception(self, exception):
"""Handle exceptions during streaming"""
with self._cleanup_lock:
if self._cleanup_completed:
return

if self._span and self._span.is_recording():
self._span.set_attribute(ERROR_TYPE, exception.__class__.__name__)
self._span.record_exception(exception)
self._span.set_status(StatusCode.ERROR, str(exception))
self._span.end()

self._cleanup_completed = True

@dont_throw
def _ensure_cleanup(self):
"""Ensure cleanup happens even if stream is not fully consumed"""
with self._cleanup_lock:
if self._cleanup_completed:
return

try:
if self._span and self._span.is_recording():
set_data_attributes(self._traced_data, self._span)
self._span.set_status(StatusCode.OK)
self._span.end()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

self._cleanup_completed = True

except Exception:
self._cleanup_completed = True
Loading
Loading