@@ -122,8 +122,8 @@ def __init__(
122
122
- customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
123
123
If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
124
124
- update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
125
- - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
126
- This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
125
+ - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb .
126
+ Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
127
127
- custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
128
128
The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
129
129
Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
@@ -178,9 +178,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
178
178
self .customized_prompt = self ._retrieve_config .get ("customized_prompt" , None )
179
179
self .customized_answer_prefix = self ._retrieve_config .get ("customized_answer_prefix" , "" ).upper ()
180
180
self .update_context = self ._retrieve_config .get ("update_context" , True )
181
- self ._get_or_create = (
182
- self ._retrieve_config .get ("get_or_create" , False ) if self ._docs_path is not None else False
183
- )
181
+ self ._get_or_create = self ._retrieve_config .get ("get_or_create" , False ) if self ._docs_path is not None else True
184
182
self .custom_token_count_function = self ._retrieve_config .get ("custom_token_count_function" , count_token )
185
183
self .custom_text_split_function = self ._retrieve_config .get ("custom_text_split_function" , None )
186
184
self ._context_max_tokens = self ._max_tokens * 0.8
@@ -360,7 +358,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
360
358
n_results (int): the number of results to be retrieved.
361
359
search_string (str): only docs containing this string will be retrieved.
362
360
"""
363
- if not self ._collection or self ._get_or_create :
361
+ if not self ._collection or not self ._get_or_create :
364
362
print ("Trying to create collection." )
365
363
self ._client = create_vector_db_from_dir (
366
364
dir_path = self ._docs_path ,
@@ -375,7 +373,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
375
373
custom_text_split_function = self .custom_text_split_function ,
376
374
)
377
375
self ._collection = True
378
- self ._get_or_create = False
376
+ self ._get_or_create = True
379
377
380
378
results = query_vector_db (
381
379
query_texts = [problem ],
0 commit comments