Skip to content

Commit

Permalink
Merge pull request #1016 from julep-ai/x/misc-fixes
Browse files Browse the repository at this point in the history
X/misc fixes
  • Loading branch information
whiterabbit1983 authored Jan 5, 2025
2 parents c567cdc + be6b292 commit 577a808
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 25 deletions.
23 changes: 15 additions & 8 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,22 @@ class ObjectWithState(Protocol):
state: State


pool = None


# TODO: This currently doesn't use env.py, we should move to using them
@asynccontextmanager
async def lifespan(*containers: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

global pool
if not pool:
pool = await create_db_pool(pg_dsn)

for container in containers:
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = await create_db_pool(pg_dsn)
container.state.postgres_pool = pool

# INIT S3 #
s3_access_key = os.environ.get("S3_ACCESS_KEY")
Expand All @@ -50,13 +57,13 @@ async def lifespan(*containers: FastAPI | ObjectWithState):
try:
yield
finally:
# CLOSE POSTGRES #
for container in containers:
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
pool = getattr(container.state, "postgres_pool", None)
if pool:
await pool.close()
container.state.postgres_pool = None
# # CLOSE POSTGRES #
# for container in containers:
# if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
# pool = getattr(container.state, "postgres_pool", None)
# if pool:
# await pool.close()
# container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def gather_messages(
doc_references = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "hybrid":
Expand Down
14 changes: 7 additions & 7 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
search_docs_by_embedding_query = """
SELECT * FROM search_by_vector(
$1, -- developer_id
$2::vector(1024), -- query_embedding
$2::vector(1024), -- embedding
$3::text[], -- owner_types
$4::uuid[], -- owner_ids
$5, -- k
Expand All @@ -33,7 +33,7 @@
async def search_docs_by_embedding(
*,
developer_id: UUID,
query_embedding: list[float],
embedding: list[float],
k: int = 10,
owners: list[tuple[Literal["user", "agent"], UUID]],
confidence: float = 0.5,
Expand All @@ -44,7 +44,7 @@ async def search_docs_by_embedding(
Parameters:
developer_id (UUID): The ID of the developer.
query_embedding (List[float]): The vector to query.
embedding (List[float]): The vector to query.
k (int): The number of results to return.
owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples.
confidence (float): The confidence threshold for the search.
Expand All @@ -56,11 +56,11 @@ async def search_docs_by_embedding(
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

if not query_embedding:
if not embedding:
raise HTTPException(status_code=400, detail="Empty embedding provided")

# Convert query_embedding to a string
query_embedding_str = f"[{', '.join(map(str, query_embedding))}]"
# Convert embedding to a string
embedding_str = f"[{', '.join(map(str, embedding))}]"

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
Expand All @@ -70,7 +70,7 @@ async def search_docs_by_embedding(
search_docs_by_embedding_query,
[
developer_id,
query_embedding_str,
embedding_str,
owner_types,
owner_ids,
k,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def wrapper(
pool = (
connection_pool
if connection_pool is not None
else cast(asyncpg.Pool, app.state.postgres_pool)
else cast(asyncpg.Pool, getattr(app.state, "postgres_pool", None))
)

try:
Expand Down
12 changes: 6 additions & 6 deletions agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def get_search_fn_and_params(
}

case VectorDocSearchRequest(
vector=query_embedding,
vector=embedding,
limit=k,
confidence=confidence,
metadata_filter=metadata_filter,
):
search_fn = search_docs_by_embedding
params = {
"query_embedding": query_embedding,
"embedding": embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"confidence": confidence,
"metadata_filter": metadata_filter,
}

case HybridDocSearchRequest(
text=query,
vector=query_embedding,
vector=embedding,
lang=lang,
limit=k,
confidence=confidence,
Expand All @@ -66,7 +66,7 @@ def get_search_fn_and_params(
search_fn = search_docs_hybrid
params = {
"text_query": query,
"embedding": query_embedding,
"embedding": embedding,
"k": k * 3 if search_params.mmr_strength > 0 else k,
"confidence": confidence,
"alpha": alpha,
Expand Down Expand Up @@ -111,7 +111,7 @@ async def search_user_docs(
and len(docs) > search_params.limit
):
indices = maximal_marginal_relevance(
np.asarray(params["query_embedding"]),
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs],
k=search_params.limit,
)
Expand Down Expand Up @@ -160,7 +160,7 @@ async def search_agent_docs(
and len(docs) > search_params.limit
):
indices = maximal_marginal_relevance(
np.asarray(params["query_embedding"]),
np.asarray(params["embedding"]),
[doc.snippet.embedding for doc in docs],
k=search_params.limit,
)
Expand Down
2 changes: 1 addition & 1 deletion memory-store/migrations/000004_agents.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CREATE TABLE IF NOT EXISTS agents (
),
about TEXT CONSTRAINT ct_agents_about_length CHECK (
about IS NULL
OR length(about) <= 1000
OR length(about) <= 5000
),
instructions TEXT[] DEFAULT ARRAY[]::TEXT[],
model TEXT NOT NULL,
Expand Down
2 changes: 1 addition & 1 deletion memory-store/migrations/000015_entries.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS entries (
token_count INTEGER DEFAULT NULL,
tokenizer TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
timestamp DOUBLE PRECISION NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at),
CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content))
);
Expand Down

0 comments on commit 577a808

Please sign in to comment.