
Commit 4820a25

Merge branch 'main' into fix-article-markdown
2 parents 830daf1 + 9069eb9

File tree

6 files changed, +360 -345 lines changed


autogen/agentchat/contrib/agent_optimizer.py

+16 -13
@@ -1,6 +1,6 @@
 import copy
 import json
-from typing import Dict, List, Optional
+from typing import Dict, List, Literal, Optional, Union

 import autogen
 from autogen.code_utils import execute_code
@@ -172,16 +172,16 @@ class AgentOptimizer:
     def __init__(
         self,
         max_actions_per_step: int,
-        config_file_or_env: Optional[str] = "OAI_CONFIG_LIST",
-        config_file_location: Optional[str] = "",
+        llm_config: dict,
         optimizer_model: Optional[str] = "gpt-4-1106-preview",
     ):
         """
         (These APIs are experimental and may change in the future.)
         Args:
             max_actions_per_step (int): the maximum number of actions that the optimizer can take in one step.
-            config_file_or_env: path or environment of the OpenAI api configs.
-            config_file_location: the location of the OpenAI config file.
+            llm_config (dict): llm inference configuration.
+                Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create) for available options.
+                When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in `llm_config` or in each config of 'config_list' in `llm_config`.
             optimizer_model: the model used for the optimizer.
         """
         self.max_actions_per_step = max_actions_per_step
@@ -199,14 +199,17 @@ def __init__(
         self._failure_functions_performance = []
         self._best_performance = -1

-        config_list = autogen.config_list_from_json(
-            config_file_or_env,
-            file_location=config_file_location,
-            filter_dict={"model": [self.optimizer_model]},
+        assert isinstance(llm_config, dict), "llm_config must be a dict"
+        llm_config = copy.deepcopy(llm_config)
+        self.llm_config = llm_config
+        if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
+            raise ValueError(
+                "When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
+            )
+        self.llm_config["config_list"] = autogen.filter_config(
+            llm_config["config_list"], {"model": [self.optimizer_model]}
         )
-        if len(config_list) == 0:
-            raise RuntimeError("No valid openai config found in the config file or environment variable.")
-        self._client = autogen.OpenAIWrapper(config_list=config_list)
+        self._client = autogen.OpenAIWrapper(**self.llm_config)

     def record_one_conversation(self, conversation_history: List[Dict], is_satisfied: bool = None):
         """
@@ -266,7 +269,7 @@ def step(self):
             actions_num=action_index,
             best_functions=best_functions,
             incumbent_functions=incumbent_functions,
-            accumerated_experience=failure_experience_prompt,
+            accumulated_experience=failure_experience_prompt,
             statistic_informations=statistic_prompt,
         )
         messages = [{"role": "user", "content": prompt}]
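With this change, AgentOptimizer is constructed from an llm_config dict rather than a config file, and the supplied config_list is narrowed to the optimizer model via autogen.filter_config before the OpenAIWrapper client is built. A minimal sketch of the new call pattern (the API keys and environment variable below are placeholders, not part of the commit):

import os

from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer

# Placeholder config: any entry whose "model" differs from optimizer_model is
# filtered out inside __init__ by autogen.filter_config(...).
llm_config = {
    "config_list": [
        {"model": "gpt-4-1106-preview", "api_key": os.environ["OPENAI_API_KEY"]},
        {"model": "gpt-3.5-turbo", "api_key": os.environ["OPENAI_API_KEY"]},
    ]
}
optimizer = AgentOptimizer(
    max_actions_per_step=3,
    llm_config=llm_config,
    optimizer_model="gpt-4-1106-preview",
)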

autogen/agentchat/contrib/capabilities/transforms.py

+20 -12
@@ -85,7 +85,8 @@ class MessageTokenLimiter:
     2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
     and other types of content, only the text content is truncated.
     3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
-    exceeds this limit, the current message being processed as well as any remaining messages are discarded.
+    exceeds this limit, the current message being processed get truncated to meet the total token count and any
+    remaining messages get discarded.
     4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
     original message order.
     """
@@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

         for msg in reversed(temp_messages):
-            msg["content"] = self._truncate_str_to_tokens(msg["content"])
-            msg_tokens = _count_tokens(msg["content"])
+            expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

-            # If adding this message would exceed the token limit, discard it and all remaining messages
-            if processed_messages_tokens + msg_tokens > self._max_tokens:
+            # If adding this message would exceed the token limit, truncate the last message to meet the total token
+            # limit and discard all remaining messages
+            if expected_tokens_remained < 0:
+                msg["content"] = self._truncate_str_to_tokens(
+                    msg["content"], self._max_tokens - processed_messages_tokens
+                )
+                processed_messages.insert(0, msg)
                 break

+            msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
+            msg_tokens = _count_tokens(msg["content"])
+
             # prepend the message to the list to preserve order
             processed_messages_tokens += msg_tokens
             processed_messages.insert(0, msg)
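The key shift in apply_transform is that _truncate_str_to_tokens now receives an explicit token budget: the per-message cap for messages that still fit, or the remaining overall budget for the message that crosses max_tokens. A standalone sketch of that budgeted truncation using tiktoken directly (the helper name and default model here are illustrative, not part of the library):

import tiktoken

def truncate_to_budget(text: str, n_tokens: int, model: str = "gpt-3.5-turbo") -> str:
    """Keep only the first n_tokens tokens of text, mirroring the patched _truncate_tokens."""
    encoding = tiktoken.encoding_for_model(model)  # tokenizer matching the model
    return encoding.decode(encoding.encode(text)[:n_tokens])

# A message that fits uses the per-message cap; the message that crosses the
# overall limit is cut to whatever budget remains instead of being dropped.
print(truncate_to_budget("the quick brown fox jumps over the lazy dog", 4))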
@@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

         return processed_messages

-    def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
+    def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]:
         if isinstance(contents, str):
-            return self._truncate_tokens(contents)
+            return self._truncate_tokens(contents, n_tokens)
         elif isinstance(contents, list):
-            return self._truncate_multimodal_text(contents)
+            return self._truncate_multimodal_text(contents, n_tokens)
         else:
             raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

-    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]:
         """Truncates text content within a list of multimodal elements, preserving the overall structure."""
         tmp_contents = []
         for content in contents:
             if content["type"] == "text":
-                truncated_text = self._truncate_tokens(content["text"])
+                truncated_text = self._truncate_tokens(content["text"], n_tokens)
                 tmp_contents.append({"type": "text", "text": truncated_text})
             else:
                 tmp_contents.append(content)
         return tmp_contents

-    def _truncate_tokens(self, text: str) -> str:
+    def _truncate_tokens(self, text: str, n_tokens: int) -> str:
         encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

         encoded_tokens = encoding.encode(text)
-        truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
+        truncated_tokens = encoded_tokens[:n_tokens]
         truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

         return truncated_text
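Taken together, the transforms.py change means the message that crosses the overall max_tokens budget is truncated to fit rather than dropped, and only the older messages beyond it are discarded. A rough usage sketch (the constructor keyword names max_tokens and max_tokens_per_message are inferred from the attributes used above, so treat them as assumptions):

from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

limiter = MessageTokenLimiter(max_tokens=60, max_tokens_per_message=40)
messages = [
    {"role": "user", "content": "a long first message " * 50},
    {"role": "assistant", "content": "an equally long reply " * 50},
]
# Most recent messages are kept; the one that crosses the 60-token total is
# truncated to the remaining budget instead of being discarded outright.
truncated = limiter.apply_transform(messages)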

notebook/agentchat_agentoptimizer.ipynb

+14 -9
@@ -41,6 +41,7 @@
    "source": [
     "import copy\n",
     "import json\n",
+    "import os\n",
     "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n",
     "\n",
     "from openai import BadRequestError\n",
@@ -299,16 +300,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "config_list = config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
+    "llm_config = {\n",
+    "    \"config_list\": [\n",
+    "        {\n",
+    "            \"model\": \"gpt-4-1106-preview\",\n",
+    "            \"api_type\": \"azure\",\n",
+    "            \"api_key\": os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
+    "            \"base_url\": \"https://ENDPOINT.openai.azure.com/\",\n",
+    "            \"api_version\": \"2023-07-01-preview\",\n",
+    "        }\n",
+    "    ]\n",
+    "}\n",
    "\n",
    "assistant = autogen.AssistantAgent(\n",
    "    name=\"assistant\",\n",
    "    system_message=\"You are a helpful assistant.\",\n",
-    "    llm_config={\n",
-    "        \"timeout\": 600,\n",
-    "        \"seed\": 42,\n",
-    "        \"config_list\": config_list,\n",
-    "    },\n",
+    "    llm_config=llm_config,\n",
    ")\n",
    "user_proxy = MathUserProxyAgent(\n",
    "    name=\"mathproxyagent\",\n",
@@ -361,9 +368,7 @@
    "source": [
     "EPOCH = 10\n",
     "optimizer_model = \"gpt-4-1106-preview\"\n",
-    "optimizer = AgentOptimizer(\n",
-    "    max_actions_per_step=3, config_file_or_env=\"OAI_CONFIG_LIST\", optimizer_model=optimizer_model\n",
-    ")\n",
+    "optimizer = AgentOptimizer(max_actions_per_step=3, llm_config=llm_config, optimizer_model=optimizer_model)\n",
    "for i in range(EPOCH):\n",
    "    for index, query in enumerate(train_data):\n",
    "        is_correct = user_proxy.initiate_chat(assistant, answer=query[\"answer\"], problem=query[\"question\"])\n",

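The notebook now builds llm_config inline for an Azure OpenAI endpoint (with ENDPOINT left as a placeholder). For a plain OpenAI endpoint, the equivalent entry would look roughly like the sketch below; the field names follow autogen's config_list convention, and the environment variable is an assumption:

import os

# Hypothetical non-Azure equivalent of the notebook's llm_config (sketch only)
llm_config = {
    "config_list": [
        {
            "model": "gpt-4-1106-preview",
            "api_key": os.environ["OPENAI_API_KEY"],  # assumed env var
        }
    ]
}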