import logging
from typing import Any, Callable, Dict, List, Optional, Union

from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks

logger = logging.getLogger(__name__)

try:
    import fastembed  # noqa: F401  # only imported to verify the fastembed extra is installed
    from qdrant_client import QdrantClient, models
    from qdrant_client.fastembed_common import QueryResponse
except ImportError as e:
    logger.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'")
    raise e


class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
    def __init__(
        self,
        name="RetrieveChatAgent",
        human_input_mode: Optional[str] = "ALWAYS",
        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
        retrieve_config: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Args:
            name (str): name of the agent.
            human_input_mode (str): whether to ask for human inputs every time a message is received.
                Possible values are "ALWAYS", "TERMINATE", "NEVER".
                (1) When "ALWAYS", the agent prompts for human input every time a message is received.
                    Under this mode, the conversation stops when the human input is "exit",
                    or when is_termination_msg is True and there is no human input.
                (2) When "TERMINATE", the agent only prompts for human input when a termination message is received
                    or the number of auto replies reaches the max_consecutive_auto_reply.
                (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops
                    when the number of auto replies reaches the max_consecutive_auto_reply or when is_termination_msg is True.
            is_termination_msg (function): a function that takes a message in the form of a dictionary
                and returns a boolean value indicating if this received message is a termination message.
                The dict can contain the following keys: "content", "role", "name", "function_call".
            retrieve_config (dict or None): config for the retrieve agent.
                To use default config, set to None. Otherwise, set to a dictionary with the following keys:
                - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". The system
                  prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
                - client (Optional, QdrantClient): a QdrantClient instance. If not provided, an in-memory instance
                  will be used; this is not recommended for production. If you want to use another vector db,
                  extend this class and override the `retrieve_docs` function.
                - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
                  or the url to a single file. Default is None, which works only if the collection is already created.
                - collection_name (Optional, str): the name of the collection.
                  If key not provided, a default name `autogen-docs` will be used.
                - model (Optional, str): the model to use for the retrieve chat.
                  If key not provided, a default model `gpt-4` will be used.
                - chunk_token_size (Optional, int): the chunk token size for the retrieve chat.
                  If key not provided, a default size `max_tokens * 0.4` will be used.
                - context_max_tokens (Optional, int): the context max token size for the retrieve chat.
                  If key not provided, a default size `max_tokens * 0.8` will be used.
                - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are
                  "multi_lines" and "one_line". If key not provided, a default mode `multi_lines` will be used.
                - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True.
                  If chunk_mode is "one_line", this parameter will be ignored.
                - embedding_model (Optional, str): the embedding model to use for the retrieve chat.
                  If key not provided, a default model `BAAI/bge-small-en-v1.5` will be used. All available models
                  can be found at `https://qdrant.github.io/fastembed/examples/Supported_Models/`.
                - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.
                - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                  If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
                - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                  The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                  Default is None; tiktoken will be used, which may not be accurate for non-OpenAI models.
                - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                  Default is None; the default function in `autogen.retrieve_utils.split_text_to_chunks` will be used.
                - parallel (Optional, int): how many parallel workers to use for embedding. Defaults to the number of CPU cores.
                - on_disk (Optional, bool): whether to store the collection on disk. Default is False.
                - quantization_config: quantization configuration. If None, quantization will be disabled.
                - hnsw_config: HNSW configuration. If None, the default configuration will be used.
                  You can find more info about the hnsw configuration options at https://qdrant.tech/documentation/concepts/indexing/#vector-index.
                  API Reference: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
                - payload_indexing: whether to create a payload index for the document field. Default is False.
                  You can find more info about the payload indexing options at https://qdrant.tech/documentation/concepts/indexing/#payload-index
                  API Reference: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_field_index
            **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
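
        Example:
            A minimal sketch of constructing the agent (the `./docs` path and the config
            values below are illustrative assumptions, not defaults):

            >>> agent = QdrantRetrieveUserProxyAgent(
            ...     name="qdrant_ragproxyagent",
            ...     human_input_mode="NEVER",
            ...     retrieve_config={
            ...         "task": "qa",
            ...         "docs_path": "./docs",
            ...         "collection_name": "autogen-docs",
            ...     },
            ... )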
        """
        super().__init__(name, human_input_mode, is_termination_msg, retrieve_config, **kwargs)
        self._client = self._retrieve_config.get("client", QdrantClient(":memory:"))
        self._embedding_model = self._retrieve_config.get("embedding_model", "BAAI/bge-small-en-v1.5")
        # Uses all available CPU cores to encode data when set to 0
        self._parallel = self._retrieve_config.get("parallel", 0)
        self._on_disk = self._retrieve_config.get("on_disk", False)
        self._quantization_config = self._retrieve_config.get("quantization_config", None)
        self._hnsw_config = self._retrieve_config.get("hnsw_config", None)
        self._payload_indexing = self._retrieve_config.get("payload_indexing", False)

    def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
        """Retrieve the chunks most relevant to the problem from the Qdrant collection,
        creating the collection from `docs_path` first if it does not exist yet.

        Args:
            problem (str): the problem to be solved.
            n_results (int): the number of results to be retrieved. Default is 20.
            search_string (str): only docs containing this string will be retrieved. Default is "".
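
        Example (illustrative; `agent` is an instance of this class):

        >>> agent.retrieve_docs("How to use AutoGen with Qdrant?", n_results=5)
        >>> agent._results["documents"][0]  # text of the top matching chunks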
| 102 | + """ |
        if not self._collection:
            logger.info("Trying to create collection.")
            create_qdrant_from_dir(
                dir_path=self._docs_path,
                max_tokens=self._chunk_token_size,
                client=self._client,
                collection_name=self._collection_name,
                chunk_mode=self._chunk_mode,
                must_break_at_empty_line=self._must_break_at_empty_line,
                embedding_model=self._embedding_model,
                custom_text_split_function=self.custom_text_split_function,
                parallel=self._parallel,
                on_disk=self._on_disk,
                quantization_config=self._quantization_config,
                hnsw_config=self._hnsw_config,
                payload_indexing=self._payload_indexing,
            )
            self._collection = True

        results = query_qdrant(
            query_texts=problem,
            n_results=n_results,
            search_string=search_string,
            client=self._client,
            collection_name=self._collection_name,
            embedding_model=self._embedding_model,
        )
        self._results = results


def create_qdrant_from_dir(
    dir_path: str,
    max_tokens: int = 4000,
    client: Optional[QdrantClient] = None,
    collection_name: str = "all-my-documents",
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    embedding_model: str = "BAAI/bge-small-en-v1.5",
    custom_text_split_function: Optional[Callable] = None,
    parallel: int = 0,
    on_disk: bool = False,
    quantization_config: Optional[models.QuantizationConfig] = None,
    hnsw_config: Optional[models.HnswConfigDiff] = None,
    payload_indexing: bool = False,
    qdrant_client_options: Optional[Dict] = None,
):
| 149 | + """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a url to |
| 150 | + a single file. |
| 151 | +
|
| 152 | + Args: |
| 153 | + dir_path (str): the path to the directory, file or url. |
| 154 | + max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000. |
| 155 | + client (Optional, QdrantClient): the QdrantClient instance. Default is None. |
| 156 | + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". |
| 157 | + chunk_mode (Optional, str): the chunk mode. Default is "multi_lines". |
| 158 | + must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True. |
| 159 | + embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5". The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/. |
| 160 | + parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores |
| 161 | + on_disk (Optional, bool): Whether to store the collection on disk. Default is False. |
| 162 | + quantization_config: Quantization configuration. If None, quantization will be disabled. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection |
| 163 | + hnsw_config: HNSW configuration. If None, default configuration will be used. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection |
| 164 | + payload_indexing: Whether to create a payload index for the document field. Default is False. |
| 165 | + qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58. |
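
    Example (a sketch; assumes a local `./docs` directory of text files exists):

    >>> client = QdrantClient(":memory:")
    >>> create_qdrant_from_dir("./docs", client=client, collection_name="my-docs")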
| 166 | + """ |
    if client is None:
        client = QdrantClient(**(qdrant_client_options or {}))
    client.set_model(embedding_model)

    if custom_text_split_function is not None:
        chunks = split_files_to_chunks(
            get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
        )
    else:
        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
    logger.info(f"Found {len(chunks)} chunks.")

    # Check if a collection with the same name exists; if not, create it with the custom options
    try:
        client.get_collection(collection_name=collection_name)
    except Exception:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=client.get_fastembed_vector_params(
                on_disk=on_disk, quantization_config=quantization_config, hnsw_config=hnsw_config
            ),
        )
        client.get_collection(collection_name=collection_name)  # sanity check: raises if creation failed

    # Upsert the chunks in batches of at most 100, with sequential integer ids
    batch_size = 100
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i : i + batch_size]
        client.add(collection_name, documents=batch, ids=list(range(i, i + len(batch))), parallel=parallel)

    # Create a payload index for the document field.
    # Enables highly efficient payload filtering. Reference: https://qdrant.tech/documentation/concepts/indexing/#indexing
    # Creating an index requires additional computational resources and memory.
    # If filtering performance is critical, consider creating an index.
    if payload_indexing:
        client.create_payload_index(
            collection_name=collection_name,
            field_name="document",
            field_schema=models.TextIndexParams(
                type="text",
                tokenizer=models.TokenizerType.WORD,
                min_token_len=2,
                max_token_len=15,
            ),
        )


def query_qdrant(
    query_texts: Union[List[str], str],
    n_results: int = 10,
    client: Optional[QdrantClient] = None,
    collection_name: str = "all-my-documents",
    search_string: str = "",
    embedding_model: str = "BAAI/bge-small-en-v1.5",
    qdrant_client_options: Optional[Dict] = None,
) -> Dict[str, List[List[Any]]]:
| 222 | + """Perform a similarity search with filters on a Qdrant collection |
| 223 | +
|
| 224 | + Args: |
| 225 | + query_texts (List[str]): the query texts. |
| 226 | + n_results (Optional, int): the number of results to return. Default is 10. |
| 227 | + client (Optional, API): the QdrantClient instance. A default in-memory client will be instantiated if None. |
| 228 | + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". |
| 229 | + search_string (Optional, str): the search string. Default is "". |
| 230 | + embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if embedding_function is not None. |
| 231 | + qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58. |
| 232 | +
|
| 233 | + Returns: |
| 234 | + List[List[QueryResponse]]: the query result. The format is: |
| 235 | + class QueryResponse(BaseModel, extra="forbid"): # type: ignore |
| 236 | + id: Union[str, int] |
| 237 | + embedding: Optional[List[float]] |
| 238 | + metadata: Dict[str, Any] |
| 239 | + document: str |
| 240 | + score: float |
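
    Example (a sketch; assumes a "my-docs" collection created with `create_qdrant_from_dir` above):

    >>> res = query_qdrant(["What is AutoGen?"], client=client, collection_name="my-docs")
    >>> res["documents"][0]  # texts of the best matches for the first query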
| 241 | + """ |
    if client is None:
        client = QdrantClient(**(qdrant_client_options or {}))
    client.set_model(embedding_model)

    # Accept a single query string for convenience; query_batch expects a list
    if isinstance(query_texts, str):
        query_texts = [query_texts]

    results = client.query_batch(
        collection_name,
        query_texts,
        limit=n_results,
        query_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="document",
                    match=models.MatchText(text=search_string),
                )
            ]
        )
        if search_string
        else None,
    )

    # Flatten the QueryResponse objects into parallel id/document lists, one sublist per query
    data = {
        "ids": [[result.id for result in sublist] for sublist in results],
        "documents": [[result.document for result in sublist] for sublist in results],
    }
    return data