Skip to content

Commit cf94e72

Browse files
committed
feat: Qdrant support
1 parent 55cc542 commit cf94e72

File tree

3 files changed

+388
-0
lines changed

3 files changed

+388
-0
lines changed

autogen/agentchat/contrib/vectordb/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def create_vector_db(db_type: str, **kwargs) -> VectorDB:
207207
from .pgvectordb import PGVectorDB
208208

209209
return PGVectorDB(**kwargs)
210+
if db_type.lower() in ["qdrant", "qdrantdb"]:
211+
from .qdrant import QdrantVectorDB
212+
213+
return QdrantVectorDB(**kwargs)
210214
else:
211215
raise ValueError(
212216
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
import abc
2+
import logging
3+
import os
4+
from typing import Callable, List, Optional, Sequence, Tuple, Union
5+
6+
from .base import Document, ItemID, QueryResults, VectorDB
7+
from .utils import get_logger
8+
9+
try:
10+
from qdrant_client import QdrantClient, models
11+
except ImportError:
12+
raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
13+
14+
logger = get_logger(__name__)
15+
16+
Embeddings = Union[Sequence[float], Sequence[int]]
17+
18+
19+
class EmbeddingFunction(abc.ABC):
    """Abstract callable that turns a batch of texts into embedding vectors."""

    @abc.abstractmethod
    def __call__(self, inputs: List[str]) -> List[Embeddings]:
        """Embed `inputs` and return one vector per input string.

        Args:
            inputs: List[str] | The texts to embed.

        Returns:
            List[Embeddings] | One embedding per entry of `inputs`.
        """
        raise NotImplementedError
23+
24+
25+
class FastEmbedEmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by FastEmbed - https://qdrant.github.io/fastembed."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        batch_size: int = 256,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        parallel: Optional[int] = None,
        **kwargs,
    ):
        """Create the underlying `fastembed.TextEmbedding` model.

        Args:
            model_name (str): Name of the model to load. Defaults to `"BAAI/bge-small-en-v1.5"`.
            batch_size (int): Batch size forwarded to `embed()` on every call. Defaults to 256.
            cache_dir (str, optional): Model cache directory, forwarded to `TextEmbedding`.
            threads (int, optional): Thread count, forwarded to `TextEmbedding`.
            parallel (int, optional): Data-parallel setting forwarded to `embed()` on every call.
                Defaults to None.
            **kwargs: Additional options passed through to `fastembed.TextEmbedding`.

        Raises:
            ValueError: If the `fastembed` package cannot be imported.
        """
        # Imported lazily so that fastembed is only required when this class is actually used.
        try:
            from fastembed import TextEmbedding
        except ImportError as e:
            raise ValueError(
                "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
            ) from e

        # Per-call encoding options are remembered here and applied in __call__.
        self._batch_size = batch_size
        self._parallel = parallel

        model_options = dict(model_name=model_name, cache_dir=cache_dir, threads=threads)
        self._model = TextEmbedding(**model_options, **kwargs)

    def __call__(self, inputs: List[str]) -> List[Embeddings]:
        """Embed `inputs` and return each resulting vector as a plain Python list."""
        vectors = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
        return [vector.tolist() for vector in vectors]
