diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index b43298d9..ac06197a 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -27,6 +27,7 @@ class QdrantVecDBConfig(BaseVecDBConfig): host: str | None = Field(default=None, description="Host for Qdrant") port: int | None = Field(default=None, description="Port for Qdrant") path: str | None = Field(default=None, description="Path for Qdrant") + api_key: str | None = Field(default=None, description="Optional API key for Qdrant authentication") @model_validator(mode="after") def set_default_path(self): diff --git a/src/memos/vec_dbs/qdrant.py b/src/memos/vec_dbs/qdrant.py index a0ebf1d8..b8ca7c8d 100644 --- a/src/memos/vec_dbs/qdrant.py +++ b/src/memos/vec_dbs/qdrant.py @@ -33,9 +33,15 @@ def __init__(self, config: QdrantVecDBConfig): "(e.g., via Docker: https://qdrant.tech/documentation/quickstart/)." ) - self.client = QdrantClient( - host=self.config.host, port=self.config.port, path=self.config.path - ) + client_kwargs = { + "host": self.config.host, + "port": self.config.port, + "path": self.config.path, + } + if self.config.api_key: + client_kwargs["api_key"] = self.config.api_key + + self.client = QdrantClient(**client_kwargs) self.create_collection() def create_collection(self) -> None: diff --git a/tests/configs/test_vec_db.py b/tests/configs/test_vec_db.py index b41e775a..ec739460 100644 --- a/tests/configs/test_vec_db.py +++ b/tests/configs/test_vec_db.py @@ -40,7 +40,7 @@ def test_qdrant_vec_db_config(): required_fields=[ "collection_name", ], - optional_fields=["vector_dimension", "distance_metric", "host", "port", "path"], + optional_fields=["vector_dimension", "distance_metric", "host", "port", "path", "api_key"], ) check_config_instantiation_valid( diff --git a/tests/vec_dbs/test_qdrant.py b/tests/vec_dbs/test_qdrant.py index 828240ae..519298c8 100644 --- a/tests/vec_dbs/test_qdrant.py +++ b/tests/vec_dbs/test_qdrant.py @@ -113,3 +113,31 @@ def test_get_all(vec_db): results = vec_db.get_all() assert len(results) == 1 assert isinstance(results[0], VecDBItem) + + +def test_client_receives_api_key(): + from unittest.mock import patch + from memos import settings + from memos.configs.vec_db import VectorDBConfigFactory + from memos.vec_dbs.factory import VecDBFactory + + api_key = "your_secure_api_key_here_change_in_production" + with patch("qdrant_client.QdrantClient") as mock_client: + cfg = VectorDBConfigFactory.model_validate( + { + "backend": "qdrant", + "config": { + "collection_name": "test_collection", + "vector_dimension": 4, + "distance_metric": "cosine", + "path": str(settings.MEMOS_DIR / "qdrant"), + "api_key": api_key, + }, + } + ) + _ = VecDBFactory.from_config(cfg) + + # Assert that QdrantClient was called with api_key keyword argument + assert mock_client.called + kwargs = mock_client.call_args.kwargs + assert kwargs.get("api_key") == api_key