Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable human interaction in AutoGenStudio #3445

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion samples/apps/autogen-studio/autogenstudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .chatmanager import *
from .datamodel import *
from .version import __version__
from .workflowmanager import *
61 changes: 26 additions & 35 deletions samples/apps/autogen-studio/autogenstudio/web/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import asyncio
import os
import queue
import threading
import traceback
from contextlib import asynccontextmanager
from typing import Any, Union

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.concurrency import run_in_threadpool

from loguru import logger
from openai import OpenAIError

from ..chatmanager import AutoGenChatManager, WebSocketConnectionManager
from .chatmanager import AutoGenChatManager, WebSocketConnectionManager
from ..database import workflow_from_id
from ..database.dbmanager import DBManager
from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow
Expand All @@ -22,41 +21,13 @@

# Shared module-level state used throughout this module.
profiler = Profiler()
# Registry for the chat manager; the "chat" entry is None until lifespan() assigns it.
managers = {"chat": None} # manage calls to autogen
# Create thread-safe queue for messages between api thread and autogen threads
message_queue = queue.Queue()
# List handed to the websocket manager as its store of connections.
active_connections = []
# NOTE(review): this asyncio.Lock is created at import time, outside any running
# event loop - confirm that is safe on every supported Python version.
active_connections_lock = asyncio.Lock()
websocket_manager = WebSocketConnectionManager(
active_connections=active_connections,
active_connections_lock=active_connections_lock,
)


def message_handler():
    """
    Forward queued agent messages to the matching websocket connection.

    Runs forever: blocks on ``message_queue.get()``, then sends the message to
    every active connection whose client id equals ``message["connection_id"]``.

    A failure while handling one message (for example a message without a
    ``"connection_id"`` key, or a send error) is logged instead of propagating.
    Without that, the first bad message would end this loop and every later
    message would be silently dropped. ``task_done()`` is called exactly once
    for every dequeued item, whether or not delivery succeeded.
    """
    while True:
        message = message_queue.get()
        try:
            logger.info(
                "** Processing Agent Message on Queue: Active Connections: "
                + str([client_id for _, client_id in websocket_manager.active_connections])
                + " **"
            )
            for connection, socket_client_id in websocket_manager.active_connections:
                if message["connection_id"] == socket_client_id:
                    logger.info(
                        f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
                    )
                    # asyncio.run creates a fresh event loop for each send.
                    asyncio.run(websocket_manager.send_message(message, connection))
                else:
                    logger.info(
                        f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
                    )
        except Exception:
            # Keep the handler alive; one undeliverable message must not stop the loop.
            logger.exception("Failed to deliver queued agent message")
        finally:
            message_queue.task_done()


# Run message_handler on a daemon thread; it is started when this module is imported.
message_handler_thread = threading.Thread(target=message_handler, daemon=True)
message_handler_thread.start()


app_file_path = os.path.dirname(os.path.abspath(__file__))
folders = init_app_folders(app_file_path)
ui_folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui")
Expand All @@ -68,7 +39,7 @@ def message_handler():
@asynccontextmanager
async def lifespan(app: FastAPI):
print("***** App started *****")
managers["chat"] = AutoGenChatManager(message_queue=message_queue)
managers["chat"] = AutoGenChatManager(websocket_manager=websocket_manager)
dbmanager.create_db_and_tables()

yield
Expand Down Expand Up @@ -449,12 +420,17 @@ async def run_session_workflow(message: Message, session_id: int, workflow_id: i
user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id))
os.makedirs(user_dir, exist_ok=True)
workflow = workflow_from_id(workflow_id, dbmanager=dbmanager)
agent_response: Message = managers["chat"].chat(
# This is where the issue begins. To ensure we are not blocking the event loop,
# we need to wrap the synchronous call in run_in_threadpool. This way, no message thread
# is needed and FastAPI's usage of asyncio is sufficient.
agent_response: Message = await run_in_threadpool(
managers["chat"].chat,
message=message,
history=user_message_history,
user_dir=user_dir,
workflow=workflow,
connection_id=message.connection_id,
human_input_function=get_human_input
)

response: Response = dbmanager.upsert(agent_response)
Expand All @@ -475,9 +451,24 @@ async def get_version():
}


