Skip to content

Commit

Permalink
TeamOne cancellation token support and making logger a member variable (
Browse files Browse the repository at this point in the history
#622)

* Added cancellation token support for team_one; made logger a member variable of each agent.

formatting

fix error

fix error

formatting

* No need to create a new cancellation token
  • Loading branch information
ZHANG-EH authored Sep 24, 2024
1 parent 6dcbf86 commit afdfb4e
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 57 deletions.
18 changes: 15 additions & 3 deletions python/packages/team-one/src/team_one/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
TeamOneMessages,
)

logger = logging.getLogger(EVENT_LOGGER_NAME + ".agent")


class TeamOneBaseAgent(RoutedAgent):
"""An agent that optionally ensures messages are handled non-concurrently in the order they arrive."""
Expand All @@ -29,6 +27,7 @@ def __init__(
super().__init__(description)
self._handle_messages_concurrently = handle_messages_concurrently
self._enabled = True
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.agent")

if not self._handle_messages_concurrently:
# TODO: make it possible to stop
Expand All @@ -40,6 +39,7 @@ async def _process(self) -> None:
message, ctx, future = await self._message_queue.get()
if ctx.cancellation_token.is_cancelled():
# TODO: Do we need to resolve the future here?
future.cancel()
continue

try:
Expand All @@ -54,6 +54,8 @@ async def _process(self) -> None:
else:
raise ValueError("Unknown message type.")
future.set_result(None)
except asyncio.CancelledError:
future.cancel()
except Exception as e:
future.set_exception(e)

Expand Down Expand Up @@ -92,9 +94,19 @@ async def _handle_request_reply(self, message: RequestReplyMessage, ctx: Message
async def _handle_deactivate(self, message: DeactivateMessage, ctx: MessageContext) -> None:
"""Handle a deactivate message."""
self._enabled = False
logger.info(
self.logger.info(
AgentEvent(
f"{self.metadata['type']} (deactivated)",
"",
)
)

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
"""Drop the message, with a log."""
# self.logger.info(
# AgentEvent(
# f"{self.metadata['type']} (unhandled message)",
# f"Unhandled message type: {type(message)}",
# )
# )
pass
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from ..utils import message_content_to_str
from .base_agent import TeamOneBaseAgent

logger = logging.getLogger(EVENT_LOGGER_NAME + ".orchestrator")


class BaseOrchestrator(TeamOneBaseAgent):
def __init__(
Expand All @@ -28,6 +26,7 @@ def __init__(
self._max_time = max_time
self._num_rounds = 0
self._start_time: float = -1.0
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.orchestrator")

async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext) -> None:
"""Handle an incoming message."""
Expand All @@ -42,11 +41,11 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext

content = message_content_to_str(message.content.content)

logger.info(OrchestrationEvent(source, content))
self.logger.info(OrchestrationEvent(source, content))

# Termination conditions
if self._num_rounds >= self._max_rounds:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"Max rounds ({self._max_rounds}) reached.",
Expand All @@ -55,7 +54,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
return

if time.time() - self._start_time >= self._max_time:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"Max time ({self._max_time}s) reached.",
Expand All @@ -64,7 +63,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
return

if message.request_halt:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
f"{source} requested halt.",
Expand All @@ -74,7 +73,7 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext

next_agent = await self._select_next_agent(message.content)
if next_agent is None:
logger.info(
self.logger.info(
OrchestrationEvent(
f"{self.metadata['type']} (termination condition)",
"No agent selected.",
Expand All @@ -84,15 +83,15 @@ async def _handle_broadcast(self, message: BroadcastMessage, ctx: MessageContext
request_reply_message = RequestReplyMessage()
# emit an event

logger.info(
self.logger.info(
OrchestrationEvent(
source=f"{self.metadata['type']} (thought)",
message=f"Next speaker {(await next_agent.metadata)['type']}" "",
)
)

self._num_rounds += 1 # Call before sending the message
await self.send_message(request_reply_message, next_agent.id)
await self.send_message(request_reply_message, next_agent.id, cancellation_token=ctx.cancellation_token)

async def _select_next_agent(self, message: LLMMessage) -> Optional[AgentProxy]:
raise NotImplementedError()
Expand Down
6 changes: 5 additions & 1 deletion python/packages/team-one/src/team_one/agents/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ async def _handle_request_reply(self, message: RequestReplyMessage, ctx: Message

user_message = UserMessage(content=response, source=self.metadata["type"])
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
await self.publish_message(
BroadcastMessage(content=user_message, request_halt=request_halt),
topic_id=topic_id,
cancellation_token=ctx.cancellation_token,
)

async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""
Expand Down
4 changes: 3 additions & 1 deletion python/packages/team-one/src/team_one/agents/coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
"""Respond to a reply request."""

# Make an inference to the model.
response = await self._model_client.create(self._system_messages + self._chat_history)
response = await self._model_client.create(
self._system_messages + self._chat_history, cancellation_token=cancellation_token
)
assert isinstance(response.content, str)
return "TERMINATE" in response.content, response.content

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
)

create_result = await self._model_client.create(
messages=history + [context_message, task_message], tools=self._tools
messages=history + [context_message, task_message], tools=self._tools, cancellation_token=cancellation_token
)

response = create_result.content
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
import re
import traceback
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast # Any, Callable, Dict, List, Literal, Tuple
from typing import Any, BinaryIO, Dict, List, Tuple, Union, cast, Optional # Any, Callable, Dict, List, Literal, Tuple
from urllib.parse import quote_plus # parse_qs, quote, unquote, urlparse, urlunparse

import aiofiles
Expand Down Expand Up @@ -67,8 +67,6 @@

SCREENSHOT_TOKENS = 1105

logger = logging.getLogger(EVENT_LOGGER_NAME + ".MultimodalWebSurfer")


# Sentinels
class DEFAULT_CHANNEL(metaclass=SentinelMeta):
Expand Down Expand Up @@ -96,6 +94,7 @@ def __init__(
self._page: Page | None = None
self._last_download: Download | None = None
self._prior_metadata_hash: str | None = None
self.logger = logging.getLogger(EVENT_LOGGER_NAME + f".{self.id.key}.MultimodalWebSurfer")

# Read page_script
self._page_script: str = ""
Expand Down Expand Up @@ -196,7 +195,7 @@ async def _set_debug_dir(self, debug_dir: str | None) -> None:
""".strip(),
)
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")
self.logger.info(f"Multimodal Web Surfer debug screens: {pathlib.Path(os.path.abspath(debug_html)).as_uri()}\n")

async def _reset(self, cancellation_token: CancellationToken) -> None:
assert self._page is not None
Expand All @@ -205,7 +204,7 @@ async def _reset(self, cancellation_token: CancellationToken) -> None:
await self._visit_page(self.start_page)
if self.debug_dir:
await self._page.screenshot(path=os.path.join(self.debug_dir, "screenshot.png"))
logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
Expand Down Expand Up @@ -250,13 +249,18 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[
return False, f"Web surfing error:\n\n{traceback.format_exc()}"

async def _execute_tool(
self, message: List[FunctionCall], rects: Dict[str, InteractiveRegion], tool_names: str, use_ocr: bool = True
self,
message: List[FunctionCall],
rects: Dict[str, InteractiveRegion],
tool_names: str,
use_ocr: bool = True,
cancellation_token: Optional[CancellationToken] = None,
) -> Tuple[bool, UserContent]:
name = message[0].name
args = json.loads(message[0].arguments)
action_description = ""
assert self._page is not None
logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
Expand Down Expand Up @@ -340,11 +344,11 @@ async def _execute_tool(
elif name == "answer_question":
question = str(args.get("question"))
# Do Q&A on the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page(question=question)
return False, await self._summarize_page(question=question, cancellation_token=cancellation_token)

elif name == "summarize_page":
# Summarize the DOM. No need to take further action. Browser state does not change.
return False, await self._summarize_page()
return False, await self._summarize_page(cancellation_token=cancellation_token)

elif name == "sleep":
action_description = "I am waiting a short period of time before taking further action."
Expand Down Expand Up @@ -394,7 +398,9 @@ async def _execute_tool(
async with aiofiles.open(os.path.join(self.debug_dir, "screenshot.png"), "wb") as file:
await file.write(new_screenshot)

ocr_text = await self._get_ocr_text(new_screenshot) if use_ocr is True else ""
ocr_text = (
await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) if use_ocr is True else ""
)

# Return the complete observation
message_content = "" # message.content or ""
Expand Down Expand Up @@ -518,7 +524,7 @@ async def __generate_reply(self, cancellation_token: CancellationToken) -> Tuple
UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.metadata["type"])
)
response = await self._model_client.create(
history, tools=tools, extra_create_args={"tool_choice": "auto"}
history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token
) # , "parallel_tool_calls": False})
message = response.content

Expand All @@ -529,7 +535,7 @@ async def __generate_reply(self, cancellation_token: CancellationToken) -> Tuple
return False, message
elif isinstance(message, list):
# Take an action
return await self._execute_tool(message, rects, tool_names)
return await self._execute_tool(message, rects, tool_names, cancellation_token=cancellation_token)
else:
# Not sure what happened here
raise AssertionError(f"Unknown response format '{message}'")
Expand Down Expand Up @@ -668,7 +674,7 @@ async def _click_id(self, identifier: str) -> None:
assert isinstance(new_page, Page)
await self._on_new_page(new_page)

logger.info(
self.logger.info(
WebSurferEvent(
source=self.metadata["type"],
url=self._page.url,
Expand Down Expand Up @@ -716,7 +722,12 @@ async def _scroll_id(self, identifier: str, direction: str) -> None:
"""
)

async def _summarize_page(self, question: str | None = None, token_limit: int = 100000) -> str:
async def _summarize_page(
self,
question: str | None = None,
token_limit: int = 100000,
cancellation_token: Optional[CancellationToken] = None,
) -> str:
assert self._page is not None

page_markdown: str = await self._get_page_markdown()
Expand Down Expand Up @@ -780,12 +791,14 @@ async def _summarize_page(self, question: str | None = None, token_limit: int =
)

# Generate the response
response = await self._model_client.create(messages)
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content

async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) -> str:
async def _get_ocr_text(
self, image: bytes | io.BufferedIOBase | Image.Image, cancellation_token: Optional[CancellationToken] = None
) -> str:
scaled_screenshot = None
if isinstance(image, Image.Image):
scaled_screenshot = image.resize((MLM_WIDTH, MLM_HEIGHT))
Expand All @@ -810,7 +823,7 @@ async def _get_ocr_text(self, image: bytes | io.BufferedIOBase | Image.Image) ->
source=self.metadata["type"],
)
)
response = await self._model_client.create(messages)
response = await self._model_client.create(messages, cancellation_token=cancellation_token)
scaled_screenshot.close()
assert isinstance(response.content, str)
return response.content
Loading

0 comments on commit afdfb4e

Please sign in to comment.