Commit

Merge branch 'main' into fix/groupchat-model-registration
Matteo-Frattaroli authored May 23, 2024
2 parents e83eef1 + 4ebfb82 commit e8e8214
Showing 17 changed files with 434 additions and 155 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/contrib-tests.yml
@@ -107,7 +107,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr poppler-utils
pip install unstructured[all-docs]==0.13.0
pip install --no-cache-dir unstructured[all-docs]==0.13.0
- name: Install packages and dependencies for RetrieveChat
run: |
pip install -e .[retrievechat]
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/capabilities/context_handling.py
@@ -8,8 +8,8 @@
from autogen import ConversableAgent, token_count_utils

warn(
"Context handling with TransformChatHistory is deprecated. "
"Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
"Context handling with TransformChatHistory is deprecated and will be removed in `0.2.30`. "
"Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
DeprecationWarning,
stacklevel=2,
)
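
For anyone migrating off TransformChatHistory, a minimal sketch of the replacement capability the new warning points to (the agent, model, and limits below are illustrative, not part of this commit):

import os

from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities import transform_messages, transforms

# Replacement for TransformChatHistory: a pipeline of message transforms
# applied to the chat context before each LLM call.
context_handling = transform_messages.TransformMessages(
    transforms=[
        transforms.MessageHistoryLimiter(max_messages=10),  # keep only the last 10 messages
        transforms.MessageTokenLimiter(max_tokens=1000),  # cap total tokens across messages
    ]
)

assistant = ConversableAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY")}]},
)
context_handling.add_to_agent(assistant)
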
42 changes: 39 additions & 3 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -8,6 +8,7 @@

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config

from .text_compressors import LLMLingua, TextCompressor

@@ -130,6 +131,8 @@ def __init__(
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
@@ -140,11 +143,17 @@
min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from token truncation. If False, messages that match the filter will be truncated.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -169,10 +178,15 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

for msg in reversed(temp_messages):
# Some messages may not have content.
if not isinstance(msg.get("content"), (str, list)):
if not _is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue

if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
continue

expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

# If adding this message would exceed the token limit, truncate the last message to meet the total token
@@ -282,6 +296,8 @@ def __init__(
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
@@ -293,6 +309,10 @@
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from compression. If False, messages that match the filter will be compressed.
"""

if text_compressor is None:
@@ -303,6 +323,8 @@ def __init__(
self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter
self._cache = cache

# Optimizing savings calculations to optimize log generation
@@ -334,7 +356,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
if not _is_content_right_type(message.get("content")):
continue

if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
continue

if _is_content_text_empty(message["content"]):
@@ -397,7 +422,7 @@ def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
@@ -427,10 +452,21 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
return token_count


def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False


def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
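
A hedged usage sketch of the filter_dict / exclude_filter arguments introduced above (the filter and messages are invented for illustration). With the default exclude_filter=True, messages matching the filter are skipped by the transform; with exclude_filter=False, only matching messages are transformed:

from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

# Protect system messages from truncation; everything else is truncated as before.
limiter = MessageTokenLimiter(
    max_tokens_per_message=20,
    filter_dict={"role": ["system"]},
    exclude_filter=True,  # default: matching messages are skipped, not truncated
)

messages = [
    {"role": "system", "content": "You are a meticulous reviewer. " * 20},  # kept intact
    {"role": "user", "content": "Please look at this very long diff. " * 20},  # truncated to 20 tokens
]
processed = limiter.apply_transform(messages)

The same two parameters are accepted by TextMessageCompressor, where they control which messages are handed to the text compressor.
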
4 changes: 2 additions & 2 deletions autogen/agentchat/contrib/compressible_agent.py
@@ -13,8 +13,8 @@
logger = logging.getLogger(__name__)

warn(
"Context handling with CompressibleAgent is deprecated. "
"Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/reference/agentchat/contrib/capabilities/transform_messages",
"Context handling with CompressibleAgent is deprecated and will be removed in `0.2.30`. "
"Please use `TransformMessages`, documentation can be found at https://microsoft.github.io/autogen/docs/topics/handling_long_contexts/intro_to_transform_messages",
DeprecationWarning,
stacklevel=2,
)
4 changes: 4 additions & 0 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
@@ -11,6 +11,7 @@
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant
from autogen.runtime_logging import log_new_agent, logging_enabled

logger = logging.getLogger(__name__)

@@ -65,6 +66,8 @@ def __init__(
super().__init__(
name=name, system_message=instructions, human_input_mode="NEVER", llm_config=openai_client_cfg, **kwargs
)
if logging_enabled():
log_new_agent(self, locals())

# GPTAssistantAgent's azure_deployment param may cause NotFoundError (404) in client.beta.assistants.list()
# See: https://github.com/microsoft/autogen/pull/1721
@@ -169,6 +172,7 @@ def __init__(
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
logger.warning("overwrite_tools is False. Using existing tools from assistant API.")

self.update_system_message(self._openai_assistant.instructions)
# lazily create threads
self._openai_threads = {}
self._unread_index = defaultdict(int)
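
With the log_new_agent hooks added above, GPTAssistantAgent construction is now captured when AutoGen's runtime logging is active. A rough sketch, assuming the SQLite-backed runtime logger from the AutoGen docs (database name, model, and instructions are placeholders):

import os

import autogen
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent

# Start runtime logging; the agent created below is recorded via log_new_agent.
session_id = autogen.runtime_logging.start(config={"dbname": "agent_logs.db"})

assistant = GPTAssistantAgent(
    name="coder",
    instructions="You are a helpful coding assistant.",
    llm_config={"config_list": [{"model": "gpt-4", "api_key": os.environ.get("OPENAI_API_KEY")}]},
)

autogen.runtime_logging.stop()
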
9 changes: 8 additions & 1 deletion autogen/agentchat/conversable_agent.py
@@ -2406,6 +2406,8 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]):
self._assert_valid_name(name)
if func is None and name not in self._function_map.keys():
warnings.warn(f"The function {name} to remove doesn't exist", name)
if name in self._function_map:
warnings.warn(f"Function '{name}' is being overridden.", UserWarning)
self._function_map.update(function_map)
self._function_map = {k: v for k, v in self._function_map.items() if v is not None}

@@ -2442,6 +2444,9 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)

self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]):
warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning)

self.llm_config["functions"] = [
func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"]
] + [func_sig]
@@ -2481,7 +2486,9 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
)
self._assert_valid_name(tool_sig["function"]["name"])
if "tools" in self.llm_config.keys():
if "tools" in self.llm_config:
if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]):
warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning)
self.llm_config["tools"] = [
tool
for tool in self.llm_config["tools"]
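
A small sketch of the new behavior in register_function, update_function_signature, and update_tool_signature: re-registering an existing name now emits a UserWarning instead of silently replacing the earlier entry (agent name and functions are made up):

import warnings

from autogen import ConversableAgent

agent = ConversableAgent(name="worker", llm_config=False)

def add(a: int, b: int) -> int:
    return a + b

agent.register_function({"add": add})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    agent.register_function({"add": lambda a, b: a * b})  # same name, new implementation
    assert any("is being overridden" in str(w.message) for w in caught)
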
38 changes: 30 additions & 8 deletions autogen/agentchat/groupchat.py
@@ -39,6 +39,7 @@ class GroupChat:
Then select the next role from {agentlist} to play. Only return the role."
- select_speaker_prompt_template: customize the select speaker prompt (used in "auto" speaker selection), which appears last in the message context and generally includes the list of agents and guidance for the LLM to select the next agent. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"Read the above conversation. Then select the next role from {agentlist} to play. Only return the role."
To ignore this prompt being used, set this to None. If set to None, ensure your instructions for selecting a speaker are in the select_speaker_message_template string.
- select_speaker_auto_multiple_template: customize the follow-up prompt used when selecting a speaker fails with a response that contains multiple agent names. This prompt guides the LLM to return just one agent name. Applies only to "auto" speaker selection method. If the string contains "{agentlist}" it will be replaced with a comma-separated list of agent names in square brackets. The default value is:
"You provided more than one name in your text, please return just the name of the next speaker. To determine the speaker use these prioritised rules:
1. If the context refers to themselves as a speaker e.g. "As the..." , choose that speaker's name
@@ -227,8 +228,8 @@ def __post_init__(self):
if self.select_speaker_message_template is None or len(self.select_speaker_message_template) == 0:
raise ValueError("select_speaker_message_template cannot be empty or None.")

if self.select_speaker_prompt_template is None or len(self.select_speaker_prompt_template) == 0:
raise ValueError("select_speaker_prompt_template cannot be empty or None.")
if self.select_speaker_prompt_template is not None and len(self.select_speaker_prompt_template) == 0:
self.select_speaker_prompt_template = None

if self.role_for_select_speaker_messages is None or len(self.role_for_select_speaker_messages) == 0:
raise ValueError("role_for_select_speaker_messages cannot be empty or None.")
@@ -332,7 +333,13 @@ def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
return return_msg

def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the floating system prompt selecting the next speaker. This is always the *last* message in the context."""
"""Return the floating system prompt selecting the next speaker.
This is always the *last* message in the context.
Will return None if the select_speaker_prompt_template is None."""

if self.select_speaker_prompt_template is None:
return None

if agents is None:
agents = self.agents

@@ -683,14 +690,20 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
agents, max_attempts, messages, validate_speaker_name
)

# Create the starting message
if self.select_speaker_prompt_template is not None:
start_message = {
"content": self.select_speaker_prompt(agents),
"override_role": self.role_for_select_speaker_messages,
}
else:
start_message = messages[-1]

# Run the speaker selection chat
result = checking_agent.initiate_chat(
speaker_selection_agent,
cache=None, # don't use caching for the speaker selection chat
message={
"content": self.select_speaker_prompt(agents),
"override_role": self.role_for_select_speaker_messages,
},
message=start_message,
max_turns=2
* max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
clear_history=False,
@@ -754,11 +767,20 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
agents, max_attempts, messages, validate_speaker_name
)

# Create the starting message
if self.select_speaker_prompt_template is not None:
start_message = {
"content": self.select_speaker_prompt(agents),
"override_role": self.role_for_select_speaker_messages,
}
else:
start_message = messages[-1]

# Run the speaker selection chat
result = await checking_agent.a_initiate_chat(
speaker_selection_agent,
cache=None, # don't use caching for the speaker selection chat
message=self.select_speaker_prompt(agents),
message=start_message,
max_turns=2
* max(1, max_attempts), # Limiting the chat to the number of attempts, including the initial one
clear_history=False,
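
A hedged sketch of the new opt-out: with select_speaker_prompt_template=None, no trailing "select the next role" message is appended during "auto" speaker selection, so the selection guidance must live in select_speaker_message_template (agents and wording below are illustrative):

from autogen import ConversableAgent, GroupChat

coder = ConversableAgent(name="coder", llm_config=False)
reviewer = ConversableAgent(name="reviewer", llm_config=False)

group_chat = GroupChat(
    agents=[coder, reviewer],
    messages=[],
    max_round=6,
    speaker_selection_method="auto",
    # Opt out of the floating last-message prompt; the selection chat then starts
    # from the latest conversation message instead.
    select_speaker_prompt_template=None,
    select_speaker_message_template=(
        "You are coordinating a code review between {roles}. Read the conversation, "
        "then return only the name of the next role from {agentlist} to speak."
    ),
)
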
44 changes: 36 additions & 8 deletions autogen/oai/gemini.py
@@ -5,8 +5,18 @@
llm_config={
"config_list": [{
"api_type": "google",
"model": "models/gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY")
"model": "gemini-pro",
"api_key": os.environ.get("GOOGLE_API_KEY"),
"safety_settings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}
],
"top_p":0.5,
"max_tokens": 2048,
"temperature": 1.0,
"top_k": 5
}
]}
@@ -47,6 +57,17 @@ class GeminiClient:
of AutoGen.
"""

# Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini
PARAMS_MAPPING = {
"max_tokens": "max_output_tokens",
# "n": "candidate_count", # Gemini supports only `n=1`
"stop_sequences": "stop_sequences",
"temperature": "temperature",
"top_p": "top_p",
"top_k": "top_k",
"max_output_tokens": "max_output_tokens",
}

def __init__(self, **kwargs):
self.api_key = kwargs.get("api_key", None)
if not self.api_key:
@@ -93,12 +114,15 @@ def create(self, params: Dict) -> ChatCompletion:
messages = params.get("messages", [])
stream = params.get("stream", False)
n_response = params.get("n", 1)
params.get("temperature", 0.5)
params.get("top_p", 1.0)
params.get("max_tokens", 4096)

generation_config = {
gemini_term: params[autogen_term]
for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
if autogen_term in params
}
safety_settings = params.get("safety_settings", {})

if stream:
# warn user that streaming is not supported
warnings.warn(
"Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
UserWarning,
@@ -112,7 +136,9 @@ def create(self, params: Dict) -> ChatCompletion:
gemini_messages = oai_messages_to_gemini_messages(messages)

# we use chat model by default
model = genai.GenerativeModel(model_name)
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
@@ -142,7 +168,9 @@ def create(self, params: Dict) -> ChatCompletion:
elif model_name == "gemini-pro-vision":
# B. handle the vision model
# Gemini's vision model does not support chat history yet
model = genai.GenerativeModel(model_name)
model = genai.GenerativeModel(
model_name, generation_config=generation_config, safety_settings=safety_settings
)
genai.configure(api_key=self.api_key)
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1])
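
The generation_config assembled in create() is a key rename driven by PARAMS_MAPPING; a standalone sketch of that translation (the params dict is a made-up request):

# Mirror of GeminiClient.PARAMS_MAPPING: AutoGen/OpenAI-style names -> Gemini names.
PARAMS_MAPPING = {
    "max_tokens": "max_output_tokens",
    "stop_sequences": "stop_sequences",
    "temperature": "temperature",
    "top_p": "top_p",
    "top_k": "top_k",
    "max_output_tokens": "max_output_tokens",
}

params = {"model": "gemini-pro", "temperature": 1.0, "top_p": 0.5, "top_k": 5, "max_tokens": 2048}

generation_config = {
    gemini_term: params[autogen_term]
    for autogen_term, gemini_term in PARAMS_MAPPING.items()
    if autogen_term in params
}
print(generation_config)
# {'max_output_tokens': 2048, 'temperature': 1.0, 'top_p': 0.5, 'top_k': 5}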