Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 110 additions & 44 deletions python/packages/mem0/agent_framework_mem0/_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,28 @@

from __future__ import annotations

import asyncio
import logging
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from agent_framework import Message
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from mem0 import AsyncMemory, AsyncMemoryClient

if sys.version_info >= (3, 11):
from typing import NotRequired, Self, TypedDict # pragma: no cover
from typing import Self # pragma: no cover
else:
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
from typing_extensions import Self # pragma: no cover

if TYPE_CHECKING:
from agent_framework._agents import SupportsAgentRun


class _MemorySearchResponse_v1_1(TypedDict):
results: list[dict[str, Any]]
relations: NotRequired[list[dict[str, Any]]]


_MemorySearchResponse_v2 = list[dict[str, Any]]
logger = logging.getLogger(__name__)
MemoryRecord: TypeAlias = dict[str, Any]
SearchResponse: TypeAlias = list[MemoryRecord] | MemoryRecord


class Mem0ContextProvider(ContextProvider):
Expand Down Expand Up @@ -106,28 +105,80 @@ async def before_run(
if not input_text.strip():
return

filters = self._build_filters()
# Query entity partitions independently to bypass strict logical AND limitations
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
search_tasks: list[Awaitable[Any]] = []

# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
# AsyncMemoryClient (Platform) expects them in a filters dict
search_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
search_kwargs.update(filters)
else:
search_kwargs["filters"] = filters

search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
**search_kwargs,
)
# 1. Query User partition independently
if self.user_id:
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]

if isinstance(search_response, list):
memories = search_response
elif isinstance(search_response, dict) and "results" in search_response:
memories = search_response["results"]
else:
memories = [search_response]
# 2. Query Agent partition independently
if self.agent_id:
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]

# Fall back to an app-scoped search when only application_id is configured
if not search_tasks and self.application_id:
app_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
app_kwargs["app_id"] = self.application_id
else:
app_kwargs["filters"] = {"app_id": self.application_id}
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if not search_tasks:
return

line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)
Comment thread
VedantSonani marked this conversation as resolved.

# Merge and deduplicate results
memories: list[MemoryRecord] = []
seen_memory_ids: set[str] = set()
failed_tasks_count: int = 0

for search_response in results:
if isinstance(search_response, asyncio.CancelledError):
raise search_response

if isinstance(search_response, BaseException):
failed_tasks_count += 1
logger.error(
"Mem0 partition search task failed: %s",
search_response,
exc_info=search_response,
)
Comment thread
VedantSonani marked this conversation as resolved.
Comment thread
VedantSonani marked this conversation as resolved.
continue
Comment thread
VedantSonani marked this conversation as resolved.

current_memories: list[MemoryRecord] = []
if isinstance(search_response, list):
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
elif isinstance(search_response, dict):
results_field = search_response.get("results")
if isinstance(results_field, list):
current_memories = [
item for item in results_field if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
]
else:
current_memories = [search_response]

for mem in current_memories:
mem_id = mem.get("id")
if mem_id is not None and not isinstance(mem_id, str):
mem_id = str(mem_id)

if mem_id is not None and mem_id in seen_memory_ids:
continue

if mem_id is not None:
seen_memory_ids.add(mem_id)

memories.append(mem)
Comment thread
VedantSonani marked this conversation as resolved.

if failed_tasks_count == len(search_tasks):
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")

line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
if line_separated_memories:
context.extend_messages(
self.source_id,
Expand Down Expand Up @@ -159,12 +210,21 @@ def get_role_value(role: Any) -> str:
]

if messages:
await self.mem0_client.add( # type: ignore[misc]
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
metadata={"application_id": self.application_id},
)
add_kwargs: dict[str, Any] = {
"messages": messages,
"user_id": self.user_id,
"agent_id": self.agent_id,
}

# Inject the application scope using the matching signature format for each SDK variant
if isinstance(self.mem0_client, AsyncMemory):
if self.application_id:
add_kwargs["app_id"] = self.application_id
else:
if self.application_id:
add_kwargs["filters"] = {"app_id": self.application_id}

await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]

# -- Internal methods ------------------------------------------------------

Expand All @@ -173,15 +233,21 @@ def _validate_filters(self) -> None:
if not self.agent_id and not self.user_id and not self.application_id:
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")

def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters."""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.application_id:
filters["app_id"] = self.application_id
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
"""Build search keyword arguments formatted for OSS vs Platform clients."""
filters: dict[str, Any] = {"query": input_text}

if isinstance(self.mem0_client, AsyncMemory):
# AsyncMemory (OSS) expects direct kwargs
filters[entity_key] = entity_value
if self.application_id:
filters["app_id"] = self.application_id
else:
# AsyncMemoryClient (Platform) expects a filters dict
filters["filters"] = {entity_key: entity_value}
if self.application_id:
filters["filters"]["app_id"] = self.application_id

return filters
Comment thread
VedantSonani marked this conversation as resolved.
Comment thread
VedantSonani marked this conversation as resolved.


Expand Down
Loading