Skip to content

Commit

Permalink
fix(agents-api,memory-store): Fix docs tests and some migrations
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Jan 3, 2025
1 parent 38b4acb commit 900cca8
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 80 deletions.
7 changes: 7 additions & 0 deletions agents-api/agents_api/queries/docs/get_doc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from uuid import UUID

from beartype import beartype
Expand Down Expand Up @@ -46,6 +47,12 @@ def transform_get_doc(d: dict) -> dict:
content = d["content"][0] if len(d["content"]) == 1 else d["content"]

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]

if isinstance(embeddings, str):
embeddings = json.loads(embeddings)
elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings):
embeddings = [json.loads(e) for e in embeddings]

if embeddings and all((e is None) for e in embeddings):
embeddings = None

Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It constructs and executes SQL queries to fetch document details based on various filters.
"""

import json
from typing import Any, Literal
from uuid import UUID

Expand Down Expand Up @@ -55,6 +56,11 @@ def transform_list_docs(d: dict) -> dict:

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]

if isinstance(embeddings, str):
embeddings = json.loads(embeddings)
elif isinstance(embeddings, list) and all(isinstance(e, str) for e in embeddings):
embeddings = [json.loads(e) for e in embeddings]

if embeddings and all((e is None) for e in embeddings):
embeddings = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
execution_id,
transition_id,
type,
step_definition,
step_label,
current_step,
next_step,
Expand All @@ -38,8 +37,7 @@
$6,
$7,
$8,
$9,
$10
$9
)
RETURNING *;
"""
Expand Down Expand Up @@ -152,7 +150,7 @@ async def create_execution_transition(
execution_id,
transition_id,
data.type,
{},
# TODO: Add step_label
None,
transition_data["current"],
transition_data["next"],
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def _check_error(error):

setattr(new_error, "__cause__", error)

print(error)
raise new_error from error

def decorator(
Expand Down
16 changes: 12 additions & 4 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from fastapi import Depends
from langcodes import Language

from ...autogen.openapi_model import (
DocReference,
Expand All @@ -26,12 +27,16 @@ def get_search_fn_and_params(
search_fn, params = None, None

match search_params:
case TextOnlyDocSearchRequest(text=query, limit=k, metadata_filter=metadata_filter):
case TextOnlyDocSearchRequest(
text=query, limit=k, lang=lang, metadata_filter=metadata_filter
):
search_language = Language.get(lang).describe()["language"].lower()
search_fn = search_docs_by_text
params = {
"query": query,
"k": k,
"metadata_filter": metadata_filter,
"search_language": search_language,
}

case VectorDocSearchRequest(
Expand All @@ -51,19 +56,22 @@ def get_search_fn_and_params(
case HybridDocSearchRequest(
text=query,
vector=query_embedding,
lang=lang,
limit=k,
confidence=confidence,
alpha=alpha,
metadata_filter=metadata_filter,
):
search_language = Language.get(lang).describe()["language"].lower()
search_fn = search_docs_hybrid
params = {
"query": query,
"query_embedding": query_embedding,
"text_query": query,
"embedding": query_embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"embed_search_options": {"confidence": confidence},
"confidence": confidence,
"alpha": alpha,
"metadata_filter": metadata_filter,
"search_language": search_language,
}

return search_fn, params
Expand Down
1 change: 1 addition & 0 deletions agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ dependencies = [
"uuid7>=0.1.0",
"asyncpg>=0.30.0",
"unique-namer>=1.6.1",
"langcodes>=3.5.0",
]

[dependency-groups]
Expand Down
19 changes: 19 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent):

yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool)

# TODO: Delete the doc
# await delete_doc(
# developer_id=developer.id,
# doc_id=resp.id,
Expand All @@ -160,6 +161,23 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent):
# )


@fixture(scope="test")
async def test_doc_with_embedding(dsn=pg_dsn, developer=test_developer, doc=test_doc):
    """Fixture: the base ``test_doc`` with one synthetic embedding row attached.

    Inserts a single 1024-dimensional all-ones vector directly into
    ``docs_embeddings_store`` (bypassing the embedding pipeline), then yields
    the doc re-fetched via ``get_doc`` so its ``embeddings`` field is
    populated. Used by the embedding/hybrid search tests.
    """
    pool = await create_db_pool(dsn=dsn)
    # Row (index=0, chunk_seq=0) embeds the doc's first content chunk.
    # NOTE(review): the embedding is passed as a pgvector-style text literal
    # "[1.0, 1.0, ...]" — assumes the column accepts that form; confirm
    # against the docs_embeddings_store schema.
    await pool.execute(
        """
INSERT INTO docs_embeddings_store (developer_id, doc_id, index, chunk_seq, chunk, embedding)
VALUES ($1, $2, 0, 0, $3, $4)
""",
        developer.id,
        doc.id,
        # content may be a single string or a list of chunk strings
        doc.content[0] if isinstance(doc.content, list) else doc.content,
        # 1024-entry vector literal: "[1.0, 1.0, ..., 1.0]"
        f"[{', '.join([str(x) for x in [1.0] * 1024])}]",
    )

    yield await get_doc(developer_id=developer.id, doc_id=doc.id, connection_pool=pool)


@fixture(scope="test")
async def test_user_doc(dsn=pg_dsn, developer=test_developer, user=test_user):
pool = await create_db_pool(dsn=dsn)
Expand All @@ -183,6 +201,7 @@ async def test_user_doc(dsn=pg_dsn, developer=test_developer, user=test_user):

yield await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool)

# TODO: Delete the doc
# await delete_doc(
# developer_id=developer.id,
# doc_id=resp.id,
Expand Down
61 changes: 19 additions & 42 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding
from agents_api.queries.docs.search_docs_by_text import search_docs_by_text
from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid
from ward import skip, test
from ward import test

from .fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user
from .fixtures import (
pg_dsn,
test_agent,
test_developer,
test_doc,
test_doc_with_embedding,
test_user,
)

EMBEDDING_SIZE: int = 1024

Expand Down Expand Up @@ -250,32 +257,12 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
assert result[0].metadata == {"test": "test"}, "Metadata should match"


@skip("embedding search: test container not vectorizing")
@test("query: search docs by embedding")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
async def _(
dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
):
pool = await create_db_pool(dsn=dsn)

# Create a test document
doc = await create_doc(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
data=CreateDocRequest(
title="Hello",
content="The world is a funny little thing",
metadata={"test": "test"},
embed_instruction="Embed the document",
),
connection_pool=pool,
)

# Create a test document
doc = await get_doc(
developer_id=developer.id,
doc_id=doc.id,
connection_pool=pool,
)

assert doc.embeddings is not None

# Get query embedding by averaging the embeddings (list of floats)
Expand All @@ -295,31 +282,21 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
assert result[0].metadata is not None


@skip("embedding search: test container not vectorizing")
@test("query: search docs by hybrid")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
async def _(
dsn=pg_dsn, agent=test_agent, developer=test_developer, doc=test_doc_with_embedding
):
pool = await create_db_pool(dsn=dsn)

# Create a test document
await create_doc(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
data=CreateDocRequest(
title="Hello",
content="The world is a funny little thing",
metadata={"test": "test"},
embed_instruction="Embed the document",
),
connection_pool=pool,
)
# Get query embedding by averaging the embeddings (list of floats)
query_embedding = [sum(k) / len(k) for k in zip(*doc.embeddings)]

# Search using the correct parameter types
result = await search_docs_hybrid(
developer_id=developer.id,
owners=[("agent", agent.id)],
text_query="funny thing",
embedding=[1.0] * 1024,
text_query=doc.content[0] if isinstance(doc.content, list) else doc.content,
embedding=query_embedding,
k=3, # Add k parameter
metadata_filter={"test": "test"}, # Add metadata filter
connection_pool=pool,
Expand Down
12 changes: 3 additions & 9 deletions agents-api/tests/test_docs_routes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio

from ward import skip, test
from ward import test

from .fixtures import (
make_request,
patch_embed_acompletion,
test_agent,
test_doc,
test_doc_with_embedding,
test_user,
test_user_doc,
)
Expand Down Expand Up @@ -174,7 +173,6 @@ def _(make_request=make_request, agent=test_agent):

@test("route: search agent docs")
async def _(make_request=make_request, agent=test_agent, doc=test_doc):
await asyncio.sleep(1)
search_params = {
"text": doc.content[0],
"limit": 1,
Expand All @@ -196,7 +194,6 @@ async def _(make_request=make_request, agent=test_agent, doc=test_doc):

@test("route: search user docs")
async def _(make_request=make_request, user=test_user, doc=test_user_doc):
await asyncio.sleep(1)
search_params = {
"text": doc.content[0],
"limit": 1,
Expand All @@ -217,11 +214,8 @@ async def _(make_request=make_request, user=test_user, doc=test_user_doc):
assert len(docs) >= 1


@skip("embedding search: test container not vectorizing")
@test("route: search agent docs hybrid with mmr")
async def _(make_request=make_request, agent=test_agent, doc=test_doc):
await asyncio.sleep(1)

async def _(make_request=make_request, agent=test_agent, doc=test_doc_with_embedding):
EMBEDDING_SIZE = 1024
search_params = {
"text": doc.content[0],
Expand Down
2 changes: 2 additions & 0 deletions agents-api/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions memory-store/migrations/000012_transitions.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ CREATE TABLE IF NOT EXISTS transitions (
execution_id UUID NOT NULL,
transition_id UUID NOT NULL,
type transition_type NOT NULL,
step_definition JSONB NOT NULL,
step_label TEXT DEFAULT NULL,
current_step transition_cursor NOT NULL,
next_step transition_cursor DEFAULT NULL,
output JSONB,
task_token TEXT DEFAULT NULL,
metadata JSONB DEFAULT '{}'::JSONB,
CONSTRAINT pk_transitions PRIMARY KEY (created_at, execution_id, transition_id),
CONSTRAINT ct_step_definition_is_object CHECK (jsonb_typeof(step_definition) = 'object'),
CONSTRAINT ct_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object')
);

Expand Down
37 changes: 22 additions & 15 deletions memory-store/migrations/000013_executions_continuous_view.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,32 @@ WITH
timescaledb.materialized_only = FALSE
) AS
SELECT
time_bucket ('1 day', created_at) AS bucket,
execution_id,
last(transition_id, created_at) AS transition_id,
time_bucket ('1 day', t.created_at) AS bucket,
t.execution_id,
last (t.transition_id, t.created_at) AS transition_id,
count(*) AS total_transitions,
state_agg(created_at, to_text(type)) AS state,
max(created_at) AS created_at,
last(type, created_at) AS type,
last(step_definition, created_at) AS step_definition,
last(step_label, created_at) AS step_label,
last(current_step, created_at) AS current_step,
last(next_step, created_at) AS next_step,
last(output, created_at) AS output,
last(task_token, created_at) AS task_token,
last(metadata, created_at) AS metadata
state_agg (t.created_at, to_text (t.type)) AS state,
max(t.created_at) AS created_at,
last (t.type, t.created_at) AS type,
last (t.step_label, t.created_at) AS step_label,
last (t.current_step, t.created_at) AS current_step,
last (t.next_step, t.created_at) AS next_step,
last (t.output, t.created_at) AS output,
last (t.task_token, t.created_at) AS task_token,
last (t.metadata, t.created_at) AS metadata,
last (e.task_id, e.created_at) AS task_id,
last (e.task_version, e.created_at) AS task_version,
last (w.step_definition, e.created_at) AS step_definition
FROM
transitions
transitions t
JOIN executions e ON t.execution_id = e.execution_id
JOIN workflows w ON e.task_id = w.task_id
AND e.task_version = w.version
AND w.step_idx = (t.current_step).step_index
AND w.name = (t.current_step).workflow_name
GROUP BY
bucket,
execution_id
t.execution_id
WITH
no data;

Expand Down
Loading

0 comments on commit 900cca8

Please sign in to comment.