Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ddtrace/contrib/internal/openai/_endpoint_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
if not resp:
return
span.set_metric("openai.response.count", len(resp.data or []))
if hasattr(resp, "data"):
span.set_metric("openai.response.count", len(resp.data or []))
return resp


Expand Down
125 changes: 101 additions & 24 deletions ddtrace/contrib/internal/openai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,40 +286,117 @@ def patched_endpoint(openai, pin, func, instance, args, kwargs):
return patched_endpoint(openai)


class _TracedAsyncPaginator:
    """Wrapper for AsyncPaginator objects to enable tracing for both await and async for usage."""

    def __init__(self, paginator, pin, integration, patch_hook, instance, args, kwargs):
        # Store everything needed to start the trace lazily: the span is only
        # opened once the wrapped paginator is actually awaited or iterated.
        self._paginator = paginator
        self._pin = pin
        self._integration = integration
        self._patch_hook = patch_hook
        self._instance = instance
        self._args = args
        self._kwargs = kwargs

    def __aiter__(self):
        """Return an async generator that opens/closes the trace span around iteration."""

        async def _traced_aiter():
            # _traced_endpoint is driven as a generator-based hook protocol:
            # priming with send(None) starts the span; sending (resp, err)
            # finishes it (see the matching sends below).
            g = _traced_endpoint(
                self._patch_hook, self._integration, self._instance, self._pin, self._args, self._kwargs
            )
            g.send(None)
            err = None
            completed = False  # True once the span has been closed via g.send
            try:
                iterator = self._paginator.__aiter__()
                # Fetch first item to trigger trace completion before iteration starts.
                # This ensures the span is recorded even if iteration stops early.
                first_item = await iterator.__anext__()
                try:
                    g.send((None, None))
                    completed = True
                except StopIteration:
                    # Hook generator finished normally; span is closed.
                    completed = True
                yield first_item
                async for item in iterator:
                    yield item
            except StopAsyncIteration:
                # Empty paginator: fall through so the span is still closed in `finally`.
                pass
            except BaseException as e:
                err = e
                raise
            finally:
                # If the span was never closed (error before/at first item, or
                # the generator was abandoned early), close it now with any error.
                if not completed:
                    try:
                        g.send((None, err))
                    except StopIteration:
                        pass

        return _traced_aiter()

    def __await__(self):
        """Trace the wrapped paginator when it is awaited directly (non-iterating usage)."""

        async def _trace_and_await():
            g = _traced_endpoint(
                self._patch_hook, self._integration, self._instance, self._pin, self._args, self._kwargs
            )
            g.send(None)
            resp, err = None, None
            try:
                resp = await self._paginator
            except BaseException as e:
                err = e
                raise
            finally:
                try:
                    g.send((resp, err))
                except StopIteration as e:
                    # The hook may override the return value; only honor it on the
                    # success path so an in-flight exception keeps propagating.
                    if err is None:
                        return e.value
            return resp

        return _trace_and_await().__await__()


def _patched_endpoint_async(openai, patch_hook):
# Same as _patched_endpoint but async
@with_traced_module
async def patched_endpoint(openai, pin, func, instance, args, kwargs):
def patched_endpoint(openai, pin, func, instance, args, kwargs):
if (
patch_hook is _endpoint_hooks._ChatCompletionWithRawResponseHook
or patch_hook is _endpoint_hooks._CompletionWithRawResponseHook
):
kwargs[OPENAI_WITH_RAW_RESPONSE_ARG] = True
return await func(*args, **kwargs)
return func(*args, **kwargs)
if kwargs.pop(OPENAI_WITH_RAW_RESPONSE_ARG, False) and kwargs.get("stream", False):
return await func(*args, **kwargs)
return func(*args, **kwargs)

integration = openai._datadog_integration
g = _traced_endpoint(patch_hook, integration, instance, pin, args, kwargs)
g.send(None)
resp, err = None, None
override_return = None
try:
resp = await func(*args, **kwargs)
except BaseException as e:
err = e
raise
finally:
result = func(*args, **kwargs)
# Detect AsyncPaginator objects (have both __aiter__ and __await__).
# These must be returned directly (not awaited) to preserve iteration behavior.
if hasattr(result, "__aiter__") and hasattr(result, "__await__"):
return _TracedAsyncPaginator(result, pin, openai._datadog_integration, patch_hook, instance, args, kwargs)