# websockets
def get_human_input(prompt, timeout=120):
    """
    Request human input from the frontend and return the chat manager's answer.

    The prompt is wrapped in a "user_input_request" socket message addressed to
    the prompt's own connection id, then handed to the chat manager together
    with ``timeout``.
    """
    # Build the envelope first so the prompt is inspected before the manager lookup.
    request = dict(
        type="user_input_request",
        data=prompt,
        connection_id=prompt.get("connection_id"),
    )
    chat_manager = managers["chat"]
    return chat_manager.get_user_input(request, timeout)


# websockets
async def process_socket_message(data: dict, websocket: WebSocket, client_id: str):
print(f"Client says: {data['type']}")
if data["type"] == "user_message":
Expand Down
134 changes: 134 additions & 0 deletions samples/apps/autogen-studio/autogenstudio/web/chatmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import asyncio
import os
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional
from loguru import logger
import websockets
from fastapi import WebSocket, WebSocketDisconnect

from ..datamodel import Message
from ..workflowmanager import WorkflowManager
from .websocketmanager import WebSocketConnectionManager

#temp, for troubleshooting
import traceback

class AutoGenChatManager:
    """
    Runs agent workflows for incoming chat messages and relays workflow traffic
    (agent messages and human-input prompts) to the connected client.

    Messages are delivered through the websocket manager when one is supplied;
    otherwise they are placed on the optional message queue.
    """

    def __init__(
        self,
        message_queue: Optional[Queue] = None,
        websocket_manager: Optional["WebSocketConnectionManager"] = None,
    ) -> None:
        """
        Initializes the AutoGenChatManager.

        :param message_queue: Optional queue used to hand messages to another thread
            when no websocket manager is available.
        :param websocket_manager: Optional manager whose ``active_connections`` are
            used to deliver messages and collect user input.
        """
        self.message_queue = message_queue
        self.websocket_manager = websocket_manager
        # Remember the event loop this manager was created on, if any. Websocket
        # coroutines must run on the loop that owns the sockets, not on a new
        # loop spun up inside a worker thread.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None

    def _run_coroutine(self, coro):
        """Run ``coro`` to completion from synchronous code and return its result."""
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop in this thread, so we are on a worker thread: schedule
                # the coroutine on the owning loop and block until it finishes.
                return asyncio.run_coroutine_threadsafe(coro, loop).result()
        # No owning loop was captured (or we are on a loop thread, where
        # asyncio.run raises as it always did): keep the original behavior.
        return asyncio.run(coro)

    def send(self, message: dict) -> None:
        """
        Sends a message to the connection whose id equals ``message["connection_id"]``.

        When no websocket manager was supplied, the message is put on the message
        queue instead (if there is one); with neither, the call is a no-op.

        :param message: The message payload; must contain a "connection_id" key
            when it is delivered through the websocket manager.
        """
        if self.websocket_manager is None:
            if self.message_queue is not None:
                self.message_queue.put_nowait(message)
            return

        # Iterate over a snapshot: the list may change while messages are being sent.
        for connection, socket_client_id in list(self.websocket_manager.active_connections):
            if message["connection_id"] == socket_client_id:
                logger.info(
                    f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}, Message: {message}"
                )
                self._run_coroutine(self.websocket_manager.send_message(message, connection))
            else:
                logger.info(
                    f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
                )

    def get_user_input(self, user_prompt: dict, timeout: int) -> str:
        """
        Prompts the user over the matching websocket and waits for the reply.

        :param user_prompt: The prompt payload; its "connection_id" selects the connection.
        :param timeout: The number of seconds passed to the websocket manager as the wait limit.
        :returns: Whatever the websocket manager returns for the matching connection,
            or an empty string when no websocket manager or matching connection exists.
        """
        response = ""
        if self.websocket_manager is None:
            return response

        # Iterate over a snapshot: the list may change while we wait for input.
        for connection, socket_client_id in list(self.websocket_manager.active_connections):
            if user_prompt["connection_id"] == socket_client_id:
                logger.info(
                    f"Sending user prompt to connection_id: {user_prompt['connection_id']}. Connection ID: {socket_client_id}, Prompt: {user_prompt}"
                )
                response = self._run_coroutine(
                    self.websocket_manager.get_user_input(user_prompt, timeout, connection)
                )
            else:
                logger.info(
                    f"Skipping message for connection_id: {user_prompt['connection_id']}. Connection ID: {socket_client_id}"
                )

        return response

    def chat(
        self,
        message: "Message",
        history: List[Dict[str, Any]],
        workflow: Any = None,
        connection_id: Optional[str] = None,
        user_dir: Optional[str] = None,
        human_input_function: Optional[callable] = None,
        **kwargs,
    ) -> "Message":
        """
        Processes an incoming message according to the given workflow
        and generates a response.

        :param message: An instance of `Message` representing an incoming message.
        :param history: A list of dictionaries, each representing a past interaction.
        :param workflow: The workflow to run. Required despite the `None` default.
        :param connection_id: An optional connection identifier.
        :param user_dir: Base path under which the per-run working folder is created.
        :param human_input_function: An optional callable to enable human input during workflows.
        :param kwargs: Additional keyword arguments (currently unused).
        :return: An instance of `Message` representing a response.
        :raises ValueError: If no workflow is given.
        """
        # Validate first so a rejected call leaves no stray directory behind.
        if workflow is None:
            raise ValueError("Workflow must be specified")

        # Create a working directory for the workflow: user_dir/session_id/timestamp.
        work_dir = os.path.join(
            user_dir,
            str(message.session_id),
            datetime.now().strftime("%Y%m%d_%H-%M-%S"),
        )
        os.makedirs(work_dir, exist_ok=True)

        workflow_manager = WorkflowManager(
            workflow=workflow,
            history=history,
            work_dir=work_dir,
            send_message_function=self.send,
            human_input_function=human_input_function,
            connection_id=connection_id,
        )

        message_text = message.content.strip()
        try:
            result_message: Message = workflow_manager.run(
                message=f"{message_text}", clear_history=False, history=history
            )
        except Exception:
            # Record the traceback for troubleshooting, then let the caller handle it.
            logger.exception("Workflow run failed")
            raise

        result_message.user_id = message.user_id
        result_message.session_id = message.session_id
        return result_message

Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,15 @@
import os
from datetime import datetime
from queue import Queue
from loguru import logger

