Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated type hints for Agent Capabilities #2408

Merged
merged 2 commits into from
Apr 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions autogen/agentchat/contrib/capabilities/teachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()

def process_last_received_message(self, text):
def process_last_received_message(self, text: Union[Dict, str]):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
Expand All @@ -103,7 +103,7 @@ def process_last_received_message(self, text):
# Return the (possibly) expanded message text.
return expanded_text

def _consider_memo_storage(self, comment):
def _consider_memo_storage(self, comment: Union[Dict, str]):
"""Decides whether to store something from one user comment in the DB."""
memo_added = False

Expand Down Expand Up @@ -161,7 +161,7 @@ def _consider_memo_storage(self, comment):
# Yes. Save them to disk.
self.memo_store._save_memos()

def _consider_memo_retrieval(self, comment):
def _consider_memo_retrieval(self, comment: Union[Dict, str]):
"""Decides whether to retrieve memos from the DB, and add them to the chat context."""

# First, use the comment directly as the lookup key.
Expand Down Expand Up @@ -195,7 +195,7 @@ def _consider_memo_retrieval(self, comment):
# Append the memos to the text of the last message.
return comment + self._concatenate_memo_texts(memo_list)

def _retrieve_relevant_memos(self, input_text):
def _retrieve_relevant_memos(self, input_text: str) -> list:
"""Returns semantically related memos from the DB."""
memo_list = self.memo_store.get_related_memos(
input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
Expand All @@ -213,7 +213,7 @@ def _retrieve_relevant_memos(self, input_text):
memo_list = [memo[1] for memo in memo_list]
return memo_list

def _concatenate_memo_texts(self, memo_list):
def _concatenate_memo_texts(self, memo_list: list) -> str:
"""Concatenates the memo texts into a single string for inclusion in the chat context."""
memo_texts = ""
if len(memo_list) > 0:
Expand All @@ -225,7 +225,7 @@ def _concatenate_memo_texts(self, memo_list):
memo_texts = memo_texts + "\n" + info
return memo_texts

def _analyze(self, text_to_analyze, analysis_instructions):
def _analyze(self, text_to_analyze: Union[Dict, str], analysis_instructions: Union[Dict, str]):
"""Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
self.analyzer.reset() # Clear the analyzer's list of messages.
self.teachable_agent.send(
Expand All @@ -246,10 +246,16 @@ class MemoStore:
Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
"""

def __init__(self, verbosity, reset, path_to_db_dir):
def __init__(
self,
verbosity: Optional[int] = 0,
reset: Optional[bool] = False,
path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
Hk669 marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Args:
- verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
- reset (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored.
"""
self.verbosity = verbosity
Expand Down Expand Up @@ -304,7 +310,7 @@ def reset_db(self):
self.uid_text_dict = {}
self._save_memos()

def add_input_output_pair(self, input_text, output_text):
def add_input_output_pair(self, input_text: str, output_text: str):
"""Adds an input-output pair to the vector DB."""
self.last_memo_id += 1
self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
Expand All @@ -321,7 +327,7 @@ def add_input_output_pair(self, input_text, output_text):
if self.verbosity >= 3:
self.list_memos()

def get_nearest_memo(self, query_text):
def get_nearest_memo(self, query_text: str):
"""Retrieves the nearest memo to the given query text."""
results = self.vec_db.query(query_texts=[query_text], n_results=1)
uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
Expand All @@ -338,7 +344,7 @@ def get_nearest_memo(self, query_text):
)
return input_text, output_text, distance

def get_related_memos(self, query_text, n_results, threshold):
def get_related_memos(self, query_text: str, n_results: int, threshold: Union[int, float]):
"""Retrieves memos that are related to the given query text within the specified distance threshold."""
if n_results > len(self.uid_text_dict):
n_results = len(self.uid_text_dict)
Expand Down
Loading