diff --git a/pyproject.toml b/pyproject.toml index f2fe8cae..0ca4b794 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ mem0_memory = [ # Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19 "mem0ai>=0.1.99,<1.0.0", "opensearch-py>=2.8.0,<3.0.0", + "psycopg2-binary", ] local_chromium_browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"] agent_core_browser = [ diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index 08f559fc..65c9b36e 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -177,6 +177,30 @@ class Mem0ServiceClient: }, } + def _get_postgresql_config(self) -> Dict: + """Get PostgreSQL configuration based on the current provider.""" + # Start with the default embedder and llm config + config = { + "embedder": self.DEFAULT_CONFIG["embedder"].copy(), + "llm": self.DEFAULT_CONFIG["llm"].copy(), + } + + # Add PostgreSQL vector store configuration + config["vector_store"] = { + "provider": "pgvector", + "config": { + "host": os.environ.get("POSTGRESQL_HOST"), + "port": int(os.environ.get("POSTGRESQL_PORT", 5432)), + "user": os.environ.get("POSTGRESQL_USER"), + "password": os.environ.get("POSTGRESQL_PASSWORD"), + "dbname": os.environ.get("DB_NAME", "postgres"), + "collection_name": os.environ.get("DB_COLLECTION_NAME", "mem0_memories"), + "embedding_model_dims": 1024, + }, + } + + return config + def __init__(self, config: Optional[Dict] = None): """Initialize the Mem0 service client. @@ -208,6 +232,10 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any: logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)") config = self._configure_neptune_analytics_backend(config) + if os.environ.get("POSTGRESQL_HOST"): + logger.info("Using PostgreSQL backend (Mem0Memory with PostgreSQL)") + return self._initialize_postgresql_client(config) + if os.environ.get("OPENSEARCH_HOST"): logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)") return self._initialize_opensearch_client(config) @@ -231,6 +259,37 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) -> } return config + def _initialize_postgresql_client(self, config: Optional[Dict] = None) -> Mem0Memory: + """Initialize a Mem0 client with PostgreSQL backend. + + Args: + config: Optional configuration dictionary to override defaults. + + Returns: + An initialized Mem0Memory instance configured for PostgreSQL. + + Raises: + ValueError: If required PostgreSQL environment variables are missing. + """ + # Validate required environment variables + required_vars = ["POSTGRESQL_HOST", "POSTGRESQL_USER", "POSTGRESQL_PASSWORD"] + missing_vars = [var for var in required_vars if not os.environ.get(var)] + if missing_vars: + raise ValueError(f"Missing required PostgreSQL environment variables: {', '.join(missing_vars)}") + + # Get PostgreSQL configuration + pg_config = self._get_postgresql_config() + + # Validate OpenAI API key if using OpenAI + provider = os.environ.get("MEM0_LLM_PROVIDER", "aws_bedrock") + if provider == "openai" and not os.environ.get("OPENAI_API_KEY"): + raise ValueError("OPENAI_API_KEY environment variable is required when using OpenAI provider") + + # Merge with user-provided config if any + merged_config = self._merge_configs(pg_config, config) + + return Mem0Memory.from_config(config_dict=merged_config) + def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory: """Initialize a Mem0 client with OpenSearch backend. @@ -296,12 +355,24 @@ def _merge_config(self, config: Optional[Dict] = None) -> Dict: Returns: A merged configuration dictionary. """ - merged_config = self.DEFAULT_CONFIG.copy() - if not config: + return self._merge_configs(self.DEFAULT_CONFIG, config) + + def _merge_configs(self, base_config: Dict, override_config: Optional[Dict] = None) -> Dict: + """Merge two configuration dictionaries. + + Args: + base_config: Base configuration dictionary + override_config: Optional configuration to merge into base + + Returns: + A merged configuration dictionary. + """ + merged_config = base_config.copy() + if not override_config: return merged_config - # Deep merge the configs - for key, value in config.items(): + # Merge the configs + for key, value in override_config.items(): if key in merged_config and isinstance(value, dict) and isinstance(merged_config[key], dict): merged_config[key].update(value) else: diff --git a/tests/test_mem0.py b/tests/test_mem0.py index be836788..279fe88f 100644 --- a/tests/test_mem0.py +++ b/tests/test_mem0.py @@ -11,6 +11,7 @@ import pytest from strands import Agent from strands.types.tools import ToolUse + from strands_tools import mem0_memory from strands_tools.mem0_memory import Mem0ServiceClient @@ -523,3 +524,279 @@ def test_faiss_client(mock_mem0_memory, mock_tool): # Assertions assert result["status"] == "success" assert "Test memory content" in str(result["content"][0]["text"]) + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + "POSTGRESQL_PASSWORD": "test_password", + "DB_NAME": "test_db", + "MEM0_LLM_PROVIDER": "openai", + "MEM0_LLM_MODEL": "gpt-4", + "MEM0_EMBEDDER_PROVIDER": "openai", + "MEM0_EMBEDDER_MODEL": "text-embedding-3-large", + "OPENAI_API_KEY": "test-api-key", + }, +) +@patch("strands_tools.mem0_memory.Mem0ServiceClient") +def test_postgresql_store_memory(mock_mem0_client, mock_mem0_service_client, mock_tool): + """Test PostgreSQL store memory functionality.""" + # Setup mocks + mock_mem0_client.return_value = mock_mem0_service_client + + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": { + "action": "store", + "content": "Test memory content", + "user_id": "test_user", + "metadata": {"category": "test"}, + }, + }.get(key, default) + + # Mock data + store_response = [ + { + "event": "store", + "memory": "Test memory content", + "id": "mem123", + "created_at": "2024-03-20T10:00:00Z", + } + ] + + # Configure mocks + mock_mem0_service_client.store_memory.return_value = store_response + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "success" + assert result["content"][0]["text"] == json.dumps(store_response, indent=2) + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + "POSTGRESQL_PASSWORD": "test_password", + "DB_NAME": "test_db", + }, +) +@patch("strands_tools.mem0_memory.Mem0ServiceClient") +def test_postgresql_get_memory(mock_mem0_client, mock_mem0_service_client, mock_tool): + """Test PostgreSQL get memory functionality.""" + # Setup mocks + mock_mem0_client.return_value = mock_mem0_service_client + + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": {"action": "get", "memory_id": "mem123"}, + }.get(key, default) + + # Mock data + get_response = { + "id": "mem123", + "memory": "Test memory content", + "created_at": "2024-03-20T10:00:00Z", + "user_id": "test_user", + "metadata": {"category": "test"}, + } + + # Configure mocks + mock_mem0_service_client.get_memory.return_value = get_response + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) > 0 + assert "text" in result["content"][0] + memory = json.loads(result["content"][0]["text"]) + assert memory["id"] == "mem123" + assert memory["memory"] == "Test memory content" + assert memory["user_id"] == "test_user" + assert memory["metadata"] == {"category": "test"} + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + "POSTGRESQL_PASSWORD": "test_password", + "DB_NAME": "test_db", + }, +) +@patch("strands_tools.mem0_memory.Mem0ServiceClient") +def test_postgresql_list_memories(mock_mem0_client, mock_mem0_service_client, mock_tool): + """Test PostgreSQL list memories functionality.""" + # Setup mocks + mock_mem0_client.return_value = mock_mem0_service_client + + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": {"action": "list", "user_id": "test_user"}, + }.get(key, default) + + # Mock data for list_memories response + list_response = { + "results": [ + { + "id": "mem123", + "memory": "Test memory content", + "created_at": "2024-03-20T10:00:00Z", + "user_id": "test_user", + "metadata": {"category": "test"}, + } + ] + } + + # Configure mocks + mock_mem0_service_client.list_memories.return_value = list_response + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) > 0 + assert "text" in result["content"][0] + # Parse the JSON string in text + memories = json.loads(result["content"][0]["text"]) + assert isinstance(memories, list) + assert len(memories) > 0 + assert "id" in memories[0] + assert memories[0]["id"] == "mem123" + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + "POSTGRESQL_PASSWORD": "test_password", + "DB_NAME": "test_db", + }, +) +@patch("strands_tools.mem0_memory.Mem0ServiceClient") +def test_postgresql_retrieve_memories(mock_mem0_client, mock_mem0_service_client, mock_tool): + """Test PostgreSQL retrieve memories functionality.""" + # Setup mocks + mock_mem0_client.return_value = mock_mem0_service_client + + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": {"action": "retrieve", "query": "test query", "user_id": "test_user"}, + }.get(key, default) + + # Mock data for search_memories response + retrieve_response = { + "results": [ + { + "id": "mem123", + "memory": "Test memory content", + "score": 0.85, + "created_at": "2024-03-20T10:00:00Z", + "user_id": "test_user", + "metadata": {"category": "test"}, + } + ] + } + + # Configure mocks + mock_mem0_service_client.search_memories.return_value = retrieve_response + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "success" + assert isinstance(result["content"], list) + assert len(result["content"]) > 0 + assert "text" in result["content"][0] + # Parse the JSON string in text + memories = json.loads(result["content"][0]["text"]) + assert isinstance(memories, list) + assert len(memories) > 0 + assert "id" in memories[0] + assert memories[0]["id"] == "mem123" + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + "POSTGRESQL_PASSWORD": "test_password", + "DB_NAME": "test_db", + "BYPASS_TOOL_CONSENT": "true", + }, +) +@patch("strands_tools.mem0_memory.Mem0ServiceClient") +def test_postgresql_delete_memory(mock_mem0_client, mock_mem0_service_client, mock_tool): + """Test PostgreSQL delete memory functionality with BYPASS_TOOL_CONSENT mode enabled.""" + # Setup mocks + mock_mem0_client.return_value = mock_mem0_service_client + + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": {"action": "delete", "memory_id": "mem123"}, + }.get(key, default) + + # Configure mocks + mock_mem0_service_client.delete_memory.return_value = {"status": "success"} + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "success" + assert "Memory mem123 deleted successfully" in str(result["content"][0]["text"]) + + # Verify correct functions were called + mock_mem0_service_client.delete_memory.assert_called_once() + call_args = mock_mem0_service_client.delete_memory.call_args[0] + assert call_args[0] == "mem123" + + +@patch.dict( + os.environ, + { + "POSTGRESQL_HOST": "test-cluster.cluster-abc123.us-west-2.rds.amazonaws.com", + "POSTGRESQL_USER": "test_user", + # Missing POSTGRESQL_PASSWORD + "MEM0_LLM_PROVIDER": "openai", + "OPENAI_API_KEY": "test-api-key", + }, +) +def test_postgresql_missing_required_vars(mock_tool): + """Test PostgreSQL client with missing required environment variables.""" + # Configure the mock_tool + mock_tool.get.side_effect = lambda key, default=None: { + "toolUseId": "test-id", + "input": { + "action": "store", + "content": "Test memory content", + "user_id": "test_user", + }, + }.get(key, default) + + # Call the memory function + result = mem0_memory.mem0_memory(tool=mock_tool) + + # Assertions + assert result["status"] == "error" + assert "Missing required PostgreSQL environment variables" in str(result["content"][0]["text"]) + assert "POSTGRESQL_PASSWORD" in str(result["content"][0]["text"])