From e1bc1e0ac328195ffb068f501c6568724d042e57 Mon Sep 17 00:00:00 2001 From: Victor Dibia Date: Fri, 30 Aug 2024 10:04:13 -0700 Subject: [PATCH 1/6] bump version, add claude default model --- .../apps/autogen-studio/autogenstudio/database/utils.py | 8 ++++++++ samples/apps/autogen-studio/autogenstudio/version.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/samples/apps/autogen-studio/autogenstudio/database/utils.py b/samples/apps/autogen-studio/autogenstudio/database/utils.py index 189fa1baf8d1..ac77a9161498 100644 --- a/samples/apps/autogen-studio/autogenstudio/database/utils.py +++ b/samples/apps/autogen-studio/autogenstudio/database/utils.py @@ -175,6 +175,13 @@ def init_db_samples(dbmanager: Any): model="gpt-4-1106-preview", description="OpenAI GPT-4 model", user_id="guestuser@gmail.com", api_type="open_ai" ) + anthropic_sonnet_model = Model( + model="claude-3-5-sonnet-20240620", + description="Anthropic's Claude 3.5 Sonnet model", + api_type="anthropic", + user_id="guestuser@gmail.com", + ) + # skills generate_pdf_skill = Skill( name="generate_and_save_pdf", @@ -303,6 +310,7 @@ def init_db_samples(dbmanager: Any): session.add(google_gemini_model) session.add(azure_model) session.add(gpt_4_model) + session.add(anthropic_sonnet_model) session.add(generate_image_skill) session.add(generate_pdf_skill) session.add(user_proxy) diff --git a/samples/apps/autogen-studio/autogenstudio/version.py b/samples/apps/autogen-studio/autogenstudio/version.py index 3d83da06d441..9ec42a697660 100644 --- a/samples/apps/autogen-studio/autogenstudio/version.py +++ b/samples/apps/autogen-studio/autogenstudio/version.py @@ -1,3 +1,3 @@ -VERSION = "0.1.4" +VERSION = "0.1.5" __version__ = VERSION APP_NAME = "autogenstudio" From 75b8054edd4ed5117c5d78af64076db7306b9bdf Mon Sep 17 00:00:00 2001 From: Joe Landers Date: Wed, 4 Sep 2024 16:22:01 -0700 Subject: [PATCH 2/6] Move WebSocketConnectionManager into its own file. 
--- .../autogenstudio/chatmanager.py | 94 ------------ .../autogen-studio/autogenstudio/web/app.py | 3 +- .../websocket_connection_manager.py | 134 ++++++++++++++++++ 3 files changed, 136 insertions(+), 95 deletions(-) create mode 100644 samples/apps/autogen-studio/autogenstudio/websocket_connection_manager.py diff --git a/samples/apps/autogen-studio/autogenstudio/chatmanager.py b/samples/apps/autogen-studio/autogenstudio/chatmanager.py index a91401e6663d..88b9b56b70a3 100644 --- a/samples/apps/autogen-studio/autogenstudio/chatmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/chatmanager.py @@ -81,97 +81,3 @@ def chat( result_message.user_id = message.user_id result_message.session_id = message.session_id return result_message - - -class WebSocketConnectionManager: - """ - Manages WebSocket connections including sending, broadcasting, and managing the lifecycle of connections. - """ - - def __init__( - self, - active_connections: List[Tuple[WebSocket, str]] = None, - active_connections_lock: asyncio.Lock = None, - ) -> None: - """ - Initializes WebSocketConnectionManager with an optional list of active WebSocket connections. - - :param active_connections: A list of tuples, each containing a WebSocket object and its corresponding client_id. - """ - if active_connections is None: - active_connections = [] - self.active_connections_lock = active_connections_lock - self.active_connections: List[Tuple[WebSocket, str]] = active_connections - - async def connect(self, websocket: WebSocket, client_id: str) -> None: - """ - Accepts a new WebSocket connection and appends it to the active connections list. - - :param websocket: The WebSocket instance representing a client connection. - :param client_id: A string representing the unique identifier of the client. 
- """ - await websocket.accept() - async with self.active_connections_lock: - self.active_connections.append((websocket, client_id)) - print(f"New Connection: {client_id}, Total: {len(self.active_connections)}") - - async def disconnect(self, websocket: WebSocket) -> None: - """ - Disconnects and removes a WebSocket connection from the active connections list. - - :param websocket: The WebSocket instance to remove. - """ - async with self.active_connections_lock: - try: - self.active_connections = [conn for conn in self.active_connections if conn[0] != websocket] - print(f"Connection Closed. Total: {len(self.active_connections)}") - except ValueError: - print("Error: WebSocket connection not found") - - async def disconnect_all(self) -> None: - """ - Disconnects all active WebSocket connections. - """ - for connection, _ in self.active_connections[:]: - await self.disconnect(connection) - - async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None: - """ - Sends a JSON message to a single WebSocket connection. - - :param message: A JSON serializable dictionary containing the message to send. - :param websocket: The WebSocket instance through which to send the message. - """ - try: - async with self.active_connections_lock: - await websocket.send_json(message) - except WebSocketDisconnect: - print("Error: Tried to send a message to a closed WebSocket") - await self.disconnect(websocket) - except websockets.exceptions.ConnectionClosedOK: - print("Error: WebSocket connection closed normally") - await self.disconnect(websocket) - except Exception as e: - print(f"Error in sending message: {str(e)}", message) - await self.disconnect(websocket) - - async def broadcast(self, message: Dict) -> None: - """ - Broadcasts a JSON message to all active WebSocket connections. - - :param message: A JSON serializable dictionary containing the message to broadcast. 
- """ - # Create a message dictionary with the desired format - message_dict = {"message": message} - - for connection, _ in self.active_connections[:]: - try: - if connection.client_state == websockets.protocol.State.OPEN: - # Call send_message method with the message dictionary and current WebSocket connection - await self.send_message(message_dict, connection) - else: - print("Error: WebSocket connection is closed") - await self.disconnect(connection) - except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e: - print(f"Error: WebSocket disconnected or closed({str(e)})") - await self.disconnect(connection) diff --git a/samples/apps/autogen-studio/autogenstudio/web/app.py b/samples/apps/autogen-studio/autogenstudio/web/app.py index 5926f6c64a14..8774b29f4a6d 100644 --- a/samples/apps/autogen-studio/autogenstudio/web/app.py +++ b/samples/apps/autogen-studio/autogenstudio/web/app.py @@ -12,7 +12,8 @@ from loguru import logger from openai import OpenAIError -from ..chatmanager import AutoGenChatManager, WebSocketConnectionManager +from ..chatmanager import AutoGenChatManager +from ..websocket_connection_manager import WebSocketConnectionManager from ..database import workflow_from_id from ..database.dbmanager import DBManager from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow diff --git a/samples/apps/autogen-studio/autogenstudio/websocket_connection_manager.py b/samples/apps/autogen-studio/autogenstudio/websocket_connection_manager.py new file mode 100644 index 000000000000..8f8691a43590 --- /dev/null +++ b/samples/apps/autogen-studio/autogenstudio/websocket_connection_manager.py @@ -0,0 +1,134 @@ +import asyncio +from typing import Any, Dict, List, Optional, Tuple, Union + +import websockets +from fastapi import WebSocket, WebSocketDisconnect + +class WebSocketConnectionManager: + """ + Manages WebSocket connections including sending, broadcasting, and managing the lifecycle of connections. 
+ """ + + def __init__( + self, + active_connections: List[Tuple[WebSocket, str]] = None, + active_connections_lock: asyncio.Lock = None, + ) -> None: + """ + Initializes WebSocketConnectionManager with an optional list of active WebSocket connections. + + :param active_connections: A list of tuples, each containing a WebSocket object and its corresponding client_id. + """ + if active_connections is None: + active_connections = [] + self.active_connections_lock = active_connections_lock + self.active_connections: List[Tuple[WebSocket, str]] = active_connections + + async def connect(self, websocket: WebSocket, client_id: str) -> None: + """ + Accepts a new WebSocket connection and appends it to the active connections list. + + :param websocket: The WebSocket instance representing a client connection. + :param client_id: A string representing the unique identifier of the client. + """ + await websocket.accept() + async with self.active_connections_lock: + self.active_connections.append((websocket, client_id)) + print(f"New Connection: {client_id}, Total: {len(self.active_connections)}") + + async def disconnect(self, websocket: WebSocket) -> None: + """ + Disconnects and removes a WebSocket connection from the active connections list. + + :param websocket: The WebSocket instance to remove. + """ + async with self.active_connections_lock: + try: + self.active_connections = [conn for conn in self.active_connections if conn[0] != websocket] + print(f"Connection Closed. Total: {len(self.active_connections)}") + except ValueError: + print("Error: WebSocket connection not found") + + async def disconnect_all(self) -> None: + """ + Disconnects all active WebSocket connections. + """ + for connection, _ in self.active_connections[:]: + await self.disconnect(connection) + + async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None: + """ + Sends a JSON message to a single WebSocket connection. 
+ + :param message: A JSON serializable dictionary containing the message to send. + :param websocket: The WebSocket instance through which to send the message. + """ + try: + async with self.active_connections_lock: + await websocket.send_json(message) + except WebSocketDisconnect: + print("Error: Tried to send a message to a closed WebSocket") + await self.disconnect(websocket) + except websockets.exceptions.ConnectionClosedOK: + print("Error: WebSocket connection closed normally") + await self.disconnect(websocket) + except Exception as e: + print(f"Error in sending message: {str(e)}", message) + await self.disconnect(websocket) + + async def get_input(self, prompt: Union[Dict, str], websocket: WebSocket, timeout: int=60) -> str: + """ + Sends a JSON message to a single WebSocket connection as a prompt for user input. + Waits on a user response or until the given timeout elapses. + + :param prompt: A JSON serializable dictionary containing the message to send. + :param websocket: The WebSocket instance through which to send the message. 
+ """ + response = "Error: Unexpected response.\nTERMINATE" + try: + async with self.active_connections_lock: + await websocket.send_json(prompt) + result = await asyncio.wait_for(websocket.receive_json(), timeout=timeout) + data = result.get("data") + if data: + response = data.get("content", "Error: Unexpected response format\nTERMINATE") + else: + response = "Error: Unexpected response format\nTERMINATE" + + except asyncio.TimeoutError: + response = f"The user was timed out after {timeout} seconds of inactivity.\nTERMINATE" + except WebSocketDisconnect: + print("Error: Tried to send a message to a closed WebSocket") + await self.disconnect(websocket) + response = "The user was disconnected\nTERMINATE" + except websockets.exceptions.ConnectionClosedOK: + print("Error: WebSocket connection closed normally") + await self.disconnect(websocket) + response = "The user was disconnected\nTERMINATE" + except Exception as e: + print(f"Error in sending message: {str(e)}", prompt) + await self.disconnect(websocket) + response = f"Error: {e}\nTERMINATE" + + return response + + async def broadcast(self, message: Dict) -> None: + """ + Broadcasts a JSON message to all active WebSocket connections. + + :param message: A JSON serializable dictionary containing the message to broadcast. 
+ """ + # Create a message dictionary with the desired format + message_dict = {"message": message} + + for connection, _ in self.active_connections[:]: + try: + if connection.client_state == websockets.protocol.State.OPEN: + # Call send_message method with the message dictionary and current WebSocket connection + await self.send_message(message_dict, connection) + else: + print("Error: WebSocket connection is closed") + await self.disconnect(connection) + except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e: + print(f"Error: WebSocket disconnected or closed({str(e)})") + await self.disconnect(connection) From a6624d8d0457afabd4769b8d1eec418d82c791f8 Mon Sep 17 00:00:00 2001 From: Joe Landers Date: Wed, 4 Sep 2024 17:27:10 -0700 Subject: [PATCH 3/6] Update the AutoGenStudio to use Async code throughout the call stack Update *WorkflowManager* classes: - Add async `a_send_message_function` parameter to mirror `send_message_function` param. - Add async `a_process_message` coroutine to mirror the synchronous `process_message` function. - Add async `a_run` coroutine to mirror the `run` function - Add async `_a_run_workflow` coroutine to mirror the synchronous `_run_workflow` function. Update *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes: - Override the async `a_receive` coroutines Update *AutoGenChatManager*: - Add async `a_send` and `a_chat` coroutines to mirror their sync counterparts. - Accept the `WebSocketManager` instance as a parameter, so it can do Async comms directly. Update *app.py* - Provide the `WebSocketManager` instance to the *AutoGenChatManager* constructor - Await the manager's `a_chat` coroutine, rather than calling the synchronous `chat` function. 
--- .../autogenstudio/chatmanager.py | 79 +++++- .../autogen-studio/autogenstudio/web/app.py | 7 +- .../autogenstudio/workflowmanager.py | 258 +++++++++++++++++- 3 files changed, 330 insertions(+), 14 deletions(-) diff --git a/samples/apps/autogen-studio/autogenstudio/chatmanager.py b/samples/apps/autogen-studio/autogenstudio/chatmanager.py index 88b9b56b70a3..723e11d637d2 100644 --- a/samples/apps/autogen-studio/autogenstudio/chatmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/chatmanager.py @@ -1,15 +1,14 @@ -import asyncio import os from datetime import datetime from queue import Queue from typing import Any, Dict, List, Optional, Tuple, Union - +from loguru import logger import websockets from fastapi import WebSocket, WebSocketDisconnect from .datamodel import Message from .workflowmanager import WorkflowManager - +from .websocket_connection_manager import WebSocketConnectionManager class AutoGenChatManager: """ @@ -17,15 +16,18 @@ class AutoGenChatManager: using an automated workflow configuration and message queue. """ - def __init__(self, message_queue: Queue) -> None: + def __init__(self, + message_queue: Queue, + websocket_manager: WebSocketConnectionManager = None): """ Initializes the AutoGenChatManager with a message queue. :param message_queue: A queue to use for sending messages asynchronously. """ self.message_queue = message_queue + self.websocket_manager = websocket_manager - def send(self, message: str) -> None: + def send(self, message: dict) -> None: """ Sends a message by putting it into the message queue. @@ -34,6 +36,23 @@ def send(self, message: str) -> None: if self.message_queue is not None: self.message_queue.put_nowait(message) + async def a_send(self, message: dict) -> None: + """ + Asynchronously sends a message via the WebSocketManager class + + :param message: The message string to be sent. 
+ """ + for connection, socket_client_id in self.websocket_manager.active_connections: + if message["connection_id"] == socket_client_id: + logger.info( + f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" + ) + await self.websocket_manager.send_message(message, connection) + else: + logger.info( + f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" + ) + def chat( self, message: Message, @@ -72,6 +91,7 @@ def chat( history=history, work_dir=work_dir, send_message_function=self.send, + a_send_message_function=self.a_send, connection_id=connection_id, ) @@ -81,3 +101,52 @@ def chat( result_message.user_id = message.user_id result_message.session_id = message.session_id return result_message + + async def a_chat( + self, + message: Message, + history: List[Dict[str, Any]], + workflow: Any = None, + connection_id: Optional[str] = None, + user_dir: Optional[str] = None, + **kwargs, + ) -> Message: + """ + Processes an incoming message according to the agent's workflow configuration + and generates a response. + + :param message: An instance of `Message` representing an incoming message. + :param history: A list of dictionaries, each representing a past interaction. + :param flow_config: An instance of `AgentWorkFlowConfig`. If None, defaults to a standard configuration. + :param connection_id: An optional connection identifier. + :param kwargs: Additional keyword arguments. + :return: An instance of `Message` representing a response. 
+ """ + + # create a working director for workflow based on user_dir/session_id/time_hash + work_dir = os.path.join( + user_dir, + str(message.session_id), + datetime.now().strftime("%Y%m%d_%H-%M-%S"), + ) + os.makedirs(work_dir, exist_ok=True) + + # if no flow config is provided, use the default + if workflow is None: + raise ValueError("Workflow must be specified") + + workflow_manager = WorkflowManager( + workflow=workflow, + history=history, + work_dir=work_dir, + send_message_function=self.send, + a_send_message_function=self.a_send, + connection_id=connection_id, + ) + + message_text = message.content.strip() + result_message: Message = await workflow_manager.a_run(message=f"{message_text}", clear_history=False, history=history) + + result_message.user_id = message.user_id + result_message.session_id = message.session_id + return result_message diff --git a/samples/apps/autogen-studio/autogenstudio/web/app.py b/samples/apps/autogen-studio/autogenstudio/web/app.py index 8774b29f4a6d..9db32bb360fa 100644 --- a/samples/apps/autogen-studio/autogenstudio/web/app.py +++ b/samples/apps/autogen-studio/autogenstudio/web/app.py @@ -4,7 +4,7 @@ import threading import traceback from contextlib import asynccontextmanager -from typing import Any, Union +from typing import Any, Coroutine from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -69,7 +69,8 @@ def message_handler(): @asynccontextmanager async def lifespan(app: FastAPI): print("***** App started *****") - managers["chat"] = AutoGenChatManager(message_queue=message_queue) + managers["chat"] = AutoGenChatManager(message_queue=message_queue, + websocket_manager=websocket_manager) dbmanager.create_db_and_tables() yield @@ -450,7 +451,7 @@ async def run_session_workflow(message: Message, session_id: int, workflow_id: i user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id)) os.makedirs(user_dir, exist_ok=True) workflow = 
workflow_from_id(workflow_id, dbmanager=dbmanager) - agent_response: Message = managers["chat"].chat( + agent_response: Message = await managers["chat"].a_chat( message=message, history=user_message_history, user_dir=user_dir, diff --git a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py index f5065e85e5c8..3c76fa8b361c 100644 --- a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py @@ -2,7 +2,7 @@ import os import time from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Coroutine import autogen @@ -40,6 +40,7 @@ def __init__( work_dir: str = None, clear_work_dir: bool = True, send_message_function: Optional[callable] = None, + a_send_message_function: Optional[Coroutine] = None, connection_id: Optional[str] = None, ) -> None: """ @@ -51,6 +52,7 @@ def __init__( work_dir (str): The working directory. clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. + a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. connection_id (Optional[str]): The connection identifier. 
""" if isinstance(workflow, str): @@ -67,6 +69,7 @@ def __init__( # TODO - improved typing for workflow self.workflow_skills = [] self.send_message_function = send_message_function + self.a_send_message_function = a_send_message_function self.connection_id = connection_id self.work_dir = work_dir or "work_dir" self.code_executor_pool = { @@ -112,6 +115,34 @@ def _run_workflow(self, message: str, history: Optional[List[Message]] = None, c else: raise ValueError("Sender and receiver agents are not defined in the workflow configuration.") + async def _a_run_workflow(self, message: str, history: Optional[List[Message]] = None, clear_history: bool = False) -> None: + """ + Asynchronously runs the workflow based on the provided configuration. + + Args: + message: The initial message to start the chat. + history: A list of messages to populate the agents' history. + clear_history: If set to True, clears the chat history before initiating. + + """ + for agent in self.workflow.get("agents", []): + if agent.get("link").get("agent_type") == "sender": + self.sender = self.load(agent.get("agent")) + elif agent.get("link").get("agent_type") == "receiver": + self.receiver = self.load(agent.get("agent")) + if self.sender and self.receiver: + # save all agent skills to skills.py + save_skills_to_file(self.workflow_skills, self.work_dir) + if history: + self._populate_history(history) + await self.sender.a_initiate_chat( + self.receiver, + message=message, + clear_history=clear_history, + ) + else: + raise ValueError("Sender and receiver agents are not defined in the workflow configuration.") + def _serialize_agent( self, agent: Agent, @@ -182,7 +213,9 @@ def process_message( "connection_id": self.connection_id, "message_type": "agent_message", } - # if the agent will respond to the message, or the message is sent by a groupchat agent. 
This avoids adding groupchat broadcast messages to the history (which are sent with request_reply=False), or when agent populated from history + # if the agent will respond to the message, or the message is sent by a groupchat agent. + # This avoids adding groupchat broadcast messages to the history (which are sent with request_reply=False), + # or when agent populated from history if request_reply is not False or sender_type == "groupchat": self.agent_history.append(message_payload) # add to history if self.send_message_function: # send over the message queue @@ -193,6 +226,54 @@ def process_message( ) self.send_message_function(socket_msg.dict()) + async def a_process_message( + self, + sender: autogen.Agent, + receiver: autogen.Agent, + message: Dict, + request_reply: bool = False, + silent: bool = False, + sender_type: str = "agent", + ) -> None: + """ + Asynchronously processes the message and adds it to the agent history. + + Args: + + sender: The sender of the message. + receiver: The receiver of the message. + message: The message content. + request_reply: If set to True, the message will be added to agent history. + silent: determining verbosity. + sender_type: The type of the sender of the message. + """ + + message = message if isinstance(message, dict) else {"content": message, "role": "user"} + message_payload = { + "recipient": receiver.name, + "sender": sender.name, + "message": message, + "timestamp": datetime.now().isoformat(), + "sender_type": sender_type, + "connection_id": self.connection_id, + "message_type": "agent_message", + } + # if the agent will respond to the message, or the message is sent by a groupchat agent. 
+ # This avoids adding groupchat broadcast messages to the history (which are sent with request_reply=False), + # or when agent populated from history + if request_reply is not False or sender_type == "groupchat": + self.agent_history.append(message_payload) # add to history + socket_msg = SocketMessage( + type="agent_message", + data=message_payload, + connection_id=self.connection_id, + ) + if self.a_send_message_function: # send over the message queue + await self.a_send_message_function(socket_msg.dict()) + elif self.send_message_function: # send over the message queue + self.send_message_function(socket_msg.dict()) + + def _populate_history(self, history: List[Message]) -> None: """ Populates the agent message history from the provided list of messages. @@ -284,6 +365,7 @@ def load(self, agent: Any) -> autogen.Agent: agent = ExtendedGroupChatManager( groupchat=groupchat, message_processor=self.process_message, + a_message_processor=self.a_process_message, llm_config=agent.config.llm_config.model_dump(), ) return agent @@ -293,11 +375,13 @@ def load(self, agent: Any) -> autogen.Agent: agent = ExtendedConversableAgent( **self._serialize_agent(agent), message_processor=self.process_message, + a_message_processor=self.a_process_message, ) elif agent.type == "userproxy": agent = ExtendedConversableAgent( **self._serialize_agent(agent), message_processor=self.process_message, + a_message_processor=self.a_process_message, ) else: raise ValueError(f"Unknown agent type: {agent.type}") @@ -409,6 +493,37 @@ def run(self, message: str, history: Optional[List[Message]] = None, clear_histo ) return result_message + async def a_run(self, message: str, history: Optional[List[Message]] = None, clear_history: bool = False) -> Message: + """ + Asynchronously initiates a chat between the sender and receiver agents with an initial message + and an option to clear the history. + + Args: + message: The initial message to start the chat. 
+ clear_history: If set to True, clears the chat history before initiating. + """ + + start_time = time.time() + await self._a_run_workflow(message=message, history=history, clear_history=clear_history) + end_time = time.time() + + output = self._generate_output(message, self.workflow.get("summary_method", "last")) + + usage = self._get_usage_summary() + # print("usage", usage) + + result_message = Message( + content=output, + role="assistant", + meta={ + "messages": self.agent_history, + "summary_method": self.workflow.get("summary_method", "last"), + "time": end_time - start_time, + "files": get_modified_files(start_time, end_time, source_dir=self.work_dir), + "usage": usage, + }, + ) + return result_message class SequentialWorkflowManager: """ @@ -422,6 +537,7 @@ def __init__( work_dir: str = None, clear_work_dir: bool = True, send_message_function: Optional[callable] = None, + a_send_message_function: Optional[Coroutine] = None, connection_id: Optional[str] = None, ) -> None: """ @@ -433,6 +549,7 @@ def __init__( work_dir (str): The working directory. clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. + a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. connection_id (Optional[str]): The connection identifier. 
""" if isinstance(workflow, str): @@ -448,6 +565,7 @@ def __init__( # TODO - improved typing for workflow self.send_message_function = send_message_function + self.a_send_message_function = a_send_message_function self.connection_id = connection_id self.work_dir = work_dir or "work_dir" if clear_work_dir: @@ -498,6 +616,7 @@ def _run_workflow(self, message: str, history: Optional[List[Message]] = None, c work_dir=self.work_dir, clear_work_dir=True, send_message_function=self.send_message_function, + a_send_message_function=self.a_send_message_function, connection_id=self.connection_id, ) task_prompt = ( @@ -519,6 +638,68 @@ def _run_workflow(self, message: str, history: Optional[List[Message]] = None, c print(f"======== end of sequence === {i}============") self.agent_history.extend(result.meta.get("messages", [])) + async def _a_run_workflow(self, message: str, history: Optional[List[Message]] = None, clear_history: bool = False) -> None: + """ + Asynchronously runs the workflow based on the provided configuration. + + Args: + message: The initial message to start the chat. + history: A list of messages to populate the agents' history. + clear_history: If set to True, clears the chat history before initiating. 
+ + """ + user_proxy = { + "config": { + "name": "user_proxy", + "human_input_mode": "NEVER", + "max_consecutive_auto_reply": 25, + "code_execution_config": "local", + "default_auto_reply": "TERMINATE", + "description": "User Proxy Agent Configuration", + "llm_config": False, + "type": "userproxy", + } + } + sequential_history = [] + for i, agent in enumerate(self.workflow.get("agents", [])): + workflow = Workflow( + name="agent workflow", type=WorkFlowType.autonomous, summary_method=WorkFlowSummaryMethod.llm + ) + workflow = workflow.model_dump(mode="json") + agent = agent.get("agent") + workflow["agents"] = [ + {"agent": user_proxy, "link": {"agent_type": "sender"}}, + {"agent": agent, "link": {"agent_type": "receiver"}}, + ] + + auto_workflow = AutoWorkflowManager( + workflow=workflow, + history=history, + work_dir=self.work_dir, + clear_work_dir=True, + send_message_function=self.send_message_function, + a_send_message_function=self.a_send_message_function, + connection_id=self.connection_id, + ) + task_prompt = ( + f""" + Your primary instructions are as follows: + {agent.get("task_instruction")} + Context for addressing your task is below: + ======= + {str(sequential_history)} + ======= + Now address your task: + """ + if i > 0 + else message + ) + result = await auto_workflow.a_run(message=task_prompt, clear_history=clear_history) + sequential_history.append(result.content) + self.model_client = auto_workflow.receiver.client + print(f"======== end of sequence === {i}============") + self.agent_history.extend(result.meta.get("messages", [])) + def _generate_output( self, message_text: str, @@ -587,6 +768,34 @@ def run(self, message: str, history: Optional[List[Message]] = None, clear_histo ) return result_message + async def a_run(self, message: str, history: Optional[List[Message]] = None, clear_history: bool = False) -> Message: + """ + Asynchronously initiates a chat between the sender and receiver agents with an initial message + and an option to clear 
the history. + + Args: + message: The initial message to start the chat. + clear_history: If set to True, clears the chat history before initiating. + """ + + start_time = time.time() + await self._a_run_workflow(message=message, history=history, clear_history=clear_history) + end_time = time.time() + output = self._generate_output(message, self.workflow.get("summary_method", "last")) + + result_message = Message( + content=output, + role="assistant", + meta={ + "messages": self.agent_history, + "summary_method": self.workflow.get("summary_method", "last"), + "time": end_time - start_time, + "files": get_modified_files(start_time, end_time, source_dir=self.work_dir), + "task": message, + }, + ) + return result_message + class WorkflowManager: """ @@ -600,6 +809,7 @@ def __new__( work_dir: str = None, clear_work_dir: bool = True, send_message_function: Optional[callable] = None, + a_send_message_function: Optional[Coroutine] = None, connection_id: Optional[str] = None, ) -> None: """ @@ -611,6 +821,7 @@ def __new__( work_dir (str): The working directory. clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. + a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. connection_id (Optional[str]): The connection identifier. 
""" if isinstance(workflow, str): @@ -631,6 +842,7 @@ def __new__( work_dir=work_dir, clear_work_dir=clear_work_dir, send_message_function=send_message_function, + a_send_message_function=a_send_message_function, connection_id=connection_id, ) elif self.workflow.get("type") == WorkFlowType.sequential.value: @@ -645,9 +857,14 @@ def __new__( class ExtendedConversableAgent(autogen.ConversableAgent): - def __init__(self, message_processor=None, *args, **kwargs): + def __init__(self, + message_processor=None, + a_message_processor=None, + *args, **kwargs): + super().__init__(*args, **kwargs) self.message_processor = message_processor + self.a_message_processor = a_message_processor def receive( self, @@ -660,14 +877,28 @@ def receive( self.message_processor(sender, self, message, request_reply, silent, sender_type="agent") super().receive(message, sender, request_reply, silent) - -"" + async def a_receive( + self, + message: Union[Dict, str], + sender: autogen.Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ) -> None: + if self.a_message_processor: + await self.a_message_processor(sender, self, message, request_reply, silent, sender_type="agent") + elif self.message_processor: + self.message_processor(sender, self, message, request_reply, silent, sender_type="agent") + await super().a_receive(message, sender, request_reply, silent) class ExtendedGroupChatManager(autogen.GroupChatManager): - def __init__(self, message_processor=None, *args, **kwargs): + def __init__(self, + message_processor=None, + a_message_processor=None, + *args, **kwargs): super().__init__(*args, **kwargs) self.message_processor = message_processor + self.a_message_processor = a_message_processor def receive( self, @@ -679,3 +910,18 @@ def receive( if self.message_processor: self.message_processor(sender, self, message, request_reply, silent, sender_type="groupchat") super().receive(message, sender, request_reply, silent) + + async def a_receive( + self, + 
message: Union[Dict, str], + sender: autogen.Agent, + request_reply: Optional[bool] = None, + silent: Optional[bool] = False, + ) -> None: + if self.a_message_processor: + await self.a_message_processor(sender, self, message, request_reply, silent, sender_type="agent") + elif self.message_processor: + self.message_processor(sender, self, message, request_reply, silent, sender_type="agent") + await super().a_receive(message, sender, request_reply, silent) + + From 330262b1b36cf464ec7d570eb66460a00bc2585c Mon Sep 17 00:00:00 2001 From: Joe Landers Date: Wed, 4 Sep 2024 17:55:03 -0700 Subject: [PATCH 4/6] Add Human Input Support Updates to *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes - override the `get_human_input` function and async `a_get_human_input` coroutine Updates to *WorkflowManager* classes: - add parameters `a_human_input_function` and `a_human_input_timeout` and pass along on to the ExtendedConversableAgent and ExtendedGroupChatManager - fix for invalid configuration passed from UI when human input mode is not NEVER and no model is attached Updates to *AutoGenChatManager* class: - add parameter `human_input_timeout` and pass it along to *WorkflowManager* classes - add async `a_prompt_for_input` coroutine that relies on `websocket_manager.get_input` coroutine (which snuck into last commit) Updates to *App.py* - global var HUMAN_INPUT_TIMEOUT_SECONDS = 180, we can replace this with a configurable value in the future --- .../autogenstudio/chatmanager.py | 31 +++- .../autogen-studio/autogenstudio/web/app.py | 6 +- .../autogenstudio/workflowmanager.py | 133 ++++++++++++++++++ .../frontend/src/components/atoms.tsx | 46 +++--- .../views/builder/utils/agentconfig.tsx | 4 +- .../components/views/playground/chatbox.tsx | 122 ++++++++++++---- .../components/views/playground/sessions.tsx | 14 +- .../frontend/src/hooks/store.tsx | 4 + 8 files changed, 310 insertions(+), 50 deletions(-) diff --git 
a/samples/apps/autogen-studio/autogenstudio/chatmanager.py b/samples/apps/autogen-studio/autogenstudio/chatmanager.py index 723e11d637d2..dd433ae576e8 100644 --- a/samples/apps/autogen-studio/autogenstudio/chatmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/chatmanager.py @@ -3,8 +3,6 @@ from queue import Queue from typing import Any, Dict, List, Optional, Tuple, Union from loguru import logger -import websockets -from fastapi import WebSocket, WebSocketDisconnect from .datamodel import Message from .workflowmanager import WorkflowManager @@ -18,7 +16,8 @@ class AutoGenChatManager: def __init__(self, message_queue: Queue, - websocket_manager: WebSocketConnectionManager = None): + websocket_manager: WebSocketConnectionManager = None, + human_input_timeout: int = 180) -> None: """ Initializes the AutoGenChatManager with a message queue. @@ -26,6 +25,7 @@ def __init__(self, """ self.message_queue = message_queue self.websocket_manager = websocket_manager + self.a_human_input_timeout = human_input_timeout def send(self, message: dict) -> None: """ @@ -53,6 +53,29 @@ async def a_send(self, message: dict) -> None: f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" ) + async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str: + """ + Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class + + :param message: The message string to be sent. + """ + + for connection, socket_client_id in self.websocket_manager.active_connections: + if prompt["connection_id"] == socket_client_id: + logger.info( + f"Sending message to connection_id: {prompt['connection_id']}. 
Connection ID: {socket_client_id}" + ) + try: + result = await self.websocket_manager.get_input(prompt, connection, timeout) + return result + except Exception as e: + traceback.print_exc() + return f"Error: {e}\nTERMINATE" + else: + logger.info( + f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}" + ) + def chat( self, message: Message, @@ -141,6 +164,8 @@ async def a_chat( work_dir=work_dir, send_message_function=self.send, a_send_message_function=self.a_send, + a_human_input_function=self.a_prompt_for_input, + a_human_input_timeout=self.a_human_input_timeout, connection_id=connection_id, ) diff --git a/samples/apps/autogen-studio/autogenstudio/web/app.py b/samples/apps/autogen-studio/autogenstudio/web/app.py index 9db32bb360fa..d7db1b85a9b0 100644 --- a/samples/apps/autogen-studio/autogenstudio/web/app.py +++ b/samples/apps/autogen-studio/autogenstudio/web/app.py @@ -4,7 +4,7 @@ import threading import traceback from contextlib import asynccontextmanager -from typing import Any, Coroutine +from typing import Any, Union from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -65,12 +65,14 @@ def message_handler(): database_engine_uri = folders["database_engine_uri"] dbmanager = DBManager(engine_uri=database_engine_uri) +HUMAN_INPUT_TIMEOUT_SECONDS = 180 @asynccontextmanager async def lifespan(app: FastAPI): print("***** App started *****") managers["chat"] = AutoGenChatManager(message_queue=message_queue, - websocket_manager=websocket_manager) + websocket_manager=websocket_manager, + human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS) dbmanager.create_db_and_tables() yield diff --git a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py index 3c76fa8b361c..fe3d698de5a9 100644 --- a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py +++ 
b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py @@ -41,6 +41,8 @@ def __init__( clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -53,6 +55,8 @@ def __init__( clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. + a_human_input_function (Optional[callable]): Async coroutine to prompt the user for input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -70,6 +74,8 @@ def __init__( self.workflow_skills = [] self.send_message_function = send_message_function self.a_send_message_function = a_send_message_function + self.a_human_input_function = a_human_input_function + self.a_human_input_timeout = a_human_input_timeout self.connection_id = connection_id self.work_dir = work_dir or "work_dir" self.code_executor_pool = { @@ -303,6 +309,12 @@ def sanitize_agent(self, agent: Dict) -> Agent: """ """ skills = agent.get("skills", []) + + # When human input mode is not NEVER and no model is attached, the ui is passing bogus llm_config. 
+ configured_models = agent.get("models") + if not configured_models or len(configured_models) == 0: + agent["config"]["llm_config"] = False + agent = Agent.model_validate(agent) agent.config.is_termination_msg = agent.config.is_termination_msg or ( lambda x: "TERMINATE" in x.get("content", "").rstrip()[-20:] @@ -366,6 +378,9 @@ def load(self, agent: Any) -> autogen.Agent: groupchat=groupchat, message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, llm_config=agent.config.llm_config.model_dump(), ) return agent @@ -376,12 +391,18 @@ def load(self, agent: Any) -> autogen.Agent: **self._serialize_agent(agent), message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, ) elif agent.type == "userproxy": agent = ExtendedConversableAgent( **self._serialize_agent(agent), message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, ) else: raise ValueError(f"Unknown agent type: {agent.type}") @@ -538,6 +559,8 @@ def __init__( clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -550,6 +573,8 @@ def __init__( clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. 
+ a_human_input_function (Optional[callable]): Async coroutine to prompt for human input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -566,6 +591,8 @@ def __init__( # TODO - improved typing for workflow self.send_message_function = send_message_function self.a_send_message_function = a_send_message_function + self.a_human_input_function = a_human_input_function + self.a_human_input_timeout = a_human_input_timeout self.connection_id = connection_id self.work_dir = work_dir or "work_dir" if clear_work_dir: @@ -617,6 +644,7 @@ def _run_workflow(self, message: str, history: Optional[List[Message]] = None, c clear_work_dir=True, send_message_function=self.send_message_function, a_send_message_function=self.a_send_message_function, + a_human_input_timeout=self.a_human_input_timeout, connection_id=self.connection_id, ) task_prompt = ( @@ -679,6 +707,8 @@ async def _a_run_workflow(self, message: str, history: Optional[List[Message]] = clear_work_dir=True, send_message_function=self.send_message_function, a_send_message_function=self.a_send_message_function, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, connection_id=self.connection_id, ) task_prompt = ( @@ -810,6 +840,8 @@ def __new__( clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -822,6 +854,8 @@ def __new__( clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. 
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. + a_human_input_function (Optional[callable]): Async coroutine to prompt for user input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -843,6 +877,8 @@ def __new__( clear_work_dir=clear_work_dir, send_message_function=send_message_function, a_send_message_function=a_send_message_function, + a_human_input_function=a_human_input_function, + a_human_input_timeout=a_human_input_timeout, connection_id=connection_id, ) elif self.workflow.get("type") == WorkFlowType.sequential.value: @@ -852,6 +888,9 @@ def __new__( work_dir=work_dir, clear_work_dir=clear_work_dir, send_message_function=send_message_function, + a_send_message_function=a_send_message_function, + a_human_input_function=a_human_input_function, + a_human_input_timeout=a_human_input_timeout, connection_id=connection_id, ) @@ -860,11 +899,18 @@ class ExtendedConversableAgent(autogen.ConversableAgent): def __init__(self, message_processor=None, a_message_processor=None, + a_human_input_function=None, + a_human_input_timeout: Optional[int] = 60, + connection_id=None, *args, **kwargs): super().__init__(*args, **kwargs) self.message_processor = message_processor self.a_message_processor = a_message_processor + self.a_human_input_function = a_human_input_function + self.a_human_input_response = None + self.a_human_input_timeout = a_human_input_timeout + self.connection_id = connection_id def receive( self, @@ -891,14 +937,65 @@ async def a_receive( await super().a_receive(message, sender, request_reply, silent) + # Strangely, when the response from a_get_human_input == "" (empty string) the libs call into the + # sync version. 
I guess that's "just in case", but it's odd because replying with an empty string + # is the intended way for the user to signal the underlying libs that they want the system to go forward + # with whatever function call, tool call or AI generated response the request calls for. Oh well, + # Que Sera Sera. + def get_human_input(self, prompt: str) -> str: + if self.a_human_input_response == None: + return super().get_human_input(prompt) + else: + response = self.a_human_input_response + self.a_human_input_response = None + return response + + async def a_get_human_input(self, prompt: str) -> str: + if self.message_processor and self.a_human_input_function: + message_dict = { + "content": prompt, + "role": "system", + "type": "user-input-request" + } + + message_payload = { + "recipient": self.name, + "sender": "system", + "message": message_dict, + "timestamp": datetime.now().isoformat(), + "sender_type": "system", + "connection_id": self.connection_id, + "message_type": "agent_message" + } + + socket_msg = SocketMessage( + type="user_input_request", + data=message_payload, + connection_id=self.connection_id, + ) + self.a_human_input_response = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout) + return self.a_human_input_response + + else: + result = await super().a_get_human_input(prompt) + return result + + class ExtendedGroupChatManager(autogen.GroupChatManager): def __init__(self, message_processor=None, a_message_processor=None, + a_human_input_function=None, + a_human_input_timeout: Optional[int] = 60, + connection_id=None, *args, **kwargs): super().__init__(*args, **kwargs) self.message_processor = message_processor self.a_message_processor = a_message_processor + self.a_human_input_function = a_human_input_function + self.a_human_input_response = None + self.a_human_input_timeout = a_human_input_timeout + self.connection_id = connection_id def receive( self, @@ -925,3 +1022,39 @@ async def a_receive( await
super().a_receive(message, sender, request_reply, silent) + def get_human_input(self, prompt: str) -> str: + if self.a_human_input_response == None: + return super().get_human_input(prompt) + else: + response = self.a_human_input_response + self.a_human_input_response = None + return response + + async def a_get_human_input(self, prompt: str) -> str: + if self.message_processor and self.a_human_input_function: + message_dict = { + "content": prompt, + "role": "system", + "type": "user-input-request" + } + + message_payload = { + "recipient": self.name, + "sender": "system", + "message": message_dict, + "timestamp": datetime.now().isoformat(), + "sender_type": "system", + "connection_id": self.connection_id, + "message_type": "agent_message" + } + socket_msg = SocketMessage( + type="user_input_request", + data=message_payload, + connection_id=self.connection_id, + ) + result = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout) + return result + + else: + result = await super().a_get_human_input(prompt) + return result diff --git a/samples/apps/autogen-studio/frontend/src/components/atoms.tsx b/samples/apps/autogen-studio/frontend/src/components/atoms.tsx index a0864153f5ac..8f52e60281b7 100644 --- a/samples/apps/autogen-studio/frontend/src/components/atoms.tsx +++ b/samples/apps/autogen-studio/frontend/src/components/atoms.tsx @@ -49,7 +49,7 @@ export const SectionHeader = ({ icon, }: IProps) => { return ( -
+

{/* {count !== null && {count}} */} {icon && <>{icon}} @@ -72,6 +72,7 @@ export const IconButton = ({ }: IProps) => { return ( { return (