Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ def run_migrations_online() -> None:
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
is_sqlite = connection.dialect.name == "sqlite"
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True
render_as_batch=is_sqlite,
)
with context.begin_transaction():
context.run_migrations()
Expand Down
12 changes: 4 additions & 8 deletions python/packages/autogen-studio/autogenstudio/datamodel/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from uuid import UUID, uuid4

from autogen_core import ComponentModel
from pydantic import ConfigDict
from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy import ForeignKey, Integer, String
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func

Expand Down Expand Up @@ -47,9 +45,7 @@ class Message(SQLModel, table=True):
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"))
)
run_id: Optional[UUID] = Field(
default=None, sa_column=Column(SQLAlchemyUUID, ForeignKey("run.id", ondelete="CASCADE"))
)
run_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("run.id", ondelete="CASCADE")))

message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))

Expand Down Expand Up @@ -84,7 +80,7 @@ class Run(SQLModel, table=True):

__table_args__ = {"sqlite_autoincrement": True}

id: UUID = Field(default_factory=uuid4, sa_column=Column(SQLAlchemyUUID, primary_key=True, index=True, unique=True))
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
Expand All @@ -106,7 +102,7 @@ class Run(SQLModel, table=True):
version: Optional[str] = "0.0.1"
messages: Union[List[Message], List[dict]] = Field(default_factory=list, sa_column=Column(JSON))

model_config = ConfigDict(json_encoders={UUID: str, datetime: lambda v: v.isoformat()})
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat()})
user_id: Optional[str] = None


Expand All @@ -125,7 +121,7 @@ class Gallery(SQLModel, table=True):
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(default_factory=GalleryConfig, sa_column=Column(JSON))

model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat(), UUID: str})
model_config = ConfigDict(json_encoders={datetime: lambda v: v.isoformat()})