async def async_wrapper():
integration = openai._datadog_integration
g = _traced_endpoint(patch_hook, integration, instance, pin, args, kwargs)
g.send(None)
resp, err = None, None
override_return = None
try:
g.send((resp, err))
except StopIteration as e:
if err is None:
# This return takes priority over `return resp`
override_return = e.value

if override_return is not None:
return override_return
return resp
resp = await result
except BaseException as e:
err = e
raise
finally:
try:
g.send((resp, err))
except StopIteration as e:
if err is None:
override_return = e.value

if override_return is not None:
return override_return
return resp

return async_wrapper()

return patched_endpoint(openai)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
openai: This fix resolves an issue where using async iteration with paginated methods (e.g., ``async for model in client.models.list()``) caused a ``TypeError: 'async for' requires an object with __aiter__ method, got coroutine``. See `issue #14574 <https://github.com/DataDog/dd-trace-py/issues/14574>`_.
32 changes: 32 additions & 0 deletions tests/contrib/openai/test_openai_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def test_model_list(api_key_in_env, request_api_key, openai, openai_vcr, snapsho
client.models.list()


@pytest.mark.parametrize("api_key_in_env", [True, False])
def test_model_list_pagination(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
    """Iterating a paginated sync ``models.list()`` is traced and tolerates an early break."""
    ignored_tags = ["meta.http.useragent", "meta.openai.api_type", "meta.openai.api_base", "meta.openai.request.user"]
    with snapshot_context(
        token="tests.contrib.openai.test_openai.test_model_list_pagination",
        ignores=ignored_tags,
    ):
        with openai_vcr.use_cassette("model_list.yaml"):
            client = openai.OpenAI(api_key=request_api_key)
            seen = 0
            for _model in client.models.list():
                seen += 1
                if seen >= 2:
                    break
            assert seen >= 2


@pytest.mark.parametrize("api_key_in_env", [True, False])
async def test_model_alist(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
with snapshot_context(
Expand All @@ -46,6 +62,22 @@ async def test_model_alist(api_key_in_env, request_api_key, openai, openai_vcr,
await client.models.list()


@pytest.mark.parametrize("api_key_in_env", [True, False])
async def test_model_alist_pagination(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
    """``async for`` over a paginated ``models.list()`` is traced and tolerates an early break."""
    ignored_tags = ["meta.http.useragent", "meta.openai.api_type", "meta.openai.api_base", "meta.openai.request.user"]
    with snapshot_context(
        token="tests.contrib.openai.test_openai.test_model_alist_pagination",
        ignores=ignored_tags,
    ):
        with openai_vcr.use_cassette("model_alist.yaml"):
            client = openai.AsyncOpenAI(api_key=request_api_key)
            seen = 0
            async for _model in client.models.list():
                seen += 1
                if seen >= 2:
                    break
            assert seen >= 2


@pytest.mark.parametrize("api_key_in_env", [True, False])
def test_model_retrieve(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
with snapshot_context(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[[
{
"name": "openai.request",
"service": "tests.contrib.openai",
"resource": "listModels",
"trace_id": 0,
"span_id": 1,
"parent_id": 0,
"type": "",
"error": 0,
"meta": {
"_dd.p.dm": "-0",
"_dd.p.tid": "68f0b1d700000000",
"component": "openai",
"language": "python",
"openai.request.endpoint": "/v1/models",
"openai.request.method": "GET",
"openai.request.provider": "OpenAI",
"runtime-id": "1e2a3154601a494f8f219a4327b659c2"
},
"metrics": {
"_dd.measured": 1,
"_dd.top_level": 1,
"_dd.tracer_kr": 1.0,
"_sampling_priority_v1": 1,
"process_id": 573
},
"duration": 1683125,
"start": 1760604631675824507
}]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[[
{
"name": "openai.request",
"service": "tests.contrib.openai",
"resource": "listModels",
"trace_id": 0,
"span_id": 1,
"parent_id": 0,
"type": "",
"error": 0,
"meta": {
"_dd.p.dm": "-0",
"_dd.p.tid": "68f0b1d500000000",
"component": "openai",
"language": "python",
"openai.request.endpoint": "/v1/models",
"openai.request.method": "GET",
"openai.request.provider": "OpenAI",
"runtime-id": "1e2a3154601a494f8f219a4327b659c2"
},
"metrics": {
"_dd.measured": 1,
"_dd.top_level": 1,
"_dd.tracer_kr": 1.0,
"_sampling_priority_v1": 1,
"openai.response.count": 112,
"process_id": 573
},
"duration": 13777416,
"start": 1760604629974266007
}]]
Loading