Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

+mdb atlas vectordb [clean_final] #3000

Merged
merged 51 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
28a009b
+mdb atlas
Jun 22, 2024
383168e
Update test/agentchat/contrib/vectordb/test_mongodb.py
ranfysvalle02 Jun 22, 2024
f531568
update test_mongodb.py; we dont need to do the assert .collection_nam…
Jun 22, 2024
a1c385c
Try fix mongodb service
thinkall Jun 22, 2024
bb1a183
Try fix mongodb service
thinkall Jun 22, 2024
1e30425
Update username and password
thinkall Jun 22, 2024
1d6dbaf
Merge branch 'main' into main
thinkall Jun 22, 2024
e0f3c59
Update autogen/agentchat/contrib/vectordb/mongodb.py
thinkall Jun 22, 2024
d6a1162
closer --- but im not super thrilled about the solution...
Jun 23, 2024
334cc25
Merge branch 'main' into main
thinkall Jun 24, 2024
de48057
PYTHON-4506 Expanded tests and simplified vector search pipelines
caseyclements Jun 24, 2024
f5e5fdf
Merge branch 'main' into pull/3000
ranfysvalle02 Jun 24, 2024
6245c30
Merge pull request #1 from caseyclements/pull/3000
ranfysvalle02 Jun 24, 2024
a367426
Update mongodb.py
ranfysvalle02 Jun 24, 2024
ffa3e38
Update mongodb.py - Casey
ranfysvalle02 Jun 24, 2024
d741ef6
Merge branch 'main' into main
ranfysvalle02 Jun 27, 2024
d2fbd02
Merge branch 'main' into main
ranfysvalle02 Jun 30, 2024
3646d1e
search_index_magic
Jun 30, 2024
3e0ac8e
Fix format
thinkall Jun 30, 2024
95e2f79
Fix tests
thinkall Jun 30, 2024
64a157c
hacking trying to figure this out
Jul 1, 2024
17d02d1
Merge branch 'main' of https://github.com/ranfysvalle02/autogen
Jul 1, 2024
6cfb689
Merge branch 'main' into main
thinkall Jul 2, 2024
66e46e8
Merge branch 'main' into main
thinkall Jul 2, 2024
7405463
Streamline checks for indexes in construction and restructure tests
Jibola Jul 18, 2024
7d778fe
Add tests for score_threshold, embedding inclusion, and multiple quer…
Jibola Jul 18, 2024
0fcf320
Merge branch 'main' into main
Jibola Jul 19, 2024
0921c53
refactored create_collection to meet base object requirements
Jibola Jul 19, 2024
01f96c7
lint
Jibola Jul 19, 2024
311259e
change the localhost port to 27017
Jibola Jul 19, 2024
cf97466
add test to check that no embedding is there unless explicitly provided
Jibola Jul 19, 2024
6df51df
Merge branch 'main' into main
ranfysvalle02 Jul 20, 2024
e003d1f
Merge branch 'main' into main
thinkall Jul 21, 2024
8491d5a
Update logger
thinkall Jul 21, 2024
1b41e18
Add test get docs with ids=None
thinkall Jul 21, 2024
14776e4
Rename and update notebook
thinkall Jul 21, 2024
de12cd1
have index management include waiting behaviors
Jibola Jul 23, 2024
5e00b2d
Adds further optional waits or users and tests. Cleans up upsert.
caseyclements Jul 23, 2024
f3a2a0c
Merge branch 'microsoft:main' into main
cozypet Jul 24, 2024
347fd0e
ensure the embedding size for multiple embedding inputs is equal to d…
Jibola Jul 24, 2024
5790e48
fix up tests and add configuration to ensure documents and indexes ar…
Jibola Jul 25, 2024
cdc6b6d
Merge branch 'main' into main
Jibola Jul 25, 2024
8804087
fix import failure
Jibola Jul 25, 2024
1f41bbd
Merge branch 'main' of https://github.com/ranfysvalle02/autogen into …
Jibola Jul 25, 2024
ead65ca
adjust typing for 3.9
Jibola Jul 25, 2024
892b81a
fix up the notebook output
Jibola Jul 25, 2024
2cca0c0
changed language to communicate time taken on first init_chat call
Jibola Jul 25, 2024
2f1bb68
Merge branch 'main' into main
Jibola Jul 25, 2024
7a44641
replace environment variable usage
Jibola Jul 25, 2024
bb9d57a
Merge branch 'main' of https://github.com/ranfysvalle02/autogen into …
Jibola Jul 25, 2024
2c788bd
Merge branch 'main' into main
thinkall Jul 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ jobs:
--health-retries 5
ports:
- 5432:5432
mongodb:
image: mongodb/mongodb-atlas-local:latest
restart: unless-stopped
ports:
- "27017:27017"
environment:
MONGODB_INITDB_ROOT_USERNAME: mongodb_user
MONGODB_INITDB_ROOT_PASSWORD: mongodb_password
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -102,6 +110,9 @@ jobs:
- name: Install pgvector when on linux
run: |
pip install -e .[retrievechat-pgvector]
- name: Install mongodb when on linux
run: |
pip install -e .[retrievechat-mongodb]
- name: Install unstructured when python-version is 3.9 and on linux
if: matrix.python-version == '3.9'
run: |
Expand Down
6 changes: 5 additions & 1 deletion autogen/agentchat/contrib/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""

PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb"]

@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
Expand All @@ -207,6 +207,10 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
from .pgvectordb import PGVectorDB

return PGVectorDB(**kwargs)
if db_type.lower() in ["mdb", "mongodb", "atlas"]:
from .mongodb import MongoDBAtlasVectorDB

return MongoDBAtlasVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
Expand Down
292 changes: 292 additions & 0 deletions autogen/agentchat/contrib/vectordb/mongodb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
from typing import Callable, List, Literal

import numpy as np
from pymongo import MongoClient, errors
from pymongo.operations import SearchIndexModel
from sentence_transformers import SentenceTransformer

from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

logger = get_logger(__name__)


class MongoDBAtlasVectorDB(VectorDB):
"""
A Collection object for MongoDB.
"""

def __init__(
self,
connection_string: str = "",
database_name: str = "vector_db",
embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,
):
"""
Initialize the vector database.

Args:
connection_string: str | The MongoDB connection string to connect to. Default is ''.
database_name: str | The name of the database. Default is 'vector_db'.
embedding_function: The embedding function used to generate the vector representation.
"""
if embedding_function:
self.embedding_function = embedding_function
try:
self.client = MongoClient(connection_string)
Jibola marked this conversation as resolved.
Show resolved Hide resolved
self.client.admin.command("ping")
except errors.ServerSelectionTimeoutError as err:
raise ConnectionError("Could not connect to MongoDB server") from err

self.db = self.client[database_name]
self.active_collection = None
# This will get the model dimension size by computing the embeddings dimensions
sentences = [
"The weather is lovely today in paradise.",
]
embeddings = self.embedding_function(sentences)
self.dimensions = len(embeddings[0])

def list_collections(self):
"""
List the collections in the vector database.

Returns:
List[str] | The list of collections.
"""
try:
return self.db.list_collection_names()
except Exception as err:
raise err

def create_collection(
self,
collection_name: str,
overwrite: bool = False,
get_or_create: bool = True,
index_name: str = "default_index",
similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
):
"""
Create a collection in the vector database and create a vector search index in the collection.

Args:
collection_name: str | The name of the collection.
index_name: str | The name of the index.
similarity: str | The similarity metric for the vector search index.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True
"""
# if overwrite is False and get_or_create is False, raise a ValueError
if not overwrite and not get_or_create:
raise ValueError("If overwrite is False, get_or_create must be True.")
# If overwrite is True and the collection already exists, drop the existing collection
collection_names = self.db.list_collection_names()
if overwrite and collection_name in collection_names:
self.db.drop_collection(collection_name)
thinkall marked this conversation as resolved.
Show resolved Hide resolved
# If get_or_create is True and the collection already exists, return the existing collection
if get_or_create and collection_name in collection_names:
return self.db[collection_name]
# If get_or_create is False and the collection already exists, raise a ValueError
if not get_or_create and collection_name in collection_names:
raise ValueError(f"Collection {collection_name} already exists.")

