From 3b978473769e611bd88ac68529c86ef257220ebf Mon Sep 17 00:00:00 2001 From: Wael Karkoub Date: Mon, 21 Oct 2024 23:54:33 -0500 Subject: [PATCH 1/2] adds tool check --- autogen/agentchat/contrib/capabilities/transforms.py | 8 +++++++- autogen/agentchat/contrib/capabilities/transforms_util.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index d9ad365b91b3..18dc4dcebdbe 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -102,6 +102,9 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: if remaining_count == 0: break + if not transforms_util.is_tool_call_valid(truncated_messages): + truncated_messages.pop() + return truncated_messages def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: @@ -229,6 +232,9 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: processed_messages_tokens += msg_tokens processed_messages.insert(0, msg) + if not transforms_util.is_tool_call_valid(processed_messages): + processed_messages.pop() + return processed_messages def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]: @@ -242,7 +248,7 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: if post_transform_messages_tokens < pre_transform_messages_tokens: logs_str = ( f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " - f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" + f"Num_ber of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" ) return logs_str, True return "No tokens were truncated.", False diff --git a/autogen/agentchat/contrib/capabilities/transforms_util.py b/autogen/agentchat/contrib/capabilities/transforms_util.py index 8678dec654c4..f6ed5f732ee8 100644 --- a/autogen/agentchat/contrib/capabilities/transforms_util.py +++ b/autogen/agentchat/contrib/capabilities/transforms_util.py @@ -112,3 +112,7 @@ def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict return True return len(filter_config([message], filter_dict, exclude)) > 0 + + +def is_tool_call_valid(messages: List[Dict[str, Any]]) -> bool: + return messages[0].get("role") != "tool" From 295aad89f391f6358418fb692e283e17613a5a53 Mon Sep 17 00:00:00 2001 From: Wael Karkoub Date: Tue, 22 Oct 2024 00:02:57 -0500 Subject: [PATCH 2/2] oops --- autogen/agentchat/contrib/capabilities/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index 18dc4dcebdbe..973a732ab7b8 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -248,7 +248,7 @@ def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: if post_transform_messages_tokens < pre_transform_messages_tokens: logs_str = ( f"Truncated {pre_transform_messages_tokens - post_transform_messages_tokens} tokens. " - f"Num_ber of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" + f"Number of tokens reduced from {pre_transform_messages_tokens} to {post_transform_messages_tokens}" ) return logs_str, True return "No tokens were truncated.", False