Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to customized vectordb and embedding functions #161

Merged
merged 16 commits into from
Oct 10, 2023
Merged
36 changes: 30 additions & 6 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(
- client (Optional, chromadb.Client): the chromadb client.
If key not provided, a default client `chromadb.Client()` will be used.
- docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
or the url to a single file. If key not provided, a default path `./docs` will be used.
or the url to a single file. Default is None, which works only if the collection is already created.
- collection_name (Optional, str): the name of the collection.
If key not provided, a default name `autogen-docs` will be used.
- model (Optional, str): the model to use for the retrieve chat.
Expand All @@ -122,12 +122,15 @@ def __init__(
If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models
can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a
fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.
- embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None,
SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or
other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
- customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
- get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
This is the same as that used in chromadb. Default is False.
This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
Expand All @@ -143,21 +146,24 @@ def __init__(
self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._task = self._retrieve_config.get("task", "default")
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", "./docs")
self._docs_path = self._retrieve_config.get("docs_path", None)
thinkall marked this conversation as resolved.
Show resolved Hide resolved
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
self._model = self._retrieve_config.get("model", "gpt-4")
self._max_tokens = self.get_max_tokens(self._model)
self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4))
self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines")
self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True)
self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2")
self._embedding_function = self._retrieve_config.get("embedding_function", None)
thinkall marked this conversation as resolved.
Show resolved Hide resolved
self.customized_prompt = self._retrieve_config.get("customized_prompt", None)
self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper()
self.update_context = self._retrieve_config.get("update_context", True)
self._get_or_create = self._retrieve_config.get("get_or_create", False)
self._get_or_create = (
self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
)
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
self._context_max_tokens = self._max_tokens * 0.8
self._collection = False # the collection is not created
self._collection = True if self._docs_path is None else False # whether the collection is created
self._ipython = get_ipython()
self._doc_idx = -1 # the index of the current used doc
self._results = {} # the results of the current query
Expand Down Expand Up @@ -185,7 +191,7 @@ def _reset(self, intermediate=False):
self._doc_contents = [] # the contents of the current used doc
self._doc_ids = [] # the ids of the current used doc

def _get_context(self, results):
def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]):
doc_contents = ""
current_tokens = 0
_doc_idx = self._doc_idx
Expand Down Expand Up @@ -293,6 +299,22 @@ def _generate_retrieve_user_reply(
return False, None

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
"""Retrieve docs based on the given problem and assign the results to the class property `_results`.
In case you want to customize the retrieval process, such as using a different vector db whose APIs are not
compatible with chromadb or filtering results with metadata, you can override this function. Just keep the current
parameters, add your own parameters with default values, and keep the results in the type shown below.

Type of the results: Dict[str, List[List[Any]]]
ids: List[List[string]]
documents: List[List[string]]
metadatas: Optional[List[List[string]]]
distances: Optional[List[List[float]]]
thinkall marked this conversation as resolved.
Show resolved Hide resolved

Args:
problem (str): the problem to be solved.
n_results (int): the number of results to be retrieved.
search_string (str): only docs containing this string will be retrieved.
"""
if not self._collection or self._get_or_create:
print("Trying to create collection.")
create_vector_db_from_dir(
Expand All @@ -304,6 +326,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
must_break_at_empty_line=self._must_break_at_empty_line,
embedding_model=self._embedding_model,
get_or_create=self._get_or_create,
embedding_function=self._embedding_function,
)
self._collection = True
self._get_or_create = False
Expand All @@ -315,6 +338,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
client=self._client,
collection_name=self._collection_name,
embedding_model=self._embedding_model,
embedding_function=self._embedding_function,
)
self._results = results
print("doc_ids: ", results["ids"])
Expand Down
62 changes: 57 additions & 5 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tiktoken
import chromadb
from chromadb.api import API
from chromadb.api.types import QueryResult
import chromadb.utils.embedding_functions as ef
import logging
import pypdf
Expand Down Expand Up @@ -263,12 +264,36 @@ def create_vector_db_from_dir(
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
):
"""Create a vector db from all the files in a given directory."""
"""Create a vector db from all the files in a given directory; the directory can also be a single file or a url to
a single file. We support chromadb-compatible APIs to create the vector db; this function is not required if
you have prepared your own vector db.

Args:
dir_path (str): the path to the directory, file or url.
max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000.
client (Optional, API): the chromadb client. Default is None.
db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db".
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection
will be recreated if it already exists.
chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.
"""
if client is None:
client = chromadb.PersistentClient(path=db_path)
try:
embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
embedding_function = (
ef.SentenceTransformerEmbeddingFunction(embedding_model)
if embedding_function is None
else embedding_function
)
collection = client.create_collection(
collection_name,
get_or_create=get_or_create,
Expand Down Expand Up @@ -300,14 +325,41 @@ def query_vector_db(
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
) -> Dict[str, List[str]]:
"""Query a vector db."""
embedding_function: Callable = None,
) -> QueryResult:
"""Query a vector db. We support chromadb-compatible APIs; this function is not required if you have prepared
your own vector db and query function.

Args:
query_texts (List[str]): the query texts.
n_results (Optional, int): the number of results to return. Default is 10.
client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used.
db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db".
collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
search_string (Optional, str): the search string. Default is "".
embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if
embedding_function is not None.
embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with
the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding
functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`.

Returns:
QueryResult: the query result. The format is:
class QueryResult(TypedDict):
ids: List[IDs]
embeddings: Optional[List[List[Embedding]]]
documents: Optional[List[List[Document]]]
metadatas: Optional[List[List[Metadata]]]
distances: Optional[List[List[float]]]
"""
if client is None:
client = chromadb.PersistentClient(path=db_path)
# the collection's embedding function is always the default one, but we want to use the one we used to create the
# collection. So we compute the embeddings ourselves and pass it to the query function.
collection = client.get_collection(collection_name)
embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
embedding_function = (
ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
)
query_embeddings = embedding_function(query_texts)
# Query/search n most similar results. You can also .get by id
results = collection.query(
Expand Down
Loading