Skip to content

Commit adb8fd8

Browse files
authored
Add a vectordb module (#2263)
* Added vectordb base and chromadb * Remove timer and unused functions * Added filter by distance * Added test utils * Fix format * Fix type hint of dict * Rename test * Add test chromadb * Fix test no chromadb * Add coverage * Don't skip test vectordb utils * Add types * Fix tests * Fix docs build error * Add types to base * Update base * Update utils * Update chromadb * Add get_docs_by_ids * Improve docstring * Add get all docs * Move chroma_results_to_query_results to utils * Improve type hints * Update logger * Update init, add embedding func * Improve docstring of vectordb, add two attributes * Improve test workflow
1 parent 5a96dc2 commit adb8fd8

File tree

8 files changed

+785
-5
lines changed

8 files changed

+785
-5
lines changed

.github/workflows/contrib-openai.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
5454
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
5555
run: |
56-
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
56+
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat
5757
coverage xml
5858
- name: Upload coverage to Codecov
5959
uses: codecov/codecov-action@v3

.github/workflows/contrib-tests.yml

+1-4
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,10 @@ jobs:
5858
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
5959
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
6060
fi
61-
- name: Test RetrieveChat
62-
run: |
63-
pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
6461
- name: Coverage
6562
run: |
6663
pip install coverage>=5.3
67-
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
64+
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai
6865
coverage xml
6966
- name: Upload coverage to Codecov
7067
uses: codecov/codecov-action@v3

autogen/agentchat/contrib/vectordb/__init__.py

Whitespace-only changes.
+209
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable
2+
3+
# Free-form extra information attached to a document (e.g. source, date); may be None.
Metadata = Union[Mapping[str, Any], None]
# Vector representation of a document's content: a sequence of floats or a sequence of ints.
Vector = Union[Sequence[float], Sequence[int]]
# Unique identifier of a record in the vector database.
ItemID = Union[str, int]  # chromadb doesn't support int ids, VikingDB does
6+
7+
8+
class Document(TypedDict):
    """A Document is a record in the vector database.

    "Optional" below means the value may be None; all four keys are declared on the TypedDict.

    id: ItemID | the unique identifier of the document.
    content: str | the text content of the chunk.
    metadata: Metadata, Optional | contains additional information about the document such as source, date, etc.
    embedding: Vector, Optional | the vector representation of the content.
    """

    id: ItemID
    content: str
    metadata: Optional[Metadata]
    embedding: Optional[Vector]
21+
22+
"""QueryResults is the response from the vector database for a query/queries.
A query is a list containing one string while queries is a list containing multiple strings.
The response is a list of query results, each query result is a list of tuples containing the document and the distance.
"""
# Shape: results[i][j] is the (document, distance) tuple for the j-th hit of the i-th query.
QueryResults = List[List[Tuple[Document, float]]]
28+
29+
30+
@runtime_checkable
class VectorDB(Protocol):
    """
    Protocol (structural interface) for a vector database. A vector database is responsible for storing and
    retrieving documents.

    The class is decorated with `runtime_checkable`, so it can be used with `isinstance()`.

    Attributes:
        active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None.
        type: str | The type of the vector database, chroma, pgvector, etc. Default is "".

    Methods:
        create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
        get_collection: Callable[[str], Any] | Get the collection from the vector database.
        delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
        insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
        update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
        delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
        retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
        get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
    """

    active_collection: Any = None
    type: str = ""

    def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
        """
        Create a collection in the vector database.
        Case 1. if the collection does not exist, create the collection.
        Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
        Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
            otherwise it raises a ValueError.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get the collection if it exists. Default is True.

        Returns:
            Any | The collection object.
        """
        ...

    def get_collection(self, collection_name: Optional[str] = None) -> Any:
        """
        Get the collection from the vector database.

        Args:
            collection_name: Optional[str] | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Any | The collection object.
        """
        ...

    def delete_collection(self, collection_name: str) -> Any:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any
        """
        ...

    def insert_docs(
        self, docs: List[Document], collection_name: Optional[str] = None, upsert: bool = False, **kwargs
    ) -> None:
        """
        Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: Optional[str] | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs) -> None:
        """
        Update documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: Optional[str] | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def delete_docs(self, ids: List[ItemID], collection_name: Optional[str] = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: Optional[str] | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        ...

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: Optional[str] = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: Optional[str] | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
                returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            QueryResults | The query results. Each query result is a list of list of tuples containing the document and
                the distance.
        """
        ...

    def get_docs_by_ids(
        self,
        ids: Optional[List[ItemID]] = None,
        collection_name: Optional[str] = None,
        include: Optional[List[str]] = None,
        **kwargs,
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: Optional[str] | The name of the collection. Default is None.
            include: Optional[List[str]] | The fields to include. Default is None.
                If None, will include ["metadatas", "documents"], ids will always be included.
            kwargs: dict | Additional keyword arguments.

        Returns:
            List[Document] | The results.
        """
        ...
181+
182+
183+
class VectorDBFactory:
    """
    Factory class for creating vector databases.
    """

    # Names of the vector database types listed in the "unsupported type" error message.
    PREDEFINED_VECTOR_DB = ["chroma"]

    @staticmethod
    def create_vector_db(db_type: str, **kwargs) -> VectorDB:
        """
        Create a vector database.

        Args:
            db_type: str | The type of the vector database. Matched case-insensitively; both "chroma" and
                "chromadb" select the chroma backend.
            kwargs: Dict | The keyword arguments for initializing the vector database.

        Returns:
            VectorDB | The vector database.

        Raises:
            ValueError | If `db_type` is not a supported vector database type.
        """
        if db_type.lower() in ["chroma", "chromadb"]:
            # Imported inside the branch, so the chromadb module is only imported when this backend is selected.
            from .chromadb import ChromaVectorDB

            return ChromaVectorDB(**kwargs)
        else:
            raise ValueError(
                f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
            )

0 commit comments

Comments
 (0)