Skip to content

Commit 87cd10c

Browse files
Anush008thinkall
authored andcommitted
feat: Qdrant support for the VectorDB interface (#3035)
* feat: Qdrant support * chore: pre-defined vector db * Fix issues --------- Co-authored-by: Li Jiang <[email protected]>
1 parent c628e46 commit 87cd10c

File tree

4 files changed

+401
-2
lines changed

4 files changed

+401
-2
lines changed

autogen/agentchat/contrib/retrieve_user_proxy_agent.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import hashlib
22
import os
33
import re
4+
import uuid
45
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
56

67
from IPython import get_ipython
@@ -365,7 +366,11 @@ def _init_db(self):
365366
else:
366367
all_docs_ids = set()
367368

368-
chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
369+
chunk_ids = (
370+
[hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
371+
if not self._vector_db.type == "qdrant"
372+
else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
373+
)
369374
chunk_ids_set = set(chunk_ids)
370375
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
371376
docs = [

autogen/agentchat/contrib/vectordb/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class VectorDBFactory:
185185
Factory class for creating vector databases.
186186
"""
187187

188-
PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
188+
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
189189

190190
@staticmethod
191191
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -207,6 +207,10 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
207207
from .pgvectordb import PGVectorDB
208208

209209
return PGVectorDB(**kwargs)
210+
if db_type.lower() in ["qdrant", "qdrantdb"]:
211+
from .qdrant import QdrantVectorDB
212+
213+
return QdrantVectorDB(**kwargs)
210214
else:
211215
raise ValueError(
212216
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
import abc
2+
import logging
3+
import os
4+
from typing import Callable, List, Optional, Sequence, Tuple, Union
5+
6+
from .base import Document, ItemID, QueryResults, VectorDB
7+
from .utils import get_logger
8+
9+
try:
10+
from qdrant_client import QdrantClient, models
11+
except ImportError:
12+
raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
13+
14+
logger = get_logger(__name__)
15+
16+
Embeddings = Union[Sequence[float], Sequence[int]]
17+
18+
19+
class EmbeddingFunction(abc.ABC):
20+
@abc.abstractmethod
21+
def __call__(self, inputs: List[str]) -> List[Embeddings]:
22+
raise NotImplementedError
23+
24+
25+
class FastEmbedEmbeddingFunction(EmbeddingFunction):
26+
"""Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""
27+
28+
def __init__(
29+
self,
30+
model_name: str = "BAAI/bge-small-en-v1.5",
31+
batch_size: int = 256,
32+
cache_dir: Optional[str] = None,
33+
threads: Optional[int] = None,
34+
parallel: Optional[int] = None,
35+
**kwargs,
36+
):
37+
"""Initialize fastembed.TextEmbedding.
38+
39+
Args:
40+
model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
41+
batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
42+
Defaults to 256.
43+
cache_dir (str, optional): The path to the model cache directory.\
44+
Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
45+
threads (int, optional): The number of threads single onnxruntime session can use.
46+
parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
47+
If `0`, use all available cores.\
48+
If `None`, don't use data-parallel processing, use default onnxruntime threading.\
49+
Defaults to None.
50+
**kwargs: Additional options to pass to fastembed.TextEmbedding
51+
Raises:
52+
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-small-en-v1.5.
53+
"""
54+
try:
55+
from fastembed import TextEmbedding
56+
except ImportError as e:
57+
raise ValueError(
58+
"The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
59+
) from e
60+
self._batch_size = batch_size
61+
self._parallel = parallel
62+
self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
63+
64+
def __call__(self, inputs: List[str]) -> List[Embeddings]:
65+
embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
66+
67+
return [embedding.tolist() for embedding in embeddings]
68+
69+
70+
class QdrantVectorDB(VectorDB):
71+
"""
72+
A vector database implementation that uses Qdrant as the backend.
73+
"""
74+
75+
def __init__(
76+
self,
77+
*,
78+
client=None,
79+
embedding_function: EmbeddingFunction = None,
80+
content_payload_key: str = "_content",
81+
metadata_payload_key: str = "_metadata",
82+
collection_options: dict = {},
83+
**kwargs,
84+
) -> None:
85+
"""
86+
Initialize the vector database.
87+
88+
Args:
89+
client: qdrant_client.QdrantClient | An instance of QdrantClient.
90+
embedding_function: Callable | The embedding function used to generate the vector representation
91+
of the documents. Defaults to FastEmbedEmbeddingFunction.
92+
collection_options: dict | The options for creating the collection.
93+
kwargs: dict | Additional keyword arguments.
94+
"""
95+
self.client: QdrantClient = client if client is not None else QdrantClient(location=":memory:")
96+
self.embedding_function = FastEmbedEmbeddingFunction() or embedding_function
97+
self.collection_options = collection_options
98+
self.content_payload_key = content_payload_key
99+
self.metadata_payload_key = metadata_payload_key
100+
self.type = "qdrant"
101+
102+
def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
103+
"""
104+
Create a collection in the vector database.
105+
106+
Args:
107+
collection_name: str | The name of the collection.
108+
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
109+
get_or_create: bool | Whether to get the collection if it exists. Default is True.
110+
111+
Returns:
112+
Any | The collection object.
113+
"""
114+
embeddings_size = len(self.embedding_function(["test"])[0])
115+
116+
if self.client.collection_exists(collection_name) and overwrite:
117+
self.client.delete_collection(collection_name)
118+
119+
if not self.client.collection_exists(collection_name):
120+
self.client.create_collection(
121+
collection_name,
122+
vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
123+
**self.collection_options,
124+
)
125+
126+
def get_collection(self, collection_name: str = None):
127+
"""
128+
Get the collection from the vector database.
129+
130+
Args:
131+
collection_name: str | The name of the collection.
132+
133+
Returns:
134+
Any | The collection object.
135+
"""
136+
if collection_name is None:
137+
raise ValueError("The collection name is required.")
138+
139+
return self.client.get_collection(collection_name)
140+
141+
def delete_collection(self, collection_name: str) -> None:
142+
"""Delete the collection from the vector database.
143+
144+
Args:
145+
collection_name: str | The name of the collection.
146+
147+
Returns:
148+
Any
149+
"""
150+
return self.client.delete_collection(collection_name)
151+
152+
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
153+
"""
154+
Insert documents into the collection of the vector database.
155+
156+
Args:
157+
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
158+
collection_name: str | The name of the collection. Default is None.
159+
upsert: bool | Whether to update the document if it exists. Default is False.
160+
kwargs: Dict | Additional keyword arguments.
161+
162+
Returns:
163+
None
164+
"""
165+
if not docs:
166+
return
167+
if any(doc.get("content") is None for doc in docs):
168+
raise ValueError("The document content is required.")
169+
if any(doc.get("id") is None for doc in docs):
170+
raise ValueError("The document id is required.")
171+
172+
if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
173+
logger.log("Some IDs already exist. Skipping insert", level=logging.WARN)
174+
175+
self.client.upsert(collection_name, points=self._documents_to_points(docs))
176+
177+
def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
178+
if not docs:
179+
return
180+
if any(doc.get("id") is None for doc in docs):
181+
raise ValueError("The document id is required.")
182+
if any(doc.get("content") is None for doc in docs):
183+
raise ValueError("The document content is required.")
184+
if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
185+
return self.client.upsert(collection_name, points=self._documents_to_points(docs))
186+
187+
raise ValueError("Some IDs do not exist. Skipping update")
188+
189+
def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
190+
"""
191+
Delete documents from the collection of the vector database.
192+
193+
Args:
194+
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
195+
collection_name: str | The name of the collection. Default is None.
196+
kwargs: Dict | Additional keyword arguments.
197+
198+
Returns:
199+
None
200+
"""
201+
self.client.delete(collection_name, ids)
202+
203+
def retrieve_docs(
204+
self,
205+
queries: List[str],
206+
collection_name: str = None,
207+
n_results: int = 10,
208+
distance_threshold: float = 0,
209+
**kwargs,
210+
) -> QueryResults:
211+
"""
212+
Retrieve documents from the collection of the vector database based on the queries.
213+
214+
Args:
215+
queries: List[str] | A list of queries. Each query is a string.
216+
collection_name: str | The name of the collection. Default is None.
217+
n_results: int | The number of relevant documents to return. Default is 10.
218+
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
219+
returned. Don't filter with it if < 0. Default is 0.
220+
kwargs: Dict | Additional keyword arguments.
221+
222+
Returns:
223+
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
224+
the distance.
225+
"""
226+
embeddings = self.embedding_function(queries)
227+
requests = [
228+
models.SearchRequest(
229+
vector=embedding,
230+
limit=n_results,
231+
score_threshold=distance_threshold,
232+
with_payload=True,
233+
with_vector=False,
234+
)
235+
for embedding in embeddings
236+
]
237+
238+
batch_results = self.client.search_batch(collection_name, requests)
239+
return [self._scored_points_to_documents(results) for results in batch_results]
240+
241+
def get_docs_by_ids(
242+
self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
243+
) -> List[Document]:
244+
"""
245+
Retrieve documents from the collection of the vector database based on the ids.
246+
247+
Args:
248+
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
249+
collection_name: str | The name of the collection. Default is None.
250+
include: List[str] | The fields to include. Default is True.
251+
If None, will include ["metadatas", "documents"], ids will always be included.
252+
kwargs: dict | Additional keyword arguments.
253+
254+
Returns:
255+
List[Document] | The results.
256+
"""
257+
if ids is None:
258+
results = self.client.scroll(collection_name=collection_name, with_payload=include, with_vectors=True)[0]
259+
else:
260+
results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
261+
return [self._point_to_document(result) for result in results]
262+
263+
def _point_to_document(self, point) -> Document:
264+
return {
265+
"id": point.id,
266+
"content": point.payload.get(self.content_payload_key, ""),
267+
"metadata": point.payload.get(self.metadata_payload_key, {}),
268+
"embedding": point.vector,
269+
}
270+
271+
def _points_to_documents(self, points) -> List[Document]:
272+
return [self._point_to_document(point) for point in points]
273+
274+
def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
275+
return self._point_to_document(scored_point), scored_point.score
276+
277+
def _documents_to_points(self, documents: List[Document]):
278+
contents = [document["content"] for document in documents]
279+
embeddings = self.embedding_function(contents)
280+
points = [
281+
models.PointStruct(
282+
id=documents[i]["id"],
283+
vector=embeddings[i],
284+
payload={
285+
self.content_payload_key: documents[i].get("content"),
286+
self.metadata_payload_key: documents[i].get("metadata"),
287+
},
288+
)
289+
for i in range(len(documents))
290+
]
291+
return points
292+
293+
def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
294+
return [self._scored_point_to_document(scored_point) for scored_point in scored_points]
295+
296+
def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
297+
"""
298+
Validates all the IDs exist in the collection
299+
"""
300+
retrieved_ids = [
301+
point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
302+
]
303+
304+
if missing_ids := set(ids) - set(retrieved_ids):
305+
logger.log(f"Missing IDs: {missing_ids}. Skipping update", level=logging.WARN)
306+
return False
307+
308+
return True
309+
310+
def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
311+
"""
312+
Validate none of the IDs exist in the collection
313+
"""
314+
retrieved_ids = [
315+
point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
316+
]
317+
318+
if existing_ids := set(ids) & set(retrieved_ids):
319+
logger.log(f"Existing IDs: {existing_ids}.", level=logging.WARN)
320+
return False
321+
322+
return True

0 commit comments

Comments
 (0)