class Settings(SQLModel, table=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ async def run_stream(
yield event
finally:
# Cleanup - remove our handler
logger.handlers.remove(llm_event_logger)
if llm_event_logger in logger.handlers:
logger.handlers.remove(llm_event_logger)

# Ensure cleanup happens
if team and hasattr(team, "_participants"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional, Union
from uuid import UUID

from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import (
Expand Down Expand Up @@ -42,11 +41,11 @@ class WebSocketManager:

def __init__(self, db_manager: DatabaseManager):
self.db_manager = db_manager
self._connections: Dict[UUID, WebSocket] = {}
self._cancellation_tokens: Dict[UUID, CancellationToken] = {}
self._connections: Dict[int, WebSocket] = {}
self._cancellation_tokens: Dict[int, CancellationToken] = {}
# Track explicitly closed connections
self._closed_connections: set[UUID] = set()
self._input_responses: Dict[UUID, asyncio.Queue] = {}
self._closed_connections: set[int] = set()
self._input_responses: Dict[int, asyncio.Queue] = {}

self._cancel_message = TeamResult(
task_result=TaskResult(
Expand All @@ -63,7 +62,7 @@ def _get_stop_message(self, reason: str) -> dict:
duration=0,
).model_dump()

async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
async def connect(self, websocket: WebSocket, run_id: int) -> bool:
try:
await websocket.accept()
self._connections[run_id] = websocket
Expand All @@ -80,7 +79,7 @@ async def connect(self, websocket: WebSocket, run_id: UUID) -> bool:
logger.error(f"Connection error for run {run_id}: {e}")
return False

async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None:
async def start_stream(self, run_id: int, task: str, team_config: dict) -> None:
"""Start streaming task execution with proper run management"""
if run_id not in self._connections or run_id in self._closed_connections:
raise ValueError(f"No active connection for run {run_id}")
Expand Down Expand Up @@ -161,7 +160,7 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
finally:
self._cancellation_tokens.pop(run_id, None)

async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
async def _save_message(self, run_id: int, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None:
"""Save a message to the database"""

run = await self._get_run(run_id)
Expand All @@ -175,7 +174,7 @@ async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMess
self.db_manager.upsert(db_message)

async def _update_run(
self, run_id: UUID, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
self, run_id: int, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None
) -> None:
"""Update run status and result"""
run = await self._get_run(run_id)
Expand All @@ -187,7 +186,7 @@ async def _update_run(
run.error_message = error
self.db_manager.upsert(run)

def create_input_func(self, run_id: UUID) -> Callable:
def create_input_func(self, run_id: int) -> Callable:
"""Creates an input function for a specific run"""

async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str:
Expand Down Expand Up @@ -216,14 +215,14 @@ async def input_handler(prompt: str = "", cancellation_token: Optional[Cancellat

return input_handler

async def handle_input_response(self, run_id: UUID, response: str) -> None:
async def handle_input_response(self, run_id: int, response: str) -> None:
"""Handle input response from client"""
if run_id in self._input_responses:
await self._input_responses[run_id].put(response)
else:
logger.warning(f"Received input response for inactive run {run_id}")

async def stop_run(self, run_id: UUID, reason: str) -> None:
async def stop_run(self, run_id: int, reason: str) -> None:
if run_id in self._cancellation_tokens:
logger.info(f"Stopping run {run_id}")

Expand Down Expand Up @@ -253,7 +252,7 @@ async def stop_run(self, run_id: UUID, reason: str) -> None:
# We might want to force disconnect here if db update failed
# await self.disconnect(run_id) # Optional

async def disconnect(self, run_id: UUID) -> None:
async def disconnect(self, run_id: int) -> None:
"""Clean up connection and associated resources"""
logger.info(f"Disconnecting run {run_id}")

Expand All @@ -268,11 +267,11 @@ async def disconnect(self, run_id: UUID) -> None:
self._cancellation_tokens.pop(run_id, None)
self._input_responses.pop(run_id, None)

async def _send_message(self, run_id: UUID, message: dict) -> None:
async def _send_message(self, run_id: int, message: dict) -> None:
"""Send a message through the WebSocket with connection state checking

Args:
run_id: UUID of the run
run_id: id of the run
message: Message dictionary to send
"""
if run_id in self._closed_connections:
Expand All @@ -292,7 +291,7 @@ async def _send_message(self, run_id: UUID, message: dict) -> None:
await self._update_run_status(run_id, RunStatus.ERROR, str(e))
await self.disconnect(run_id)

async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None:
async def _handle_stream_error(self, run_id: int, error: Exception) -> None:
"""Handle stream errors with proper run updates"""
if run_id not in self._closed_connections:
error_result = TeamResult(
Expand Down Expand Up @@ -366,11 +365,11 @@ def _format_message(self, message: Any) -> Optional[dict]:
logger.error(f"Message formatting error: {e}")
return None

async def _get_run(self, run_id: UUID) -> Optional[Run]:
async def _get_run(self, run_id: int) -> Optional[Run]:
"""Get run from database

Args:
run_id: UUID of the run to retrieve
run_id: id of the run to retrieve

Returns:
Optional[Run]: Run object if found, None otherwise
Expand All @@ -388,11 +387,11 @@ async def _get_settings(self, user_id: str) -> Optional[Settings]:
response = self.db_manager.get(filters={"user_id": user_id}, model_class=Settings, return_json=False)
return response.data[0] if response.status and response.data else None

async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None:
async def _update_run_status(self, run_id: int, status: RunStatus, error: Optional[str] = None) -> None:
"""Update run status in database

Args:
run_id: UUID of the run to update
run_id: id of the run to update
status: New status to set
error: Optional error message
"""
Expand Down Expand Up @@ -451,11 +450,11 @@ async def cleanup(self) -> None:
self._input_responses.clear()

@property
def active_connections(self) -> set[UUID]:
def active_connections(self) -> set[int]:
"""Get set of active run IDs"""
return set(self._connections.keys()) - self._closed_connections

@property
def active_runs(self) -> set[UUID]:
def active_runs(self) -> set[int]:
"""Get set of runs with active cancellation tokens"""
return set(self._cancellation_tokens.keys())
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# /api/runs routes
from typing import Dict
from uuid import UUID

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel
Expand Down Expand Up @@ -40,7 +39,7 @@ async def create_run(
),
return_json=False,
)
return {"status": run.status, "data": {"run_id": str(run.data.id)}}
return {"status": run.status, "data": {"run_id": run.data.id}}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -49,7 +48,7 @@ async def create_run(


@router.get("/{run_id}")
async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:
async def get_run(run_id: int, db=Depends(get_db)) -> Dict:
"""Get run details including task and result"""
run = db.get(Run, filters={"id": run_id}, return_json=False)
if not run.status or not run.data:
Expand All @@ -59,7 +58,7 @@ async def get_run(run_id: UUID, db=Depends(get_db)) -> Dict:


@router.get("/{run_id}/messages")
async def get_run_messages(run_id: UUID, db=Depends(get_db)) -> Dict:
async def get_run_messages(run_id: int, db=Depends(get_db)) -> Dict:
"""Get all messages for a run"""
messages = db.get(Message, filters={"run_id": run_id}, order="created_at asc", return_json=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi import APIRouter, Depends, HTTPException

from ...datamodel import Team
from ...gallery.builder import create_default_gallery
from ..deps import get_db

router = APIRouter()
Expand All @@ -13,6 +14,13 @@
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
"""List all teams for a user"""
response = db.get(Team, filters={"user_id": user_id})
if not response.data or len(response.data) == 0:
default_gallery = create_default_gallery()
default_team = Team(user_id=user_id, component=default_gallery.components.teams[0].model_dump())

db.upsert(default_team)
response = db.get(Team, filters={"user_id": user_id})

return {"status": True, "data": response.data}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import asyncio
import json
from datetime import datetime
from uuid import UUID

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger
Expand All @@ -17,7 +16,7 @@
@router.websocket("/runs/{run_id}")
async def run_websocket(
websocket: WebSocket,
run_id: UUID,
run_id: int,
ws_manager: WebSocketManager = Depends(get_websocket_manager),
db=Depends(get_db),
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ export interface DBModel {
export interface Message extends DBModel {
config: AgentMessageConfig;
session_id: number;
run_id: string;
run_id: number;
}

export interface Team extends DBModel {
Expand Down Expand Up @@ -321,7 +321,7 @@ export interface TeamResult {
}

export interface Run {
id: string;
id: number;
created_at: string;
updated_at?: string;
status: RunStatus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export default function ChatView({ session }: ChatViewProps) {

const chatContainerRef = React.useRef<HTMLDivElement | null>(null);
const [streamingContent, setStreamingContent] = React.useState<{
runId: string;
runId: number;
content: string;
source: string;
} | null>(null);
Expand All @@ -62,7 +62,7 @@ export default function ChatView({ session }: ChatViewProps) {
// Create a Message object from AgentMessageConfig
const createMessage = (
config: AgentMessageConfig,
runId: string,
runId: number,
sessionId: number
): Message => ({
created_at: new Date().toISOString(),
Expand Down Expand Up @@ -134,7 +134,7 @@ export default function ChatView({ session }: ChatViewProps) {
};
}, [activeSocket]);

const createRun = async (sessionId: number): Promise<string> => {
const createRun = async (sessionId: number): Promise<number> => {
const payload = { session_id: sessionId, user_id: user?.email || "" };
const response = await fetch(`${serverUrl}/runs/`, {
method: "POST",
Expand Down Expand Up @@ -423,7 +423,7 @@ export default function ChatView({ session }: ChatViewProps) {
}
};

const setupWebSocket = (runId: string, query: string): WebSocket => {
const setupWebSocket = (runId: number, query: string): WebSocket => {
if (!session || !session.id) {
throw new Error("Invalid session configuration");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ interface RunViewProps {
onCancel?: () => void;
isFirstRun?: boolean;
streamingContent?: {
runId: string;
runId: number;
content: string;
source: string;
} | null;
Expand Down Expand Up @@ -211,8 +211,7 @@ const RunView: React.FC<RunViewProps> = ({
}
>
<span className="cursor-help">
Run ...{run.id.slice(-6)} |{" "}
{getRelativeTimeString(run?.created_at || "")}{" "}
Run ...{run.id} | {getRelativeTimeString(run?.created_at || "")}{" "}
</span>
</Tooltip>
{!isFirstRun && (
Expand Down
Loading
Loading