Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…lient
  • Loading branch information
BeibinLi committed Mar 30, 2024
2 parents 4d93ce8 + 7a685b5 commit 3b7640f
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 46 deletions.
69 changes: 44 additions & 25 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2262,32 +2262,55 @@ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Un
"""
if message is None:
message = self.get_human_input(">")

return self._handle_carryover(message, kwargs)

def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
if not kwargs.get("carryover"):
return message

if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, list):
return message # TODO: this multimodal issue is handled in PR #2124
return self._process_multimodal_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
# TODO: Do we need to do the following?
# if message.get("content") is None:
# message["content"] = self.get_human_input(">")
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
return message
if isinstance(message.get("content"), str):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
elif isinstance(message.get("content"), list):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
else:
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")

def _process_carryover(self, message: str, kwargs: dict) -> str:
carryover = kwargs.get("carryover")
if carryover:
# if carryover is string
if isinstance(carryover, str):
message += "\nContext: \n" + carryover
elif isinstance(carryover, list):
message += "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return message

def _process_carryover(self, content: str, kwargs: dict) -> str:
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content

# if carryover is string
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return content

def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
"""Prepends the context to a multimodal message."""
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content

return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content

async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent.
If message is None, input() will be called to get the initial message.
Expand All @@ -2300,12 +2323,8 @@ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwarg
"""
if message is None:
message = await self.a_get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
return message

return self._handle_carryover(message, kwargs)

def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent.
Expand Down
6 changes: 4 additions & 2 deletions autogen/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def get_default() -> "IOStream":
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
return IOStream.get_global_default()
logger.info("No default IOStream has been set, defaulting to IOConsole.")
iostream = IOStream.get_global_default()
# Set the default IOStream of the current context (thread/cooroutine)
IOStream.set_default(iostream)
return iostream

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,11 @@ def get_config(
"""
config = {"api_key": api_key}
if base_url:
config["base_url"] = base_url
config["base_url"] = os.getenv(base_url, default=base_url)
if api_type:
config["api_type"] = api_type
config["api_type"] = os.getenv(api_type, default=api_type)
if api_version:
config["api_version"] = api_version
config["api_version"] = os.getenv(api_version, default=api_version)
return config


Expand Down
62 changes: 47 additions & 15 deletions samples/tools/autogenbench/autogenbench/run_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,14 @@ def run_scenario_in_docker(work_dir, env, timeout=TASK_TIMEOUT, docker_image=Non

# Create and run the container
container = client.containers.run(
image, command=["sh", "run.sh"], working_dir="/workspace", environment=env, detach=True, volumes=volumes
image,
command=["sh", "run.sh"],
working_dir="/workspace",
environment=env,
detach=True,
remove=True,
auto_remove=True,
volumes=volumes,
)

# Read the logs in a streaming fashion. Keep an eye on the time to make sure we don't need to stop.
Expand All @@ -494,30 +501,55 @@ def run_scenario_in_docker(work_dir, env, timeout=TASK_TIMEOUT, docker_image=Non
logs = container.logs(stream=True)
log_file = open(os.path.join(work_dir, "console_log.txt"), "wt", encoding="utf-8")
stopping = False
exiting = False

for chunk in logs: # When streaming it should return a generator
# Stream the data to the log file and the console
chunk = chunk.decode("utf-8")
log_file.write(chunk)
log_file.flush()
sys.stdout.reconfigure(encoding="utf-8")
sys.stdout.write(chunk)
sys.stdout.flush()

# Check if we need to terminate
if not stopping and time.time() - start_time >= docker_timeout:
while True:
try:
chunk = next(logs) # Manually step the iterator so it is captures with the try-catch

# Stream the data to the log file and the console
chunk = chunk.decode("utf-8")
log_file.write(chunk)
log_file.flush()
sys.stdout.reconfigure(encoding="utf-8")
sys.stdout.write(chunk)
sys.stdout.flush()

# Check if we need to terminate
if not stopping and time.time() - start_time >= docker_timeout:
container.stop()

# Don't exit the loop right away, as there are things we may still want to read from the logs
# but remember how we got here.
stopping = True
except KeyboardInterrupt:
log_file.write("\nKeyboard interrupt (Ctrl-C). Attempting to exit gracefully.\n")
log_file.flush()
sys.stdout.write("\nKeyboard interrupt (Ctrl-C). Attempting to exit gracefully.\n")
sys.stdout.flush()

# Start the exit process, and give it a minute, but keep iterating
container.stop()
exiting = True
docker_timeout = time.time() - start_time + 60
except StopIteration:
break

# Don't exit the loop right away, as there are things we may still want to read from the logs
# but remember how we got here.
stopping = True
# Clean up the container
try:
container.remove()
except docker.errors.APIError:
pass

if stopping: # By this line we've exited the loop, and the container has actually stopped.
log_file.write("\nDocker timed out.\n")
log_file.flush()
sys.stdout.write("\nDocker timed out.\n")
sys.stdout.flush()

if exiting: # User hit ctrl-C
sys.exit(1)


def build_default_docker_image(docker_client, image_tag):
for segment in docker_client.api.build(
Expand Down
2 changes: 1 addition & 1 deletion samples/tools/autogenbench/autogenbench/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.2"
__version__ = "0.0.3"
48 changes: 48 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)

# Test multimodal messages
mm_content = [
{"type": "text", "text": "hello"},
{"type": "text", "text": "goodbye"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
]
mm_message = {"content": mm_content}
context = dict(
message=mm_message,
carryover="Testing carryover.",
)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4

context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4

context = dict(message=mm_message, carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)

# Test without carryover
print(mm_message)
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 3

# Test without text in multimodal message
mm_content = [
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
]
mm_message = {"content": mm_content}
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 1

generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2


if __name__ == "__main__":
# test_trigger()
Expand Down

0 comments on commit 3b7640f

Please sign in to comment.