Skip to content

feat(agents-api): added mmr to chat #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ def get_handler(system: SystemDef) -> Callable:
from ..queries.agents.update_agent import update_agent as update_agent_query
from ..queries.docs.delete_doc import delete_doc as delete_doc_query
from ..queries.docs.list_docs import list_docs as list_docs_query
from ..queries.entries.get_history import get_history as get_history_query
from ..queries.sessions.create_session import create_session as create_session_query
from ..queries.sessions.delete_session import delete_session as delete_session_query
from ..queries.sessions.get_session import get_session as get_session_query
from ..queries.sessions.list_sessions import list_sessions as list_sessions_query
from ..queries.sessions.update_session import update_session as update_session_query
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
pg_dsn = os.environ.get("PG_DSN")

for container in containers:
if not getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = await create_db_pool(pg_dsn)

# INIT S3 #
Expand All @@ -35,7 +35,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
s3_endpoint = os.environ.get("S3_ENDPOINT")

for container in containers:
if not getattr(container.state, "s3_client", None):
if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
session = get_session()
container.state.s3_client = await session.create_client(
"s3",
Expand All @@ -49,13 +49,13 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
finally:
# CLOSE POSTGRES #
for container in containers:
if getattr(container.state, "postgres_pool", None):
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
await container.state.postgres_pool.close()
container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
if getattr(container.state, "s3_client", None):
if hasattr(container, "state") and getattr(container.state, "s3_client", None):
await container.state.s3_client.close()
container.state.s3_client = None

Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ async def aembedding(
embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

# Truncate the embedding to the specified dimensions
embedding_list = [
return [
item["embedding"][:dimensions]
for item in embedding_list
if len(item["embedding"]) >= dimensions
]

return embedding_list
142 changes: 85 additions & 57 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypeVar
from uuid import UUID

import numpy as np
from beartype import beartype
from fastapi import HTTPException
from pydantic import ValidationError
Expand All @@ -10,6 +11,7 @@
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...common.utils.db_exceptions import common_db_exceptions, partialclass
from ..docs.mmr import maximal_marginal_relevance
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
Expand Down Expand Up @@ -75,64 +77,90 @@ async def gather_messages(
)
recall_options = session.recall_options

# search the last `search_threshold` messages
search_messages = [
msg
for msg in (past_messages + new_raw_messages)[-(recall_options.num_search_messages) :]
if isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"]
]

if len(search_messages) == 0:
return past_messages, []

# Search matching docs
embed_text = "\n\n".join([
f"{msg.get('name') or msg['role']}: {msg['content']}" for msg in search_messages
]).strip()

# Don't embed if search mode is text only
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)

# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references: list[DocReference] = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
connection_pool=connection_pool,
# Ensure recall_options is not None and has the necessary attributes
if recall and recall_options:
# search the last `search_threshold` messages
search_messages = [
msg
for msg in (past_messages + new_raw_messages)[
-(recall_options.num_search_messages) :
]
if isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"]
]

if len(search_messages) == 0:
return past_messages, []

# Search matching docs
embed_text = "\n\n".join([
f"{msg.get('name') or msg['role']}: {msg['content']}" for msg in search_messages
]).strip()

# Don't embed if search mode is text only
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)
case "hybrid":
doc_references: list[DocReference] = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
text_query=query_text,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
doc_references: list[DocReference] = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
connection_pool=connection_pool,

# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references = await search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
connection_pool=connection_pool,
)
case "hybrid":
doc_references = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
text_query=query_text,
embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
doc_references = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
connection_pool=connection_pool,
)

# Apply MMR if enabled
if (
recall_options.mmr_strength > 0
and len(doc_references) > recall_options.limit
and recall_options.mode != "text"
and len([doc for doc in doc_references if doc.snippet.embedding is not None]) >= 2
):
# FIXME: This is a temporary fix to ensure that the MMR algorithm works.
# We shouldn't be having references without embeddings.
doc_references = [
doc for doc in doc_references if doc.snippet.embedding is not None
]

# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
doc_references = [doc for i, doc in enumerate(doc_references) if i in set(indices)]

# TODO: Add missing MMR implementation
return past_messages, doc_references

return past_messages, doc_references
# If recall is False or recall_options is None, return past messages with no doc references
return past_messages, []
6 changes: 0 additions & 6 deletions agents-api/scripts/agents_api.py

This file was deleted.

2 changes: 1 addition & 1 deletion agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def _(
connection_pool=pool,
)

(embed, _) = mocks
(_embed, _) = mocks

chat_context = await prepare_chat_context(
developer_id=developer_id,
Expand Down
14 changes: 11 additions & 3 deletions cookbooks/01-website-crawler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import uuid

import yaml
from julep import Client

Expand All @@ -10,7 +11,8 @@
# Creating Julep Client with the API Key
api_key = os.getenv("JULEP_API_KEY")
if not api_key:
raise ValueError("JULEP_API_KEY not found in environment variables")
msg = "JULEP_API_KEY not found in environment variables"
raise ValueError(msg)

client = Client(api_key=api_key, environment="dev")

Expand All @@ -26,6 +28,11 @@
model="gpt-4o",
)

spider_api_key = os.getenv("SPIDER_API_KEY")
if not spider_api_key:
msg = "SPIDER_API_KEY not found in environment variables"
raise ValueError(msg)

# Defining a Task
task_def = yaml.safe_load(f"""
name: Crawling Task
Expand Down Expand Up @@ -63,7 +70,7 @@
page['content'] for page in _['result']
)
)

# Prompt step to create a summary of the results
- prompt: |
You are {{{{agent.about}}}}
Expand All @@ -90,6 +97,7 @@

# Waiting for the execution to complete
import time

time.sleep(5)

# Getting the execution details
Expand All @@ -104,4 +112,4 @@

# Stream the steps of the defined task
print("Streaming execution transitions:")
print(client.executions.transitions.stream(execution_id=execution.id))
print(client.executions.transitions.stream(execution_id=execution.id))
8 changes: 5 additions & 3 deletions cookbooks/02-sarcastic-news-headline-generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import uuid

import yaml
from julep import Client

Expand All @@ -10,7 +11,8 @@
# Create Julep Client with the API Key
api_key = os.getenv("JULEP_API_KEY")
if not api_key:
raise ValueError("JULEP_API_KEY not found in environment variables")
msg = "JULEP_API_KEY not found in environment variables"
raise ValueError(msg)

client = Client(api_key=api_key, environment="dev")

Expand Down Expand Up @@ -76,7 +78,8 @@
)

# Waiting for the execution to complete
import time
import time

time.sleep(5)

# Getting the execution details
Expand All @@ -92,4 +95,3 @@
# Stream the steps of the defined task
print("Streaming execution transitions:")
print(client.executions.transitions.stream(execution_id=execution.id))

12 changes: 7 additions & 5 deletions cookbooks/03-trip-planning-assistant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import uuid

import yaml
from julep import Client
import os

openweathermap_api_key = os.getenv("OPENWEATHERMAP_API_KEY")
brave_api_key = os.getenv("BRAVE_API_KEY")
Expand Down Expand Up @@ -139,17 +140,18 @@

# Wait for the execution to complete
import time

time.sleep(200)

# Getting the execution details
# Get execution details
execution = client.executions.get(execution.id)
# Print the output
print(execution.output)
print("-"*50)
print("-" * 50)

if 'final_plan' in execution.output:
print(execution.output['final_plan'])
if "final_plan" in execution.output:
print(execution.output["final_plan"])

# Lists all the task steps that have been executed up to this point in time
transitions = client.executions.transitions.list(execution_id=execution.id).items
Expand All @@ -158,4 +160,4 @@
for transition in reversed(transitions):
print("Transition type: ", transition.type)
print("Transition output: ", transition.output)
print("-"*50)
print("-" * 50)
Loading
Loading