Skip to content

Commit a5f00b1

Browse files
authored
Merge pull request #1000 from julep-ai/x/misc-session-fixes
Fix(agents-api): Miscellaneous fixes related to sessions & entries
2 parents 81841e4 + 7f76bc3 commit a5f00b1

14 files changed

+80
-64
lines changed

Diff for: agents-api/agents_api/queries/chat/gather_messages.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ async def gather_messages(
120120
doc_references: list[DocReference] = await search_docs_hybrid(
121121
developer_id=developer.id,
122122
owners=owners,
123-
query=query_text,
124-
query_embedding=query_embedding,
123+
text_query=query_text,
124+
embedding=query_embedding,
125125
connection_pool=connection_pool,
126126
)
127127
case "text":

Diff for: agents-api/agents_api/queries/docs/list_docs.py

+9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
It constructs and executes SQL queries to fetch document details based on various filters.
44
"""
55

6+
import ast
67
from typing import Any, Literal
78
from uuid import UUID
89

@@ -57,6 +58,14 @@ def transform_list_docs(d: dict) -> dict:
5758
content = d["content"][0] if len(d["content"]) == 1 else d["content"]
5859

5960
embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]
61+
62+
# try:
63+
# # Embeddings are retreived as a string, so we need to evaluate it
64+
# embeddings = ast.literal_eval(embeddings)
65+
# except Exception as e:
66+
# msg = f"Error evaluating embeddings: {e}"
67+
# raise ValueError(msg)
68+
6069
if embeddings and all((e is None) for e in embeddings):
6170
embeddings = None
6271

Diff for: agents-api/agents_api/queries/docs/search_docs_by_embedding.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ...autogen.openapi_model import DocReference
88
from ...common.utils.db_exceptions import common_db_exceptions
99
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
10+
from .utils import transform_to_doc_reference
1011

1112
# Raw query for vector search
1213
search_docs_by_embedding_query = """
@@ -25,14 +26,7 @@
2526
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
2627
@wrap_in_class(
2728
DocReference,
28-
transform=lambda d: {
29-
"owner": {
30-
"id": d["owner_id"],
31-
"role": d["owner_type"],
32-
},
33-
"metadata": d.get("metadata", {}),
34-
**d,
35-
},
29+
transform=transform_to_doc_reference,
3630
)
3731
@pg_query
3832
@beartype

Diff for: agents-api/agents_api/queries/docs/search_docs_by_text.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ...autogen.openapi_model import DocReference
88
from ...common.utils.db_exceptions import common_db_exceptions
99
from ..utils import pg_query, rewrap_exceptions, wrap_in_class
10+
from .utils import transform_to_doc_reference
1011

1112
# Raw query for text search
1213
search_docs_text_query = """
@@ -25,14 +26,7 @@
2526
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
2627
@wrap_in_class(
2728
DocReference,
28-
transform=lambda d: {
29-
"owner": {
30-
"id": d["owner_id"],
31-
"role": d["owner_type"],
32-
},
33-
"metadata": d.get("metadata", {}),
34-
**d,
35-
},
29+
transform=transform_to_doc_reference,
3630
)
3731
@pg_query
3832
@beartype

Diff for: agents-api/agents_api/queries/docs/search_docs_hybrid.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
rewrap_exceptions,
1212
wrap_in_class,
1313
)
14+
from .utils import transform_to_doc_reference
1415

1516
# Raw query for hybrid search
1617
search_docs_hybrid_query = """
@@ -32,14 +33,7 @@
3233
@rewrap_exceptions(common_db_exceptions("doc", ["search"]))
3334
@wrap_in_class(
3435
DocReference,
35-
transform=lambda d: {
36-
"owner": {
37-
"id": d["owner_id"],
38-
"role": d["owner_type"],
39-
},
40-
"metadata": d.get("metadata", {}),
41-
**d,
42-
},
36+
transform=transform_to_doc_reference,
4337
)
4438
@pg_query
4539
@beartype

Diff for: agents-api/agents_api/queries/docs/utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import ast
2+
3+
4+
def transform_to_doc_reference(d: dict) -> dict:
5+
id = d.pop("doc_id")
6+
content = d.pop("content")
7+
index = d.pop("index")
8+
9+
embedding = d.pop("embedding")
10+
11+
try:
12+
# Embeddings are retreived as a string, so we need to evaluate it
13+
embedding = ast.literal_eval(embedding)
14+
except Exception as e:
15+
msg = f"Error evaluating embeddings: {e}"
16+
raise ValueError(msg)
17+
18+
owner = {
19+
"id": d.pop("owner_id"),
20+
"role": d.pop("owner_type"),
21+
}
22+
snippet = {
23+
"content": content,
24+
"index": index,
25+
"embedding": embedding,
26+
}
27+
metadata = d.pop("metadata")
28+
29+
return {
30+
"id": id,
31+
"owner": owner,
32+
"snippet": snippet,
33+
"metadata": metadata,
34+
**d,
35+
}

Diff for: agents-api/agents_api/queries/entries/create_entries.py

-15
Original file line numberDiff line numberDiff line change
@@ -89,21 +89,6 @@ async def create_entries(
8989
# Convert the data to a list of dictionaries
9090
data_dicts = [item.model_dump(mode="json") for item in data]
9191

92-
# Prepare the parameters for the query
93-
# $1
94-
# $2
95-
# $3
96-
# $4
97-
# $5
98-
# $6
99-
# $7
100-
# $8
101-
# $9
102-
# $10
103-
# $11
104-
# $12
105-
# $13
106-
# $14
10792
params = [
10893
[
10994
session_id, # $1

Diff for: agents-api/agents_api/queries/entries/get_history.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@
4444
"""
4545

4646

47-
@rewrap_exceptions(common_db_exceptions("history", ["get"]))
48-
@wrap_in_class(
49-
History,
50-
one=True,
51-
transform=lambda d: {
52-
"entries": json.loads(d.get("entries") or "[]"),
47+
def _transform(d):
48+
return {
49+
"entries": [
50+
{
51+
**entry,
52+
}
53+
for entry in json.loads(d.get("entries") or "[]")
54+
],
5355
"relations": [
5456
{
5557
"head": r["head"],
@@ -60,7 +62,14 @@
6062
],
6163
"session_id": d.get("session_id"),
6264
"created_at": utcnow(),
63-
},
65+
}
66+
67+
68+
@rewrap_exceptions(common_db_exceptions("history", ["get"]))
69+
@wrap_in_class(
70+
History,
71+
one=True,
72+
transform=_transform,
6473
)
6574
@pg_query
6675
@beartype

Diff for: agents-api/agents_api/queries/sessions/create_or_update_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def create_or_update_session(
119119
data.token_budget, # $7
120120
data.context_overflow, # $8
121121
data.forward_tool_calls, # $9
122-
data.recall_options or {}, # $10
122+
data.recall_options.model_dump() if data.recall_options else {}, # $10
123123
]
124124

125125
# Prepare lookup parameters

Diff for: agents-api/agents_api/queries/sessions/patch_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ async def patch_session(
6565
data.token_budget, # $7
6666
data.context_overflow, # $8
6767
data.forward_tool_calls, # $9
68-
data.recall_options or {}, # $10
68+
data.recall_options.model_dump() if data.recall_options else {}, # $10
6969
]
7070

7171
return [(session_query, session_params)]

Diff for: agents-api/agents_api/queries/sessions/update_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def update_session(
6464
data.token_budget, # $7
6565
data.context_overflow, # $8
6666
data.forward_tool_calls, # $9
67-
data.recall_options or {}, # $10
67+
data.recall_options.model_dump() if data.recall_options else {}, # $10
6868
]
6969

7070
return [

Diff for: agents-api/agents_api/routers/docs/create_doc.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Annotated
22
from uuid import UUID
33

4-
from fastapi import BackgroundTasks, Depends
4+
from fastapi import Depends
55
from starlette.status import HTTP_201_CREATED
66

77
from ...autogen.openapi_model import CreateDocRequest, Doc, ResourceCreatedResponse
@@ -15,7 +15,6 @@ async def create_user_doc(
1515
user_id: UUID,
1616
data: CreateDocRequest,
1717
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
18-
background_tasks: BackgroundTasks,
1918
) -> ResourceCreatedResponse:
2019
"""
2120
Creates a new document for a user.
@@ -24,7 +23,6 @@ async def create_user_doc(
2423
user_id (UUID): The unique identifier of the user associated with the document.
2524
data (CreateDocRequest): The data to create the document with.
2625
x_developer_id (UUID): The unique identifier of the developer associated with the document.
27-
background_tasks (BackgroundTasks): The background tasks to run.
2826
2927
Returns:
3028
ResourceCreatedResponse: The created document.
@@ -45,7 +43,6 @@ async def create_agent_doc(
4543
agent_id: UUID,
4644
data: CreateDocRequest,
4745
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
48-
background_tasks: BackgroundTasks,
4946
) -> ResourceCreatedResponse:
5047
doc: Doc = await create_doc_query(
5148
developer_id=x_developer_id,

Diff for: agents-api/agents_api/routers/docs/search_docs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .router import router
2121

2222

23-
async def get_search_fn_and_params(
23+
def get_search_fn_and_params(
2424
search_params,
2525
) -> tuple[Any, dict[str, float | int | str | dict[str, float] | list[float]] | None]:
2626
search_fn, params = None, None
@@ -58,10 +58,10 @@ async def get_search_fn_and_params(
5858
):
5959
search_fn = search_docs_hybrid
6060
params = {
61-
"text_query": query,
62-
"embedding": query_embedding,
61+
"query": query,
62+
"query_embedding": query_embedding,
6363
"k": k * 3 if search_params.mmr_strength > 0 else k,
64-
"confidence": confidence,
64+
"embed_search_options": {"confidence": confidence},
6565
"alpha": alpha,
6666
"metadata_filter": metadata_filter,
6767
}
@@ -88,7 +88,7 @@ async def search_user_docs(
8888
"""
8989

9090
# MMR here
91-
search_fn, params = await get_search_fn_and_params(search_params)
91+
search_fn, params = get_search_fn_and_params(search_params)
9292

9393
start = time.time()
9494
docs: list[DocReference] = await search_fn(
@@ -137,7 +137,7 @@ async def search_agent_docs(
137137
DocSearchResponse: The search results.
138138
"""
139139

140-
search_fn, params = await get_search_fn_and_params(search_params)
140+
search_fn, params = get_search_fn_and_params(search_params)
141141

142142
start = time.time()
143143
docs: list[DocReference] = await search_fn(

Diff for: memory-store/migrations/000015_entries.up.sql

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@ CREATE TABLE IF NOT EXISTS entries (
3737
name TEXT,
3838
content JSONB[] NOT NULL,
3939
tool_call_id TEXT DEFAULT NULL,
40-
tool_calls JSONB[] NOT NULL DEFAULT '{}'::JSONB[],
40+
tool_calls JSONB[] DEFAULT NULL,
4141
model TEXT NOT NULL,
4242
token_count INTEGER DEFAULT NULL,
4343
tokenizer TEXT NOT NULL,
4444
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
4545
timestamp DOUBLE PRECISION NOT NULL,
4646
CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at),
47-
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content)),
48-
CONSTRAINT ct_tool_calls_is_array_of_objects CHECK (all_jsonb_elements_are_objects (tool_calls))
47+
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content))
4948
);
5049

5150
-- Convert to hypertable if not already

0 commit comments

Comments
 (0)