Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RAG Agent to AutoGen Studio #2881

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions samples/apps/autogen-studio/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ autogenstudio/web/workdir/*
autogenstudio/web/ui/*
autogenstudio/web/skills/user/*
.release.sh
.nightly.sh

notebooks/work_dir/*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
81 changes: 5 additions & 76 deletions samples/apps/autogen-studio/autogenstudio/chatmanager.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import asyncio
import json
import os
import time
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union

import websockets
from fastapi import WebSocket, WebSocketDisconnect

from .datamodel import Message, SocketMessage, Workflow
from .utils import (
extract_successful_code_blocks,
get_modified_files,
summarize_chat_history,
)
from .datamodel import Message
from .workflowmanager import WorkflowManager


Expand Down Expand Up @@ -82,76 +75,12 @@ def chat(
connection_id=connection_id,
)

workflow = Workflow.model_validate(workflow)

message_text = message.content.strip()
result_message: Message = workflow_manager.run(message=f"{message_text}", clear_history=False, history=history)

start_time = time.time()
workflow_manager.run(message=f"{message_text}", clear_history=False)
end_time = time.time()

metadata = {
"messages": workflow_manager.agent_history,
"summary_method": workflow.summary_method,
"time": end_time - start_time,
"files": get_modified_files(start_time, end_time, source_dir=work_dir),
}

output = self._generate_output(message_text, workflow_manager, workflow)

output_message = Message(
user_id=message.user_id,
role="assistant",
content=output,
meta=json.dumps(metadata),
session_id=message.session_id,
)

return output_message

def _generate_output(
self,
message_text: str,
workflow_manager: WorkflowManager,
workflow: Workflow,
) -> str:
"""
Generates the output response based on the workflow configuration and agent history.

:param message_text: The text of the incoming message.
:param flow: An instance of `WorkflowManager`.
:param flow_config: An instance of `AgentWorkFlowConfig`.
:return: The output response as a string.
"""

output = ""
if workflow.summary_method == "last":
successful_code_blocks = extract_successful_code_blocks(workflow_manager.agent_history)
last_message = (
workflow_manager.agent_history[-1]["message"]["content"] if workflow_manager.agent_history else ""
)
successful_code_blocks = "\n\n".join(successful_code_blocks)
output = (last_message + "\n" + successful_code_blocks) if successful_code_blocks else last_message
elif workflow.summary_method == "llm":
client = workflow_manager.receiver.client
status_message = SocketMessage(
type="agent_status",
data={
"status": "summarizing",
"message": "Summarizing agent dialogue",
},
connection_id=workflow_manager.connection_id,
)
self.send(status_message.dict())
output = summarize_chat_history(
task=message_text,
messages=workflow_manager.agent_history,
client=client,
)

elif workflow.summary_method == "none":
output = ""
return output
result_message.user_id = message.user_id
result_message.session_id = message.session_id
return result_message


class WebSocketConnectionManager:
Expand Down
35 changes: 34 additions & 1 deletion samples/apps/autogen-studio/autogenstudio/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def ui(
port: int = 8081,
workers: int = 1,
reload: Annotated[bool, typer.Option("--reload")] = False,
docs: bool = False,
docs: bool = True,
appdir: str = None,
database_uri: Optional[str] = None,
):
Expand Down Expand Up @@ -48,6 +48,39 @@ def ui(
)


@app.command()
def serve(
workflow: str = "",
host: str = "127.0.0.1",
port: int = 8084,
workers: int = 1,
docs: bool = False,
):
"""
Serve an API Endpoint based on an AutoGen Studio workflow json file.

Args:
workflow (str): Path to the workflow json file.
host (str, optional): Host to run the UI on. Defaults to 127.0.0.1 (localhost).
port (int, optional): Port to run the UI on. Defaults to 8081.
workers (int, optional): Number of workers to run the UI with. Defaults to 1.
reload (bool, optional): Whether to reload the UI on code changes. Defaults to False.
docs (bool, optional): Whether to generate API docs. Defaults to False.

"""

os.environ["AUTOGENSTUDIO_API_DOCS"] = str(docs)
os.environ["AUTOGENSTUDIO_WORKFLOW_FILE"] = workflow

uvicorn.run(
"autogenstudio.web.serve:app",
host=host,
port=port,
workers=workers,
reload=False,
)


@app.command()
def version():
"""
Expand Down
55 changes: 37 additions & 18 deletions samples/apps/autogen-studio/autogenstudio/database/dbmanager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
from datetime import datetime
from typing import Optional

Expand All @@ -15,30 +16,39 @@
Skill,
Workflow,
WorkflowAgentLink,
WorkflowAgentType,
)
from .utils import init_db_samples

valid_link_types = ["agent_model", "agent_skill", "agent_agent", "workflow_agent"]


class WorkflowAgentMap(SQLModel):
agent: Agent
link: WorkflowAgentLink


class DBManager:
"""A class to manage database operations"""

_init_lock = threading.Lock() # Class-level lock

def __init__(self, engine_uri: str):
connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
self.engine = create_engine(engine_uri, connect_args=connection_args)
# run_migration(engine_uri=engine_uri)

def create_db_and_tables(self):
"""Create a new database and tables"""
try:
SQLModel.metadata.create_all(self.engine)
with self._init_lock: # Use the lock
try:
init_db_samples(self)
SQLModel.metadata.create_all(self.engine)
try:
init_db_samples(self)
except Exception as e:
logger.info("Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while initializing database samples: " + str(e))
except Exception as e:
logger.info("Error while creating database tables:" + str(e))
logger.info("Error while creating database tables:" + str(e))

def upsert(self, model: SQLModel):
"""Create a new entity"""
Expand All @@ -62,7 +72,7 @@ def upsert(self, model: SQLModel):
session.refresh(model)
except Exception as e:
session.rollback()
logger.error("Error while upserting %s", e)
logger.error("Error while updating " + str(model_class.__name__) + ": " + str(e))
status = False

response = Response(
Expand Down Expand Up @@ -115,7 +125,7 @@ def get_items(
session.rollback()
status = False
status_message = f"Error while fetching {model_class.__name__}"
logger.error("Error while getting %s: %s", model_class.__name__, e)
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))

response: Response = Response(
message=status_message,
Expand Down Expand Up @@ -157,16 +167,16 @@ def delete(self, model_class: SQLModel, filters: dict = None):
status_message = f"{model_class.__name__} Deleted Successfully"
else:
print(f"Row with filters {filters} not found")
logger.info("Row with filters %s not found", filters)
logger.info("Row with filters + filters + not found")
status_message = "Row not found"
except exc.IntegrityError as e:
session.rollback()
logger.error("Integrity ... Error while deleting: %s", e)
logger.error("Integrity ... Error while deleting: " + str(e))
status_message = f"The {model_class.__name__} is linked to another entity and cannot be deleted."
status = False
except Exception as e:
session.rollback()
logger.error("Error while deleting: %s", e)
logger.error("Error while deleting: " + str(e))
status_message = f"Error while deleting: {e}"
status = False
response = Response(
Expand All @@ -182,6 +192,7 @@ def get_linked_entities(
primary_id: int,
return_json: bool = False,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
):
"""
Get all entities linked to the primary entity.
Expand Down Expand Up @@ -217,19 +228,21 @@ def get_linked_entities(
linked_entities = agent.agents
elif link_type == "workflow_agent":
linked_entities = session.exec(
select(Agent)
.join(WorkflowAgentLink)
select(WorkflowAgentLink, Agent)
.join(Agent, WorkflowAgentLink.agent_id == Agent.id)
.where(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_type == agent_type,
)
).all()

linked_entities = [WorkflowAgentMap(agent=agent, link=link) for link, agent in linked_entities]
linked_entities = sorted(linked_entities, key=lambda x: x.link.sequence_id) # type: ignore
except Exception as e:
logger.error("Error while getting linked entities: %s", e)
logger.error("Error while getting linked entities: " + str(e))
status_message = f"Error while getting linked entities: {e}"
status = False
if return_json:
linked_entities = [self._model_to_dict(row) for row in linked_entities]
linked_entities = [row.model_dump() for row in linked_entities]

response = Response(
message=status_message,
Expand All @@ -245,6 +258,7 @@ def link(
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = None,
) -> Response:
"""
Link two entities together.
Expand Down Expand Up @@ -357,6 +371,7 @@ def link(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()
if existing_link:
Expand All @@ -373,6 +388,7 @@ def link(
workflow_id=primary_id,
agent_id=secondary_id,
agent_type=agent_type,
sequence_id=sequence_id,
)
session.add(workflow_agent_link)
# add and commit the link
Expand All @@ -385,7 +401,7 @@ def link(

except Exception as e:
session.rollback()
logger.error("Error while linking: %s", e)
logger.error("Error while linking: " + str(e))
status = False
status_message = f"Error while linking due to an exception: {e}"

Expand All @@ -402,6 +418,7 @@ def unlink(
primary_id: int,
secondary_id: int,
agent_type: Optional[str] = None,
sequence_id: Optional[int] = 0,
) -> Response:
"""
Unlink two entities.
Expand All @@ -417,6 +434,7 @@ def unlink(
"""
status = True
status_message = ""
print("primary", primary_id, "secondary", secondary_id, "sequence", sequence_id, "agent_type", agent_type)

if link_type not in valid_link_types:
status = False
Expand Down Expand Up @@ -452,6 +470,7 @@ def unlink(
WorkflowAgentLink.workflow_id == primary_id,
WorkflowAgentLink.agent_id == secondary_id,
WorkflowAgentLink.agent_type == agent_type,
WorkflowAgentLink.sequence_id == sequence_id,
)
).first()

Expand All @@ -465,7 +484,7 @@ def unlink(

except Exception as e:
session.rollback()
logger.error("Error while unlinking: %s", e)
logger.error("Error while unlinking: " + str(e))
status = False
status_message = f"Error while unlinking due to an exception: {e}"

Expand Down
Loading
Loading