Skip to content

Commit

Permalink
Add a warning message if docs_path not explicitly set (#814)
Browse files Browse the repository at this point in the history
* Add a warning message if docs_path not explicitly set

* update

* Add how to suppress warning message

* Fix tests errors

* Fix tests errors

* Fix tests errors
  • Loading branch information
thinkall authored Nov 30, 2023
1 parent f654946 commit ae7066b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
7 changes: 7 additions & 0 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code
from autogen import logger

from typing import Callable, Dict, Optional, Union, List, Tuple, Any
from IPython import get_ipython
Expand Down Expand Up @@ -171,6 +172,12 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", None)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
"docs_path is not provided in retrieve_config. "
f"Will raise ValueError if the collection `{self._collection_name}` doesn't exist. "
"Set docs_path to None to suppress this warning."
)
self._model = self._retrieve_config.get("model", "gpt-4")
self._max_tokens = self.get_max_tokens(self._model)
self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4))
Expand Down
31 changes: 30 additions & 1 deletion test/agentchat/contrib/test_retrievechat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,34 @@ def test_retrievechat():
print(conversations)


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip_test,
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_retrieve_config(caplog):
# test warning message when no docs_path is provided
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=2,
retrieve_config={
"chunk_token_size": 2000,
"get_or_create": True,
},
)

# Capture the printed content
captured_logs = caplog.records[0]
print(captured_logs)

# Assert on the printed content
assert (
f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
in captured_logs.message
)
assert captured_logs.levelname == "WARNING"


if __name__ == "__main__":
test_retrievechat()
# test_retrievechat()
test_retrieve_config()

0 comments on commit ae7066b

Please sign in to comment.