Skip to content

Commit c00d08d

Browse files
committed
feat(agents-api): added mmr to chat
1 parent 63894e4 commit c00d08d

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

Diff for: agents-api/agents_api/queries/chat/gather_messages.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from ..entries.get_history import get_history
1717
from ..sessions.get_session import get_session
1818
from ..utils import rewrap_exceptions
19+
from ..docs.mmr import maximal_marginal_relevance
20+
import numpy as np
1921

2022
T = TypeVar("T")
2123

@@ -133,6 +135,30 @@ async def gather_messages(
133135
connection_pool=connection_pool,
134136
)
135137

136-
# TODO: Add missing MMR implementation
138+
# Apply MMR if enabled
139+
if (
140+
# MMR is enabled
141+
recall_options.mmr_strength > 0
142+
# The number of doc references is greater than the limit
143+
and len(doc_references) > recall_options.limit
144+
# MMR is not applied to text search
145+
and recall_options.mode != "text"
146+
):
147+
# FIXME: This is a temporary fix to ensure that the MMR algorithm works.
148+
# We shouldn't be having references without embeddings.
149+
doc_references = [
150+
doc for doc in doc_references if doc.snippet.embedding is not None
151+
]
152+
153+
# Apply MMR
154+
indices = maximal_marginal_relevance(
155+
np.asarray(query_embedding),
156+
[doc.snippet.embedding for doc in doc_references],
157+
k=recall_options.limit,
158+
)
159+
# Apply MMR
160+
doc_references = [
161+
doc for i, doc in enumerate(doc_references) if i in set(indices)
162+
]
137163

138164
return past_messages, doc_references

0 commit comments

Comments
 (0)