from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect

from .datamodel import Message
from .workflowmanager import WorkflowManager


class AutoGenChatManager:
    """
    This class handles the automated generation and management of chat interactions
    using an automated workflow configuration and message queue.
    """

    def __init__(self, message_queue: Queue) -> None:
        """
        Initializes the AutoGenChatManager with a message queue.

        :param message_queue: A queue to use for sending messages asynchronously.
            May be None, in which case `send` does nothing.
        """
        self.message_queue = message_queue

    def send(self, message: Any) -> None:
        """
        Sends a message by putting it into the message queue without blocking.

        :param message: The message payload to enqueue; it is passed through unchanged.
        """
        if self.message_queue is not None:
            self.message_queue.put_nowait(message)

    def chat(
        self,
        message: "Message",
        history: List[Dict[str, Any]],
        workflow: Any = None,
        connection_id: Optional[str] = None,
        user_dir: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """
        Processes an incoming message according to the given workflow
        and generates a response.

        :param message: An instance of `Message` representing an incoming message.
        :param history: A list of dictionaries, each representing a past interaction.
        :param workflow: The workflow to run. Required despite the `None` default.
        :param connection_id: An optional connection identifier.
        :param user_dir: Base path under which the per-run working folder is created.
        :param kwargs: Additional keyword arguments (currently unused).
        :return: An instance of `Message` representing a response.
        :raises ValueError: If no workflow is given.
        """
        # Validate first so a rejected call leaves no stray directory behind.
        if workflow is None:
            raise ValueError("Workflow must be specified")

        # Create a working directory for the workflow: user_dir/session_id/timestamp.
        work_dir = os.path.join(
            user_dir,
            str(message.session_id),
            datetime.now().strftime("%Y%m%d_%H-%M-%S"),
        )
        os.makedirs(work_dir, exist_ok=True)

        workflow_manager = WorkflowManager(
            workflow=workflow,
            history=history,
            work_dir=work_dir,
            send_message_function=self.send,
            connection_id=connection_id,
        )

        message_text = message.content.strip()
        result_message: Message = workflow_manager.run(
            message=f"{message_text}", clear_history=False, history=history
        )

        result_message.user_id = message.user_id
        result_message.session_id = message.session_id
        return result_message

from ..datamodel import Message
from ..workflowmanager import WorkflowManager

class WebSocketConnectionManager:
"""
Expand Down Expand Up @@ -135,6 +64,35 @@ async def disconnect_all(self) -> None:
for connection, _ in self.active_connections[:]:
await self.disconnect(connection)

async def get_user_input(self, user_prompt: Dict, timeout: int, websocket: WebSocket) -> str:
    """
    Send ``user_prompt`` to ``websocket`` and wait for the user's reply.

    The socket is polled in 0.1 second slices, ``timeout * 10`` times, so the
    total wait is roughly ``timeout`` seconds.

    :param user_prompt: The prompt payload forwarded to the client.
    :param timeout: Approximate number of seconds to wait for a reply.
    :param websocket: The connection to prompt and read from.
    :returns: The ``content`` field of the reply's ``data`` object, or "exit" when
        no reply arrives in time, the connection fails, or the reply has no content.
    """
    await self.send_message(user_prompt, websocket)
    # Default answer used whenever no usable reply is received.
    message = "exit"
    for _ in range(timeout * 10):
        try:
            data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1)
        except asyncio.TimeoutError:
            # Nothing arrived in this slice; poll again.
            continue
        except WebSocketDisconnect:
            print("Error: Tried to send a message to a closed WebSocket")
            await self.disconnect(websocket)
            # The socket is gone and no reply was read: answer with the default
            # instead of falling through to use an unassigned ``data``.
            return message
        except websockets.exceptions.ConnectionClosedOK:
            print("Error: WebSocket connection closed normally")
            await self.disconnect(websocket)
            return message
        except Exception as e:
            print(f"Error in sending message: {str(e)}", user_prompt)
            await self.disconnect(websocket)
            return message

        # Tolerate replies that lack a "data" object or a "content" field.
        payload = data.get("data") if isinstance(data, dict) else None
        if isinstance(payload, dict) and payload.get("content") is not None:
            message = payload["content"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SailorJoe6

Thanks for the rewrite and 3.10 support.

I am seeing some errors while testing ... .

With that fixed on my local version, I am still seeing a few other socket connection errors.
I tried to address this using an async context manager, but I am still seeing some of the timeout/delay issues where receive_json() does not return.

@asynccontextmanager
async def timeout_manager(timeout):
    try:
        yield await asyncio.wait_for(asyncio.sleep(timeout), timeout)
    except asyncio.TimeoutError:
        raise asyncio.TimeoutError

... 
async def get_user_input(self, user_prompt: Dict, timeout: int, websocket: WebSocket) -> str:
        await self.send_message(user_prompt, websocket)
        message = f"exit" 
        try:
            async with timeout_manager(20):
                try:
                    data = await websocket.receive_json()
                except WebSocketDisconnect:
                    print("Error: Tried to send a message to a closed WebSocket")
                    await self.disconnect(websocket)
                except websockets.exceptions.ConnectionClosedOK:
                    print("Error: WebSocket connection closed normally")
                    await self.disconnect(websocket)
                except Exception as e:
                    print(
                        f"Error in sending message: {str(e)}, {user_prompt}")
                    await self.disconnect(websocket)

                message = data.get("data").get("content")
                # break
        except asyncio.TimeoutError:
            print(">> Timeout")

        return message

Overall, once the core errors are addressed (even without the delay), I am happy to mark this feature as experimental (always/terminate human input), merge in and then improve iteratively.

Also, if you can attach a quick recording of how you are testing and results, that would be useful. Thanks.

Copy link
Collaborator Author

@SailorJoe6 SailorJoe6 Aug 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll address these issues to the best of my abilities tomorrow. Thanks again for the review!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @victordibia, I'm working on these issues now. I see there's a new commit on autogenstudio branch. I can merge or rebase. Personally, I prefer rebase, but it's your call. Do you have a preference?

break

return message


async def send_message(self, message: Union[Dict, str], websocket: WebSocket) -> None:
"""
Sends a JSON message to a single WebSocket connection.
Expand Down Expand Up @@ -175,3 +133,4 @@ async def broadcast(self, message: Dict) -> None:
except (WebSocketDisconnect, websockets.exceptions.ConnectionClosedOK) as e:
print(f"Error: WebSocket disconnected or closed({str(e)})")
await self.disconnect(connection)

Loading
Loading