Skip to content

Commit 5e3e078

Browse files
committed
fix: threadpool for local task queue supports 10 workers
1 parent 45b1836 commit 5e3e078

File tree

7 files changed

+51
-20
lines changed

7 files changed

+51
-20
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ redis:
199199
server:
200200
host: 0.0.0.0
201201
port: 8000
202+
203+
queue: # if using local task queue
204+
max_workers: 10 # default
202205
```
203206
204207
#### Using Environment Variables

agentserve/agent_server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _initialize_task_queue(self):
3333
return RedisTaskQueue(self.config)
3434
else:
3535
from .queues.local_task_queue import LocalTaskQueue
36-
return LocalTaskQueue()
36+
return LocalTaskQueue(self.config)
3737
except Exception as e:
3838
self.logger.error(f"Failed to initialize task queue: {str(e)}")
3939
raise

agentserve/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def _load_config(self):
4242
server_config['host'] = server_host
4343
if server_port:
4444
server_config['port'] = int(server_port)
45+
46+
queue_config = config.setdefault('queue', {})
47+
queue_config['max_workers'] = int(os.getenv('AGENTSERVE_QUEUE_MAX_WORKERS', queue_config.get('max_workers', 10)))
4548

4649
return config
4750

agentserve/logging_config.py

+7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import sys
33
from typing import Optional
44

5+
_loggers = {}
6+
57
def setup_logger(name: str = "agentserve", level: Optional[str] = None) -> logging.Logger:
8+
if name in _loggers:
9+
return _loggers[name]
10+
611
logger = logging.getLogger(name)
712

813
if not logger.handlers:
@@ -12,8 +17,10 @@ def setup_logger(name: str = "agentserve", level: Optional[str] = None) -> loggi
1217
)
1318
handler.setFormatter(formatter)
1419
logger.addHandler(handler)
20+
logger.propagate = False # Prevent duplicate logging
1521

1622
log_level = getattr(logging, (level or "DEBUG").upper())
1723
logger.setLevel(log_level)
1824

25+
_loggers[name] = logger
1926
return logger

agentserve/queues/local_task_queue.py

+35-17
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,62 @@
55
from .task_queue import TaskQueue
66
import threading
77
from ..logging_config import setup_logger
8+
import concurrent.futures
89

910
class LocalTaskQueue(TaskQueue):
1011
def __init__(self):
1112
self.logger = setup_logger("agentserve.queue.local")
1213
self.results = {}
1314
self.statuses = {}
14-
self.loop = asyncio.new_event_loop()
15+
max_workers = 10 # default
16+
if config:
17+
max_workers = config.get('queue', {}).get('max_workers', 10)
18+
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
19+
self.lock = threading.Lock()
1520
self.logger.info("LocalTaskQueue initialized")
1621

1722
def enqueue(self, agent_function, task_data: Dict[str, Any], task_id: str):
1823
self.logger.debug(f"Enqueueing task {task_id}")
19-
self.statuses[task_id] = 'queued'
20-
threading.Thread(target=self._run_task, args=(agent_function, task_data, task_id)).start()
24+
with self.lock:
25+
self.statuses[task_id] = 'queued'
26+
self.thread_pool.submit(self._run_task, agent_function, task_data, task_id)
2127

2228
def _run_task(self, agent_function, task_data: Dict[str, Any], task_id: str):
2329
self.logger.debug(f"Starting task {task_id}")
24-
self.statuses[task_id] = 'in_progress'
30+
with self.lock:
31+
self.statuses[task_id] = 'in_progress'
32+
2533
try:
2634
if getattr(agent_function, '_is_async', False):
27-
asyncio.set_event_loop(self.loop)
28-
result = self.loop.run_until_complete(agent_function(task_data))
35+
loop = asyncio.new_event_loop()
36+
asyncio.set_event_loop(loop)
37+
try:
38+
result = loop.run_until_complete(agent_function(task_data))
39+
finally:
40+
loop.close()
2941
else:
3042
result = agent_function(task_data)
31-
self.results[task_id] = result
32-
self.statuses[task_id] = 'completed'
43+
44+
with self.lock:
45+
self.results[task_id] = result
46+
self.statuses[task_id] = 'completed'
3347
self.logger.info(f"Task {task_id} completed successfully")
48+
3449
except Exception as e:
3550
self.logger.error(f"Task {task_id} failed: {str(e)}")
36-
self.results[task_id] = e
37-
self.statuses[task_id] = 'failed'
51+
with self.lock:
52+
self.results[task_id] = e
53+
self.statuses[task_id] = 'failed'
3854

3955
def get_status(self, task_id: str) -> str:
40-
return self.statuses.get(task_id, 'not_found')
56+
with self.lock:
57+
return self.statuses.get(task_id, 'not_found')
4158

4259
def get_result(self, task_id: str) -> Any:
43-
if task_id not in self.results:
44-
return None
45-
result = self.results[task_id]
46-
if isinstance(result, Exception):
47-
raise result
48-
return result
60+
with self.lock:
61+
if task_id not in self.results:
62+
return None
63+
result = self.results[task_id]
64+
if isinstance(result, Exception):
65+
raise result
66+
return result

async_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class MyInputSchema(BaseModel):
1212

1313
@app.agent(input_schema=MyInputSchema)
1414
async def my_agent(task_data):
15-
await asyncio.sleep(1)
15+
await asyncio.sleep(20)
1616
return task_data
1717

1818
app.run()

example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydantic import BaseModel
33

44
# Configure logging level
5-
agentserve.setup_logger(level="DEBUG") # or "INFO", "WARNING", "ERROR"
5+
agentserve.setup_logger(level="INFO") # or "INFO", "WARNING", "ERROR"
66

77
app = agentserve.app()
88

0 commit comments

Comments
 (0)