68+
69+
70+
class QdrantVectorDB(VectorDB):
    """
    A vector database implementation that uses Qdrant as the backend.

    Every document is stored as one Qdrant point. The document text and metadata are kept in
    the point payload under `content_payload_key` and `metadata_payload_key`; the vector is
    produced by `embedding_function`.
    """

    def __init__(
        self,
        *,
        client=None,
        embedding_function: EmbeddingFunction = None,
        content_payload_key: str = "_content",
        metadata_payload_key: str = "_metadata",
        collection_options: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Initialize the vector database.

        Args:
            client: qdrant_client.QdrantClient | An instance of QdrantClient.
                Defaults to an in-memory client (`QdrantClient(location=":memory:")`).
            embedding_function: Callable | The embedding function used to generate the vector representation
                of the documents. Defaults to FastEmbedEmbeddingFunction.
            content_payload_key: str | The payload key the document content is stored under.
            metadata_payload_key: str | The payload key the document metadata is stored under.
            collection_options: dict | Extra keyword options passed to `client.create_collection`.
                Defaults to no extra options.
            kwargs: dict | Additional keyword arguments (currently unused).
        """
        self.client: QdrantClient = client if client is not None else QdrantClient(location=":memory:")
        # Prefer the caller's embedding function; only build the (costly) FastEmbed default when needed.
        self.embedding_function = (
            embedding_function if embedding_function is not None else FastEmbedEmbeddingFunction()
        )
        # Copy so that neither a shared default nor the caller's dict is aliased by this instance.
        self.collection_options = dict(collection_options) if collection_options else {}
        self.content_payload_key = content_payload_key
        self.metadata_payload_key = metadata_payload_key
        self.type = "qdrant"

    def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
        """
        Create a collection in the vector database.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to delete and recreate the collection if it exists. Default is False.
            get_or_create: bool | Whether to silently reuse the collection if it exists. Default is True.

        Returns:
            None

        Raises:
            ValueError: If the collection exists, `overwrite` is False and `get_or_create` is False.
        """
        exists = self.client.collection_exists(collection_name)

        if exists and overwrite:
            # An existing collection must be removed first; creating over it is rejected by Qdrant.
            self.client.delete_collection(collection_name)
            exists = False

        if exists:
            if not get_or_create:
                raise ValueError(f"Collection {collection_name} already exists.")
            return

        # Probe the embedding function once to learn the vector size for the new collection.
        embeddings_size = len(self.embedding_function(["test"])[0])
        self.client.create_collection(
            collection_name,
            vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
            **self.collection_options,
        )

    def get_collection(self, collection_name: str = None):
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any | The collection info object returned by `client.get_collection`.

        Raises:
            ValueError: If `collection_name` is None.
        """
        if collection_name is None:
            raise ValueError("The collection name is required.")

        return self.client.get_collection(collection_name)

    def delete_collection(self, collection_name: str) -> None:
        """Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any | Whatever `client.delete_collection` returns.
        """
        return self.client.delete_collection(collection_name)

    def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
        """
        Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
                When False and any of the ids already exist, nothing is inserted.

        Returns:
            None

        Raises:
            ValueError: If any document is missing its content or its id.
        """
        if not docs:
            return
        if any(doc.get("content") is None for doc in docs):
            raise ValueError("The document content is required.")
        if any(doc.get("id") is None for doc in docs):
            raise ValueError("The document id is required.")

        if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
            # Without upsert permission existing points must not be overwritten, so skip the batch.
            logger.warning("Some IDs already exist. Skipping insert")
            return

        self.client.upsert(collection_name, points=self._documents_to_points(docs))

    def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
        """
        Update existing documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.

        Returns:
            None

        Raises:
            ValueError: If any document is missing its id or content, or if any id does not
                already exist in the collection.
        """
        if not docs:
            return
        if any(doc.get("id") is None for doc in docs):
            raise ValueError("The document id is required.")
        if any(doc.get("content") is None for doc in docs):
            raise ValueError("The document content is required.")
        if not self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
            raise ValueError("Some IDs do not exist. Skipping update")

        self.client.upsert(collection_name, points=self._documents_to_points(docs))

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments (currently unused).

        Returns:
            None
        """
        self.client.delete(collection_name, ids)

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = 0,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | Forwarded unchanged to Qdrant as the search `score_threshold`.
                Default is 0.
                NOTE(review): no special handling of negative values is implemented here - confirm
                the intended "don't filter" semantics against the VectorDB base contract.
            kwargs: Dict | Additional keyword arguments (currently unused).

        Returns:
            QueryResults | The query results. One list per query, each a list of
                (document, score) tuples.
        """
        embeddings = self.embedding_function(queries)
        requests = [
            models.SearchRequest(
                vector=embedding,
                limit=n_results,
                score_threshold=distance_threshold,
                with_payload=True,
                with_vector=False,
            )
            for embedding in embeddings
        ]

        batch_results = self.client.search_batch(collection_name, requests)
        return [self._scored_points_to_documents(results) for results in batch_results]

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. Default is None.
                NOTE(review): None is forwarded to `client.retrieve` as-is; fetching all documents
                is not implemented in this method - confirm whether callers rely on it.
            collection_name: str | The name of the collection. Default is None.
            include: Forwarded to `client.retrieve` as `with_payload`. Default is True.
                Vectors are always requested.
            kwargs: dict | Additional keyword arguments (currently unused).

        Returns:
            List[Document] | The results.
        """
        results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
        return [self._point_to_document(result) for result in results]

    def _point_to_document(self, point) -> Document:
        """Convert a Qdrant point/record into a `Document` dict."""
        # The payload is None when it was not requested (e.g. include=False), so fall back to empty.
        payload = point.payload or {}
        return {
            "id": point.id,
            "content": payload.get(self.content_payload_key, ""),
            "metadata": payload.get(self.metadata_payload_key, {}),
            "embedding": point.vector,
        }

    def _points_to_documents(self, points) -> List[Document]:
        """Convert an iterable of Qdrant points into `Document` dicts."""
        return [self._point_to_document(point) for point in points]

    def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
        """Convert a scored search hit into a (document, score) tuple."""
        return self._point_to_document(scored_point), scored_point.score

    def _documents_to_points(self, documents: List[Document]):
        """Embed the documents' contents and build one `PointStruct` per document."""
        contents = [document["content"] for document in documents]
        embeddings = self.embedding_function(contents)
        return [
            models.PointStruct(
                id=document["id"],
                vector=embeddings[index],
                payload={
                    self.content_payload_key: document.get("content"),
                    self.metadata_payload_key: document.get("metadata"),
                },
            )
            for index, document in enumerate(documents)
        ]

    def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
        """Convert a list of scored search hits into (document, score) tuples."""
        return [self._scored_point_to_document(scored_point) for scored_point in scored_points]

    def _validate_update_ids(self, collection_name: str, ids: List[ItemID]) -> bool:
        """
        Validates all the IDs exist in the collection

        Returns:
            bool | True when every id was found; False (after logging the missing ids) otherwise.
        """
        retrieved_ids = [
            point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
        ]

        if missing_ids := set(ids) - set(retrieved_ids):
            logger.warning("Missing IDs: %s. Skipping update", missing_ids)
            return False

        return True

    def _validate_upsert_ids(self, collection_name: str, ids: List[ItemID]) -> bool:
        """
        Validate none of the IDs exist in the collection

        Returns:
            bool | True when none of the ids were found; False (after logging the existing ids) otherwise.
        """
        retrieved_ids = [
            point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
        ]

        if existing_ids := set(ids) & set(retrieved_ids):
            logger.warning("Existing IDs: %s.", existing_ids)
            return False

        return True
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import os
2+
import sys
3+
4+
import pytest
5+
6+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
7+
8+
try:
9+
import uuid
10+
11+
from qdrant_client import QdrantClient
12+
13+
from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB
14+
except ImportError:
15+
skip = True
16+
else:
17+
skip = False
18+
19+
20+
@pytest.mark.skipif(skip, reason="dependency is not installed")
def test_qdrant():
    """End-to-end check of QdrantVectorDB against an in-memory Qdrant instance."""
    qdrant = QdrantClient(location=":memory:")
    vector_db = QdrantVectorDB(client=qdrant)
    name = uuid.uuid4().hex

    # create_collection: the collection must exist afterwards
    vector_db.create_collection(name, overwrite=True, get_or_create=True)
    assert qdrant.collection_exists(name)

    # delete_collection: the collection must be gone afterwards
    vector_db.delete_collection(name)
    assert not qdrant.collection_exists(name)

    # get_collection: the vector size matches the default FastEmbed model dimensions
    vector_db.create_collection(name, overwrite=True, get_or_create=True)
    info = vector_db.get_collection(name)
    assert info.config.params.vectors.size == 384

    # insert_docs: freshly inserted documents can be read back by id
    initial_docs = [{"content": "doc1", "id": 1}, {"content": "doc2", "id": 2}]
    vector_db.insert_docs(initial_docs, name, upsert=False)
    fetched = vector_db.get_docs_by_ids([1, 2], name)
    first, second = fetched[0], fetched[1]
    assert (first["id"], first["content"]) == (1, "doc1")
    assert (second["id"], second["content"]) == (2, "doc2")

    # update_docs + get_docs_by_ids: the new content replaces the old one
    updated_docs = [{"content": "doc11", "id": 1}, {"content": "doc22", "id": 2}]
    vector_db.update_docs(updated_docs, name)
    fetched = vector_db.get_docs_by_ids([1, 2], name)
    first, second = fetched[0], fetched[1]
    assert (first["id"], first["content"]) == (1, "doc11")
    assert (second["id"], second["content"]) == (2, "doc22")

    # retrieve_docs: each query ranks its exact match first
    hits_per_query = vector_db.retrieve_docs(["doc22", "doc11"], name)
    ranked_ids = [[hit[0]["id"] for hit in hits] for hits in hits_per_query]
    assert ranked_ids == [[2, 1], [1, 2]]

    # delete_docs: exactly one document is left after removing id 1
    vector_db.delete_docs([1], name)
    assert vector_db.client.count(name).count == 1
66+
67+
if __name__ == "__main__":
68+
test_qdrant()

0 commit comments

Comments
 (0)