# Create a new collection
collection = self.db.create_collection(collection_name)
# Create a vector search index in the collection
search_index_model = SearchIndexModel(
definition={
"fields": [
{"type": "vector", "numDimensions": self.dimensions, "path": "embedding", "similarity": similarity},
]
},
name=index_name,
type="vectorSearch",
)
# Create the search index
try:
collection.create_search_index(model=search_index_model)
return collection
except Exception as e:
logger.error(f"Error creating search index: {e}")
raise e

def get_collection(self, collection_name: str = None):
"""
Get the collection from the vector database.

Args:
collection_name: str | The name of the collection. Default is None. If None, return the
current active collection.

Returns:
Collection | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.debug(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
if collection_name not in self.list_collections():
raise ValueError(f"Collection {collection_name} does not exist.")
if self.active_collection is None:
self.active_collection = self.db[collection_name]
return self.active_collection

def delete_collection(self, collection_name: str):
"""
Delete the collection from the vector database.

Args:
collection_name: str | The name of the collection.
"""
return self.db[collection_name].drop()

def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False):
"""
Insert documents into the collection of the vector database.

Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
"""
if not docs:
return
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")
collection = self.get_collection(collection_name)
for doc in docs:
if "embedding" not in doc:
doc["embedding"] = np.array(self.embedding_function([str(doc["content"])])).tolist()[0]
if upsert:
for doc in docs:
return collection.replace_one({"id": doc["id"]}, doc, upsert=True)
else:
return collection.insert_many(docs)

def update_docs(self, docs: List[Document], collection_name: str = None):
"""
Update documents in the collection of the vector database.

Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.
"""
return self.insert_docs(docs, collection_name, upsert=True)

def delete_docs(self, ids: List[ItemID], collection_name: str = None):
"""
Delete documents from the collection of the vector database.

Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: str | The name of the collection. Default is None.
"""
collection = self.get_collection(collection_name)
return collection.delete_many({"id": {"$in": ids}})

def get_docs_by_ids(self, ids: List[ItemID] = None, collection_name: str = None):
"""
Retrieve documents from the collection of the vector database based on the ids.

Args:
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
"""
results = []
if ids is None:
collection = self.get_collection(collection_name)
results = list(collection.find({}, {"embedding": 0}))
else:
for id in ids:
id = str(id)
collection = self.get_collection(collection_name)
results = list(collection.find({"id": {"$in": ids}}, {"embedding": 0}))
return results

def retrieve_docs(
self,
queries: List[str],
collection_name: str = None,
n_results: int = 10,
distance_threshold: float = -1,
index_name: str = "default",
**kwargs,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.

Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: str | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is -1.
kwargs: Dict | Additional keyword arguments.

Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
results = []
for query_text in queries:
query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
# Find documents with similar vectors using the specified index
search_collection = self.get_collection(collection_name)
pipeline = [
{
"$vectorSearch": {
"index": index_name,
"limit": n_results,
"numCandidates": n_results,
"queryVector": query_vector,
"path": "embedding",
}
},
{"$project": {"score": {"$meta": "vectorSearchScore"}}},
]
if distance_threshold >= 0.00:
similarity_threshold = 1 - distance_threshold
pipeline.append({"$match": {"score": {"gte": similarity_threshold}}})

# do a lookup on the same collection
pipeline.append(
{
"$lookup": {
"from": collection_name,
"localField": "_id",
"foreignField": "_id",
"as": "full_document_array",
}
}
)
pipeline.append(
{
"$addFields": {
"full_document": {
"$arrayElemAt": [
{
"$map": {
"input": "$full_document_array",
"as": "doc",
"in": {"id": "$$doc.id", "content": "$$doc.content"},
}
},
0,
]
}
}
}
)
pipeline.append({"$project": {"full_document_array": 0, "embedding": 0}})
tmp_results = []
for doc in search_collection.aggregate(pipeline):
tmp_results.append((doc["full_document"], 1 - doc["score"]))
results.append(tmp_results)
return results
Loading
Loading