Skip to content

Commit

Permalink
Fix generate_init_message for Multimodal Messages (microsoft#2124)
Browse files Browse the repository at this point in the history
* multimodal carryover

* adds mm carryover tests

* more tests + cleanup code

* check content instead

* beibin suggestion

* cleanup

* fix async

* use deepcopy

* handle carryover method

* remove content copy

* sonichi suggestions

---------

Co-authored-by: Beibin Li <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Mar 30, 2024
1 parent 0a6fed5 commit 4bedecf
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 24 deletions.
68 changes: 44 additions & 24 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,30 +2259,54 @@ def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Un
"""
if message is None:
message = self.get_human_input(">")

return self._handle_carryover(message, kwargs)

def _handle_carryover(self, message: Union[str, Dict], kwargs: dict) -> Union[str, Dict]:
if not kwargs.get("carryover"):
return message

if isinstance(message, str):
return self._process_carryover(message, kwargs)

elif isinstance(message, dict):
message = message.copy()
# TODO: Do we need to do the following?
# if message.get("content") is None:
# message["content"] = self.get_human_input(">")
message["content"] = self._process_carryover(message.get("content", ""), kwargs)
return message
if isinstance(message.get("content"), str):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
elif isinstance(message.get("content"), list):
# Makes sure the original message is not mutated
message = message.copy()
message["content"] = self._process_multimodal_carryover(message["content"], kwargs)
else:
raise InvalidCarryOverType("Carryover should be a string or a list of strings.")

def _process_carryover(self, message: str, kwargs: dict) -> str:
carryover = kwargs.get("carryover")
if carryover:
# if carryover is string
if isinstance(carryover, str):
message += "\nContext: \n" + carryover
elif isinstance(carryover, list):
message += "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return message

def _process_carryover(self, content: str, kwargs: dict) -> str:
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content

# if carryover is string
if isinstance(kwargs["carryover"], str):
content += "\nContext: \n" + kwargs["carryover"]
elif isinstance(kwargs["carryover"], list):
content += "\nContext: \n" + ("\n").join([t for t in kwargs["carryover"]])
else:
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)
return content

def _process_multimodal_carryover(self, content: List[Dict], kwargs: dict) -> List[Dict]:
"""Prepends the context to a multimodal message."""
# Makes sure there's a carryover
if not kwargs.get("carryover"):
return content

return [{"type": "text", "text": self._process_carryover("", kwargs)}] + content

async def a_generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
"""Generate the initial message for the agent.
If message is None, input() will be called to get the initial message.
Expand All @@ -2295,12 +2319,8 @@ async def a_generate_init_message(self, message: Union[Dict, str, None], **kwarg
"""
if message is None:
message = await self.a_get_human_input(">")
if isinstance(message, str):
return self._process_carryover(message, kwargs)
elif isinstance(message, dict):
message = message.copy()
message["content"] = self._process_carryover(message["content"], kwargs)
return message

return self._handle_carryover(message, kwargs)

def register_function(self, function_map: Dict[str, Union[Callable, None]]):
"""Register functions to the agent.
Expand Down
48 changes: 48 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,54 @@ def test_messages_with_carryover():
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)

# Test multimodal messages
mm_content = [
{"type": "text", "text": "hello"},
{"type": "text", "text": "goodbye"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.png"},
},
]
mm_message = {"content": mm_content}
context = dict(
message=mm_message,
carryover="Testing carryover.",
)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4

context = dict(message=mm_message, carryover=["Testing carryover.", "This should pass"])
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 4

context = dict(message=mm_message, carryover=3)
with pytest.raises(InvalidCarryOverType):
agent1.generate_init_message(**context)

# Test without carryover
print(mm_message)
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 3

# Test without text in multimodal message
mm_content = [
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
]
mm_message = {"content": mm_content}
context = dict(message=mm_message)
generated_message = agent1.generate_init_message(**context)
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 1

generated_message = agent1.generate_init_message(**context, carryover="Testing carryover.")
assert isinstance(generated_message, dict)
assert len(generated_message["content"]) == 2


if __name__ == "__main__":
# test_trigger()
Expand Down

0 comments on commit 4bedecf

Please sign in to comment.