Skip to content

Commit

Permalink
fix(agents-api): fix serilizing data.output in list_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Jan 28, 2025
1 parent 2262af8 commit 1f49c2f
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
5 changes: 4 additions & 1 deletion agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
AND doc_own.owner_id = $4
"""


@rewrap_exceptions(common_db_exceptions("doc", ["list"]))
@wrap_in_class(
Doc,
Expand Down Expand Up @@ -117,7 +118,9 @@ async def list_docs(
d.created_at"""

# Add sorting and pagination
query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
query += (
f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
)
params.extend([limit, offset])

return query, params
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ...common.utils.datetime import utcnow
from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import increase_counter
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, rewrap_exceptions, serialize_model_data, wrap_in_class

# Query to create a transition
create_execution_transition_query = """
Expand Down Expand Up @@ -121,14 +121,7 @@ async def create_execution_transition(
data.execution_id = execution_id

# Dump to json
if isinstance(data.output, list):
data.output = [
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
for item in data.output
]

elif hasattr(data.output, "model_dump"):
data.output = data.output.model_dump(mode="json")
data.output = serialize_model_data(data.output)

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})
Expand Down
19 changes: 19 additions & 0 deletions agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,22 @@ def run_concurrently(
]

return [future.result() for future in concurrent.futures.as_completed(futures)]


def serialize_model_data(data: Any) -> Any:
"""
Recursively serialize Pydantic models and their nested structures.
Args:
data: Any data structure that might contain Pydantic models
Returns:
JSON-serializable data structure
"""
if hasattr(data, "model_dump"):
return data.model_dump(mode="json")
if isinstance(data, dict):
return {key: serialize_model_data(value) for key, value in data.items()}
if isinstance(data, list | tuple):
return [serialize_model_data(item) for item in data]
return data
2 changes: 1 addition & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test
# ensuring it's up-to-date before executing queries.
# This can be achieved by executing a REINDEX command
await pool.execute("REINDEX DATABASE")

yield await get_doc(developer_id=developer.id, doc_id=doc.id, connection_pool=pool)


Expand Down
1 change: 1 addition & 0 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
assert any(d.id == doc_user_different_metadata.id for d in docs_list_metadata)
assert any(d.metadata == {"test": "test2"} for d in docs_list_metadata)


@test("query: list agent docs")
async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
pool = await create_db_pool(dsn=dsn)
Expand Down

0 comments on commit 1f49c2f

Please sign in to comment.