Skip to content

feat(agents-api): added mmr to chat #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 4, 2025
28 changes: 27 additions & 1 deletion agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from ..entries.get_history import get_history
from ..sessions.get_session import get_session
from ..utils import rewrap_exceptions
from ..docs.mmr import maximal_marginal_relevance
import numpy as np

T = TypeVar("T")

Expand Down Expand Up @@ -133,6 +135,30 @@ async def gather_messages(
connection_pool=connection_pool,
)

# TODO: Add missing MMR implementation
# Apply MMR if enabled
if (
# MMR is enabled
recall_options.mmr_strength > 0
# The number of doc references is greater than the limit
and len(doc_references) > recall_options.limit
# MMR is not applied to text search
and recall_options.mode != "text"
):
# FIXME: This is a temporary fix to ensure that the MMR algorithm works.
# We shouldn't be having references without embeddings.
doc_references = [
doc for doc in doc_references if doc.snippet.embedding is not None
]

# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
# Apply MMR
doc_references = [
doc for i, doc in enumerate(doc_references) if i in set(indices)
]

return past_messages, doc_references
Loading