Skip to content

Commit

Permalink
Support for Human Input Mode in AutoGen Studio (#3484)
Browse files Browse the repository at this point in the history
* bump version, add claude default model

* Move WebSocketConnectionManager into its own file.

* Update the AutoGenStudio to use Async code throughout the call stack

Update *WorkflowManager* classes:
- Add async `a_send_message_function` parameter to mirror `send_message_function` param.
- Add async `a_process_message` coroutine to mirror the synchronous `process_message` function.
- Add async `a_run` coroutine to mirror the `run` function
- Add async `_a_run_workflow` coroutine to mirror the synchronous `_run_workflow` function.

Update *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes:
- Override the async `a_receive` coroutines

Update *AutoGenChatManager*:
- Add async `a_send` and `a_chat` coroutines to mirror their sync counterparts.
- Accept the `WebSocketManager` instance as a parameter, so it can do Async comms directly.

Update *app.py*
- Provide the `WebSocketManager` instance to the *AutoGenChatManager* constructor
- Await the manager's `a_chat` coroutine, rather than calling the synchronous `chat` function.

* Add Human Input Support

Updates to *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes
- override the `get_human_input` function and async `a_get_human_input` coroutine

Updates to *WorkflowManager* classes:
- add parameters `a_human_input_function` and `a_human_input_timeout` and pass along on to the ExtendedConversableAgent and ExtendedGroupChatManager
- fix for invalid configuration passed from UI when human input mode is not NEVER and no model is attached

Updates to *AutoGenChatManager* class:
- add parameter `human_input_timeout` and pass it along to *WorkflowManager* classes
- add async `a_prompt_for_input` coroutine that relies on `websocket_manager.get_input` coroutine (which snuck into last commit)

Updates to *App.py*
- global var HUMAN_INPUT_TIMEOUT_SECONDS = 180, we can replace this with a configurable value in the future

* add formatting/precommit fixes

* version bump

---------

Co-authored-by: Joe Landers <[email protected]>
  • Loading branch information
victordibia and SailorJoe6 authored Sep 8, 2024
1 parent 70a1791 commit 084a54d
Show file tree
Hide file tree
Showing 11 changed files with 780 additions and 143 deletions.
180 changes: 91 additions & 89 deletions samples/apps/autogen-studio/autogenstudio/chatmanager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import os
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger

from .datamodel import Message
from .websocket_connection_manager import WebSocketConnectionManager
from .workflowmanager import WorkflowManager


Expand All @@ -17,15 +16,19 @@ class AutoGenChatManager:
using an automated workflow configuration and message queue.
"""

def __init__(self, message_queue: Queue) -> None:
def __init__(
self, message_queue: Queue, websocket_manager: WebSocketConnectionManager = None, human_input_timeout: int = 180
) -> None:
"""
Initializes the AutoGenChatManager with a message queue.
:param message_queue: A queue to use for sending messages asynchronously.
"""
self.message_queue = message_queue
self.websocket_manager = websocket_manager
self.a_human_input_timeout = human_input_timeout

def send(self, message: str) -> None:
def send(self, message: dict) -> None:
"""
Sends a message by putting it into the message queue.
Expand All @@ -34,6 +37,45 @@ def send(self, message: str) -> None:
if self.message_queue is not None:
self.message_queue.put_nowait(message)

async def a_send(self, message: dict) -> None:
"""
Asynchronously sends a message via the WebSocketManager class
:param message: The message string to be sent.
"""
for connection, socket_client_id in self.websocket_manager.active_connections:
if message["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)
await self.websocket_manager.send_message(message, connection)
else:
logger.info(
f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)

async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str:
"""
Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class
:param message: The message string to be sent.
"""

for connection, socket_client_id in self.websocket_manager.active_connections:
if prompt["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)
try:
result = await self.websocket_manager.get_input(prompt, connection, timeout)
return result
except Exception as e:
return f"Error: {e}\nTERMINATE"
else:
logger.info(
f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)

def chat(
self,
message: Message,
Expand Down Expand Up @@ -72,6 +114,7 @@ def chat(
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
connection_id=connection_id,
)

Expand All @@ -82,96 +125,55 @@ def chat(
result_message.session_id = message.session_id
return result_message


class WebSocketConnectionManager:
"""
Manages WebSocket connections including sending, broadcasting, and managing the lifecycle of connections.
"""

def __init__(
async def a_chat(
self,
active_connections: List[Tuple[WebSocket, str]] = None,
active_connections_lock: asyncio.Lock = None,
) -> None:
"""
Initializes WebSocketConnectionManager with an optional list of active WebSocket connections.
:param active_connections: A list of tuples, each containing a WebSocket object and its corresponding client_id.
"""
if active_connections is None:
active_connections = []
self.active_connections_lock = active_connections_lock
self.active_connections: List[Tuple[WebSocket, str]] = active_connections

async def connect(self, websocket: WebSocket, client_id: str) -> None:
message: Message,
history: List[Dict[str, Any]],
workflow: Any = None,
connection_id: Optional[str] = None,
user_dir: Optional[str] = None,
**kwargs,
) -> Message:
"""
Accepts a new WebSocket connection and appends it to the active connections list.
Processes an incoming message according to the agent's workflow configuration
and generates a response.
:param websocket: The WebSocket instance representing a client connection.
:param client_id: A string representing the unique identifier of the client.
:param message: An instance of `Message` representing an incoming message.
:param history: A list of dictionaries, each representing a past interaction.
:param flow_config: An instance of `AgentWorkFlowConfig`. If None, defaults to a standard configuration.
:param connection_id: An optional connection identifier.
:param kwargs: Additional keyword arguments.
:return: An instance of `Message` representing a response.
"""
await websocket.accept()
async with self.active_connections_lock:
self.active_connections.append((websocket, client_id))
print(f"New Connection: {client_id}, Total: {len(self.active_connections)}")

async def disconnect(self, websocket: WebSocket) -> None:
"""
Disconnects and removes a WebSocket connection from the active connections list.
# create a working director for workflow based on user_dir/session_id/time_hash
work_dir = os.path.join(
user_dir,
str(message.session_id),
datetime.now().strftime("%Y%m%d_%H-%M-%S"),
)
os.makedirs(work_dir, exist_ok=True)

:param websocket: The WebSocket instance to remove.
"""
async with self.active_connections_lock:
try:
self.active_connections = [conn for conn in self.active_connections if conn[0] != websocket]
print(f"Connection Closed. Total: {len(self.active_connections)}")
except ValueError:
print("Error: WebSocket connection not found")

async def disconnect_all(self) -> None:
"""
Disconnects all active WebSocket connections.
"""
for connection, _ in self.active_connections[:]:
await self.disconnect(connection)
# if no flow config is provided, use the default
if workflow is None:
raise ValueError("Workflow must be specified")

async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None:
"""
Sends a JSON message to a single WebSocket connection.
workflow_manager = WorkflowManager(
workflow=workflow,
history=history,
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
a_human_input_function=self.a_prompt_for_input,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=connection_id,
)

:param message: A JSON serializable dictionary containing the message to send.
:param websocket: The WebSocket instance through which to send the message.
"""
try:
async with self.active_connections_lock:
await websocket.send_json(message)
except WebSocketDisconnect:
print("Error: Tried to send a message to a closed WebSocket")
await self.disconnect(websocket)
except websockets.exceptions.ConnectionClosedOK:
print("Error: WebSocket connection closed normally")
await self.disconnect(websocket)
except Exception as e:
print(f"Error in sending message: {str(e)}", message)
await self.disconnect(websocket)

async def broadcast(self, message: Dict) -> None:
"""
Broadcasts a JSON message to all active WebSocket connections.
message_text = message.content.strip()
result_message: Message = await workflow_manager.a_run(
message=f"{message_text}", clear_history=False, history=history
)

:param message: A JSON serializable dictionary containing the message to broadcast.
"""
# Create a message dictionary with the desired format
message_dict = {"message": message}

for connection, _ in self.active_connections[:]:
try:
if connection.client_state == websockets.protocol.State.OPEN:
# Call send_message method with the message dictionary and current WebSocket connection
await self.send_message(message_dict, connection)
else:
print("Error: WebSocket connection is closed")
await self.disconnect(connection)
except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e:
print(f"Error: WebSocket disconnected or closed({str(e)})")
await self.disconnect(connection)
result_message.user_id = message.user_id
result_message.session_id = message.session_id
return result_message
8 changes: 8 additions & 0 deletions samples/apps/autogen-studio/autogenstudio/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def init_db_samples(dbmanager: Any):
model="gpt-4-1106-preview", description="OpenAI GPT-4 model", user_id="[email protected]", api_type="open_ai"
)

anthropic_sonnet_model = Model(
model="claude-3-5-sonnet-20240620",
description="Anthropic's Claude 3.5 Sonnet model",
api_type="anthropic",
user_id="[email protected]",
)

# skills
generate_pdf_skill = Skill(
name="generate_and_save_pdf",
Expand Down Expand Up @@ -303,6 +310,7 @@ def init_db_samples(dbmanager: Any):
session.add(google_gemini_model)
session.add(azure_model)
session.add(gpt_4_model)
session.add(anthropic_sonnet_model)
session.add(generate_image_skill)
session.add(generate_pdf_skill)
session.add(user_proxy)
Expand Down
2 changes: 1 addition & 1 deletion samples/apps/autogen-studio/autogenstudio/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = "0.1.4"
VERSION = "0.1.6"
__version__ = VERSION
APP_NAME = "autogenstudio"
13 changes: 10 additions & 3 deletions samples/apps/autogen-studio/autogenstudio/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from loguru import logger
from openai import OpenAIError

from ..chatmanager import AutoGenChatManager, WebSocketConnectionManager
from ..chatmanager import AutoGenChatManager
from ..database import workflow_from_id
from ..database.dbmanager import DBManager
from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow
from ..profiler import Profiler
from ..utils import check_and_cast_datetime_fields, init_app_folders, md5_hash, test_model
from ..version import VERSION
from ..websocket_connection_manager import WebSocketConnectionManager

profiler = Profiler()
managers = {"chat": None} # manage calls to autogen
Expand Down Expand Up @@ -64,11 +65,17 @@ def message_handler():
database_engine_uri = folders["database_engine_uri"]
dbmanager = DBManager(engine_uri=database_engine_uri)

HUMAN_INPUT_TIMEOUT_SECONDS = 180


@asynccontextmanager
async def lifespan(app: FastAPI):
print("***** App started *****")
managers["chat"] = AutoGenChatManager(message_queue=message_queue)
managers["chat"] = AutoGenChatManager(
message_queue=message_queue,
websocket_manager=websocket_manager,
human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS,
)
dbmanager.create_db_and_tables()

yield
Expand Down Expand Up @@ -449,7 +456,7 @@ async def run_session_workflow(message: Message, session_id: int, workflow_id: i
user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id))
os.makedirs(user_dir, exist_ok=True)
workflow = workflow_from_id(workflow_id, dbmanager=dbmanager)
agent_response: Message = managers["chat"].chat(
agent_response: Message = await managers["chat"].a_chat(
message=message,
history=user_message_history,
user_dir=user_dir,
Expand Down
Loading

0 comments on commit 084a54d

Please sign in to comment.