Skip to content

Commit 6711ea6

Browse files
committed
fix: work with async agents
1 parent 34d37d7 commit 6711ea6

9 files changed

+141
-24
lines changed

agentserve/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
# agentserve/__init__.py
22
from .agent_server import AgentServer as app
3+
from .logging_config import setup_logger
4+
5+
logger = setup_logger()

agentserve/agent_registry.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,45 @@
11
# agentserve/agent_registry.py
22
from typing import Callable, Optional, Type
33
from pydantic import BaseModel
4+
from .logging_config import setup_logger
5+
import asyncio
46

57
class AgentRegistry:
68
def __init__(self):
79
self.agent_function = None
810
self.input_schema: Optional[Type[BaseModel]] = None
9-
11+
self.logger = setup_logger("agentserve.agent_registry")
12+
1013
def register_agent(self, func: Optional[Callable] = None, *, input_schema: Optional[Type[BaseModel]] = None):
1114
if func is None:
12-
# Decorator is called with arguments
1315
def wrapper(func: Callable):
1416
return self.register_agent(func, input_schema=input_schema)
1517
return wrapper
1618

1719
self.input_schema = input_schema
20+
is_async = asyncio.iscoroutinefunction(func)
21+
self.logger.info(f"Registering {'async' if is_async else 'sync'} function")
1822

19-
def validated_func(task_data):
23+
async def async_validated_func(task_data):
24+
if self.input_schema is not None:
25+
validated_data = self.input_schema(**task_data)
26+
return await func(validated_data)
27+
return await func(task_data)
28+
29+
def sync_validated_func(task_data):
2030
if self.input_schema is not None:
2131
validated_data = self.input_schema(**task_data)
2232
return func(validated_data)
23-
else:
24-
return func(task_data)
33+
return func(task_data)
34+
35+
if is_async:
36+
self.agent_function = async_validated_func
37+
setattr(self.agent_function, '_is_async', True)
38+
else:
39+
self.agent_function = sync_validated_func
40+
setattr(self.agent_function, '_is_async', False)
2541

26-
self.agent_function = validated_func
27-
return validated_func
42+
return self.agent_function
2843

2944
def get_agent(self):
3045
if self.agent_function is None:

agentserve/agent_server.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,50 @@
66
from .agent_registry import AgentRegistry
77
from typing import Dict, Any, Optional
88
from .config import Config
9+
from .logging_config import setup_logger
910
import uuid
1011

1112
class AgentServer:
1213
def __init__(self, config: Optional[Config] = None):
13-
self.app = FastAPI()
14+
self.logger = setup_logger("agentserve.server")
15+
self.app = FastAPI(debug=True)
1416
self.agent_registry = AgentRegistry()
1517
self.config = config or Config()
1618
self.task_queue = self._initialize_task_queue()
1719
self.agent = self.agent_registry.register_agent
1820
self._setup_routes()
21+
self.logger.info("AgentServer initialized")
1922

2023
def _initialize_task_queue(self):
2124
task_queue_type = self.config.get('task_queue', 'local').lower()
22-
if task_queue_type == 'celery':
23-
from .celery_task_queue import CeleryTaskQueue
24-
return CeleryTaskQueue(self.config)
25-
elif task_queue_type == 'redis':
26-
from .redis_task_queue import RedisTaskQueue
27-
return RedisTaskQueue(self.config)
28-
else:
29-
from .queues.local_task_queue import LocalTaskQueue
30-
return LocalTaskQueue()
25+
self.logger.info(f"Initializing {task_queue_type} task queue")
26+
27+
try:
28+
if task_queue_type == 'celery':
29+
from .queues.celery_task_queue import CeleryTaskQueue
30+
return CeleryTaskQueue(self.config)
31+
elif task_queue_type == 'redis':
32+
from .queues.redis_task_queue import RedisTaskQueue
33+
return RedisTaskQueue(self.config)
34+
else:
35+
from .queues.local_task_queue import LocalTaskQueue
36+
return LocalTaskQueue()
37+
except Exception as e:
38+
self.logger.error(f"Failed to initialize task queue: {str(e)}")
39+
raise
3140

