Skip to content

Commit

Permalink
fix(agetns-api): fixed the docs metadata filter query + added test
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Jan 28, 2025
1 parent 1575f87 commit 9dfaade
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 16 deletions.
32 changes: 16 additions & 16 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,8 @@
WHERE d.developer_id = $1
AND doc_own.owner_type = $3
AND doc_own.owner_id = $4
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at
"""


@rewrap_exceptions(common_db_exceptions("doc", ["list"]))
@wrap_in_class(
Doc,
Expand Down Expand Up @@ -108,16 +97,27 @@ async def list_docs(
query = base_docs_query
params = [developer_id, include_without_embeddings, owner_type, owner_id]

# Add metadata filtering
# Add metadata filtering before GROUP BY
if metadata_filter:
for key, value in metadata_filter.items():
query += f" AND metadata->>'{key}' = ${len(params) + 1}"
query += f" AND d.metadata->>'{key}' = ${len(params) + 1}"
params.append(value)

# Add GROUP BY clause
query += """
GROUP BY
d.doc_id,
d.developer_id,
d.title,
d.modality,
d.embedding_model,
d.embedding_dimensions,
d.language,
d.metadata,
d.created_at"""

# Add sorting and pagination
query += (
f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
)
query += f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}"
params.extend([limit, offset])

return query, params
5 changes: 5 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test
f"[{', '.join([str(x) for x in embedding_with_confidence_1_neg])}]",
)

# Explicitly Refresh Indices: After inserting data, run a command to refresh the index,
# ensuring it's up-to-date before executing queries.
# This can be achieved by executing a REINDEX command
await pool.execute("REINDEX DATABASE")

yield await get_doc(developer_id=developer.id, doc_id=doc.id, connection_pool=pool)


Expand Down
52 changes: 52 additions & 0 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,32 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
)
assert len(docs_list) >= 1
assert any(d.id == doc_user.id for d in docs_list)
assert any(d.id == doc_user.id for d in docs_list)

# Create a doc with a different metadata
doc_user_different_metadata = await create_doc(
developer_id=developer.id,
data=CreateDocRequest(
title="User List Test 2",
content="Some user doc content 2",
metadata={"test": "test2"},
embed_instruction="Embed the document",
),
owner_type="user",
owner_id=user.id,
connection_pool=pool,
)

docs_list_metadata = await list_docs(
developer_id=developer.id,
owner_type="user",
owner_id=user.id,
connection_pool=pool,
metadata_filter={"test": "test2"},
)
assert len(docs_list_metadata) >= 1
assert any(d.id == doc_user_different_metadata.id for d in docs_list_metadata)
assert any(d.metadata == {"test": "test2"} for d in docs_list_metadata)

@test("query: list agent docs")
async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
Expand All @@ -145,9 +170,36 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
owner_id=agent.id,
connection_pool=pool,
)

assert len(docs_list) >= 1
assert any(d.id == doc_agent.id for d in docs_list)

# Create a doc with a different metadata
doc_agent_different_metadata = await create_doc(
developer_id=developer.id,
data=CreateDocRequest(
title="Agent List Test 2",
content="Some agent doc content 2",
metadata={"test": "test2"},
embed_instruction="Embed the document",
),
owner_type="agent",
owner_id=agent.id,
connection_pool=pool,
)

# List agent's docs
docs_list_metadata = await list_docs(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
connection_pool=pool,
metadata_filter={"test": "test2"},
)
assert len(docs_list_metadata) >= 1
assert any(d.id == doc_agent_different_metadata.id for d in docs_list_metadata)
assert any(d.metadata == {"test": "test2"} for d in docs_list_metadata)


@test("query: delete user doc")
async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
Expand Down

0 comments on commit 9dfaade

Please sign in to comment.