-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding Native OpenSearch support for Mem0 (#2211)
- Loading branch information
1 parent
6e781f6
commit f4c0f98
Showing
9 changed files
with
446 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
[OpenSearch](https://opensearch.org/) is an open-source, enterprise-grade search and observability suite that brings order to unstructured data at scale. OpenSearch supports k-NN (k-Nearest Neighbors) and allows you to store and retrieve high-dimensional vector embeddings efficiently. | ||
|
||
### Installation | ||
|
||
OpenSearch support requires additional dependencies. Install them with: | ||
|
||
```bash | ||
pip install "opensearch-py>=2.8.0" | ||
``` | ||
|
||
### Usage | ||
|
||
```python | ||
import os | ||
from mem0 import Memory | ||
|
||
os.environ["OPENAI_API_KEY"] = "sk-xx" | ||
|
||
config = { | ||
"vector_store": { | ||
"provider": "opensearch", | ||
"config": { | ||
"collection_name": "mem0", | ||
"host": "localhost", | ||
"port": 9200, | ||
"embedding_model_dims": 1536 | ||
} | ||
} | ||
} | ||
|
||
m = Memory.from_config(config) | ||
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"}) | ||
``` | ||
|
||
### Config | ||
|
||
Let's see the available parameters for the `opensearch` config: | ||
|
||
| Parameter | Description | Default Value | | ||
| ---------------------- | -------------------------------------------------- | ------------- | | ||
| `collection_name` | The name of the index to store the vectors | `mem0` | | ||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | | ||
| `host` | The host where the OpenSearch server is running | `localhost` | | ||
| `port` | The port where the OpenSearch server is running | `9200` | | ||
| `api_key` | API key for authentication | `None` | | ||
| `user` | Username for basic authentication | `None` | | ||
| `password` | Password for basic authentication | `None` | | ||
| `verify_certs` | Whether to verify SSL certificates | `False` | | ||
| `auto_create_index` | Whether to automatically create the index | `True` | | ||
| `use_ssl` | Whether to use SSL for connection | `False` | | ||
|
||
### Features | ||
|
||
- Fast and Efficient Vector Search | ||
- Can be deployed on-premises, in containers, or on cloud platforms such as Amazon OpenSearch Service | ||
- Multiple Authentication and Security Methods (Basic Authentication, API Keys, LDAP, SAML, and OpenID Connect) | ||
- Automatic index creation with optimized mappings for vector search | ||
- Memory Optimization through Disk-Based Vector Search and Quantization | ||
- Real-Time Analytics and Observability |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Any, Dict, Optional | ||
|
||
from pydantic import BaseModel, Field, model_validator | ||
|
||
|
||
class OpenSearchConfig(BaseModel):
    """Connection and index settings for the OpenSearch vector store.

    Every field has a default, so an empty configuration targets an
    unauthenticated, non-SSL server on ``localhost:9200`` and an index named
    ``mem0``. Keys that are not declared fields are rejected.
    """

    collection_name: str = Field("mem0", description="Name of the index")
    host: str = Field("localhost", description="OpenSearch host")
    port: int = Field(9200, description="OpenSearch port")
    user: Optional[str] = Field(None, description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    api_key: Optional[str] = Field(None, description="API key for authentication (if applicable)")
    embedding_model_dims: int = Field(1536, description="Dimension of the embedding vector")
    verify_certs: bool = Field(False, description="Verify SSL certificates (default False for OpenSearch)")
    use_ssl: bool = Field(False, description="Use SSL for connection (default False for OpenSearch)")
    auto_create_index: bool = Field(True, description="Automatically create index during initialization")

    @model_validator(mode="before")
    @classmethod
    def validate_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Check the host and credential inputs before field parsing.

        Args:
            values: Raw keyword arguments passed to the model.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If ``host`` is supplied but empty, or if only one of
                ``user``/``password`` is supplied and no ``api_key`` is given.
        """
        # This validator runs on the raw input, before field defaults are
        # filled in. A missing "host" key therefore just means "use the
        # default"; only an explicitly empty host is an error.
        if "host" in values and not values["host"]:
            raise ValueError("Host must be provided for OpenSearch")

        # Anonymous connections are allowed (all credential fields default to
        # None). What is rejected is an incomplete basic-auth pair with no
        # api_key to fall back on.
        has_api_key = bool(values.get("api_key"))
        has_user = bool(values.get("user"))
        has_password = bool(values.get("password"))
        if not has_api_key and has_user != has_password:
            raise ValueError("Either api_key or user/password must be provided for OpenSearch authentication")

        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Args:
            values: Raw keyword arguments passed to the model.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If ``values`` contains keys other than the model's fields.
        """
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            # Sort both lists so the message is stable across runs
            # (set iteration order is not).
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Allowed fields: {', '.join(sorted(allowed_fields))}"
            )
        return values
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import logging | ||
from typing import Any, Dict, List, Optional | ||
|
||
try: | ||
from opensearchpy import OpenSearch | ||
from opensearchpy.helpers import bulk | ||
except ImportError: | ||
raise ImportError("OpenSearch requires extra dependencies. Install with `pip install opensearch-py`") from None | ||
|
||
from pydantic import BaseModel | ||
|
||
from mem0.configs.vector_stores.opensearch import OpenSearchConfig | ||
from mem0.vector_stores.base import VectorStoreBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class OutputData(BaseModel):
    """One record returned by the OpenSearch vector store."""

    # Document identifier.
    id: str
    # Relevance score of the record.
    score: float
    # Metadata dictionary associated with the record.
    payload: Dict
|
||
|
||
class OpenSearchDB(VectorStoreBase):
    """Vector store that keeps embeddings and their metadata in an OpenSearch index.

    Documents are written with two source fields: ``vector`` (the embedding)
    and ``metadata`` (the payload dictionary).
    """

    def __init__(self, **kwargs):
        """Build the OpenSearch client and optionally create the index.

        Args:
            **kwargs: Settings validated by ``OpenSearchConfig``.
        """
        config = OpenSearchConfig(**kwargs)

        # Basic auth is only enabled when both halves of the credential exist.
        # NOTE(review): config.api_key is accepted by OpenSearchConfig but is
        # not forwarded to the client here - confirm whether that is intended.
        basic_auth = (config.user, config.password) if (config.user and config.password) else None

        self.client = OpenSearch(
            hosts=[{"host": config.host, "port": config.port or 9200}],
            http_auth=basic_auth,
            use_ssl=config.use_ssl,
            verify_certs=config.verify_certs,
        )

        self.collection_name = config.collection_name
        self.vector_dim = config.embedding_model_dims

        # Create index only if auto_create_index is True
        if config.auto_create_index:
            self.create_index()

    def create_index(self) -> None:
        """Create the configured index with k-NN mappings if it doesn't exist."""
        index_settings = {
            "settings": {
                "index": {"number_of_replicas": 1, "number_of_shards": 5, "refresh_interval": "1s", "knn": True}
            },
            "mappings": {
                "properties": {
                    "text": {"type": "text"},
                    # NOTE(review): no "method" is set here, so the server's
                    # default k-NN engine is used, while search() nests its
                    # filter inside the knn clause - confirm that engine
                    # supports filtered k-NN queries.
                    "vector": {
                        "type": "knn_vector",
                        "dimension": self.vector_dim,
                    },
                    "metadata": {"type": "object", "properties": {"user_id": {"type": "keyword"}}},
                }
            },
        }

        if not self.client.indices.exists(index=self.collection_name):
            self.client.indices.create(index=self.collection_name, body=index_settings)
            logger.info("Created index %s", self.collection_name)
        else:
            logger.info("Index %s already exists", self.collection_name)

    def create_col(self, name: str, vector_size: int) -> None:
        """Create a new collection (index in OpenSearch) if it doesn't exist.

        Args:
            name: Name of the index to create.
            vector_size: Dimension of the ``vector`` field.
        """
        # NOTE(review): this mapping declares "payload" and "id" fields,
        # whereas insert() writes "vector" and "metadata" - confirm whether
        # the two mappings are meant to differ.
        index_settings = {
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "knn_vector",
                        "dimension": vector_size,
                        "method": {"engine": "lucene", "name": "hnsw", "space_type": "cosinesimil"},
                    },
                    "payload": {"type": "object"},
                    "id": {"type": "keyword"},
                }
            }
        }

        if not self.client.indices.exists(index=name):
            self.client.indices.create(index=name, body=index_settings)
            logger.info("Created index %s", name)

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[Dict]] = None, ids: Optional[List[str]] = None
    ) -> List[OutputData]:
        """Insert vectors into the index with a single bulk request.

        Args:
            vectors: Embeddings to store.
            payloads: One metadata dict per vector; defaults to empty dicts.
            ids: One document id per vector; random UUIDs are generated when
                omitted.

        Returns:
            One ``OutputData`` per inserted vector, with ``score`` fixed at 1.0.

        Raises:
            ValueError: If ``vectors``, ``ids`` and ``payloads`` differ in length.
        """
        import uuid  # local import: only needed when the caller supplies no ids

        if len(vectors) == 0:
            return []

        if not ids:
            # Positional ids ("0", "1", ...) would collide across calls and
            # overwrite earlier documents, so generate unique ones instead.
            ids = [str(uuid.uuid4()) for _ in vectors]

        if payloads is None:
            payloads = [{} for _ in vectors]

        if not (len(vectors) == len(ids) == len(payloads)):
            raise ValueError(
                f"vectors, ids and payloads must have the same length "
                f"(got {len(vectors)}, {len(ids)} and {len(payloads)})"
            )

        actions = [
            {
                "_index": self.collection_name,
                "_id": id_,
                "_source": {
                    "vector": vec,
                    "metadata": payload,  # Store metadata in the metadata field
                },
            }
            for vec, id_, payload in zip(vectors, ids, payloads)
        ]

        bulk(self.client, actions)

        return [OutputData(id=id_, score=1.0, payload=payload) for id_, payload in zip(ids, payloads)]

    def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
        """Search for similar vectors using an OpenSearch k-NN query.

        Args:
            query: Query embedding.
            limit: Used both as the k-NN ``k`` and as the result ``size``.
            filters: Optional ``{key: value}`` pairs; each becomes a ``term``
                clause on ``metadata.<key>`` placed inside the knn clause.

        Returns:
            Hits converted to ``OutputData`` with the hit's ``_score``.
        """
        knn_clause: Dict[str, Any] = {"vector": query, "k": limit}

        if filters:
            filter_conditions = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
            knn_clause["filter"] = {"bool": {"filter": filter_conditions}}

        search_query = {"size": limit, "query": {"knn": {"vector": knn_clause}}}

        response = self.client.search(index=self.collection_name, body=search_query)

        return [
            OutputData(id=hit["_id"], score=hit["_score"], payload=hit["_source"].get("metadata", {}))
            for hit in response["hits"]["hits"]
        ]

    def delete(self, vector_id: str) -> None:
        """Delete a vector by ID."""
        self.client.delete(index=self.collection_name, id=vector_id)

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[Dict] = None) -> None:
        """Partially update a stored document.

        Args:
            vector_id: Id of the document to update.
            vector: New embedding, or ``None`` to leave it untouched.
            payload: New metadata, or ``None`` to leave it untouched.
        """
        doc = {}
        if vector is not None:
            doc["vector"] = vector
        if payload is not None:
            doc["metadata"] = payload

        self.client.update(index=self.collection_name, id=vector_id, body={"doc": doc})

    def get(self, vector_id: str) -> Optional[OutputData]:
        """Retrieve a vector by ID.

        Returns:
            The stored record with ``score`` fixed at 1.0, or ``None`` if the
            lookup raised any exception (the error is logged).
        """
        try:
            response = self.client.get(index=self.collection_name, id=vector_id)
            return OutputData(id=response["_id"], score=1.0, payload=response["_source"].get("metadata", {}))
        except Exception as e:
            # Best-effort lookup: every failure is reported as "no result".
            logger.error("Error retrieving vector %s: %s", vector_id, e)
            return None

    def list_cols(self) -> List[str]:
        """List all collections (indices)."""
        return list(self.client.indices.get_alias().keys())

    def delete_col(self) -> None:
        """Delete the configured collection (index)."""
        self.client.indices.delete(index=self.collection_name)

    def col_info(self, name: str) -> Any:
        """Get information about a collection (index)."""
        return self.client.indices.get(index=name)

    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
        """List stored memories, optionally filtered by metadata.

        Args:
            filters: Optional ``{key: value}`` pairs; each becomes a ``term``
                clause on ``metadata.<key>``.
            limit: Maximum number of hits. When falsy, no ``size`` is sent and
                the server's default result size applies.

        Returns:
            A single-element list wrapping the list of records, each with
            ``score`` fixed at 1.0.
        """
        query: Dict[str, Any] = {"query": {"match_all": {}}}

        if filters:
            term_clauses = [{"term": {f"metadata.{key}": value}} for key, value in filters.items()]
            query["query"] = {"bool": {"must": term_clauses}}

        if limit:
            query["size"] = limit

        response = self.client.search(index=self.collection_name, body=query)
        records = [
            OutputData(id=hit["_id"], score=1.0, payload=hit["_source"].get("metadata", {}))
            for hit in response["hits"]["hits"]
        ]
        return [records]
Oops, something went wrong.