3241
def _setup_routes(self):
3342
@self.app.post("/task/sync")
3443
async def sync_task(task_data: Dict[str, Any]):
44+
self.logger.debug(f"sync_task called with data: {task_data}")
3545
try:
3646
agent_function = self.agent_registry.get_agent()
37-
result = agent_function(task_data)
47+
if getattr(agent_function, '_is_async', False):
48+
self.logger.info("Function is async, running in event loop")
49+
result = await agent_function(task_data)
50+
else:
51+
self.logger.info("Function is sync, running directly")
52+
result = agent_function(task_data)
3853
return {"result": result}
3954
except ValidationError as ve:
4055
if hasattr(ve, 'errors'):

agentserve/logging_config.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import logging
2+
import sys
3+
from typing import Optional
4+
5+
def setup_logger(name: str = "agentserve", level: Optional[str] = None) -> logging.Logger:
6+
logger = logging.getLogger(name)
7+
8+
if not logger.handlers:
9+
handler = logging.StreamHandler(sys.stdout)
10+
formatter = logging.Formatter(
11+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12+
)
13+
handler.setFormatter(formatter)
14+
logger.addHandler(handler)
15+
16+
log_level = getattr(logging, (level or "DEBUG").upper())
17+
logger.setLevel(log_level)
18+
19+
return logger

agentserve/queues/calery_task_queue.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,44 @@
11
# agentserve/celery_task_queue.py
22

3+
import asyncio
34
from typing import Any, Dict
45
from .task_queue import TaskQueue
56
from ..config import Config
7+
from ..logging_config import setup_logger
8+
69
class CeleryTaskQueue(TaskQueue):
710
def __init__(self, config: Config):
811
try:
912
from celery import Celery
1013
except ImportError:
1114
raise ImportError("CeleryTaskQueue requires the 'celery' package. Please install it.")
1215

16+
self.logger = setup_logger("agentserve.queue.celery")
1317
broker_url = config.get('celery', {}).get('broker_url', 'pyamqp://guest@localhost//')
1418
self.celery_app = Celery('agent_server', broker=broker_url)
19+
self.loop = asyncio.new_event_loop()
1520
self._register_tasks()
21+
self.logger.info("CeleryTaskQueue initialized")
1622

1723
def _register_tasks(self):
1824
@self.celery_app.task(name='agent_task')
19-
def agent_task(task_data):
20-
from .agent_registry import AgentRegistry
25+
def agent_task(task_data, is_async=False):
26+
from ..agent_registry import AgentRegistry
2127
agent_registry = AgentRegistry()
2228
agent_function = agent_registry.get_agent()
29+
30+
if is_async:
31+
asyncio.set_event_loop(self.loop)
32+
return self.loop.run_until_complete(agent_function(task_data))
2333
return agent_function(task_data)
2434

2535
def enqueue(self, agent_function, task_data: Dict[str, Any], task_id: str):
26-
# Since the agent task is registered with Celery, we just send the task name
27-
self.celery_app.send_task('agent_task', args=[task_data], task_id=task_id)
36+
self.logger.debug(f"Enqueueing task {task_id}")
37+
is_async = getattr(agent_function, '_is_async', False)
38+
self.celery_app.send_task('agent_task',
39+
args=[task_data],
40+
kwargs={'is_async': is_async},
41+
task_id=task_id)
2842

2943
def get_status(self, task_id: str) -> str:
3044
result = self.celery_app.AsyncResult(task_id)

agentserve/queues/local_task_queue.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,35 @@
44
from typing import Any, Dict
55
from .task_queue import TaskQueue
66
import threading
7+
from ..logging_config import setup_logger
78

89
class LocalTaskQueue(TaskQueue):
910
def __init__(self):
11+
self.logger = setup_logger("agentserve.queue.local")
1012
self.results = {}
1113
self.statuses = {}
14+
self.loop = asyncio.new_event_loop()
15+
self.logger.info("LocalTaskQueue initialized")
1216

