From afdfb4ea5834fd566239bb54f784a442a95c445c Mon Sep 17 00:00:00 2001
From: Enhao Zhang
Date: Tue, 24 Sep 2024 16:54:22 -0700
Subject: [PATCH] TeamOne cancellation token support and making logger a member
 variable (#622)

* Added cancellation token support for team_one; made logger a member
  variable of each agent.

formatting

fix error

fix error

formatting

* No need to create a new cancellation token
---
 .../src/team_one/agents/base_agent.py         | 18 +++++-
 .../src/team_one/agents/base_orchestrator.py  | 17 +++--
 .../src/team_one/agents/base_worker.py        |  6 +-
 .../team-one/src/team_one/agents/coder.py     |  4 +-
 .../agents/file_surfer/file_surfer.py         |  2 +-
 .../multimodal_web_surfer.py                  | 47 +++++++++-----
 .../src/team_one/agents/orchestrator.py       | 63 +++++++++++--------
 7 files changed, 100 insertions(+), 57 deletions(-)

diff --git a/python/packages/team-one/src/team_one/agents/base_agent.py b/python/packages/team-one/src/team_one/agents/base_agent.py
index b8df8e0e8829..837104989436 100644
--- a/python/packages/team-one/src/team_one/agents/base_agent.py
+++ b/python/packages/team-one/src/team_one/agents/base_agent.py
@@ -15,8 +15,6 @@
     TeamOneMessages,
 )
 
-logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")
-
 
 class TeamOneBaseAgent(RoutedAgent):
     """An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
@@ -29,6 +27,7 @@ def __init__(
         super().__init__(description)
        self._handle_messages_concurrently = handle_messages_concurrently
         self._enabled = True
+        self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")
 
         if not self._handle_messages_concurrently:
             # TODO: make it possible to stop
@@ -40,6 +39,7 @@ async def _process(self) -> None:
             message, ctx, future = await self._message_queue.get()
             if ctx.cancellation_token.is_cancelled():
                 # TODO: Do we need to resolve the future here?
+                future.cancel()
                 continue
 
             try:
@@ -54,6 +54,8 @@ async def _process(self) -> None:
                 else:
                     raise ValueError("Unknown message type.")
                 future.set_result(None)
+            except asyncio.CancelledError:
+                future.cancel()
             except Exception as e:
                 future.set_exception(e)
 
@@ -92,9 +94,19 @@ async def _handle_request_reply(self, message: RequestReplyMessage, ctx: Message
     async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
         """Handle a deactivate message."""
         self._enabled = False
-        logger.info(
+        self.logger.info(
             AgentEvent(
                 f"{self.metadata['type']} (deactivated)",
                 "",
             )
         )
+
+    async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
+        """Drop the message, with a log."""
+        # self.logger.info(
+        #     AgentEvent(
+        #         f"{self.metadata['type']} (unhandled message)",
+        #         f"Unhandled message type: {type(message)}",
+        #     )
+        # )
+        pass
diff --git a/python/packages/team-one/src/team_one/agents/base_orchestrator.py b/python/packages/team-one/src/team_one/agents/base_orchestrator.py
index a146613ccd24..4f1eead9148e 100644
--- a/python/packages/team-one/src/team_one/agents/base_orchestrator.py
+++ b/python/packages/team-one/src/team_one/agents/base_orchestrator.py
@@ -10,8 +10,6 @@
 from ..utils import message_content_to_str
 from .base_agent import TeamOneBaseAgent
 
-logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")
-
 
 class BaseOrchestrator(TeamOneBaseAgent):
     def __init__(
@@ -28,6 +26,7 @@ def __init__(
         self._max_time = max_time
         self._num_rounds = 0
         self._start_time: float = -1.0
+        self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.orchestrator")
 
     async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
         """Handle an incoming message."""
@@ -42,11 +41,11 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
 
         content = message_content_to_str(message.content.content)
 
-        logger.info(OrchestrationEvent(source, content))
+        self.logger.info(OrchestrationEvent(source, content))
 
         # Termination conditions
         if self._num_rounds >= self._max_rounds:
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (termination condition)",
                     f"Max rounds ({self._max_rounds}) reached.",
@@ -55,7 +54,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
             return
 
         if time.time() - self._start_time >= self._max_time:
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (termination condition)",
                     f"Max time ({self._max_time}s) reached.",
@@ -64,7 +63,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
             return
 
         if message.request_halt:
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (termination condition)",
                     f"{source} requested halt.",
@@ -74,7 +73,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
 
         next_agent = await self._select_next_agent(message.content)
         if next_agent is None:
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (termination condition)",
                     "No agent selected.",
@@ -84,7 +83,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
         request_reply_message = RequestReplyMessage()
         # emit an event
-        logger.info(
+        self.logger.info(
             OrchestrationEvent(
                 source=f"{self.metadata['type']} (thought)",
                 message=f"Next speaker {(await next_agent.metadata)['type']}" "",
             )
         )
         self._num_rounds += 1  # Call before sending the message
-        await self.send_message(request_reply_message, next_agent.id)
+        await self.send_message(request_reply_message, next_agent.id, cancellation_token=ctx.cancellation_token)
 
     async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
         raise NotImplementedError()
diff --git a/python/packages/team-one/src/team_one/agents/base_worker.py b/python/packages/team-one/src/team_one/agents/base_worker.py
index 6da6de975690..5a6516d8daf7 100644
--- a/python/packages/team-one/src/team_one/agents/base_worker.py
+++ b/python/packages/team-one/src/team_one/agents/base_worker.py
@@ -46,7 +46,11 @@ async def _handle_request_reply(self, message: RequestReplyMessage, ctx: Message
 
         user_message = UserMessage(content=response, source=self.metadata["type"])
         topic_id = TopicId("default", self.id.key)
-        await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
+        await self.publish_message(
+            BroadcastMessage(content=user_message, request_halt=request_halt),
+            topic_id=topic_id,
+            cancellation_token=ctx.cancellation_token,
+        )
 
     async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
         """Returns (request_halt, response_message)"""
diff --git a/python/packages/team-one/src/team_one/agents/coder.py b/python/packages/team-one/src/team_one/agents/coder.py
index d5badcae4fdb..ea83ae0a9155 100644
--- a/python/packages/team-one/src/team_one/agents/coder.py
+++ b/python/packages/team-one/src/team_one/agents/coder.py
@@ -49,7 +49,9 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
         """Respond to a reply request."""
 
         # Make an inference to the model.
-        response = await self._model_client.create(self._system_messages + self._chat_history)
+        response = await self._model_client.create(
+            self._system_messages + self._chat_history, cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         return "TERMINATE" in response.content, response.content
diff --git a/python/packages/team-one/src/team_one/agents/file_surfer/file_surfer.py b/python/packages/team-one/src/team_one/agents/file_surfer/file_surfer.py
index 6cd9f959196f..af798af5c5c6 100644
--- a/python/packages/team-one/src/team_one/agents/file_surfer/file_surfer.py
+++ b/python/packages/team-one/src/team_one/agents/file_surfer/file_surfer.py
@@ -90,7 +90,7 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
         )
 
         create_result = await self._model_client.create(
-            messages=history + [context_message, task_message], tools=self._tools
+            messages=history + [context_message, task_message], tools=self._tools, cancellation_token=cancellation_token
         )
 
         response = create_result.content
diff --git a/python/packages/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py b/python/packages/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py
index 3032770c1858..17b7ebe98a42 100644
--- a/python/packages/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py
+++ b/python/packages/team-one/src/team_one/agents/multimodal_web_surfer/multimodal_web_surfer.py
@@ -7,7 +7,7 @@
 import pathlib
 import re
 import traceback
-from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast  # Any, Callable, Dict, List, Literal, Tuple
+from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast, Optional  # Any, Callable, Dict, List, Literal, Tuple
 from urllib.parse import quote_plus  # parse_qs, quote, unquote, urlparse, urlunparse
 
 import aiofiles
@@ -67,8 +67,6 @@
 
 SCREENSHOT_TOKENS = 1105
 
-logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")
-
 
 # Sentinels
 class DEFAULT_CHANNEL(metaclass=SentinelMeta):
@@ -96,6 +94,7 @@ def __init__(
         self._page: Page | None = None
         self._last_download: Download | None = None
         self._prior_metadata_hash: str | None = None
+        self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.MultimodalWebSurfer")
 
         # Read page_script
         self._page_script: str = ""
@@ -196,7 +195,7 @@ async def _set_debug_dir(self, debug_dir: str | None) -> None:
 """.strip(),
             )
             await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
-            logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
+            self.logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
 
     async def _reset(self, cancellation_token: CancellationToken) -> None:
         assert self._page is not None
@@ -205,7 +204,7 @@ async def _reset(self, cancellation_token: CancellationToken) -> None:
         await self._visit_page(self.start_page)
         if self.debug_dir:
             await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
-        logger.info(
+        self.logger.info(
             WebSurferEvent(
                 source=self.metadata["type"],
                 url=self._page.url,
@@ -250,13 +249,18 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
             return False, f"Web surfing error:\n\n{traceback.format_exc()}"
 
     async def _execute_tool(
-        self, message: List[FunctionCall], rects: Dict[str, InteractiveRegion], tool_names: str, use_ocr: bool = True
+        self,
+        message: List[FunctionCall],
+        rects: Dict[str, InteractiveRegion],
+        tool_names: str,
+        use_ocr: bool = True,
+        cancellation_token: Optional[CancellationToken] = None,
     ) -> Tuple[bool, UserContent]:
         name = message[0].name
         args = json.loads(message[0].arguments)
         action_description = ""
         assert self._page is not None
-        logger.info(
+        self.logger.info(
             WebSurferEvent(
                 source=self.metadata["type"],
                 url=self._page.url,
@@ -340,11 +344,11 @@ async def _execute_tool(
         elif name == "answer_question":
             question = str(args.get("question"))
             # Do Q&A on the DOM. No need to take further action. Browser state does not change.
-            return False, await self._summarize_page(question=question)
+            return False, await self._summarize_page(question=question, cancellation_token=cancellation_token)
 
         elif name == "summarize_page":
             # Summarize the DOM. No need to take further action. Browser state does not change.
-            return False, await self._summarize_page()
+            return False, await self._summarize_page(cancellation_token=cancellation_token)
 
         elif name == "sleep":
             action_description = "I am waiting a short period of time before taking further action."
@@ -394,7 +398,9 @@ async def _execute_tool(
             async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
                 await file.write(new_screenshot)
 
-        ocr_text = await self._get_ocr_text(new_screenshot) if use_ocr is True else ""
+        ocr_text = (
+            await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) if use_ocr is True else ""
+        )
 
         # Return the complete observation
         message_content = ""  # message.content or ""
@@ -518,7 +524,7 @@ async def __generate_reply(self, cancellation_token: CancellationToken) -> Tuple
             UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
         )
         response = await self._model_client.create(
-            history, tools=tools, extra_create_args={"tool_choice": "auto"}
+            history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
         )  # , "parallel_tool_calls": False})
 
         message = response.content
@@ -529,7 +535,7 @@ async def __generate_reply(self, cancellation_token: CancellationToken) -> Tuple
             return False, message
         elif isinstance(message, list):
             # Take an action
-            return await self._execute_tool(message, rects, tool_names)
+            return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
         else:
             # Not sure what happened here
             raise AssertionError(f"Unknown response format '{message}'")
@@ -668,7 +674,7 @@ async def _click_id(self, identifier: str) -> None:
                 assert isinstance(new_page, Page)
                 await self._on_new_page(new_page)
 
-                logger.info(
+                self.logger.info(
                     WebSurferEvent(
                         source=self.metadata["type"],
                         url=self._page.url,
@@ -716,7 +722,12 @@ async def _scroll_id(self, identifier: str, direction: str) -> None:
         """
         )
 
-    async def _summarize_page(self, question: str | None = None, token_limit: int = 100000) -> str:
+    async def _summarize_page(
+        self,
+        question: str | None = None,
+        token_limit: int = 100000,
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> str:
         assert self._page is not None
 
         page_markdown: str = await self._get_page_markdown()
@@ -780,12 +791,14 @@ async def _summarize_page(self, question: str | None = None, token_limit: int =
             )
 
         # Generate the response
-        response = await self._model_client.create(messages)
+        response = await self._model_client.create(messages, cancellation_token=cancellation_token)
         scaled_screenshot.close()
         assert isinstance(response.content, str)
         return response.content
 
-    async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) -> str:
+    async def _get_ocr_text(
+        self, image: bytes | io.BufferedIOBase | Image.Image, cancellation_token: Optional[CancellationToken] = None
+    ) -> str:
         scaled_screenshot = None
         if isinstance(image, Image.Image):
             scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
@@ -810,7 +823,7 @@ async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) ->
                 source=self.metadata["type"],
             )
         )
-        response = await self._model_client.create(messages)
+        response = await self._model_client.create(messages, cancellation_token=cancellation_token)
         scaled_screenshot.close()
         assert isinstance(response.content, str)
         return response.content
diff --git a/python/packages/team-one/src/team_one/agents/orchestrator.py b/python/packages/team-one/src/team_one/agents/orchestrator.py
index 16b2643483a6..2d4c70c9789d 100644
--- a/python/packages/team-one/src/team_one/agents/orchestrator.py
+++ b/python/packages/team-one/src/team_one/agents/orchestrator.py
@@ -1,7 +1,7 @@
 import json
 from typing import Any, Dict, List, Optional
 
-from autogen_core.base import AgentProxy, MessageContext, TopicId
+from autogen_core.base import AgentProxy, MessageContext, TopicId, CancellationToken
 from autogen_core.components import default_subscription
 from autogen_core.components.models import (
     AssistantMessage,
@@ -12,7 +12,7 @@
 )
 
 from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
-from .base_orchestrator import BaseOrchestrator, logger
+from .base_orchestrator import BaseOrchestrator
 from .orchestrator_prompts import (
     ORCHESTRATOR_CLOSED_BOOK_PROMPT,
     ORCHESTRATOR_LEDGER_PROMPT,
@@ -128,7 +128,7 @@ def _get_message_str(self, message: LLMMessage) -> str:
         assert len(result) > 0
         return result
 
-    async def _initialize_task(self, task: str) -> None:
+    async def _initialize_task(self, task: str, cancellation_token: Optional[CancellationToken] = None) -> None:
         self._task = task
         self._team_description = await self._get_team_description()
@@ -140,7 +140,9 @@ async def _initialize_task(self, task: str) -> None:
         planning_conversation.append(
             UserMessage(content=self._get_closed_book_prompt(self._task), source=self.metadata["type"])
         )
-        response = await self._model_client.create(self._system_messages + planning_conversation)
+        response = await self._model_client.create(
+            self._system_messages + planning_conversation, cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         self._facts = response.content
@@ -151,14 +153,16 @@ async def _initialize_task(self, task: str) -> None:
         planning_conversation.append(
             UserMessage(content=self._get_plan_prompt(self._team_description), source=self.metadata["type"])
         )
-        response = await self._model_client.create(self._system_messages + planning_conversation)
+        response = await self._model_client.create(
+            self._system_messages + planning_conversation, cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         self._plan = response.content
 
         # At this point, the planning conversation is dropped.
 
-    async def _update_facts_and_plan(self) -> None:
+    async def _update_facts_and_plan(self, cancellation_token: Optional[CancellationToken] = None) -> None:
         # Shallow-copy the conversation
         planning_conversation = [m for m in self._chat_history]
 
@@ -166,7 +170,9 @@ async def _update_facts_and_plan(self) -> None:
         planning_conversation.append(
             UserMessage(content=self._get_update_facts_prompt(self._task, self._facts), source=self.metadata["type"])
         )
-        response = await self._model_client.create(self._system_messages + planning_conversation)
+        response = await self._model_client.create(
+            self._system_messages + planning_conversation, cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         self._facts = response.content
@@ -176,14 +182,16 @@ async def _update_facts_and_plan(self) -> None:
         planning_conversation.append(
             UserMessage(content=self._get_update_plan_prompt(self._team_description), source=self.metadata["type"])
         )
-        response = await self._model_client.create(self._system_messages + planning_conversation)
+        response = await self._model_client.create(
+            self._system_messages + planning_conversation, cancellation_token=cancellation_token
+        )
         assert isinstance(response.content, str)
         self._plan = response.content
 
         # At this point, the planning conversation is dropped.
 
-    async def update_ledger(self) -> Dict[str, Any]:
+    async def update_ledger(self, cancellation_token: Optional[CancellationToken] = None) -> Dict[str, Any]:
         max_json_retries = 10
 
         team_description = await self._get_team_description()
@@ -197,6 +205,7 @@ async def update_ledger(self) -> Dict[str, Any]:
             ledger_response = await self._model_client.create(
                 self._system_messages + self._chat_history + ledger_user_messages,
                 json_output=True,
+                cancellation_token=cancellation_token,
             )
             ledger_str = ledger_response.content
@@ -230,7 +239,7 @@ async def update_ledger(self) -> Dict[str, Any]:
                     continue
                 return ledger_dict
             except json.JSONDecodeError as e:
-                logger.info(
+                self.logger.info(
                     OrchestrationEvent(
                         f"{self.metadata['type']} (error)",
                         f"Failed to parse ledger information: {ledger_str}",
@@ -244,10 +253,12 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
         self._chat_history.append(message.content)
         await super()._handle_broadcast(message, ctx)
 
-    async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
+    async def _select_next_agent(
+        self, message: LLMMessage, cancellation_token: Optional[CancellationToken] = None
+    ) -> Optional[AgentProxy]:
         # Check if the task is still unset, in which case this message contains the task string
         if len(self._task) == 0:
-            await self._initialize_task(self._get_message_str(message))
+            await self._initialize_task(self._get_message_str(message), cancellation_token)
 
         # At this point the task, plan and facts should all be set
         assert len(self._task) > 0
@@ -263,9 +274,10 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
             await self.publish_message(
                 BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
                 topic_id=topic_id,
+                cancellation_token=cancellation_token,
             )
 
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (thought)",
                     f"Initial plan:\n{synthesized_prompt}",
@@ -279,11 +291,11 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
             self._chat_history.append(synthesized_message)
 
             # Answer from this synthesized message
-            return await self._select_next_agent(synthesized_message)
+            return await self._select_next_agent(synthesized_message, cancellation_token)
 
         # Orchestrate the next step
-        ledger_dict = await self.update_ledger()
-        logger.info(
+        ledger_dict = await self.update_ledger(cancellation_token)
+        self.logger.info(
             OrchestrationEvent(
                 f"{self.metadata['type']} (thought)",
                 f"Updated Ledger:\n{json.dumps(ledger_dict, indent=2)}",
@@ -292,7 +304,7 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
 
         # Task is complete
         if ledger_dict["is_request_satisfied"]["answer"] is True:
-            logger.info(
+            self.logger.info(
                 OrchestrationEvent(
                     f"{self.metadata['type']} (thought)",
                     "Request satisfied.",
@@ -312,7 +324,7 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
 
             # We exceeded our replan counter
             if self._replan_counter > self._max_replans:
-                logger.info(
+                self.logger.info(
                     OrchestrationEvent(
                         f"{self.metadata['type']} (thought)",
                         "Replan counter exceeded... Terminating.",
@@ -321,7 +333,7 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
                 return None
             # Let's create a new plan
             else:
-                logger.info(
+                self.logger.info(
                     OrchestrationEvent(
                         f"{self.metadata['type']} (thought)",
                         "Stalled.... Replanning...",
@@ -329,24 +341,24 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
                     )
                 )
 
                 # Update our plan.
-                await self._update_facts_and_plan()
+                await self._update_facts_and_plan(cancellation_token)
 
                 # Reset everyone, then rebroadcast the new plan
                 self._chat_history = [self._chat_history[0]]
                 topic_id = TopicId("default", self.id.key)
-                await self.publish_message(ResetMessage(), topic_id=topic_id)
+                await self.publish_message(ResetMessage(), topic_id=topic_id, cancellation_token=cancellation_token)
 
                 # Send everyone the NEW plan
                 synthesized_prompt = self._get_synthesize_prompt(
                     self._task, self._team_description, self._facts, self._plan
                 )
-                topic_id = TopicId("default", self.id.key)
                 await self.publish_message(
                     BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
                     topic_id=topic_id,
+                    cancellation_token=cancellation_token,
                 )
 
-                logger.info(
+                self.logger.info(
                     OrchestrationEvent(
                         f"{self.metadata['type']} (thought)",
                         f"New plan:\n{synthesized_prompt}",
@@ -357,7 +369,7 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
                 self._chat_history.append(synthesized_message)
 
                 # Answer from this synthesized message
-                return await self._select_next_agent(synthesized_message)
+                return await self._select_next_agent(synthesized_message, cancellation_token)
 
         # If we got this far, we were not starting, done, or stuck
         next_agent_name = ledger_dict["next_speaker"]["answer"]
@@ -367,12 +379,13 @@ async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
             instruction = ledger_dict["instruction_or_question"]["answer"]
             user_message = UserMessage(content=instruction, source=self.metadata["type"])
             assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
-            logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
+            self.logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
             self._chat_history.append(assistant_message)  # My copy
             topic_id = TopicId("default", self.id.key)
             await self.publish_message(
                 BroadcastMessage(content=user_message, request_halt=False),
                 topic_id=topic_id,
+                cancellation_token=cancellation_token,
             )  # Send to everyone else
         return agent
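
Reviewer note (not part of the patch): the `base_agent.py` change gives `_process()` two new exits for a queued `(message, ctx, future)` triple: `future.cancel()` when the token was already cancelled at dequeue time, and again if the handler raises `asyncio.CancelledError`. Either way, a sender awaiting that future unblocks with `CancelledError` instead of hanging. Below is a minimal, self-contained sketch of that pattern, using plain `asyncio` plus the `CancellationToken` type imported in the diff; the tuple payload, queue wiring, and prints are illustrative stand-ins, not TeamOne code.

```python
import asyncio

from autogen_core.base import CancellationToken  # the token type the patch threads through


async def process(queue: asyncio.Queue) -> None:
    """Sketch of the loop shape base_agent._process() now has."""
    while True:
        message, token, future = await queue.get()
        if token.is_cancelled():
            future.cancel()  # unblock the sender instead of silently dropping the message
            continue
        try:
            print(f"handling {message!r}")  # stand-in for real message dispatch
            future.set_result(None)
        except asyncio.CancelledError:
            future.cancel()  # handler interrupted mid-flight: propagate to the sender
        except Exception as e:
            future.set_exception(e)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    token = CancellationToken()
    future = asyncio.get_running_loop().create_future()

    token.cancel()  # cancel before the worker ever sees the message
    await queue.put(("hello", token, future))
    worker = asyncio.create_task(process(queue))

    try:
        await future
    except asyncio.CancelledError:
        print("sender observed the cancellation")  # expected path; no hang
    worker.cancel()


asyncio.run(main())
```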
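
Reviewer note (not part of the patch): replacing the module-level `logger` with `self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")` (and the `.orchestrator` / `.MultimodalWebSurfer` variants) puts every agent instance at its own node in the standard `logging` hierarchy, keyed by `self.id.key`. A sketch of what that enables; the `"teamone.events"` value and the `"default"` key are assumptions for illustration, since `EVENT_LOGGER_NAME` is defined elsewhere in the team_one package and not shown in this diff.

```python
import logging

# Assumed placeholder; the real constant lives in the team_one package.
EVENT_LOGGER_NAME = "teamone.events"

# Target one agent instance: names now look like f"{EVENT_LOGGER_NAME}.{agent_key}.agent".
per_agent = logging.getLogger(f"{EVENT_LOGGER_NAME}.default.agent")
per_agent.setLevel(logging.INFO)
per_agent.addHandler(logging.FileHandler("default_agent_events.log"))

# Records still propagate upward, so one root event handler keeps seeing every agent.
all_events = logging.getLogger(EVENT_LOGGER_NAME)
all_events.setLevel(logging.INFO)
all_events.addHandler(logging.StreamHandler())

per_agent.info("goes to this agent's file, plus the shared stream")
```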