1317
def enqueue(self, agent_function, task_data: Dict[str, Any], task_id: str):
18+
self.logger.debug(f"Enqueueing task {task_id}")
1419
self.statuses[task_id] = 'queued'
1520
threading.Thread(target=self._run_task, args=(agent_function, task_data, task_id)).start()
1621

1722
def _run_task(self, agent_function, task_data: Dict[str, Any], task_id: str):
23+
self.logger.debug(f"Starting task {task_id}")
1824
self.statuses[task_id] = 'in_progress'
1925
try:
20-
result = agent_function(task_data)
26+
if getattr(agent_function, '_is_async', False):
27+
asyncio.set_event_loop(self.loop)
28+
result = self.loop.run_until_complete(agent_function(task_data))
29+
else:
30+
result = agent_function(task_data)
2131
self.results[task_id] = result
2232
self.statuses[task_id] = 'completed'
33+
self.logger.info(f"Task {task_id} completed successfully")
2334
except Exception as e:
35+
self.logger.error(f"Task {task_id} failed: {str(e)}")
2436
self.results[task_id] = e
2537
self.statuses[task_id] = 'failed'
2638

agentserve/queues/redis_task_queue.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# agentserve/redis_task_queue.py
22

3+
import asyncio
34
from typing import Any, Dict
45
from .task_queue import TaskQueue
6+
from ..logging_config import setup_logger
57

68
class RedisTaskQueue(TaskQueue):
79
def __init__(self, config: Config):
@@ -11,14 +13,28 @@ def __init__(self, config: Config):
1113
except ImportError:
1214
raise ImportError("RedisTaskQueue requires 'redis' and 'rq' packages. Please install them.")
1315

16+
self.logger = setup_logger("agentserve.queue.redis")
1417
redis_config = config.get('redis', {})
1518
redis_host = redis_config.get('host', 'localhost')
1619
redis_port = redis_config.get('port', 6379)
1720
self.redis_conn = Redis(host=redis_host, port=redis_port)
1821
self.task_queue = Queue(connection=self.redis_conn)
22+
self.loop = asyncio.new_event_loop()
23+
self.logger.info("RedisTaskQueue initialized")
1924

2025
def enqueue(self, agent_function, task_data: Dict[str, Any], task_id: str):
21-
self.task_queue.enqueue_call(func=agent_function, args=(task_data,), job_id=task_id)
26+
self.logger.debug(f"Enqueueing task {task_id}")
27+
if getattr(agent_function, '_is_async', False):
28+
wrapped_func = self._wrap_async_function(agent_function)
29+
self.task_queue.enqueue_call(func=wrapped_func, args=(task_data,), job_id=task_id)
30+
else:
31+
self.task_queue.enqueue_call(func=agent_function, args=(task_data,), job_id=task_id)
32+
33+
def _wrap_async_function(self, func):
34+
def wrapper(task_data):
35+
asyncio.set_event_loop(self.loop)
36+
return self.loop.run_until_complete(func(task_data))
37+
return wrapper
2238

2339
def get_status(self, task_id: str) -> str:
2440
job = self.task_queue.fetch_job(task_id)

async_example.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import agentserve
2+
from pydantic import BaseModel
3+
import asyncio
4+
5+
# Configure logging level
6+
agentserve.setup_logger(level="DEBUG") # or "INFO", "WARNING", "ERROR"
7+
8+
app = agentserve.app()
9+
10+
class MyInputSchema(BaseModel):
11+
prompt: str
12+
13+
@app.agent(input_schema=MyInputSchema)
14+
async def my_agent(task_data):
15+
await asyncio.sleep(1)
16+
return task_data
17+
18+
app.run()

example.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import agentserve
22
from pydantic import BaseModel
3+
4+
# Configure logging level
5+
agentserve.setup_logger(level="DEBUG") # or "INFO", "WARNING", "ERROR"
6+
37
app = agentserve.app()
48

9+
510
class MyInputSchema(BaseModel):
611
prompt: str
712

0 commit comments

Comments
 (0)